import os
import torch
import pickle
import torchvision.transforms as transforms
from PIL import Image
from langdetect import detect
from huggingface_hub import hf_hub_download
from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
# DOWNLOAD WEIGHTS FROM YOUR MODEL REPO
os.makedirs("weights", exist_ok=True)
if not os.path.exists("weights/vqa_model.pth"):
    hf_hub_download(repo_id="PRUTHVIn/vqa_project", filename="weights/vqa_model.pth", local_dir=".")

device = torch.device("cpu")
# LOAD BLIP-2 (The accurate "General" model)
print("Loading BLIP2...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model.to(device).eval()
# LOAD TRANSLATOR
print("Loading Translator...")
translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model.to(device).eval()

# Map langdetect ISO codes to NLLB-200 language codes
lang_code_map = {"en": "eng_Latn", "hi": "hin_Deva", "te": "tel_Telu", "ta": "tam_Taml", "kn": "kan_Knda", "ml": "mal_Mlym"}
# HELPER FUNCTIONS
def translate(text, src, tgt):
    translator_tokenizer.src_lang = lang_code_map[src]
    inputs = translator_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        tokens = translator_model.generate(
            **inputs,
            forced_bos_token_id=translator_tokenizer.convert_tokens_to_ids(lang_code_map[tgt]),
            max_length=50,
        )
    return translator_tokenizer.decode(tokens[0], skip_special_tokens=True)
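# Quick, hedged sanity check for the translator (hypothetical call, not part of the
# original app flow). Uncomment to verify the NLLB round trip locally; the exact
# output wording depends on the checkpoint:
# print(translate("What colour is the car?", "en", "hi"))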
# LOAD CUSTOM MODEL
from models.vqa_model import VQAModel

with open("weights/vocab.pkl", "rb") as f:
    vocab = pickle.load(f)
with open("weights/answers.pkl", "rb") as f:
    idx_to_answer = pickle.load(f)

custom_model = VQAModel(len(vocab), 300, 256, len(idx_to_answer))
custom_model.load_state_dict(torch.load("weights/vqa_model.pth", map_location=device))
custom_model.to(device).eval()
def predict_custom_vqa(image, question):
    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    img_t = transform(image.convert("RGB")).unsqueeze(0)
    tokens = question.lower().split()
    enc = [vocab.get(w, vocab["<UNK>"]) for w in tokens]
    # Truncate/pad the encoded question to a fixed length of 20 tokens
    enc = torch.tensor(enc[:20] + [0] * max(0, 20 - len(enc))).unsqueeze(0)
    with torch.no_grad():
        out = custom_model(img_t, enc)
    _, pred = torch.max(out, 1)
    return idx_to_answer[pred.item()]
def open_vqa(image, question):
    inputs = processor(image, question, return_tensors="pt")
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=20)
    return processor.decode(out[0], skip_special_tokens=True)
# ========================
# THE SMART PIPELINE
# ========================
def predict(image, question):
    try:
        lang = detect(question)
    except Exception:
        lang = "en"

    # 1. Translate the question to English if it is in a supported non-English language
    q_en = translate(question, lang, "en") if lang != "en" and lang in lang_code_map else question

    # 2. Smart routing: use BLIP-2 for almost everything to ensure high accuracy,
    #    since it is much better at "how many", "what color", and "describe" questions
    complex_q = ["how many", "color", "what", "describe", "where", "who"]
    if any(word in q_en.lower() for word in complex_q):
        answer_en = open_vqa(image, q_en)
    else:
        # Custom model is used only for the very specific patterns it was trained on
        answer_en = predict_custom_vqa(image, q_en)

    # 3. Translate the answer back to the original language if necessary
    if lang != "en" and lang in lang_code_map:
        return translate(answer_en, "en", lang)
    return answer_en
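
# ========================
# UI (sketch)
# ========================
# The source stops at predict(); the wiring below is a minimal, assumed Gradio
# interface for a Hugging Face Space (gradio itself is an assumption, not shown
# in the original). It simply exposes predict() on an image upload plus a textbox.
import gradio as gr

demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil", label="Image"), gr.Textbox(label="Question (any supported language)")],
    outputs=gr.Textbox(label="Answer"),
    title="Multilingual VQA (BLIP-2 + custom model)",
)

if __name__ == "__main__":
    demo.launch()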