PRUTHVIn committed on
Commit
ee742af
·
verified ·
1 Parent(s): 660ccf8

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +190 -0
main.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+ import pickle
5
+ import re
6
+ from PIL import Image
7
+ import torchvision.transforms as transforms
8
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
9
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10
+ from langdetect import detect
11
+ import numpy as np
12
+ import os
13
+
14
# Lazily-initialized globals: populated exactly once by load_models()
# (via init_models()) and shared by every inference call.
models_dict = None   # dict holding all loaded models, tokenizers and mappings
device = None        # torch.device selected at load time ("cuda" if available)

# Image preprocessing for the custom VQA model: resize to the 224x224 input
# the ResNet-18 backbone expects, then convert to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
23
+
24
class VQAModel(nn.Module):
    """Closed-vocabulary VQA classifier.

    Fuses ResNet-18 image features (512-d) with the final hidden state of an
    LSTM over the question tokens, then classifies over a fixed answer set.

    Args:
        vocab_size: number of entries in the question vocabulary.
        embed_dim: word-embedding dimension.
        hidden_dim: LSTM hidden-state dimension.
        num_answers: number of output classes (size of the answer set).
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_answers):
        super().__init__()
        # Image branch: ResNet-18 with its classification head replaced by
        # Identity, yielding a 512-d feature vector. pretrained=False because
        # trained weights are restored later from a checkpoint (see load_models).
        self.cnn = models.resnet18(pretrained=False)
        self.cnn.fc = nn.Identity()
        # Question branch: embedding + single-layer LSTM.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # Fusion head over concatenated image (512) + question (hidden_dim) features.
        self.fc1 = nn.Linear(512 + hidden_dim, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, num_answers)

    def forward(self, image, question):
        """Return answer logits of shape (batch, num_answers).

        image: image batch tensor (224x224 RGB per the module-level `transform`);
        question: (batch, seq_len) tensor of vocabulary token ids.
        """
        img_feat = self.cnn(image)
        q_embed = self.embedding(question)
        _, (h, _) = self.lstm(q_embed)
        # h is (num_layers=1, batch, hidden); drop the layer dimension.
        q_feat = h.squeeze(0)
        combined = torch.cat((img_feat, q_feat), dim=1)
        x = self.relu(self.fc1(combined))
        out = self.fc2(x)
        return out
44
+
45
def load_models():
    """Load every model once at startup and cache them in module globals.

    Loads the custom VQA classifier (plus its vocab and answer mapping from
    pickle files), a BLIP2 captioning model for open-ended questions, and the
    NLLB translator. Populates the module-level ``models_dict`` and ``device``
    and returns ``models_dict``.
    """
    global models_dict, device

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # Half precision only on GPU; CPU stays in float32.
    dtype = torch.float16 if use_cuda else torch.float32

    # --- Custom closed-vocabulary VQA classifier -------------------------
    with open("models/vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    with open("models/answer_mapping.pkl", "rb") as f:
        idx_to_answer = pickle.load(f)

    model = VQAModel(len(vocab), 300, 256, len(idx_to_answer))
    model.load_state_dict(torch.load("models/vqa_custom_model.pth", map_location=device))
    model.to(device)
    model.eval()

    # --- BLIP2 for open-ended questions (small checkpoint for free tier) --
    print("Loading BLIP2...")
    processor = Blip2Processor.from_pretrained(
        "Salesforce/blip2-flan-t5-base",
        torch_dtype=dtype
    )
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-base",
        device_map="auto" if use_cuda else None,
        torch_dtype=dtype,
        low_cpu_mem_usage=True
    )

    # --- NLLB translator for multilingual questions/answers ---------------
    print("Loading Translator...")
    translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    translator_model = AutoModelForSeq2SeqLM.from_pretrained(
        "facebook/nllb-200-distilled-600M",
        torch_dtype=dtype
    ).to(device)

    # ISO 639-1 -> NLLB (FLORES-200) language codes.
    lang_code_map = {
        "en": "eng_Latn", "hi": "hin_Deva", "te": "tel_Telu",
        "ta": "tam_Taml", "kn": "kan_Knda", "ml": "mal_Mlym"
    }

    models_dict = {
        'model': model, 'vocab': vocab, 'idx_to_answer': idx_to_answer,
        'processor': processor, 'blip_model': blip_model,
        'translator_tokenizer': translator_tokenizer,
        'translator_model': translator_model, 'lang_code_map': lang_code_map,
        'device': device
    }

    print("✅ All models loaded successfully!")
    return models_dict
100
+
101
def init_models():
    """Return the global models dict, loading all models on first call only."""
    global models_dict
    if models_dict is not None:
        return models_dict
    # Silence the HF tokenizers fork warning before any tokenizer is created.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    load_models()
    return models_dict
108
+
109
+ # All your functions remain EXACTLY the same...
110
def clean_text(text):
    """Lowercase *text* and strip everything except a-z, 0-9 and spaces."""
    return re.sub(r"[^a-z0-9 ]", "", text.lower())
114
+
115
def encode_question_infer(q, vocab):
    """Encode a question as a fixed-length (20) LongTensor of vocab ids.

    Unknown words map to ``vocab["<unk>"]``; short questions are right-padded
    with ``vocab["<pad>"]`` and long ones are truncated to 20 tokens.
    """
    MAX_LEN = 20
    words = clean_text(q).split()
    ids = [vocab.get(w, vocab["<unk>"]) for w in words][:MAX_LEN]
    ids.extend([vocab["<pad>"]] * (MAX_LEN - len(ids)))
    return torch.tensor(ids, dtype=torch.long)
122
+
123
def translate(text, src_lang, tgt_lang, tokenizer, model, lang_code_map, device):
    """Translate *text* from ``src_lang`` to ``tgt_lang`` with the NLLB model.

    Language arguments are ISO 639-1 codes resolved to NLLB (FLORES-200)
    codes via ``lang_code_map``; an unknown source language falls back to
    English. On any failure (unmapped target language, tokenizer/model
    error) the original text is returned unchanged so the calling pipeline
    keeps working — translation here is deliberately best-effort.
    """
    try:
        tokenizer.src_lang = lang_code_map.get(src_lang, "eng_Latn")
        inputs = tokenizer(text, return_tensors="pt", padding=True).to(device)
        tokens = model.generate(
            **inputs,
            # Force decoding to begin with the target-language token.
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(lang_code_map[tgt_lang]),
            max_length=50, num_beams=5
        )
        return tokenizer.decode(tokens[0], skip_special_tokens=True)
    except Exception:
        # Fix: was a bare ``except:``, which also swallowed SystemExit and
        # KeyboardInterrupt. Narrowed while keeping the best-effort fallback.
        return text
135
+
136
def predict_custom_vqa(image_tensor, question_tensor, model, idx_to_answer, device):
    """Run the custom VQA classifier and map the top logit to its answer string."""
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor.to(device), question_tensor.to(device))
        _, best = torch.max(logits, 1)
    return idx_to_answer[best.item()]
144
+
145
def open_vqa(image, question, processor, blip_model):
    """Answer an open-ended question about *image* with BLIP2 free-form generation."""
    batch = processor(image, question, return_tensors="pt")
    if torch.cuda.is_available():
        # Move every input tensor onto whatever device BLIP2 was mapped to.
        batch = {name: tensor.to(blip_model.device) for name, tensor in batch.items()}
    generated = blip_model.generate(**batch, max_new_tokens=50)
    return processor.decode(generated[0], skip_special_tokens=True)
151
+
152
def final_pipeline(image_path_or_pil, question):
    """Run the full multilingual VQA pipeline for one image/question pair.

    Steps: lazily load models, normalize the image, detect the question's
    language, translate to English if needed, route to BLIP2 (descriptive
    questions) or the custom classifier, then translate the answer back.
    Returns a markdown string containing the detected language and answer.
    """
    init_models()
    m = models_dict

    # Accept either an already-open PIL image or a filesystem path.
    if hasattr(image_path_or_pil, 'convert'):
        image = image_path_or_pil.convert("RGB")
    else:
        image = Image.open(image_path_or_pil).convert("RGB")
    # Fix: the tensor conversion was duplicated in both branches; hoisted here.
    image_tensor = transform(image).unsqueeze(0)

    # langdetect raises on empty/undetectable text; default to English.
    try:
        lang = detect(question)
    except Exception:
        # Fix: was a bare ``except:`` — narrowed so SystemExit/KeyboardInterrupt
        # still propagate.
        lang = "en"

    # Work in English internally; translate the question if necessary.
    if lang != "en":
        q_en = translate(question, lang, "en",
                         m['translator_tokenizer'], m['translator_model'],
                         m['lang_code_map'], m['device'])
    else:
        q_en = question

    # Heuristic routing: descriptive phrasing goes to BLIP2 free-form
    # generation, everything else to the closed-vocabulary classifier.
    if any(x in q_en.lower() for x in ["what is", "describe", "this place", "show"]):
        answer_en = open_vqa(image, q_en, m['processor'], m['blip_model'])
    else:
        q_tensor = encode_question_infer(q_en, m['vocab']).unsqueeze(0)
        answer_en = predict_custom_vqa(image_tensor, q_tensor,
                                       m['model'], m['idx_to_answer'], m['device'])

    # Translate the answer back into the user's language.
    if lang != "en":
        answer = translate(answer_en, "en", lang,
                           m['translator_tokenizer'], m['translator_model'],
                           m['lang_code_map'], m['device'])
    else:
        answer = answer_en

    return f"**Detected Language:** {lang}\n**Answer:** {answer}"