from transformers import (
    Blip2Processor,
    Blip2ForConditionalGeneration,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
from langdetect import detect
from PIL import Image
import torch
import pickle
import torchvision.transforms as transforms

# Everything runs on CPU; cap the thread count for a predictable footprint.
torch.set_num_threads(4)
device = torch.device("cpu")
| print("Loading BLIP2...") |
|
|
| processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl") |
|
|
| blip_model = Blip2ForConditionalGeneration.from_pretrained( |
| "Salesforce/blip2-flan-t5-xl" |
| ) |
|
|
| blip_model.to(device) |
| blip_model.eval() |
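
# Note: blip2-flan-t5-xl is a multi-billion-parameter checkpoint, so the fp32
# CPU load above needs a large download and a lot of RAM. If memory is tight,
# a lower-precision load is one option (a sketch, not verified on every setup):
#   blip_model = Blip2ForConditionalGeneration.from_pretrained(
#       "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.bfloat16
#   )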

# NLLB-200 translator for translating questions and answers
print("Loading Translator...")

translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model.to(device)
translator_model.eval()

# langdetect ISO 639-1 codes -> NLLB / FLORES-200 language codes
lang_code_map = {
    "en": "eng_Latn", "hi": "hin_Deva", "te": "tel_Telu",
    "ta": "tam_Taml", "kn": "kan_Knda", "ml": "mal_Mlym",
}


def translate(text, src, tgt):
    """Translate text from src to tgt (langdetect-style codes) with NLLB-200."""
    translator_tokenizer.src_lang = lang_code_map[src]
    inputs = translator_tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        tokens = translator_model.generate(
            **inputs,
            forced_bos_token_id=translator_tokenizer.convert_tokens_to_ids(lang_code_map[tgt]),
            max_length=50,
        )

    return translator_tokenizer.decode(tokens[0], skip_special_tokens=True)
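
# Example usage (the exact wording of the output depends on the NLLB model):
#   translate("यह कौन सी जगह है?", "hi", "en")   # Hindi question -> English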


# Custom closed-vocabulary VQA model trained separately
from models.vqa_model import VQAModel

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

with open("weights/vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

with open("weights/answers.pkl", "rb") as f:
    idx_to_answer = pickle.load(f)

# Rebuild the custom model and load its trained weights.
custom_model = VQAModel(len(vocab), 300, 256, len(idx_to_answer))
custom_model.load_state_dict(torch.load("weights/vqa_model.pth", map_location=device))
custom_model.to(device)
custom_model.eval()


def encode_question(q):
    """Map a question to a fixed-length (20-token) index tensor."""
    tokens = q.lower().split()
    enc = [vocab.get(w, vocab["<UNK>"]) for w in tokens]
    enc = enc[:20]                                    # truncate to 20 tokens
    enc = enc + [vocab["<PAD>"]] * (20 - len(enc))    # pad short questions
    return torch.tensor(enc).unsqueeze(0)
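
# e.g. encode_question("what colour is the bus") returns a LongTensor of shape
# (1, 20): out-of-vocabulary words map to <UNK>, the tail is filled with <PAD>.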


def predict_custom_vqa(image_path, question):
    """Answer with the custom model's fixed answer vocabulary."""
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0)
    q = encode_question(question)

    with torch.no_grad():
        out = custom_model(image, q)
        _, pred = torch.max(out, 1)

    return idx_to_answer[pred.item()]


def open_vqa(image_path, question):
    """Open-ended answer from BLIP-2."""
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, text=question, return_tensors="pt")

    with torch.no_grad():
        out = blip_model.generate(
            **inputs,
            max_new_tokens=15,
        )

    return processor.decode(out[0], skip_special_tokens=True)
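
# Note: the Flan-T5 BLIP-2 checkpoints are commonly prompted as
# "Question: {question} Answer:"; if raw questions give poor answers, that
# template is worth trying, e.g.:
#   inputs = processor(images=image, text=f"Question: {question} Answer:", return_tensors="pt")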


def final_pipeline(image_path, question):
    lang = detect(question)
    # langdetect can return languages the translator does not cover;
    # fall back to treating the question as English in that case.
    if lang not in lang_code_map:
        lang = "en"

    if lang != "en":
        q_en = translate(question, lang, "en")
    else:
        q_en = question

    # Simple keyword routing: open-ended questions go to BLIP-2,
    # everything else to the custom closed-vocabulary model.
    if "what is" in q_en.lower() or "this place" in q_en.lower():
        answer_en = open_vqa(image_path, q_en)
    else:
        answer_en = predict_custom_vqa(image_path, q_en)

    if lang != "en":
        return translate(answer_en, "en", lang)
    else:
        return answer_en
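
# The routing heuristic above is deliberately crude. If more question types
# should reach BLIP-2, one option is a small trigger list (hypothetical names):
#   OPEN_ENDED_TRIGGERS = ("what is", "this place", "describe", "where is")
#   if any(t in q_en.lower() for t in OPEN_ENDED_TRIGGERS): ...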


def predict(image_path, question):
    return final_pipeline(image_path, question)


# Warm-up: exercises the processor/tokenizer once so the first real request
# does not pay that cost (the models themselves are not run here).
print("Warming up...")
dummy = Image.new("RGB", (224, 224))
processor(images=dummy, text="test", return_tensors="pt")

print("✅ Ready!")


if __name__ == "__main__":
    print(predict("test.jpg", "What is in the image?"))