main.py
CHANGED
@@ -14,7 +14,7 @@ from langdetect import detect
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 MAX_LEN = 20

-#
+# Same transform as training
 transform = transforms.Compose([
     transforms.Resize((224, 224)),
     transforms.ToTensor()

@@ -40,7 +40,10 @@ class VQADataset(Dataset):


 def prepare_data(max_answers=50, min_word_count=3, max_len=MAX_LEN):
-    """
+    """
+    Local training helper: load VQA-RAD, clean, build vocab, dataset.
+    NOT used on the Space at runtime.
+    """
     from datasets import load_dataset
     import pandas as pd

@@ -97,13 +100,11 @@ def prepare_data(max_answers=50, min_word_count=3, max_len=MAX_LEN):
 class VQAModel(nn.Module):
     def __init__(self, vocab_size, embed_dim, hidden_dim, num_answers):
         super().__init__()
-        #
+        # same backbone as original code (ResNet18 pretrained)
         self.cnn = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
         self.cnn.fc = nn.Identity()
-
         self.embedding = nn.Embedding(vocab_size, embed_dim)
         self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
-
         self.fc1 = nn.Linear(512 + hidden_dim, 256)
         self.relu = nn.ReLU()
         self.fc2 = nn.Linear(256, num_answers)

@@ -118,10 +119,10 @@ class VQAModel(nn.Module):
         return self.fc2(x)


-# ========== Training (local only) ==========
+# ========== Training (local only, not used on Space) ==========
 def train_model(train_dataset, vocab, idx_to_answer,
                 epochs=20, batch_size=32, lr=1e-3, save_prefix="vqa_custom"):
-    """
+    """Use only in Colab / local to create vqa_custom_model.pth etc."""
     vocab_size = len(vocab)
     num_answers = len(idx_to_answer)

@@ -147,31 +148,39 @@ def train_model(train_dataset, vocab, idx_to_answer,

         print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}")

-    torch.save(model.state_dict(),
-    with open(
+    torch.save(model.state_dict(), "vqa_custom_model.pth")
+    with open("vocab.pkl", "wb") as f:
         pickle.dump(vocab, f)
-    with open(
+    with open("answer_mapping.pkl", "wb") as f:
         pickle.dump(idx_to_answer, f)

     return model


-# ========== Load artifacts + inference ==========
-def load_artifacts(prefix="vqa_custom", map_location=None):
-    …
+# ========== Load artifacts + inference (used in Space) ==========
+def load_artifacts(prefix=None, map_location=None):
+    """
+    Load your original good model:
+      - vqa_custom_model.pth
+      - vocab.pkl
+      - answer_mapping.pkl
+    """
+    with open("vocab.pkl", "rb") as f:
         vocab = pickle.load(f)
-    with open(
+    with open("answer_mapping.pkl", "rb") as f:
         idx_to_answer = pickle.load(f)

     model = VQAModel(len(vocab), 300, 256, len(idx_to_answer))
-    model.load_state_dict(torch.load(
+    model.load_state_dict(torch.load("vqa_custom_model.pth",
+                                     map_location=map_location or device))
     model.to(device)
     model.eval()

     def encode_question_infer(q, max_len=MAX_LEN):
-        …
-        …
-        enc =
+        q = str(q).lower()
+        tokens = q.split()
+        enc = [vocab.get(w, vocab.get("<UNK>", 0)) for w in tokens]
+        enc = enc[:max_len] + [vocab.get("<PAD>", 0)] * (max_len - len(enc))
         return torch.tensor(enc).unsqueeze(0)

     def predict_custom_vqa(image_path, question):

@@ -184,41 +193,37 @@ def load_artifacts(prefix="vqa_custom", map_location=None):
         return idx_to_answer[pred.item()]

     def final_pipeline(image_path, question, open_vqa_fn=None, translate_fn=None):
-        # Keep
+        # Keep exactly what your good model expects (English radiology questions)
         lang = detect(question)
-        q_en = question
-
-        if open_vqa_fn is not None and ("what is" in q_en.lower() or "this place" in q_en.lower()):
-            answer_en = open_vqa_fn(image_path, q_en)
-        else:
-            answer_en = predict_custom_vqa(image_path, q_en)
+        q_en = question  # you trained in English; skip translation

-
-
+        # Always use custom model here; you can add BLIP routing later if needed
+        answer_en = predict_custom_vqa(image_path, q_en)
         return answer_en

     return final_pipeline, predict_custom_vqa, vocab, idx_to_answer, model, encode_question_infer


 def load_artifacts_and_helpers(prefix="vqa_custom", map_location=None):
-    …
+    # wrapper used by app.py
+    return load_artifacts(map_location=map_location)


+# ========== Optional CLI (local only) ==========
 if __name__ == "__main__":
-    # Local CLI only; never runs in Space
     import argparse
-    …
+
+    parser = argparse.ArgumentParser(description="VQA pipeline (prepare/train/infer)")
     parser.add_argument("--prepare", action="store_true")
     parser.add_argument("--train", action="store_true")
     parser.add_argument("--epochs", type=int, default=20)
-    parser.add_argument("--prefix", default="vqa_custom")
     parser.add_argument("--image")
     parser.add_argument("--question", default="What is in the image?")
     args = parser.parse_args()

     if args.prepare or args.train:
         artifacts = prepare_data()
-        print("Prepared dataset with", len(artifacts["answer_to_idx"]), "
+        print("Prepared dataset with", len(artifacts["answer_to_idx"]), "answer classes.")

     if args.train:
         train_model(

@@ -226,9 +231,8 @@ if __name__ == "__main__":
             artifacts["vocab"],
             artifacts["idx_to_answer"],
             epochs=args.epochs,
-            save_prefix=args.prefix,
         )

     if args.image:
-        final_pipeline, *_ = load_artifacts(
-        print(final_pipeline(args.image, args.question
+        final_pipeline, *_ = load_artifacts()
+        print(final_pipeline(args.image, args.question))