PRUTHVIn committed
Commit d7a7065 · verified · 1 Parent(s): 37a21e2
Files changed (1)
  1. main.py +39 -35
main.py CHANGED
@@ -14,7 +14,7 @@ from langdetect import detect
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 MAX_LEN = 20
 
-# Fixed image transform
+# Same transform as training
 transform = transforms.Compose([
     transforms.Resize((224, 224)),
     transforms.ToTensor()
@@ -40,7 +40,10 @@ class VQADataset(Dataset):
 
 
 def prepare_data(max_answers=50, min_word_count=3, max_len=MAX_LEN):
-    """For local training only – not used in Space."""
+    """
+    Local training helper: load VQA-RAD, clean, build vocab, dataset.
+    NOT used on the Space at runtime.
+    """
     from datasets import load_dataset
     import pandas as pd
 
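Note: the new docstring says prepare_data loads and cleans VQA-RAD locally, but the body is largely outside this diff. A minimal sketch of that kind of step, assuming the flaviagiammarino/vqa-rad dataset id and its question/answer columns (assumptions for illustration, not lines from this file):

from datasets import load_dataset
import pandas as pd

# Hypothetical prepare step; the dataset id and column names are assumed.
ds = load_dataset("flaviagiammarino/vqa-rad", split="train")
df = pd.DataFrame({"question": ds["question"], "answer": ds["answer"]})
df["answer"] = df["answer"].str.lower().str.strip()
keep = df["answer"].value_counts().nlargest(50).index   # max_answers=50
df = df[df["answer"].isin(keep)].reset_index(drop=True)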
@@ -97,13 +100,11 @@ def prepare_data(max_answers=50, min_word_count=3, max_len=MAX_LEN):
 class VQAModel(nn.Module):
     def __init__(self, vocab_size, embed_dim, hidden_dim, num_answers):
         super().__init__()
-        # Use lightweight ResNet18 backbone
+        # same backbone as original code (ResNet18 pretrained)
         self.cnn = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
         self.cnn.fc = nn.Identity()
-
         self.embedding = nn.Embedding(vocab_size, embed_dim)
         self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
-
         self.fc1 = nn.Linear(512 + hidden_dim, 256)
         self.relu = nn.ReLU()
         self.fc2 = nn.Linear(256, num_answers)
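Note: only __init__ appears in this hunk; the fc1 input size of 512 + hidden_dim implies fusion by concatenation. A sketch of a forward pass consistent with these layers (illustrative, not the file's exact forward method):

# Sketch only: combines the 512-d ResNet18 feature with the LSTM's
# final hidden state, matching fc1 = nn.Linear(512 + hidden_dim, 256).
def forward(self, image, question):
    img_feat = self.cnn(image)               # (B, 512); cnn.fc is Identity
    emb = self.embedding(question)           # (B, T, embed_dim)
    _, (h, _) = self.lstm(emb)               # h: (1, B, hidden_dim)
    x = torch.cat([img_feat, h[-1]], dim=1)  # (B, 512 + hidden_dim)
    x = self.relu(self.fc1(x))
    return self.fc2(x)                       # (B, num_answers)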
@@ -118,10 +119,10 @@ class VQAModel(nn.Module):
         return self.fc2(x)
 
 
-# ========== Training (local only) ==========
+# ========== Training (local only, not used on Space) ==========
 def train_model(train_dataset, vocab, idx_to_answer,
                 epochs=20, batch_size=32, lr=1e-3, save_prefix="vqa_custom"):
-    """Run this only on Colab / local, not in Space."""
+    """Use only in Colab / local to create vqa_custom_model.pth etc."""
     vocab_size = len(vocab)
     num_answers = len(idx_to_answer)
 
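Note: the loop between this hunk and the next is untouched by the commit. Given the epoch-loss print it feeds, a conventional shape would be the following (a sketch assuming Adam and CrossEntropyLoss, neither of which is visible in the diff):

from torch.utils.data import DataLoader

# Assumed loop; optimizer and criterion are guesses consistent with
# lr=1e-3 and the Loss print below, not lines from this commit.
loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
model = VQAModel(vocab_size, 300, 256, num_answers).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(epochs):
    total_loss = 0.0
    for images, questions, labels in loader:
        images, questions = images.to(device), questions.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images, questions), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}")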
@@ -147,31 +148,39 @@ def train_model(train_dataset, vocab, idx_to_answer,
 
         print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}")
 
-    torch.save(model.state_dict(), f"{save_prefix}_model.pth")
-    with open(f"{save_prefix}_vocab.pkl", "wb") as f:
+    torch.save(model.state_dict(), "vqa_custom_model.pth")
+    with open("vocab.pkl", "wb") as f:
         pickle.dump(vocab, f)
-    with open(f"{save_prefix}_answers.pkl", "wb") as f:
+    with open("answer_mapping.pkl", "wb") as f:
         pickle.dump(idx_to_answer, f)
 
     return model
 
 
-# ========== Load artifacts + inference ==========
-def load_artifacts(prefix="vqa_custom", map_location=None):
-    with open(f"{prefix}_vocab.pkl", "rb") as f:
+# ========== Load artifacts + inference (used in Space) ==========
+def load_artifacts(prefix=None, map_location=None):
+    """
+    Load your original good model:
+      - vqa_custom_model.pth
+      - vocab.pkl
+      - answer_mapping.pkl
+    """
+    with open("vocab.pkl", "rb") as f:
         vocab = pickle.load(f)
-    with open(f"{prefix}_answers.pkl", "rb") as f:
+    with open("answer_mapping.pkl", "rb") as f:
         idx_to_answer = pickle.load(f)
 
     model = VQAModel(len(vocab), 300, 256, len(idx_to_answer))
-    model.load_state_dict(torch.load(f"{prefix}_model.pth", map_location=map_location or device))
+    model.load_state_dict(torch.load("vqa_custom_model.pth",
+                                     map_location=map_location or device))
     model.to(device)
     model.eval()
 
     def encode_question_infer(q, max_len=MAX_LEN):
-        tokens = str(q).lower().split()
-        enc = [vocab.get(w, vocab["<UNK>"]) for w in tokens]
-        enc = enc[:max_len] + [vocab["<PAD>"]] * (max_len - len(enc))
+        q = str(q).lower()
+        tokens = q.split()
+        enc = [vocab.get(w, vocab.get("<UNK>", 0)) for w in tokens]
+        enc = enc[:max_len] + [vocab.get("<PAD>", 0)] * (max_len - len(enc))
         return torch.tensor(enc).unsqueeze(0)
 
     def predict_custom_vqa(image_path, question):
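Note: the rewritten encode_question_infer now falls back to index 0 when the special tokens are missing from the vocab. A worked example with a toy vocab (the real vocab.pkl contents are not in this diff):

vocab = {"<PAD>": 0, "<UNK>": 1, "is": 2, "there": 3, "a": 4, "fracture": 5}
q = "Is there a fracture?"
tokens = str(q).lower().split()          # ['is', 'there', 'a', 'fracture?']
enc = [vocab.get(w, vocab.get("<UNK>", 0)) for w in tokens]   # [2, 3, 4, 1]
enc = enc[:20] + [vocab.get("<PAD>", 0)] * (20 - len(enc))    # pad to MAX_LEN=20
# The plain split keeps punctuation attached, so 'fracture?' maps to <UNK>.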
@@ -184,41 +193,37 @@ def load_artifacts(prefix="vqa_custom", map_location=None):
         return idx_to_answer[pred.item()]
 
     def final_pipeline(image_path, question, open_vqa_fn=None, translate_fn=None):
-        # Keep it simple in Space: English only, no BLIP / translator unless passed in
+        # Keep exactly what your good model expects (English radiology questions)
         lang = detect(question)
-        q_en = question if lang == "en" or translate_fn is None else translate_fn(question, lang, "en")
-
-        if open_vqa_fn is not None and ("what is" in q_en.lower() or "this place" in q_en.lower()):
-            answer_en = open_vqa_fn(image_path, q_en)
-        else:
-            answer_en = predict_custom_vqa(image_path, q_en)
+        q_en = question  # you trained in English; skip translation
 
-        if lang != "en" and translate_fn is not None:
-            return translate_fn(answer_en, "en", lang)
+        # Always use custom model here; you can add BLIP routing later if needed
+        answer_en = predict_custom_vqa(image_path, q_en)
         return answer_en
 
     return final_pipeline, predict_custom_vqa, vocab, idx_to_answer, model, encode_question_infer
 
 
 def load_artifacts_and_helpers(prefix="vqa_custom", map_location=None):
-    return load_artifacts(prefix=prefix, map_location=map_location)
+    # wrapper used by app.py
+    return load_artifacts(map_location=map_location)
 
 
+# ========== Optional CLI (local only) ==========
 if __name__ == "__main__":
-    # Local CLI only; never runs in Space
     import argparse
-    parser = argparse.ArgumentParser()
+
+    parser = argparse.ArgumentParser(description="VQA pipeline (prepare/train/infer)")
    parser.add_argument("--prepare", action="store_true")
     parser.add_argument("--train", action="store_true")
     parser.add_argument("--epochs", type=int, default=20)
-    parser.add_argument("--prefix", default="vqa_custom")
     parser.add_argument("--image")
     parser.add_argument("--question", default="What is in the image?")
     args = parser.parse_args()
 
     if args.prepare or args.train:
         artifacts = prepare_data()
-        print("Prepared dataset with", len(artifacts["answer_to_idx"]), "answers")
+        print("Prepared dataset with", len(artifacts["answer_to_idx"]), "answer classes.")
 
     if args.train:
         train_model(
@@ -226,9 +231,8 @@ if __name__ == "__main__":
             artifacts["vocab"],
             artifacts["idx_to_answer"],
             epochs=args.epochs,
-            save_prefix=args.prefix,
         )
 
     if args.image:
-        final_pipeline, *_ = load_artifacts(prefix=args.prefix)
-        print(final_pipeline(args.image, args.question, open_vqa_fn=None, translate_fn=None))
+        final_pipeline, *_ = load_artifacts()
+        print(final_pipeline(args.image, args.question))
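Note: with --prefix removed, the retained flags imply local invocations like these (image path and question are placeholders):

python main.py --prepare
python main.py --train --epochs 10
python main.py --image chest_xray.png --question "Is there pleural effusion?"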