import os
import pickle

import torch
import torchvision.transforms as transforms
from langdetect import detect
from huggingface_hub import hf_hub_download
from transformers import (
    Blip2Processor,
    Blip2ForConditionalGeneration,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

# DOWNLOAD WEIGHTS FROM YOUR MODEL REPO
os.makedirs("weights", exist_ok=True)
if not os.path.exists("weights/vqa_model.pth"):
    hf_hub_download(repo_id="PRUTHVIn/vqa_project", filename="weights/vqa_model.pth", local_dir=".")

device = torch.device("cpu")

# LOAD BLIP-2 (the accurate "general" model)
print("Loading BLIP-2...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model.to(device).eval()

# LOAD TRANSLATOR
print("Loading Translator...")
translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model.to(device).eval()

# Map langdetect ISO codes to NLLB language codes.
lang_code_map = {
    "en": "eng_Latn",
    "hi": "hin_Deva",
    "te": "tel_Telu",
    "ta": "tam_Taml",
    "kn": "kan_Knda",
    "ml": "mal_Mlym",
}

# HELPER FUNCTIONS
def translate(text, src, tgt):
    """Translate `text` from `src` to `tgt` (langdetect codes) with NLLB."""
    translator_tokenizer.src_lang = lang_code_map[src]
    inputs = translator_tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        tokens = translator_model.generate(
            **inputs,
            forced_bos_token_id=translator_tokenizer.convert_tokens_to_ids(lang_code_map[tgt]),
            max_length=50,
        )
    return translator_tokenizer.decode(tokens[0], skip_special_tokens=True)

# LOAD CUSTOM MODEL
from models.vqa_model import VQAModel

with open("weights/vocab.pkl", "rb") as f:
    vocab = pickle.load(f)
with open("weights/answers.pkl", "rb") as f:
    idx_to_answer = pickle.load(f)

custom_model = VQAModel(len(vocab), 300, 256, len(idx_to_answer))
custom_model.load_state_dict(torch.load("weights/vqa_model.pth", map_location=device))
custom_model.to(device).eval()

def predict_custom_vqa(image, question):
    """Classify an answer with the custom VQA model."""
    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    img_t = transform(image.convert("RGB")).unsqueeze(0)
    tokens = question.lower().split()
    # Out-of-vocabulary words fall back to the "<unk>" token (index 0 if absent).
    enc = [vocab.get(w, vocab.get("<unk>", 0)) for w in tokens]
    # Truncate/pad the question to a fixed length of 20 tokens.
    enc = torch.tensor(enc[:20] + [0] * (20 - len(enc))).unsqueeze(0)
    with torch.no_grad():
        out = custom_model(img_t, enc)
    _, pred = torch.max(out, 1)
    return idx_to_answer[pred.item()]

def open_vqa(image, question):
    """Answer an open-ended question with BLIP-2."""
    inputs = processor(images=image, text=question, return_tensors="pt").to(device)
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=20)
    return processor.decode(out[0], skip_special_tokens=True)

# ========================
# THE SMART PIPELINE
# ========================
def predict(image, question):
    try:
        lang = detect(question)
    except Exception:
        lang = "en"

    # 1. Translate the question to English if it is in a supported language.
    q_en = translate(question, lang, "en") if lang != "en" and lang in lang_code_map else question

    # 2. Smart routing: use BLIP-2 for almost everything to keep accuracy high.
    #    BLIP-2 is much better at "how many", "what color", and "describe" questions.
    complex_q = ["how many", "color", "what", "describe", "where", "who"]
    if any(word in q_en.lower() for word in complex_q):
        answer_en = open_vqa(image, q_en)
    else:
        # The custom model handles only the specific patterns it was trained on.
        answer_en = predict_custom_vqa(image, q_en)

    # 3. Translate the answer back if the question was not in English.
    if lang != "en" and lang in lang_code_map:
        return translate(answer_en, "en", lang)
    return answer_en
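
# USAGE SKETCH (assumption: not part of the deployed app, where `predict` is
# typically wired to a UI callback instead). A minimal smoke test of the
# pipeline above; "sample.jpg" and the questions are hypothetical placeholders.
if __name__ == "__main__":
    from PIL import Image

    sample = Image.open("sample.jpg")  # hypothetical local test image
    # Contains "what"/"color", so the router sends it to BLIP-2.
    print(predict(sample, "What color is the car?"))
    # Hindi question: detected, translated to English, answered, translated back.
    print(predict(sample, "इसमें कितने लोग हैं?"))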