import os
import torch
import pickle
import torchvision.transforms as transforms
from PIL import Image
from langdetect import detect
from huggingface_hub import hf_hub_download
from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
# DOWNLOAD WEIGHTS FROM YOUR MODEL REPO
os.makedirs("weights", exist_ok=True)
if not os.path.exists("weights/vqa_model.pth"):
    hf_hub_download(repo_id="PRUTHVIn/vqa_project", filename="weights/vqa_model.pth", local_dir=".")
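# NOTE: hf_hub_download keeps the repo-relative path under local_dir,
# so the checkpoint lands at ./weights/vqa_model.pth as the code expects.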
device = torch.device("cpu")
# LOAD BLIP-2 (the accurate, general-purpose model)
print("Loading BLIP2...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model.to(device).eval()
# LOAD TRANSLATOR
print("Loading Translator...")
translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model.to(device).eval()
lang_code_map = {"en": "eng_Latn", "hi": "hin_Deva", "te": "tel_Telu", "ta": "tam_Taml", "kn": "kan_Knda", "ml": "mal_Mlym"}
# HELPER FUNCTIONS
def translate(text, src, tgt):
    translator_tokenizer.src_lang = lang_code_map[src]
    inputs = translator_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        tokens = translator_model.generate(
            **inputs,
            forced_bos_token_id=translator_tokenizer.convert_tokens_to_ids(lang_code_map[tgt]),
            max_length=50,
        )
    return translator_tokenizer.decode(tokens[0], skip_special_tokens=True)
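# Illustrative call (not in the original code): translate("what is the man holding", "en", "hi")
# returns the Hindi rendering; both src and tgt must be keys of lang_code_map.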
# LOAD CUSTOM MODEL
from models.vqa_model import VQAModel
with open("weights/vocab.pkl","rb") as f: vocab = pickle.load(f)
with open("weights/answers.pkl","rb") as f: idx_to_answer = pickle.load(f)
custom_model = VQAModel(len(vocab),300,256,len(idx_to_answer))
custom_model.load_state_dict(torch.load("weights/vqa_model.pth", map_location=device))
custom_model.to(device).eval()
def predict_custom_vqa(image, question):
    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    img_t = transform(image.convert("RGB")).unsqueeze(0)
    # Encode the question, mapping unknown words to <UNK>, padded/truncated to 20 tokens
    tokens = question.lower().split()
    enc = [vocab.get(w, vocab["<UNK>"]) for w in tokens]
    enc = torch.tensor(enc[:20] + [0] * (20 - len(enc))).unsqueeze(0)
    with torch.no_grad():
        out = custom_model(img_t, enc)
    _, pred = torch.max(out, 1)
    return idx_to_answer[pred.item()]
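# Note: the custom model is a classifier over a fixed answer set (idx_to_answer),
# so it can only return answers it saw during training, unlike BLIP-2's free-form decoding.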
def open_vqa(image, question):
    inputs = processor(image, question, return_tensors="pt")
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=20)
    return processor.decode(out[0], skip_special_tokens=True)
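# blip2-flan-t5-xl is a multi-billion-parameter model, so CPU generation is slow;
# max_new_tokens=20 keeps answers short and latency bounded.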
# ========================
# THE SMART PIPELINE
# ========================
def predict(image, question):
    try:
        lang = detect(question)
    except Exception:
        lang = "en"
    # 1. Translate the question to English if it is in a supported non-English language
    q_en = translate(question, lang, "en") if lang != "en" and lang in lang_code_map else question
    # 2. Smart routing: send most question types to BLIP-2 for higher accuracy,
    #    since it handles "how many", "what color", and "describe" far better.
    complex_q = ["how many", "color", "what", "describe", "where", "who"]
    if any(word in q_en.lower() for word in complex_q):
        answer_en = open_vqa(image, q_en)
    else:
        # The custom model handles only the specific patterns it was trained on
        answer_en = predict_custom_vqa(image, q_en)
    # 3. Translate the answer back to the original language if necessary
    if lang != "en" and lang in lang_code_map:
        return translate(answer_en, "en", lang)
    return answer_en
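# ========================
# LOCAL SMOKE TEST
# ========================
# A minimal sketch for exercising the pipeline outside the Space (the Space itself
# wires `predict` into its UI). "sample.jpg" is a hypothetical local image path.
if __name__ == "__main__":
    img = Image.open("sample.jpg")
    print(predict(img, "How many people are in the picture?"))  # routed to BLIP-2
    print(predict(img, "इस तस्वीर में कितने लोग हैं?"))  # Hindi in, Hindi out via NLLB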