import os
import torch
import pickle
import torchvision.transforms as transforms
from PIL import Image
from langdetect import detect
from huggingface_hub import hf_hub_download
from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM

# DOWNLOAD WEIGHTS FROM YOUR MODEL REPO
os.makedirs("weights", exist_ok=True)
if not os.path.exists("weights/vqa_model.pth"):
    hf_hub_download(repo_id="PRUTHVIn/vqa_project", filename="weights/vqa_model.pth", local_dir=".")

device = torch.device("cpu")

# LOAD BLIP-2 (the accurate, general-purpose VQA model)
print("Loading BLIP-2...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model.to(device).eval()

# LOAD TRANSLATOR
print("Loading Translator...")
translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model.to(device).eval()

# Map langdetect ISO-639-1 codes to NLLB-200 language codes (the languages this app supports)
lang_code_map = {"en": "eng_Latn", "hi": "hin_Deva", "te": "tel_Telu", "ta": "tam_Taml", "kn": "kan_Knda", "ml": "mal_Mlym"}

# HELPER FUNCTIONS
def translate(text, src, tgt):
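    """Translate `text` from `src` to `tgt` (langdetect-style codes) with NLLB-200."""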
    translator_tokenizer.src_lang = lang_code_map[src]
    inputs = translator_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        tokens = translator_model.generate(
            **inputs,
            forced_bos_token_id=translator_tokenizer.convert_tokens_to_ids(lang_code_map[tgt]),
            max_length=50,
        )
    return translator_tokenizer.decode(tokens[0], skip_special_tokens=True)

# LOAD CUSTOM MODEL
from models.vqa_model import VQAModel
with open("weights/vocab.pkl","rb") as f: vocab = pickle.load(f)
with open("weights/answers.pkl","rb") as f: idx_to_answer = pickle.load(f)

custom_model = VQAModel(len(vocab),300,256,len(idx_to_answer))
custom_model.load_state_dict(torch.load("weights/vqa_model.pth", map_location=device))
custom_model.to(device).eval()

def predict_custom_vqa(image, question):
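    """Classify over the trained answer set with the custom VQAModel (question padded/truncated to 20 token ids)."""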
    transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor()])
    img_t = transform(image.convert("RGB")).unsqueeze(0)
    tokens = question.lower().split()
    enc = [vocab.get(w, vocab["<UNK>"]) for w in tokens]
    enc = torch.tensor(enc[:20] + [0]*(20-len(enc))).unsqueeze(0)
    with torch.no_grad():
        out = custom_model(img_t, enc)
        _, pred = torch.max(out, 1)
    return idx_to_answer[pred.item()]

def open_vqa(image, question):
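    """Open-ended VQA with BLIP-2: generate up to 20 new tokens as the answer."""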
    inputs = processor(image, question, return_tensors="pt")
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=20)
    return processor.decode(out[0], skip_special_tokens=True)

# ========================
# THE SMART PIPELINE
# ========================
def predict(image, question):
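    """Full pipeline: detect language, translate to English, route to BLIP-2 or the custom model, translate back."""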
    try:
        lang = detect(question)
    except Exception:
        # langdetect raises on empty or ambiguous input; fall back to English
        lang = "en"

    # 1. Translate to English
    q_en = translate(question, lang, "en") if lang != "en" and lang in lang_code_map else question
    
    # 2. Smart Routing: Use BLIP-2 for almost everything to ensure high accuracy
    # BLIP-2 is much better at "How many", "What color", and "Describe"
    complex_q = ["how many", "color", "what", "describe", "where", "who"]
    
    if any(word in q_en.lower() for word in complex_q):
        answer_en = open_vqa(image, q_en)
    else:
        # Custom model used only for very specific trained patterns
        answer_en = predict_custom_vqa(image, q_en)

    # 3. Translate back if necessary
    if lang != "en" and lang in lang_code_map:
        return translate(answer_en, "en", lang)
    return answer_en
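
# Example usage (hypothetical local smoke test; "example.jpg" is a placeholder path):
if __name__ == "__main__":
    img = Image.open("example.jpg")
    print(predict(img, "What color is the car?"))          # routed to BLIP-2
    print(predict(img, "इस तस्वीर में क्या है?"))            # Hindi question; answer translated back to Hindi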