PRUTHVIn committed on
Commit
87d4a35
·
verified ·
1 Parent(s): d7a7065

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +157 -205
main.py CHANGED
@@ -1,107 +1,30 @@
1
- import re
2
- from collections import Counter
3
- import pickle
4
-
5
  import torch
6
  import torch.nn as nn
7
- from torch.utils.data import Dataset, DataLoader
 
 
8
  from PIL import Image
9
  import torchvision.transforms as transforms
10
- import torchvision.models as models
 
11
  from langdetect import detect
 
 
12
 
13
- # ========== Global config ==========
14
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- MAX_LEN = 20
16
 
17
- # Same transform as training
18
  transform = transforms.Compose([
19
  transforms.Resize((224, 224)),
20
  transforms.ToTensor()
21
  ])
22
 
23
-
24
- # ========== Dataset & preprocessing (for local training only) ==========
25
class VQADataset(Dataset):
    """Dataset over a DataFrame with ``image``, ``question_encoded`` and
    ``answer_encoded`` columns, for use with a PyTorch DataLoader."""

    def __init__(self, df, transform):
        # Re-number rows so positional access in __getitem__ is safe even
        # if the caller passed a filtered DataFrame.
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Force 3-channel RGB before the image transform.
        rgb = row["image"].convert("RGB")
        img_tensor = self.transform(rgb)
        q_tensor = torch.tensor(row["question_encoded"], dtype=torch.long)
        a_tensor = torch.tensor(row["answer_encoded"], dtype=torch.long)
        return img_tensor, q_tensor, a_tensor
40
-
41
-
42
def prepare_data(max_answers=50, min_word_count=3, max_len=MAX_LEN):
    """
    Local training helper: load VQA-RAD, clean, build vocab, dataset.
    NOT used on the Space at runtime.

    Args:
        max_answers: keep only the N most frequent answers (closed-set VQA).
        min_word_count: words seen fewer times than this are mapped to <UNK>.
        max_len: fixed question length (truncate/pad).

    Returns:
        dict with the cleaned DataFrame, a VQADataset, the question vocab,
        both answer mappings, and max_len.
    """
    # Local imports: `datasets` is only needed for offline preparation.
    from datasets import load_dataset
    import pandas as pd

    dataset = load_dataset("flaviagiammarino/vqa-rad")
    df = pd.DataFrame(dataset["train"])[["image", "question", "answer"]]

    def clean_text(text):
        # Lowercase and keep only a-z, 0-9 and spaces.
        text = str(text).lower()
        text = re.sub(r"[^a-z0-9 ]", "", text)
        return text

    df["question"] = df["question"].apply(clean_text)
    df["answer"] = df["answer"].apply(clean_text)

    # Restrict to the most frequent answers so this is a classification task.
    top_answers = df["answer"].value_counts().nlargest(max_answers).index
    df = df[df["answer"].isin(top_answers)].reset_index(drop=True)

    answer_to_idx = {a: i for i, a in enumerate(top_answers)}
    idx_to_answer = {i: a for a, i in answer_to_idx.items()}
    df["answer_encoded"] = df["answer"].apply(lambda x: answer_to_idx[x])

    # NOTE: special tokens are UPPERCASE here; inference-side lookups must
    # use the same spelling ("<PAD>" == 0, "<UNK>" == 1).
    vocab = {"<PAD>": 0, "<UNK>": 1}
    counter = Counter()
    for q in df["question"]:
        for w in q.split():
            counter[w] += 1

    # Only words above the frequency threshold get their own id.
    idx = 2
    for word, count in counter.items():
        if count >= min_word_count:
            vocab[word] = idx
            idx += 1

    def encode_question(q):
        # Truncate to max_len, then right-pad with <PAD> to a fixed length.
        tokens = q.split()
        enc = [vocab.get(w, vocab["<UNK>"]) for w in tokens]
        enc = enc[:max_len] + [vocab["<PAD>"]] * (max_len - len(enc))
        return enc

    df["question_encoded"] = df["question"].apply(encode_question)

    train_dataset = VQADataset(df, transform)
    return {
        "dataset_df": df,
        "train_dataset": train_dataset,
        "vocab": vocab,
        "answer_to_idx": answer_to_idx,
        "idx_to_answer": idx_to_answer,
        "max_len": max_len,
    }
97
-
98
-
99
- # ========== Model ==========
100
  class VQAModel(nn.Module):
101
  def __init__(self, vocab_size, embed_dim, hidden_dim, num_answers):
102
  super().__init__()
103
- # same backbone as original code (ResNet18 pretrained)
104
- self.cnn = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
105
  self.cnn.fc = nn.Identity()
106
  self.embedding = nn.Embedding(vocab_size, embed_dim)
107
  self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
@@ -114,125 +37,154 @@ class VQAModel(nn.Module):
114
  q_embed = self.embedding(question)
115
  _, (h, _) = self.lstm(q_embed)
116
  q_feat = h.squeeze(0)
117
- x = torch.cat((img_feat, q_feat), dim=1)
118
- x = self.relu(self.fc1(x))
119
- return self.fc2(x)
120
-
121
-
122
- # ========== Training (local only, not used on Space) ==========
123
def train_model(train_dataset, vocab, idx_to_answer,
                epochs=20, batch_size=32, lr=1e-3, save_prefix="vqa_custom"):
    """Use only in Colab / local to create vqa_custom_model.pth etc.

    Trains the CNN+LSTM VQAModel with cross-entropy over the closed answer
    set and writes three artifacts to the working directory:
    vqa_custom_model.pth, vocab.pkl, answer_mapping.pkl.

    NOTE(review): `save_prefix` is accepted but never used — the artifact
    filenames are hard-coded below; confirm before relying on it.
    """
    vocab_size = len(vocab)
    num_answers = len(idx_to_answer)

    # embed_dim=300, hidden_dim=256 are fixed here and must match the
    # values used when the weights are re-loaded at inference time.
    model = VQAModel(vocab_size, 300, 256, num_answers).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    from tqdm import tqdm  # local import: progress bar only needed here

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for images, questions, answers in tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, questions, answers = images.to(device), questions.to(device), answers.to(device)
            outputs = model(images, questions)
            loss = criterion(outputs, answers)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}")

    # Persist weights + the exact vocab/answer mapping used for training.
    torch.save(model.state_dict(), "vqa_custom_model.pth")
    with open("vocab.pkl", "wb") as f:
        pickle.dump(vocab, f)
    with open("answer_mapping.pkl", "wb") as f:
        pickle.dump(idx_to_answer, f)

    return model
158
-
159
-
160
- # ========== Load artifacts + inference (used in Space) ==========
161
def load_artifacts(prefix=None, map_location=None):
    """
    Load your original good model:
    - vqa_custom_model.pth
    - vocab.pkl
    - answer_mapping.pkl

    Returns a tuple of (final_pipeline, predict_custom_vqa, vocab,
    idx_to_answer, model, encode_question_infer), where the callables are
    closures over the loaded artifacts.

    NOTE(review): `prefix` is accepted but unused — filenames are
    hard-coded; confirm before relying on it.
    """
    # pickle.load on these files assumes they are trusted local artifacts.
    with open("vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    with open("answer_mapping.pkl", "rb") as f:
        idx_to_answer = pickle.load(f)

    # Dimensions (300, 256) must match those used in train_model.
    model = VQAModel(len(vocab), 300, 256, len(idx_to_answer))
    model.load_state_dict(torch.load("vqa_custom_model.pth",
                                     map_location=map_location or device))
    model.to(device)
    model.eval()

    def encode_question_infer(q, max_len=MAX_LEN):
        # Lowercase, tokenize, map to ids, then truncate/pad to max_len.
        # .get fallbacks keep this safe even if special tokens are missing.
        q = str(q).lower()
        tokens = q.split()
        enc = [vocab.get(w, vocab.get("<UNK>", 0)) for w in tokens]
        enc = enc[:max_len] + [vocab.get("<PAD>", 0)] * (max_len - len(enc))
        return torch.tensor(enc).unsqueeze(0)

    def predict_custom_vqa(image_path, question):
        # Single-image, single-question inference on the custom model.
        image = Image.open(image_path).convert("RGB")
        image_t = transform(image).unsqueeze(0).to(device)
        q = encode_question_infer(question).to(device)
        with torch.no_grad():
            out = model(image_t, q)
            _, pred = torch.max(out, 1)
        return idx_to_answer[pred.item()]

    def final_pipeline(image_path, question, open_vqa_fn=None, translate_fn=None):
        # Keep exactly what your good model expects (English radiology questions)
        # NOTE(review): `lang` is detected but unused — translation is
        # deliberately skipped since the model was trained on English.
        lang = detect(question)
        q_en = question  # you trained in English; skip translation

        # Always use custom model here; you can add BLIP routing later if needed
        answer_en = predict_custom_vqa(image_path, q_en)
        return answer_en

    return final_pipeline, predict_custom_vqa, vocab, idx_to_answer, model, encode_question_infer
205
-
206
-
207
def load_artifacts_and_helpers(prefix="vqa_custom", map_location=None):
    """Compatibility wrapper used by app.py.

    `prefix` is accepted for signature stability but is not forwarded.
    """
    artifacts = load_artifacts(map_location=map_location)
    return artifacts
210
-
211
-
212
- # ========== Optional CLI (local only) ==========
213
# ========== Optional CLI (local only) ==========
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="VQA pipeline (prepare/train/infer)")
    parser.add_argument("--prepare", action="store_true")
    parser.add_argument("--train", action="store_true")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--image")
    parser.add_argument("--question", default="What is in the image?")
    args = parser.parse_args()

    # --train implies --prepare: the dataset is rebuilt either way.
    if args.prepare or args.train:
        artifacts = prepare_data()
        print("Prepared dataset with", len(artifacts["answer_to_idx"]), "answer classes.")

        if args.train:
            train_model(
                artifacts["train_dataset"],
                artifacts["vocab"],
                artifacts["idx_to_answer"],
                epochs=args.epochs,
            )

    # One-off inference from the saved artifacts.
    if args.image:
        final_pipeline, *_ = load_artifacts()
        print(final_pipeline(args.image, args.question))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ import torchvision.models as models
4
+ import pickle
5
+ import re
6
  from PIL import Image
7
  import torchvision.transforms as transforms
8
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
9
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10
  from langdetect import detect
11
+ import numpy as np
12
+ import os
13
 
14
+ # Global models dictionary
15
+ models_dict = None
16
+ device = None
17
 
18
+ # Transforms
19
  transform = transforms.Compose([
20
  transforms.Resize((224, 224)),
21
  transforms.ToTensor()
22
  ])
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  class VQAModel(nn.Module):
25
  def __init__(self, vocab_size, embed_dim, hidden_dim, num_answers):
26
  super().__init__()
27
+ self.cnn = models.resnet18(pretrained=False)
 
28
  self.cnn.fc = nn.Identity()
29
  self.embedding = nn.Embedding(vocab_size, embed_dim)
30
  self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
 
37
  q_embed = self.embedding(question)
38
  _, (h, _) = self.lstm(q_embed)
39
  q_feat = h.squeeze(0)
40
+ combined = torch.cat((img_feat, q_feat), dim=1)
41
+ x = self.relu(self.fc1(combined))
42
+ out = self.fc2(x)
43
+ return out
44
+
45
def load_models():
    """Load all models once at startup.

    Populates the module-level ``models_dict`` with: the custom CNN+LSTM
    VQA model plus its vocab/answer mapping, a BLIP-2 model+processor for
    open-ended questions, and an NLLB translator for multilingual I/O.
    """
    global models_dict, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Using device: {device}")

    # Load custom VQA model
    # pickle.load assumes these are trusted local artifacts shipped with
    # the Space; never point them at untrusted files.
    with open("models/vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    with open("models/answer_mapping.pkl", "rb") as f:
        idx_to_answer = pickle.load(f)

    # Dimensions (300, 256) must match those used at training time.
    vocab_size = len(vocab)
    model = VQAModel(vocab_size, 300, 256, len(idx_to_answer))
    model.load_state_dict(torch.load("models/vqa_custom_model.pth", map_location=device))
    model.to(device)
    model.eval()

    # BLIP2 for open-ended (smaller model for free tier)
    # NOTE(review): verify "Salesforce/blip2-flan-t5-base" exists on the Hub —
    # the published BLIP-2 Flan-T5 checkpoints are -xl/-xxl variants.
    # NOTE(review): `torch_dtype` passed to a *processor* is likely ignored;
    # it only affects model weights — confirm.
    print("Loading BLIP2...")
    processor = Blip2Processor.from_pretrained(
        "Salesforce/blip2-flan-t5-base",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    )
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-base",
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True
    )

    # Translator
    print("Loading Translator...")
    translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    translator_model = AutoModelForSeq2SeqLM.from_pretrained(
        "facebook/nllb-200-distilled-600M",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    ).to(device)

    # langdetect ISO codes -> NLLB FLORES-200 language codes.
    lang_code_map = {
        "en": "eng_Latn", "hi": "hin_Deva", "te": "tel_Telu",
        "ta": "tam_Taml", "kn": "kan_Knda", "ml": "mal_Mlym"
    }

    models_dict = {
        'model': model, 'vocab': vocab, 'idx_to_answer': idx_to_answer,
        'processor': processor, 'blip_model': blip_model,
        'translator_tokenizer': translator_tokenizer,
        'translator_model': translator_model, 'lang_code_map': lang_code_map,
        'device': device
    }

    print("✅ All models loaded successfully!")
    return models_dict
100
+
101
def init_models():
    """Lazily load the model bundle on first use and return it."""
    global models_dict
    if models_dict is not None:
        return models_dict
    # Silence the HuggingFace tokenizers fork warning before loading.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    load_models()
    return models_dict
108
+
109
+ # All your functions remain EXACTLY the same...
110
def clean_text(text):
    """Normalize *text* for vocab lookup: lowercase, keep only a-z, 0-9
    and spaces.

    Coerces the input via ``str`` first, matching the training-side
    cleaner, so non-string values (e.g. NaN or numbers coming out of a
    DataFrame) no longer raise ``AttributeError``.
    """
    text = str(text).lower()
    return re.sub(r"[^a-z0-9 ]", "", text)
114
+
115
def encode_question_infer(q, vocab, max_len=20):
    """Clean, tokenize and encode a question as a fixed-length LongTensor.

    Bug fix: the vocabulary built at training time uses UPPERCASE special
    tokens ("<PAD>" = 0, "<UNK>" = 1), but this function looked up
    lowercase "<unk>"/"<pad>", raising KeyError on every call. Both
    spellings are now accepted, falling back to the ids the training
    vocab assigns (presumably 1 for unknown, 0 for padding — confirm
    against the shipped vocab.pkl).

    Args:
        q: raw question string (any language-cleaned text).
        vocab: token -> id mapping.
        max_len: fixed output length (truncate then right-pad); defaults
            to the historical hard-coded value of 20.
    """
    unk_id = vocab.get("<UNK>", vocab.get("<unk>", 1))
    pad_id = vocab.get("<PAD>", vocab.get("<pad>", 0))
    tokens = clean_text(q).split()
    enc = [vocab.get(w, unk_id) for w in tokens]
    enc = enc[:max_len] + [pad_id] * (max_len - len(enc))
    return torch.tensor(enc, dtype=torch.long)
122
+
123
def translate(text, src_lang, tgt_lang, tokenizer, model, lang_code_map, device):
    """Translate *text* with an NLLB seq2seq model; best-effort fallback.

    Fixes: the bare ``except:`` (which also swallowed KeyboardInterrupt /
    SystemExit) is narrowed to ``Exception``, and the implicit KeyError
    on an unsupported ``tgt_lang`` is replaced by an explicit guard —
    observable behavior is unchanged: unsupported targets return the
    input unmodified.

    Args:
        text: source-language string.
        src_lang / tgt_lang: langdetect-style ISO codes.
        tokenizer / model: NLLB tokenizer and AutoModelForSeq2SeqLM.
        lang_code_map: ISO code -> FLORES-200 code mapping.
        device: torch device the model lives on.
    """
    if tgt_lang not in lang_code_map:
        return text
    try:
        # Unknown source languages fall back to English tokenization.
        tokenizer.src_lang = lang_code_map.get(src_lang, "eng_Latn")
        inputs = tokenizer(text, return_tensors="pt", padding=True).to(device)
        generated = model.generate(
            **inputs,
            # Forces the decoder to emit the target language.
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(lang_code_map[tgt_lang]),
            max_length=50, num_beams=5
        )
        return tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception:
        # Best-effort: surface the untranslated text rather than crash the UI.
        return text
135
 
136
def predict_custom_vqa(image_tensor, question_tensor, model, idx_to_answer, device):
    """Run the custom VQA model on one (image, question) pair and return
    the answer string for the highest-scoring class."""
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor.to(device), question_tensor.to(device))
        best_idx = logits.argmax(dim=1)
    return idx_to_answer[best_idx.item()]
144
+
145
def open_vqa(image, question, processor, blip_model):
    """Answer an open-ended question about *image* with BLIP-2 and return
    the decoded generation."""
    inputs = processor(image, question, return_tensors="pt")
    if torch.cuda.is_available():
        # Move every input tensor onto the (possibly sharded) model device.
        inputs = {name: value.to(blip_model.device) for name, value in inputs.items()}
    generated = blip_model.generate(**inputs, max_new_tokens=50)
    return processor.decode(generated[0], skip_special_tokens=True)
151
+
152
def final_pipeline(image_path_or_pil, question):
    """End-to-end multilingual VQA: detect language, translate to English
    if needed, route to BLIP-2 (open-ended) or the custom model (closed
    set), then translate the answer back.

    Fixes: the bare ``except:`` around language detection is narrowed to
    ``Exception``, and the image tensor is only built in the branch that
    actually consumes it (the BLIP-2 path works on the PIL image).

    Returns a markdown string with the detected language and the answer.
    """
    init_models()
    m = models_dict

    # Accept either an in-memory PIL image or a filesystem path.
    if hasattr(image_path_or_pil, 'convert'):
        image = image_path_or_pil.convert("RGB")
    else:
        image = Image.open(image_path_or_pil).convert("RGB")

    # langdetect raises on short/ambiguous text; default to English.
    try:
        lang = detect(question)
    except Exception:
        lang = "en"

    if lang != "en":
        q_en = translate(question, lang, "en",
                         m['translator_tokenizer'], m['translator_model'],
                         m['lang_code_map'], m['device'])
    else:
        q_en = question

    # Heuristic routing: open-ended phrasings go to BLIP-2, everything
    # else to the closed-vocabulary custom model.
    if any(x in q_en.lower() for x in ["what is", "describe", "this place", "show"]):
        answer_en = open_vqa(image, q_en, m['processor'], m['blip_model'])
    else:
        image_tensor = transform(image).unsqueeze(0)
        q_tensor = encode_question_infer(q_en, m['vocab']).unsqueeze(0)
        answer_en = predict_custom_vqa(image_tensor, q_tensor,
                                       m['model'], m['idx_to_answer'], m['device'])

    # Translate the answer back into the user's language if needed.
    if lang != "en":
        answer = translate(answer_en, "en", lang,
                           m['translator_tokenizer'], m['translator_model'],
                           m['lang_code_map'], m['device'])
    else:
        answer = answer_en

    return f"**Detected Language:** {lang}\n**Answer:** {answer}"