PRUTHVIn committed on
Commit
876d5d2
·
verified ·
1 Parent(s): 13a580c

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +46 -133
inference.py CHANGED
@@ -1,177 +1,90 @@
"""Multilingual VQA inference service: BLIP-2 + custom model + NLLB translator.

Fix: the original imported `torch` twice and scattered imports between
statements; all imports are consolidated at the top per PEP 8.
"""
import os
import pickle

import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from langdetect import detect
from PIL import Image
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Blip2ForConditionalGeneration,
    Blip2Processor,
)

# 1. Create the weights folder in the Space environment.
os.makedirs("weights", exist_ok=True)

# 2. Download the heavy .pth file from the model repo to the Space.
#    This only happens once when the Space starts up.
if not os.path.exists("weights/vqa_model.pth"):
    print("Downloading weights from Model Hub...")
    hf_hub_download(
        repo_id="PRUTHVIn/vqa_project",
        filename="weights/vqa_model.pth",
        local_dir=".",
    )

# ========================
# PERFORMANCE SETTINGS
# ========================
# Cap intra-op threads for the CPU-only Space.
torch.set_num_threads(4)

# ========================
# DEVICE (CPU ONLY)
# ========================
device = torch.device("cpu")

# ========================
# LOAD BLIP2 (SAFE)
# ========================
print("Loading BLIP2...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-flan-t5-xl"
)
blip_model.to(device)
blip_model.eval()

# ========================
# LOAD TRANSLATOR
# ========================
print("Loading Translator...")
translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model.to(device)
translator_model.eval()

# Map: langdetect ISO code -> NLLB language code.
lang_code_map = {
    "en": "eng_Latn", "hi": "hin_Deva", "te": "tel_Telu",
    "ta": "tam_Taml", "kn": "kan_Knda", "ml": "mal_Mlym",
}
 
 
def translate(text, src, tgt):
    """Translate *text* from language code *src* to *tgt* via NLLB."""
    translator_tokenizer.src_lang = lang_code_map[src]
    encoded = translator_tokenizer(text, return_tensors="pt")
    target_id = translator_tokenizer.convert_tokens_to_ids(lang_code_map[tgt])
    with torch.no_grad():
        generated = translator_model.generate(
            **encoded, forced_bos_token_id=target_id, max_length=50
        )
    return translator_tokenizer.decode(generated[0], skip_special_tokens=True)
# ========================
# LOAD CUSTOM MODEL
# ========================
from models.vqa_model import VQAModel

# Image preprocessing used by the custom model.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# Question vocabulary and answer index, pickled at training time.
with open("weights/vocab.pkl", "rb") as vocab_file:
    vocab = pickle.load(vocab_file)
with open("weights/answers.pkl", "rb") as answers_file:
    idx_to_answer = pickle.load(answers_file)

custom_model = VQAModel(len(vocab), 300, 256, len(idx_to_answer))
state = torch.load("weights/vqa_model.pth", map_location=device)
custom_model.load_state_dict(state)
custom_model.to(device)
custom_model.eval()
def encode_question(q):
    """Encode a question string as a (1, 20) tensor of vocabulary ids."""
    words = q.lower().split()
    ids = [vocab.get(word, vocab["<UNK>"]) for word in words]
    # Truncate to 20 tokens, then pad with <PAD> up to exactly 20.
    ids = ids[:20]
    ids += [vocab["<PAD>"]] * (20 - len(ids))
    return torch.tensor(ids).unsqueeze(0)
# ========================
# CUSTOM MODEL
# ========================
def predict_custom_vqa(image_path, question):
    """Answer *question* about the image at *image_path* with the custom model."""
    pil_img = Image.open(image_path).convert("RGB")
    img_batch = transform(pil_img).unsqueeze(0)
    question_ids = encode_question(question)
    with torch.no_grad():
        logits = custom_model(img_batch, question_ids)
        pred = logits.argmax(dim=1)  # index of the most likely answer class
    return idx_to_answer[pred.item()]
# ========================
# BLIP2 (OPTIMIZED)
# ========================
def open_vqa(image_path, question):
    """Answer *question* about the image at *image_path* with BLIP-2."""
    rgb = Image.open(image_path).convert("RGB")
    model_inputs = processor(rgb, question, return_tensors="pt")
    with torch.no_grad():
        # max_new_tokens kept small to keep CPU latency acceptable.
        generated = blip_model.generate(**model_inputs, max_new_tokens=15)
    return processor.decode(generated[0], skip_special_tokens=True)
# ========================
# FINAL PIPELINE
# ========================
def final_pipeline(image_path, question):
    """Detect the question language, route to BLIP-2 or the custom model,
    and return the answer translated back to the question's language.

    Fixes: ``detect`` raises LangDetectException on empty/ambiguous text
    (previously uncaught), and a detected language missing from
    ``lang_code_map`` previously caused a KeyError inside ``translate``.
    """
    try:
        lang = detect(question)
    except Exception:
        # Undetectable text: fall back to English rather than crashing.
        lang = "en"

    # Only translate for languages the NLLB mapping supports.
    translatable = lang != "en" and lang in lang_code_map
    q_en = translate(question, lang, "en") if translatable else question

    # Route open-ended "what is / this place" questions to BLIP-2,
    # everything else to the custom trained model.
    if "what is" in q_en.lower() or "this place" in q_en.lower():
        answer_en = open_vqa(image_path, q_en)
    else:
        answer_en = predict_custom_vqa(image_path, q_en)

    # Translate the answer back only when we translated the question.
    if translatable:
        return translate(answer_en, "en", lang)
    return answer_en


def predict(image_path, question):
    """Public entry point; thin wrapper over final_pipeline."""
    return final_pipeline(image_path, question)
# ========================
# WARMUP
# ========================
# Prime the processor so the first real request is not the slowest.
print("Warming up...")
warmup_image = Image.new("RGB", (224, 224))
processor(warmup_image, "test", return_tensors="pt")
print("✅ Ready!")

# ========================
# TEST
# ========================
# Manual smoke test when run as a script.
if __name__ == "__main__":
    print(predict("test.jpg", "What is in the image?"))
 
"""Multilingual VQA inference: BLIP-2, a custom trained model, and NLLB."""
import os
import pickle

import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from langdetect import detect
from PIL import Image
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Blip2ForConditionalGeneration,
    Blip2Processor,
)

# DOWNLOAD WEIGHTS FROM YOUR MODEL REPO (first startup only).
os.makedirs("weights", exist_ok=True)
if not os.path.exists("weights/vqa_model.pth"):
    hf_hub_download(
        repo_id="PRUTHVIn/vqa_project",
        filename="weights/vqa_model.pth",
        local_dir=".",
    )

# Everything runs on CPU in this Space.
device = torch.device("cpu")

# LOAD BLIP-2 (the accurate "general" model).
print("Loading BLIP2...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model.to(device).eval()

# LOAD TRANSLATOR (NLLB, for non-English questions and answers).
print("Loading Translator...")
translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model.to(device).eval()

# langdetect ISO code -> NLLB language code.
lang_code_map = {
    "en": "eng_Latn",
    "hi": "hin_Deva",
    "te": "tel_Telu",
    "ta": "tam_Taml",
    "kn": "kan_Knda",
    "ml": "mal_Mlym",
}
 
# HELPER FUNCTIONS
def translate(text, src, tgt):
    """Translate *text* from language code *src* to *tgt* via NLLB."""
    translator_tokenizer.src_lang = lang_code_map[src]
    encoded = translator_tokenizer(text, return_tensors="pt")
    target_id = translator_tokenizer.convert_tokens_to_ids(lang_code_map[tgt])
    with torch.no_grad():
        generated = translator_model.generate(
            **encoded, forced_bos_token_id=target_id, max_length=50
        )
    return translator_tokenizer.decode(generated[0], skip_special_tokens=True)
# LOAD CUSTOM MODEL (small trained-from-scratch VQA network).
from models.vqa_model import VQAModel

# Question vocabulary and answer index, pickled at training time.
with open("weights/vocab.pkl", "rb") as vocab_file:
    vocab = pickle.load(vocab_file)
with open("weights/answers.pkl", "rb") as answers_file:
    idx_to_answer = pickle.load(answers_file)

custom_model = VQAModel(len(vocab), 300, 256, len(idx_to_answer))
custom_model.load_state_dict(torch.load("weights/vqa_model.pth", map_location=device))
custom_model.to(device).eval()
# Built once at import time instead of on every call (the original
# recreated the Compose pipeline inside the function on each request).
_custom_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)


def predict_custom_vqa(image, question):
    """Answer *question* about PIL image *image* with the custom model.

    The question is lower-cased, whitespace-tokenized, mapped through the
    training vocabulary (unknown words -> <UNK>) and padded/truncated to
    exactly 20 ids. NOTE(review): 0 is assumed to be the padding id here —
    confirm against the training code's <PAD> index.
    """
    img_t = _custom_transform(image.convert("RGB")).unsqueeze(0)
    tokens = question.lower().split()
    enc = [vocab.get(w, vocab["<UNK>"]) for w in tokens]
    enc = torch.tensor(enc[:20] + [0] * (20 - len(enc))).unsqueeze(0)
    with torch.no_grad():
        out = custom_model(img_t, enc)
        _, pred = torch.max(out, 1)  # argmax over answer classes
    return idx_to_answer[pred.item()]
def open_vqa(image, question):
    """Answer *question* about PIL image *image* using BLIP-2."""
    model_inputs = processor(image, question, return_tensors="pt")
    with torch.no_grad():
        generated = blip_model.generate(**model_inputs, max_new_tokens=20)
    answer = processor.decode(generated[0], skip_special_tokens=True)
    return answer
# ========================
# THE SMART PIPELINE
# ========================
# Keywords that route a question to BLIP-2 (better at counting, colors and
# open descriptions) instead of the custom model. Module-level constant so
# the list is not rebuilt on every call.
_BLIP_KEYWORDS = ("how many", "color", "what", "describe", "where", "who")


def predict(image, question):
    """End-to-end VQA: detect language, answer in English, translate back.

    Parameters:
        image: PIL image to be questioned.
        question: question text; translated when its language is in
            ``lang_code_map``.
    Returns:
        The answer string, in the question's language when supported.
    """
    try:
        lang = detect(question)
    except Exception:
        # Fix: was a bare `except:` that also swallowed SystemExit and
        # KeyboardInterrupt. langdetect raises on empty/ambiguous text.
        lang = "en"

    translatable = lang != "en" and lang in lang_code_map

    # 1. Translate the question to English for both answer models.
    q_en = translate(question, lang, "en") if translatable else question

    # 2. Smart routing: use BLIP-2 for almost everything to ensure high
    #    accuracy; the custom model handles only its trained patterns.
    if any(kw in q_en.lower() for kw in _BLIP_KEYWORDS):
        answer_en = open_vqa(image, q_en)
    else:
        answer_en = predict_custom_vqa(image, q_en)

    # 3. Translate the answer back only when we translated the question.
    if translatable:
        return translate(answer_en, "en", lang)
    return answer_en