PRUTHVIn committed on
Commit
1e5f3d4
·
verified ·
1 Parent(s): 9905fe1

Upload folder using huggingface_hub

Browse files
.DS_Store ADDED
Binary file (8.2 kB). View file
 
__pycache__/api.cpython-39.pyc ADDED
Binary file (850 Bytes). View file
 
__pycache__/config.cpython-39.pyc ADDED
Binary file (479 Bytes). View file
 
__pycache__/inference.cpython-39.pyc ADDED
Binary file (2.25 kB). View file
 
api.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, File, UploadFile, Form
from inference import predict
import shutil
import os

app = FastAPI()

# Directory used to stage uploaded images before running inference.
UPLOAD_DIR = "temp"
os.makedirs(UPLOAD_DIR, exist_ok=True)


@app.post("/predict")
async def predict_api(file: UploadFile = File(...), question: str = Form(...)):
    """Accept an image upload plus a question and return the VQA answer.

    Always returns a JSON body: {"answer": ...} on success or
    {"error": ...} on failure, so the frontend can parse either case.
    """
    file_path = None
    try:
        # basename() strips any directory components from the client-supplied
        # filename, preventing path traversal (e.g. "../../etc/passwd").
        safe_name = os.path.basename(file.filename or "upload")
        file_path = os.path.join(UPLOAD_DIR, safe_name)

        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        answer = predict(file_path, question)

        return {"answer": answer}

    except Exception as e:
        # Surface the failure to the client instead of an opaque 500.
        return {"error": str(e)}
    finally:
        # Best-effort cleanup of the staged upload.
        if file_path and os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError:
                pass
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from inference import predict
import torch
from huggingface_hub import hf_hub_download

# Pre-download the model weights from the Hub so they are cached locally.
# NOTE(review): the returned path is never passed onward — inference.py loads
# "weights/vqa_model.pth" itself; confirm this call is only a cache warm-up.
model_path = hf_hub_download(repo_id="PRUTHVIn/vqa_project", filename="weights/vqa_model.pth")


def vqa_interface(image, question):
    """Gradio callback: validate inputs, run the VQA pipeline, return the answer."""
    try:
        # Guard against a missing image and against an empty or None question
        # (Gradio can hand back None when the textbox was never touched,
        # which would crash the original `question.strip()` call).
        if image is None or not question or question.strip() == "":
            return "Please upload an image and enter a question."

        answer = predict(image, question)
        return answer

    except Exception as e:
        print("ERROR:", str(e))
        return f"Error: {str(e)}"


iface = gr.Interface(
    fn=vqa_interface,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.Textbox(
            label="Ask a Question",
            placeholder="e.g. What is in the image?"
        )
    ],
    outputs=gr.Textbox(label="Answer"),
    title="🧠 Smart Visual Question Answering System",
    description="Upload any image and ask anything (works for medical + general images)",
    theme="soft"
)

if __name__ == "__main__":
    iface.launch()
config.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch

# Run on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Maximum question length in tokens (questions are truncated/padded to this).
MAX_LEN = 20
# Word-embedding dimensionality for the question encoder.
EMBED_DIM = 300
# Hidden size of the question LSTM.
HIDDEN_DIM = 256
BATCH_SIZE = 32
# Optimizer learning rate.
LR = 1e-3
EPOCHS = 5

# Artifact locations written by train.py and read by inference.py.
MODEL_PATH = "weights/vqa_model.pth"
VOCAB_PATH = "weights/vocab.pkl"
ANSWER_PATH = "weights/answers.pkl"
index.html ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<html>
<head>
<title>VQA App</title>
</head>
<body>

<h2>Visual Question Answering</h2>

<!-- Image upload + free-text question; both are posted to the FastAPI backend. -->
<input type="file" id="image"><br><br>
<input type="text" id="question" placeholder="Ask a question"><br><br>

<button onclick="send()">Submit</button>

<h3 id="result"></h3>

<script>
// POST the selected image and question to the local api.py server.
// Field names ("file", "question") must match the FastAPI endpoint's
// File(...)/Form(...) parameter names.
async function send() {
const file = document.getElementById("image").files[0];
const question = document.getElementById("question").value;

let formData = new FormData();
formData.append("file", file);
formData.append("question", question);

const res = await fetch("http://127.0.0.1:8000/predict", {
method: "POST",
body: formData
});

// The API always returns JSON: either {"answer": ...} or {"error": ...}.
const data = await res.json();
document.getElementById("result").innerText = data.answer || data.error;
}
</script>

</body>
</html>
inference.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import (
    Blip2Processor,
    Blip2ForConditionalGeneration,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
from langdetect import detect
from PIL import Image
import torch
import pickle
import torchvision.transforms as transforms

# ========================
# PERFORMANCE SETTINGS
# ========================
# Cap intra-op parallelism so CPU inference does not oversubscribe cores.
torch.set_num_threads(4)

# ========================
# DEVICE (CPU ONLY)
# ========================
device = torch.device("cpu")

# ========================
# LOAD BLIP2 (SAFE)
# ========================
# BLIP2 handles open-ended questions (see final_pipeline's routing heuristic).
print("Loading BLIP2...")

processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")

blip_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-flan-t5-xl"
)

blip_model.to(device)
blip_model.eval()

# ========================
# LOAD TRANSLATOR
# ========================
# NLLB translates non-English questions to English and answers back.
print("Loading Translator...")

translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

translator_model.to(device)
translator_model.eval()

# ISO-639-1 codes (as returned by langdetect) -> NLLB language codes.
# Only these languages are supported by translate().
lang_code_map = {
    "en":"eng_Latn","hi":"hin_Deva","te":"tel_Telu",
    "ta":"tam_Taml","kn":"kan_Knda","ml":"mal_Mlym"
}
52
+
53
def translate(text, src, tgt):
    """Translate *text* from *src* to *tgt* with NLLB.

    src/tgt are ISO-639-1 codes and must be keys of lang_code_map;
    an unsupported code raises KeyError (callers must guard for that).
    """
    translator_tokenizer.src_lang = lang_code_map[src]
    inputs = translator_tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        tokens = translator_model.generate(
            **inputs,
            # Force the decoder to start generating in the target language.
            forced_bos_token_id=translator_tokenizer.convert_tokens_to_ids(lang_code_map[tgt]),
            max_length=50
        )

    return translator_tokenizer.decode(tokens[0], skip_special_tokens=True)
65
+
66
# ========================
# LOAD CUSTOM MODEL
# ========================
from models.vqa_model import VQAModel

# Same preprocessing as training: 224x224 RGB, ToTensor, no normalization.
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

# word -> index mapping produced by train.py
with open("weights/vocab.pkl","rb") as f:
    vocab = pickle.load(f)

# class index -> answer string mapping produced by train.py
with open("weights/answers.pkl","rb") as f:
    idx_to_answer = pickle.load(f)

# Hyperparameters (embed 300, hidden 256) must match the training run,
# otherwise load_state_dict fails on shape mismatch.
custom_model = VQAModel(len(vocab),300,256,len(idx_to_answer))
custom_model.load_state_dict(torch.load("weights/vqa_model.pth", map_location=device))
custom_model.to(device)
custom_model.eval()
86
+
87
def encode_question(q, max_len=20):
    """Encode a raw question into a fixed-length tensor of vocab ids.

    Lowercases and whitespace-tokenizes *q*, maps out-of-vocabulary words
    to <UNK>, then truncates/pads with <PAD> to exactly *max_len* tokens.

    Args:
        q: raw question string.
        max_len: output length; defaults to 20, matching the trained model
            (previously a hard-coded magic number duplicating config.MAX_LEN).

    Returns:
        torch.LongTensor of shape (1, max_len), ready for the model.
    """
    tokens = q.lower().split()
    enc = [vocab.get(w, vocab["<UNK>"]) for w in tokens]
    # When the question is longer than max_len, the pad count is negative and
    # yields an empty list, so the slice alone produces exactly max_len ids.
    enc = enc[:max_len] + [vocab["<PAD>"]] * (max_len - len(enc))
    return torch.tensor(enc).unsqueeze(0)
92
+
93
+ # ========================
94
+ # CUSTOM MODEL
95
+ # ========================
96
def predict_custom_vqa(image_path, question):
    """Answer *question* about the image at *image_path* using the custom
    ResNet18+LSTM classifier (fixed answer set loaded from answers.pkl)."""
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0)  # add batch dim: (1, 3, 224, 224)
    q = encode_question(question)

    with torch.no_grad():
        out = custom_model(image, q)
        # argmax over the answer-classification logits
        _, pred = torch.max(out,1)

    return idx_to_answer[pred.item()]
106
+
107
+ # ========================
108
+ # BLIP2 (OPTIMIZED)
109
+ # ========================
110
def open_vqa(image_path, question):
    """Open-ended VQA via BLIP2: returns a free-form text answer."""
    image = Image.open(image_path).convert("RGB")

    inputs = processor(image, question, return_tensors="pt")

    with torch.no_grad():
        out = blip_model.generate(
            **inputs,
            max_new_tokens=15  # kept short to bound CPU latency
        )

    return processor.decode(out[0], skip_special_tokens=True)
122
+
123
+ # ========================
124
+ # FINAL PIPELINE
125
+ # ========================
126
def final_pipeline(image_path, question):
    """Run the full multilingual VQA pipeline.

    1. Detect the question language, falling back to English when detection
       fails or the language is not in lang_code_map.
    2. Translate the question to English when needed.
    3. Route open-ended phrasings to BLIP2, everything else to the custom model.
    4. Translate the answer back to the original language when needed.
    """
    try:
        lang = detect(question)
    except Exception:
        # langdetect raises LangDetectException on empty/feature-less input;
        # treat such questions as English rather than crashing.
        lang = "en"

    # Only translate when the language is both non-English and supported —
    # otherwise translate() would KeyError on lang_code_map (e.g. "fr").
    needs_translation = lang != "en" and lang in lang_code_map

    if needs_translation:
        q_en = translate(question, lang, "en")
    else:
        q_en = question

    # Heuristic router: open-ended phrasing goes to BLIP2, the rest to the
    # closed-set custom classifier.
    if "what is" in q_en.lower() or "this place" in q_en.lower():
        answer_en = open_vqa(image_path, q_en)
    else:
        answer_en = predict_custom_vqa(image_path, q_en)

    if needs_translation:
        return translate(answer_en, "en", lang)
    else:
        return answer_en
143
+
144
def predict(image_path, question):
    """Public entry point used by api.py and app.py; delegates to the pipeline."""
    return final_pipeline(image_path, question)
146
+
147
# ========================
# WARMUP
# ========================
print("Warming up...")
# Run the processor once on a dummy image so the first real request
# does not pay the lazy-initialization cost.
dummy = Image.new("RGB", (224,224))
processor(dummy, "test", return_tensors="pt")

print("✅ Ready!")

# ========================
# TEST
# ========================
if __name__ == "__main__":
    print(predict("test.jpg","What is in the image?"))
models/__pycache__/vqa_model.cpython-39.pyc ADDED
Binary file (1.22 kB). View file
 
models/pretrained.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
def open_vqa_stub(image_path, question):
    """Placeholder for the pretrained VQA path.

    Ignores both arguments and returns a fixed notice string.
    """
    notice = "Pretrained VQA disabled (too heavy for local)."
    return notice
models/vqa_model.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+
5
class VQAModel(nn.Module):
    """Baseline VQA classifier: ResNet18 image features and LSTM question
    features, fused by concatenation into a small answer-classification MLP."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_answers):
        super().__init__()

        # Pretrained ResNet18 backbone; replacing fc with Identity exposes
        # the pooled 512-d features (512 matches fc1's input below).
        self.cnn = models.resnet18(weights="DEFAULT")
        self.cnn.fc = nn.Identity()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

        # Fuse 512 image dims + hidden_dim question dims -> answer logits.
        self.fc1 = nn.Linear(512 + hidden_dim, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, num_answers)

    def forward(self, image, question):
        # image: (B, 3, H, W) -> img_feat: (B, 512)
        img_feat = self.cnn(image)

        # question: (B, seq_len) token ids; take the LSTM's final hidden state.
        q_embed = self.embedding(question)
        _, (h, _) = self.lstm(q_embed)
        # Single-layer LSTM, so h is (1, B, hidden_dim); squeeze to (B, hidden_dim).
        q_feat = h.squeeze(0)

        x = self.relu(self.fc1(torch.cat((img_feat, q_feat), dim=1)))
        # Raw logits — callers apply CrossEntropyLoss or argmax.
        return self.fc2(x)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
torch
torchvision
pillow
pandas
scikit-learn
langdetect
tqdm
gradio
huggingface_hub
transformers
datasets
fastapi
uvicorn
python-multipart
sentencepiece
test.jpg ADDED
train.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ import pandas as pd
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import torchvision.transforms as transforms
7
+ from torch.utils.data import Dataset, DataLoader, random_split
8
+ from PIL import Image
9
+ from collections import Counter
10
+ import pickle
11
+ import re
12
+ from tqdm import tqdm
13
+ import os
14
+
15
+ # ========================
16
+ # CONFIG
17
+ # ========================
18
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ EPOCHS = 50
20
+ BATCH_SIZE = 32
21
+ LR = 5e-4
22
+ MAX_LEN = 20
23
+
24
+ # ========================
25
+ # LOAD DATASET
26
+ # ========================
27
+ dataset = load_dataset("flaviagiammarino/vqa-rad")
28
+ df = pd.DataFrame(dataset["train"])
29
+ df = df[["image", "question", "answer"]]
30
+
31
+ # ========================
32
+ # CLEAN TEXT
33
+ # ========================
34
def clean_text(text):
    """Normalize a question/answer string: lowercase it and strip every
    character except a-z, 0-9 and spaces."""
    lowered = text.lower()
    return re.sub(r"[^a-z0-9 ]", "", lowered)
37
+
38
+ df["question"] = df["question"].apply(clean_text)
39
+ df["answer"] = df["answer"].apply(clean_text)
40
+
41
+ # ========================
42
+ # FILTER TOP ANSWERS
43
+ # ========================
44
+ top_answers = df["answer"].value_counts().nlargest(50).index
45
+ df = df[df["answer"].isin(top_answers)]
46
+
47
+ answer_to_idx = {a:i for i,a in enumerate(top_answers)}
48
+ idx_to_answer = {i:a for a,i in answer_to_idx.items()}
49
+ df["answer_encoded"] = df["answer"].apply(lambda x: answer_to_idx[x])
50
+
51
+ # ========================
52
+ # VOCAB
53
+ # ========================
54
+ vocab = {"<PAD>":0, "<UNK>":1}
55
+ counter = Counter()
56
+
57
+ for q in df["question"]:
58
+ for w in q.split():
59
+ counter[w] += 1
60
+
61
+ idx = 2
62
+ for word, count in counter.items():
63
+ if count > 2:
64
+ vocab[word] = idx
65
+ idx += 1
66
+
67
def encode_question(q):
    """Turn a cleaned question into a fixed-length list of vocab ids:
    unknown words map to <UNK>, then truncate/pad with <PAD> to MAX_LEN."""
    ids = [vocab.get(word, vocab["<UNK>"]) for word in q.split()]
    padding = [vocab["<PAD>"]] * (MAX_LEN - len(ids))
    return ids[:MAX_LEN] + padding
72
+
73
+ df["question_encoded"] = df["question"].apply(encode_question)
74
+
75
+ # ========================
76
+ # DATASET CLASS
77
+ # ========================
78
+ transform = transforms.Compose([
79
+ transforms.Resize((224,224)),
80
+ transforms.ToTensor()
81
+ ])
82
+
83
class VQADataset(Dataset):
    """Wraps the preprocessed VQA-RAD dataframe for DataLoader consumption."""

    def __init__(self, df):
        # df must carry "image", "question_encoded" and "answer_encoded" columns.
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # row["image"] is a PIL-style image (it supports .convert) —
        # presumably decoded by the HF datasets library; verify upstream.
        image = row["image"].convert("RGB")
        image = transform(image)

        question = torch.tensor(row["question_encoded"])
        answer = torch.tensor(row["answer_encoded"])

        return image, question, answer
100
+
101
+ # ========================
102
+ # SPLIT DATA
103
+ # ========================
104
+ dataset_full = VQADataset(df)
105
+ train_size = int(0.8 * len(dataset_full))
106
+ val_size = len(dataset_full) - train_size
107
+
108
+ train_dataset, val_dataset = random_split(dataset_full, [train_size, val_size])
109
+
110
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
111
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
112
+
113
+ # ========================
114
+ # MODEL
115
+ # ========================
116
+ import torchvision.models as models
117
+
118
class VQAModel(nn.Module):
    """ResNet18 + LSTM VQA classifier used for training.

    NOTE(review): this duplicates models/vqa_model.py:VQAModel line-for-line;
    consider importing that class instead so the two cannot drift apart.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_answers):
        super().__init__()

        # Pretrained ResNet18; Identity() in place of fc exposes 512-d features.
        self.cnn = models.resnet18(weights="DEFAULT")
        self.cnn.fc = nn.Identity()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

        # Fuse 512 image dims + hidden_dim question dims into answer logits.
        self.fc1 = nn.Linear(512 + hidden_dim, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, num_answers)

    def forward(self, image, question):
        # image: (B, 3, H, W) -> (B, 512)
        img_feat = self.cnn(image)

        # question token ids -> final LSTM hidden state (B, hidden_dim)
        q_embed = self.embedding(question)
        _, (h, _) = self.lstm(q_embed)
        q_feat = h.squeeze(0)  # single-layer LSTM: h is (1, B, hidden_dim)

        x = self.relu(self.fc1(torch.cat((img_feat, q_feat), dim=1)))
        # Raw logits — CrossEntropyLoss below applies log-softmax itself.
        return self.fc2(x)
141
+
142
+ model = VQAModel(len(vocab), 300, 256, len(answer_to_idx)).to(DEVICE)
143
+
144
+ criterion = nn.CrossEntropyLoss()
145
+ optimizer = optim.Adam(model.parameters(), lr=LR)
146
+
147
+ # ========================
148
+ # TRAIN LOOP
149
+ # ========================
150
# Standard supervised training loop with per-epoch validation.
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for images, questions, answers in tqdm(train_loader):
        images, questions, answers = images.to(DEVICE), questions.to(DEVICE), answers.to(DEVICE)

        outputs = model(images, questions)
        loss = criterion(outputs, answers)

        # zero-grad / backprop / step cycle
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # VALIDATION — no gradients, eval mode.
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for images, questions, answers in val_loader:
            images, questions, answers = images.to(DEVICE), questions.to(DEVICE), answers.to(DEVICE)

            outputs = model(images, questions)
            loss = criterion(outputs, answers)
            val_loss += loss.item()

    # Report mean per-batch losses for this epoch.
    print(f"\nEpoch {epoch+1}")
    print(f"Train Loss: {total_loss/len(train_loader):.4f}")
    print(f"Val Loss: {val_loss/len(val_loader):.4f}")
181
+
182
+ # ========================
183
+ # SAVE MODEL
184
+ # ========================
185
+ os.makedirs("weights", exist_ok=True)
186
+
187
+ torch.save(model.state_dict(), "weights/vqa_model.pth")
188
+
189
+ with open("weights/vocab.pkl", "wb") as f:
190
+ pickle.dump(vocab, f)
191
+
192
+ with open("weights/answers.pkl", "wb") as f:
193
+ pickle.dump(idx_to_answer, f)
194
+
195
+ print("\n✅ Training Complete & Model Saved!")
utils/__pycache__/text_utils.cpython-39.pyc ADDED
Binary file (719 Bytes). View file
 
utils/text_utils.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
def clean_text(text):
    """Lowercase *text* and drop every character outside a-z, 0-9 and space."""
    normalized = text.lower()
    return re.sub(r"[^a-z0-9 ]", "", normalized)
6
+
7
def encode_question(q, vocab, max_len=20):
    """Map a question string to a fixed-length list of vocab indices.

    Words missing from *vocab* become <UNK>; the result is truncated or
    padded with <PAD> to exactly *max_len* entries.
    """
    ids = [vocab.get(word, vocab["<UNK>"]) for word in q.split()]
    ids = ids[:max_len]
    return ids + [vocab["<PAD>"]] * (max_len - len(ids))
utils/translator.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from langdetect import detect
2
+
3
def detect_lang(text):
    """Best-effort language detection via langdetect.

    Returns the detected ISO-639-1 code, or "en" when detection fails
    (langdetect raises LangDetectException on empty/feature-less input).
    """
    try:
        return detect(text)
    # Narrowed from a bare `except:` so KeyboardInterrupt and SystemExit
    # are no longer swallowed by the fallback.
    except Exception:
        return "en"
8
+
9
def translate(text, src, tgt):
    """No-op translation stub: returns *text* unchanged regardless of the
    requested source/target languages."""
    del src, tgt  # intentionally unused in the stub
    return text