File size: 4,218 Bytes
364daa0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | from transformers import (
Blip2Processor,
Blip2ForConditionalGeneration,
AutoTokenizer,
AutoModelForSeq2SeqLM
)
from langdetect import detect
from PIL import Image
import torch
import pickle
import torchvision.transforms as transforms
# ========================
# PERFORMANCE SETTINGS
# ========================
torch.set_num_threads(4)
# ========================
# DEVICE (CPU ONLY)
# ========================
device = torch.device("cpu")
# ========================
# LOAD BLIP2 (SAFE)
# ========================
print("Loading BLIP2...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-flan-t5-xl"
)
blip_model.to(device)
blip_model.eval()
# ========================
# LOAD TRANSLATOR
# ========================
print("Loading Translator...")
translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model.to(device)
translator_model.eval()
lang_code_map = {
"en":"eng_Latn","hi":"hin_Deva","te":"tel_Telu",
"ta":"tam_Taml","kn":"kan_Knda","ml":"mal_Mlym"
}
def translate(text, src, tgt):
translator_tokenizer.src_lang = lang_code_map[src]
inputs = translator_tokenizer(text, return_tensors="pt")
with torch.no_grad():
tokens = translator_model.generate(
**inputs,
forced_bos_token_id=translator_tokenizer.convert_tokens_to_ids(lang_code_map[tgt]),
max_length=50
)
return translator_tokenizer.decode(tokens[0], skip_special_tokens=True)
# ========================
# LOAD CUSTOM MODEL
# ========================
from models.vqa_model import VQAModel
transform = transforms.Compose([
transforms.Resize((224,224)),
transforms.ToTensor()
])
with open("weights/vocab.pkl","rb") as f:
vocab = pickle.load(f)
with open("weights/answers.pkl","rb") as f:
idx_to_answer = pickle.load(f)
custom_model = VQAModel(len(vocab),300,256,len(idx_to_answer))
custom_model.load_state_dict(torch.load("weights/vqa_model.pth", map_location=device))
custom_model.to(device)
custom_model.eval()
def encode_question(q):
tokens = q.lower().split()
enc = [vocab.get(w, vocab["<UNK>"]) for w in tokens]
enc = enc[:20] + [vocab["<PAD>"]] * (20-len(enc))
return torch.tensor(enc).unsqueeze(0)
# ========================
# CUSTOM MODEL
# ========================
def predict_custom_vqa(image_path, question):
image = Image.open(image_path).convert("RGB")
image = transform(image).unsqueeze(0)
q = encode_question(question)
with torch.no_grad():
out = custom_model(image, q)
_, pred = torch.max(out,1)
return idx_to_answer[pred.item()]
# ========================
# BLIP2 (OPTIMIZED)
# ========================
def open_vqa(image_path, question):
image = Image.open(image_path).convert("RGB")
inputs = processor(image, question, return_tensors="pt")
with torch.no_grad():
out = blip_model.generate(
**inputs,
max_new_tokens=15 # 🔥 reduced for speed
)
return processor.decode(out[0], skip_special_tokens=True)
# ========================
# FINAL PIPELINE
# ========================
def final_pipeline(image_path, question):
lang = detect(question)
if lang != "en":
q_en = translate(question, lang, "en")
else:
q_en = question
if "what is" in q_en.lower() or "this place" in q_en.lower():
answer_en = open_vqa(image_path, q_en)
else:
answer_en = predict_custom_vqa(image_path, q_en)
if lang != "en":
return translate(answer_en, "en", lang)
else:
return answer_en
def predict(image_path, question):
return final_pipeline(image_path, question)
# ========================
# WARMUP
# ========================
print("Warming up...")
dummy = Image.new("RGB", (224,224))
processor(dummy, "test", return_tensors="pt")
print("✅ Ready!")
# ========================
# TEST
# ========================
if __name__ == "__main__":
print(predict("test.jpg","What is in the image?")) |