# vqa_project/inference.py
import os
import torch
import pickle
import torchvision.transforms as transforms
from PIL import Image
from langdetect import detect
from huggingface_hub import hf_hub_download
from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
# DOWNLOAD WEIGHTS FROM YOUR MODEL REPO
os.makedirs("weights", exist_ok=True)
# The vocab/answer pickles loaded further down are assumed to sit alongside the
# checkpoint in the same repo; fetch whichever files are missing locally.
for fname in ("weights/vqa_model.pth", "weights/vocab.pkl", "weights/answers.pkl"):
    if not os.path.exists(fname):
        hf_hub_download(repo_id="PRUTHVIn/vqa_project", filename=fname, local_dir=".")
device = torch.device("cpu")
# LOAD BLIP-2 (the accurate general-purpose model)
print("Loading BLIP2...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model.to(device).eval()
# LOAD TRANSLATOR
print("Loading Translator...")
translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model.to(device).eval()
lang_code_map = {"en":"eng_Latn","hi":"hin_Deva","te":"tel_Telu","ta":"tam_Taml","kn":"kan_Knda","ml":"mal_Mlym"}
# HELPER FUNCTIONS
def translate(text, src, tgt):
    # Translate `text` between languages given as langdetect codes (see lang_code_map).
    translator_tokenizer.src_lang = lang_code_map[src]
    inputs = translator_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        tokens = translator_model.generate(
            **inputs,
            forced_bos_token_id=translator_tokenizer.convert_tokens_to_ids(lang_code_map[tgt]),
            max_length=50,
        )
    return translator_tokenizer.decode(tokens[0], skip_special_tokens=True)
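# A minimal usage sketch (illustrative only; the Hindi string below is an
# assumed example, not taken from the original project):
#   translate("यह क्या है?", "hi", "en")   # -> e.g. "What is this?"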
# LOAD CUSTOM MODEL
from models.vqa_model import VQAModel
with open("weights/vocab.pkl","rb") as f: vocab = pickle.load(f)
with open("weights/answers.pkl","rb") as f: idx_to_answer = pickle.load(f)
custom_model = VQAModel(len(vocab),300,256,len(idx_to_answer))
custom_model.load_state_dict(torch.load("weights/vqa_model.pth", map_location=device))
custom_model.to(device).eval()
def predict_custom_vqa(image, question):
    # Preprocess the image for the custom model: 224x224, tensor in [0, 1], no normalisation.
    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    img_t = transform(image.convert("RGB")).unsqueeze(0)
    # Encode the question: lowercase whitespace tokens, truncated/zero-padded to 20 ids.
    tokens = question.lower().split()
    enc = [vocab.get(w, vocab["<UNK>"]) for w in tokens]
    enc = torch.tensor(enc[:20] + [0] * (20 - len(enc))).unsqueeze(0)
    with torch.no_grad():
        out = custom_model(img_t, enc)
        _, pred = torch.max(out, 1)
    return idx_to_answer[pred.item()]
def open_vqa(image, question):
    # BLIP-2 handles open-ended questions; cap answer length with max_new_tokens.
    inputs = processor(images=image, text=question, return_tensors="pt")
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=20)
    return processor.decode(out[0], skip_special_tokens=True)
# ========================
# THE SMART PIPELINE
# ========================
def predict(image, question):
    try:
        lang = detect(question)
    except Exception:
        # langdetect raises on very short or ambiguous input; fall back to English.
        lang = "en"
    # 1. Translate the question to English if it is in a supported non-English language.
    q_en = translate(question, lang, "en") if lang != "en" and lang in lang_code_map else question
    # 2. Smart routing: BLIP-2 answers most question types ("how many", "what color",
    #    "describe", ...) far more accurately than the small custom model.
    complex_q = ["how many", "color", "what", "describe", "where", "who"]
    if any(word in q_en.lower() for word in complex_q):
        answer_en = open_vqa(image, q_en)
    else:
        # The custom model is reserved for the narrow patterns it was trained on.
        answer_en = predict_custom_vqa(image, q_en)
    # 3. Translate the answer back to the question's original language if necessary.
    if lang != "en" and lang in lang_code_map:
        return translate(answer_en, "en", lang)
    return answer_en
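# A minimal smoke test of the full pipeline. The image path and question are
# hypothetical placeholders, not part of the original project.
if __name__ == "__main__":
    sample = Image.open("sample.jpg")  # hypothetical local test image
    print(predict(sample, "How many people are in the picture?"))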