alvikhan committed on
Commit
62305fe
·
1 Parent(s): 56713ef

restructured files

Files changed (8)
  1. __init__.py +14 -0
  2. functions.py +197 -0
  3. fussionmodel.py +87 -0
  4. model.py +330 -0
  5. model_functions.py +210 -0
  6. models.py +29 -0
  7. qtype.py +25 -0
  8. tpred.py +24 -0
__init__.py ADDED
@@ -0,0 +1,14 @@
+ """
+ from qtype import QuestionTypeClassifier
+ from tpred import TaskPredictor
+ from model import VQAModel
+ from fussionmodel import CoAttentionFusion
+ from functions import preprocess_example, preprocess_image, collate_fn
+ import torch
+ import torch.nn as nn
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader
+ """
+
+ from .functions import preprocess_example, preprocess_image, collate_fn
+ from .model import VQAModel
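
For orientation after the restructure, here is a minimal usage sketch of what this `__init__.py` exposes. The package name `vqa_package` and the image path are placeholders; the dimensions follow the defaults used elsewhere in this commit (768-dim ViT/BERT features, a 23-dim disease vector, 512-dim fusion space):

```python
# "vqa_package" is a placeholder; substitute the actual package directory name.
from vqa_package import VQAModel, preprocess_image

model = VQAModel(img_dim=768, ques_dim=768, disease_dim=23, hidden_dim=512)
model.load("vqa_model.pt")  # restore a checkpoint written by VQAModel.save()

# Placeholder image path; a URL or PIL image also works via preprocess_image.
answer = model.predict("sample_endoscopy.jpg", "Is there a polyp in the image?")
print(answer)
```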
functions.py ADDED
@@ -0,0 +1,197 @@
+ from PIL import Image
+ import requests
+ import torch
+ import torchvision.transforms as transforms
+ from collections import defaultdict
+ from torch.utils.data import DataLoader
+ import os
+ from re import findall
+ from transformers import AutoTokenizer
+
+ transformten = transforms.Compose([
+     transforms.Resize((224, 224)),                    # adjust size for your model
+     transforms.ToTensor(),                            # convert to tensor
+     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet normalization
+                          std=[0.229, 0.224, 0.225]),
+ ])
+
+ # Load the question tokenizer once at module import instead of on every call.
+ router_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+
+ # Cache of decoded PIL images, keyed by path, to avoid re-reading from disk.
+ image_cache = {}
+
+
+ def preprocess_image(image_source):
+     """
+     Preprocess a single image for inference.
+     `image_source` can be a URL, a local file path, or a PIL image.
+     Returns a tensor [C, H, W].
+     """
+     if isinstance(image_source, str):
+         if image_source.startswith("http"):  # URL
+             image = Image.open(requests.get(image_source, stream=True).raw).convert("RGB")
+         else:  # local path
+             image = Image.open(image_source).convert("RGB")
+     elif isinstance(image_source, Image.Image):  # already a PIL image
+         image = image_source
+     else:
+         raise ValueError("Unsupported image_source type")
+
+     # Apply the same transform used during training
+     image = transformten(image)  # Resize(224) -> ToTensor() -> Normalize()
+
+     return image  # torch.Tensor [3, 224, 224]
+
+
+ def preprocess_example(example):
+     # Resolve the image file inside the dataset directory
+     image_name = example["image"].split("/")[-1]
+     image_path = os.path.join("/kaggle/input/medico2025", image_name)
+
+     # Reuse the cached PIL image when available
+     if image_path in image_cache:
+         image = image_cache[image_path]
+     else:
+         image = Image.open(image_path)
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+         image_cache[image_path] = image  # cache the loaded image object
+
+     # Apply the normalize/transform pipeline (Resize + ToTensor + Normalize)
+     image = transformten(image)
+
+     # Tokenize the question
+     q_inputs = router_tokenizer(example["question"],
+                                 return_tensors="pt",
+                                 truncation=True,
+                                 padding="max_length",
+                                 max_length=32)
+
+     # q_inputs is a BatchEncoding with batch_size=1 tensors, so squeeze the batch dim
+     input_ids = q_inputs["input_ids"].squeeze(0)  # torch.Tensor [seq_len]
+     attention_mask = q_inputs["attention_mask"].squeeze(0)
+
+     # Pack features
+     return {
+         "image": image,
+         "input_ids": input_ids,
+         "attention_mask": attention_mask,
+         "answer": example["answer"],
+         "question_class": example["question_class"],
+         "image_url": example["image"],
+     }
+
+
+ def normalize_answer(ans, q_type):
+     ans = ans.strip().lower()
+
+     if q_type == "yesno":
+         # Note: these substring checks are deliberately loose ("no" also matches
+         # inside words like "normal"); "yes" is tested first.
+         if "yes" in ans or "present" in ans or "evidence" in ans:
+             return "Yes"
+         elif "no" in ans or "absent" in ans or "none" in ans:
+             return "No"
+         else:
+             return None  # ambiguous
+
+     if q_type == "count":
+         # Extract a numeric value or return None
+         numbers = findall(r"\d+", ans)
+         if numbers:
+             return numbers[0]
+         elif "one" in ans:
+             return "1"
+         elif "two" in ans:
+             return "2"
+         return None
+
+     if q_type == "color":
+         for color in ["red", "green", "yellow", "blue", "white", "black"]:
+             if color in ans:
+                 return color
+         return None
+
+     if q_type == "location":
+         # Simplify locations to a small fixed set
+         for loc in ["upper", "lower", "left", "right", "central"]:
+             if loc in ans:
+                 return loc
+         return None
+
+     if q_type in ["single", "multi"]:
+         return ans  # keep the original answer; the choice set could also be restricted
+
+     return ans
+
+
+ def build_vocabs(dataset, q_types_mapping):
+     # Build task-specific vocabularies over *normalized* answers
+     task_vocabs = {}
+     for general_class in set(q_types_mapping.values()):
+         task_vocabs[general_class] = {}
+
+     for row in dataset:
+         fine_class = row["question_class"]
+
+         # question_class may be a list of labels; use the first
+         if isinstance(fine_class, list):
+             fine_class = fine_class[0]
+
+         general_class = q_types_mapping[fine_class]
+
+         norm_ans = normalize_answer(row["answer"], general_class)
+         if norm_ans is None:
+             continue  # skip unnormalizable answers
+
+         if norm_ans not in task_vocabs[general_class]:
+             idx = len(task_vocabs[general_class])
+             task_vocabs[general_class][norm_ans] = idx
+
+     return task_vocabs
+
+
+ def build_answer_vocab(dataset, q_types_mapping):
+     # Build task-specific vocabularies over *raw* answer strings
+     answer_vocab = defaultdict(dict)
+     counters = defaultdict(int)
+
+     for ans, q_class in zip(dataset["answer"], dataset["question_class"]):
+         # q_class may be a list of labels; pick the first
+         if isinstance(q_class, list):
+             q_class = q_class[0]
+
+         general_class = q_types_mapping[q_class]
+
+         if ans not in answer_vocab[general_class]:
+             answer_vocab[general_class][ans] = counters[general_class]
+             counters[general_class] += 1
+
+     return answer_vocab
+
+
+ def collate_fn(batch):
+     # datasets.map can serialize tensors to nested lists, so convert back
+     # to tensors where needed before stacking into batch tensors.
+     images = torch.stack([torch.tensor(item["image"]) if isinstance(item["image"], list)
+                           else item["image"] for item in batch])
+     input_ids = torch.stack([torch.tensor(item["input_ids"]) if isinstance(item["input_ids"], list)
+                              else item["input_ids"] for item in batch])
+     attention_mask = torch.stack([torch.tensor(item["attention_mask"]) if isinstance(item["attention_mask"], list)
+                                   else item["attention_mask"] for item in batch])
+
+     answers = [item["answer"] for item in batch]  # keep as list for label encoding later
+     q_classes = [item["question_class"] for item in batch]
+     return {
+         "images": images,
+         "input_ids": input_ids,
+         "attention_mask": attention_mask,
+         "answers": answers,
+         "question_classes": q_classes,
+     }
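
A sketch of how these helpers plug into a `DataLoader`. The dataset id and split are placeholders (the code points at the Medico 2025 data under `/kaggle/input/medico2025`); `preprocess_example` expects rows with string `image`, `question`, `answer`, and `question_class` fields:

```python
from datasets import load_dataset
from torch.utils.data import DataLoader

raw = load_dataset("path/to/medico2025-vqa", split="train")  # placeholder dataset id

# Tokenize questions and load/normalize images up front; datasets.map may
# serialize tensors to lists, which is why collate_fn converts them back.
processed = raw.map(preprocess_example)

train_loader = DataLoader(processed, batch_size=16, shuffle=True, collate_fn=collate_fn)
batch = next(iter(train_loader))
print(batch["images"].shape)     # [16, 3, 224, 224]
print(batch["input_ids"].shape)  # [16, 32]
```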
fussionmodel.py ADDED
@@ -0,0 +1,87 @@
+ from transformers import ViTModel, BertModel
+ import torch.nn.functional as F
+
+ import torch
+ import torch.nn as nn
+
+
+ class CoAttentionFusion(nn.Module):
+     def __init__(self, img_dim, ques_dim, disease_dim, hidden_dim, answer_vocab):
+         super(CoAttentionFusion, self).__init__()
+
+         self.img_proj = nn.Linear(img_dim, hidden_dim)
+         self.ques_proj = nn.Linear(ques_dim, hidden_dim)
+         self.dis_proj = nn.Linear(disease_dim, hidden_dim)
+
+         self.att_img = nn.Linear(hidden_dim, 1)
+         self.att_dis = nn.Linear(hidden_dim, 1)
+         self.fusion = nn.Linear(hidden_dim * 3, hidden_dim)
+
+         # Store the answer vocab inside the model for later use
+         self.answer_vocab = answer_vocab
+
+     def forward(self, img_feat, ques_feat, dis_vec):
+         # Project features into the shared hidden space
+         img_proj = torch.tanh(self.img_proj(img_feat))     # [B, R, H] (R = image patches)
+         ques_proj = torch.tanh(self.ques_proj(ques_feat))  # [B, H]
+
+         dis_vec = dis_vec.to(torch.float32)
+         dis_proj = torch.tanh(self.dis_proj(dis_vec))      # [H] (disease vec is shared across the batch)
+
+         # Expand the question over the image regions for co-attention
+         ques_proj = ques_proj.unsqueeze(1)                        # [B, 1, H]
+         ques_expand = ques_proj.expand(-1, img_proj.size(1), -1)  # [B, R, H]
+         img_co = img_proj * ques_expand                           # [B, R, H]
+
+         # Question-guided image attention, pooled over regions
+         att_img_weights = torch.sigmoid(self.att_img(img_co))  # [B, R, 1]
+         img_att = (att_img_weights * img_proj).sum(1)          # [B, H]
+
+         # Co-attention with the disease vector
+         dis_co = dis_proj * ques_proj                          # [B, 1, H] (broadcast)
+         att_dis_weights = torch.sigmoid(self.att_dis(dis_co))  # [B, 1, 1]
+         dis_att = att_dis_weights * dis_proj                   # [B, 1, H]
+
+         # Guard against unbatched inputs
+         if img_att.dim() == 1:
+             img_att = img_att.unsqueeze(0)  # [1, H]
+         if ques_proj.dim() == 1:
+             ques_proj = ques_proj.unsqueeze(0)
+         if dis_att.dim() == 1:
+             dis_att = dis_att.unsqueeze(0)
+
+         # Concatenate the three attended features and fuse
+         ques_proj_flat = ques_proj.squeeze(1)  # [B, H]
+         dis_att_flat = dis_att.squeeze(1)      # [B, H]
+
+         joint_feat = torch.cat([img_att, ques_proj_flat, dis_att_flat], dim=1)  # [B, 3H]
+         fused = torch.tanh(self.fusion(joint_feat))                             # [B, H]
+         return fused
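
A quick shape check with dummy tensors; the dimensions mirror the defaults used in this commit (ViT patch features [B, 197, 768], BERT pooler output [B, 768], a 23-dim disease vector broadcast over the batch):

```python
import torch

fusion = CoAttentionFusion(img_dim=768, ques_dim=768, disease_dim=23,
                           hidden_dim=512, answer_vocab={})

img_feat = torch.randn(4, 197, 768)  # ViT last_hidden_state for a batch of 4
ques_feat = torch.randn(4, 768)      # BERT pooler_output
dis_vec = torch.zeros(23)            # placeholder disease vector, shared across the batch

fused = fusion(img_feat, ques_feat, dis_vec)
print(fused.shape)  # torch.Size([4, 512])
```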
model.py ADDED
@@ -0,0 +1,330 @@
+ import torch
+ import torch.nn as nn
+ import os
+ from .qtype import QuestionTypeClassifier
+ from .functions import build_vocabs, build_answer_vocab, collate_fn, preprocess_example, normalize_answer, preprocess_image
+ from .models import disease_model, device, generate_descriptive_answer, router_tokenizer, gen_model
+ from .tpred import TaskPredictor
+ from .model_functions import compute_loss, compute_meteor, compute_rouge, extract_count, forward_batch
+ from .fussionmodel import BertModel, CoAttentionFusion, ViTModel, F
+
+
+ class VQAModel(nn.Module):
+     def __init__(self, img_dim, ques_dim, disease_dim, hidden_dim):
+         super(VQAModel, self).__init__()
+         self.qtype_classifier = None
+         self.answer_classifier = None
+         self.epochs = 1
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.hidden_dim = hidden_dim
+         self.input_dim = 768
+         self.ques_dim = ques_dim
+         self.disease_dim = disease_dim
+         self.img_dim = img_dim
+         self.fusion_module = None
+         self.question_encoder = BertModel.from_pretrained("bert-base-uncased").to(self.device)
+         self.image_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224").to(self.device)
+         self.optimizer = None
+         self.answer_vocabs = None
+         self.task_vocabs = None
+         self.train_data = None
+         self.train_loader = None
+         self.q_types = ["yesno", "single", "multi", "color", "location", "count"]
+         # Create task-specific heads (trainable)
+         self.task_heads = nn.ModuleDict({
+             t: TaskPredictor(t, hidden=hidden_dim) for t in self.q_types
+         })
+         # Map fine-grained dataset classes to the six general task types
+         self.q_types_mapping = {
+             'abnormality_color': 'color',
+             'landmark_color': 'color',
+             'abnormality_location': 'location',
+             'instrument_location': 'location',
+             'landmark_location': 'location',
+             'finding_count': 'count',
+             'instrument_count': 'count',
+             'polyp_count': 'count',
+             'abnormality_presence': 'yesno',
+             'box_artifact_presence': 'yesno',
+             'finding_presence': 'yesno',
+             'instrument_presence': 'yesno',
+             'landmark_presence': 'yesno',
+             'text_presence': 'yesno',
+             'polyp_removal_status': 'yesno',
+             'polyp_type': 'single',
+             'polyp_size': 'single',
+             'procedure_type': 'single',
+         }
+
+     # Note: this intentionally shadows nn.Module.train() and runs the whole loop.
+     def train(self, epochs, data_train, train_loader):
+         self.epochs = epochs
+         self.train_data = data_train
+         self.train_loader = train_loader
+         self.answer_vocabs = build_answer_vocab(self.train_data, self.q_types_mapping)
+         self.task_vocabs = build_vocabs(self.train_data, self.q_types_mapping)
+         self.qtype_classifier = QuestionTypeClassifier(num_types=len(self.q_types)).to(self.device)
+         self.answer_classifier = nn.Linear(self.hidden_dim, len(self.answer_vocabs))  # currently unused by the forward path
+         self.fusion_module = CoAttentionFusion(img_dim=self.img_dim,
+                                                ques_dim=self.ques_dim,
+                                                disease_dim=self.disease_dim,
+                                                hidden_dim=self.hidden_dim,
+                                                answer_vocab=self.answer_vocabs).to(self.device)
+         self.optimizer = torch.optim.AdamW(list(self.fusion_module.parameters()) +
+                                            list(self.question_encoder.parameters()) +
+                                            list(self.image_encoder.parameters()) +
+                                            list(self.qtype_classifier.parameters()), lr=2e-5)
+         for epoch in range(self.epochs):
+             self.fusion_module.train()
+             self.qtype_classifier.train()
+             total_loss = 0
+             for batch in self.train_loader:
+                 self.optimizer.zero_grad()
+                 preds, answers, task_logits = forward_batch(
+                     batch["images"],
+                     batch["input_ids"],
+                     batch["attention_mask"],
+                     batch["answers"],
+                     batch["question_classes"],  # fine-grained classes from the dataset
+                     qtype_classifier=self.qtype_classifier,
+                     fusion_module=self.fusion_module,
+                     q_types=self.q_types,
+                     q_types_mapping=self.q_types_mapping,
+                     task_heads=self.task_heads,
+                     device=self.device,
+                     image_encoder=self.image_encoder,
+                     question_encoder=self.question_encoder
+                 )
+                 loss = compute_loss(preds,
+                                     answers,
+                                     task_logits,
+                                     batch["question_classes"],
+                                     answer_vocabs=self.answer_vocabs,
+                                     q_types_mapping=self.q_types_mapping,
+                                     q_types=self.q_types,
+                                     task_heads=self.task_heads)
+                 loss.backward()
+                 self.optimizer.step()
+                 total_loss += loss.item()
+             print(f"Epoch {epoch}, Train Loss: {total_loss / len(train_loader)}")
+
+     # Note: this intentionally shadows nn.Module.eval().
+     def eval(self, val_loader):
+         """
+         Evaluate the model on the validation set.
+
+         Args:
+             val_loader: DataLoader for validation data.
+
+         Returns:
+             avg_loss: average validation loss
+             all_preds: list of predicted labels
+             all_answers: list of ground-truth answers
+         """
+         self.fusion_module.eval()
+         self.question_encoder.eval()
+         self.image_encoder.eval()
+         self.qtype_classifier.eval()
+         for head in self.task_heads.values():
+             head.eval()
+
+         total_loss = 0.0
+         all_preds, all_answers = [], []
+
+         with torch.no_grad():
+             for batch in val_loader:
+                 images = batch["images"].to(self.device)
+                 input_ids = batch["input_ids"].to(self.device)
+                 attention_mask = batch["attention_mask"].to(self.device)
+                 answers = batch["answers"]
+                 q_classes = batch["question_classes"]
+
+                 # ---- Disease vector ----
+                 disease_vec = disease_model(images)
+
+                 # ---- Question type classifier ----
+                 task_logits = self.qtype_classifier(
+                     input_ids=input_ids,
+                     attention_mask=attention_mask
+                 )  # [B, num_types]
+
+                 # Map fine-grained classes to general task types
+                 mapped_classes = [
+                     self.q_types_mapping[c[0] if isinstance(c, list) else c]
+                     for c in q_classes
+                 ]
+
+                 # ---- Encoders ----
+                 q_feat = self.question_encoder(
+                     input_ids=input_ids,
+                     attention_mask=attention_mask
+                 ).pooler_output  # [B, 768]
+
+                 img_outputs = self.image_encoder(pixel_values=images)
+                 img_feat = img_outputs.last_hidden_state  # [B, R, 768]
+
+                 # ---- Fusion ----
+                 fused = self.fusion_module(img_feat, q_feat, disease_vec)
+
+                 # ---- Predict per sample ----
+                 pred_tensors = []
+                 batch_preds = []
+                 for i, task_type in enumerate(mapped_classes):
+                     predictor = self.task_heads[task_type]
+                     pred_tensor = predictor(fused[i].unsqueeze(0))  # [1, C], or [1, 1] for count
+                     pred_tensors.append(pred_tensor)
+
+                     if task_type == "yesno":
+                         pred_label = "Yes" if torch.argmax(pred_tensor, dim=1).item() == 1 else "No"
+                     elif task_type == "count":
+                         pred_val = pred_tensor.squeeze()
+                         pred_label = str(int(round(pred_val.item())))
+                     else:
+                         ans_idx = torch.argmax(pred_tensor, dim=1).item()
+                         if task_type in self.answer_vocabs and ans_idx < len(self.answer_vocabs[task_type]):
+                             inv_vocab = {v: k for k, v in self.answer_vocabs[task_type].items()}
+                             pred_label = inv_vocab.get(ans_idx, str(ans_idx))
+                         else:
+                             pred_label = str(ans_idx)
+
+                     batch_preds.append(pred_label)
+
+                 # ---- Compute loss (same preds as tensors, plus the required extra args) ----
+                 batch_loss = compute_loss(
+                     preds=pred_tensors,
+                     answers=answers,
+                     task_logits=task_logits,
+                     true_q_classes=q_classes,
+                     answer_vocabs=self.answer_vocabs,
+                     q_types_mapping=self.q_types_mapping,
+                     q_types=self.q_types,
+                     task_heads=self.task_heads
+                 )
+                 total_loss += batch_loss.item()
+
+                 all_preds.extend(batch_preds)
+                 all_answers.extend(answers)
+
+         avg_loss = total_loss / len(val_loader)
+         return avg_loss, all_preds, all_answers
+
+     def load(self, load_path="vqa_model.pt"):
+         checkpoint = torch.load(load_path, map_location=self.device, weights_only=False)
+         self.task_vocabs = checkpoint["task_vocabs"]
+         self.answer_vocabs = checkpoint["answer_vocabs"]
+         self.fusion_module = CoAttentionFusion(
+             img_dim=self.img_dim, ques_dim=self.ques_dim, disease_dim=self.disease_dim,
+             hidden_dim=self.hidden_dim, answer_vocab=checkpoint["answer_vocabs"]
+         ).to(self.device)
+         self.fusion_module.load_state_dict(checkpoint["fusion_module"])
+         self.question_encoder.load_state_dict(checkpoint["question_encoder"])
+         self.image_encoder.load_state_dict(checkpoint["image_encoder"])
+         # Recreate the question-type classifier before restoring its weights
+         # (otherwise it is only instantiated inside train()).
+         self.qtype_classifier = QuestionTypeClassifier(num_types=len(self.q_types)).to(self.device)
+         self.qtype_classifier.load_state_dict(checkpoint["qtype_classifier"])
+
+         for k, v in checkpoint["task_heads"].items():
+             self.task_heads[k].load_state_dict(v)
+
+         # Recreate the optimizer with the restored parameters
+         self.optimizer = torch.optim.AdamW(
+             list(self.fusion_module.parameters()) +
+             list(self.question_encoder.parameters()) +
+             list(self.image_encoder.parameters()) +
+             list(self.qtype_classifier.parameters()),
+             lr=2e-5
+         )
+         self.optimizer.load_state_dict(checkpoint["optimizer"])
+         print("Model and components loaded successfully")
+
+     def save(self, save_path="vqa_model.pt"):
+         torch.save({
+             "fusion_module": self.fusion_module.state_dict(),
+             "question_encoder": self.question_encoder.state_dict(),
+             "image_encoder": self.image_encoder.state_dict(),
+             "qtype_classifier": self.qtype_classifier.state_dict(),
+             "task_heads": {k: v.state_dict() for k, v in self.task_heads.items()},
+             "optimizer": self.optimizer.state_dict(),
+             "epochs": self.epochs,
+             "answer_vocabs": self.answer_vocabs,
+             "task_vocabs": self.task_vocabs
+         }, save_path)
+         print(f"Model saved at {save_path}")
+
+     def predict(self, image, question):
+         self.fusion_module.eval()
+         self.question_encoder.eval()
+         self.image_encoder.eval()
+         self.qtype_classifier.eval()
+
+         with torch.no_grad():
+             # ---- Preprocess image ----
+             image_tensor = preprocess_image(image).unsqueeze(0).to(self.device)
+
+             # ---- Disease vector ----
+             disease_vec = disease_model(image_tensor)
+
+             # ---- Encode question ----
+             q_inputs = router_tokenizer(
+                 question,
+                 return_tensors="pt",
+                 truncation=True,
+                 padding=True
+             ).to(self.device)
+
+             # DistilBERT classifier for the question type
+             task_logits = self.qtype_classifier(
+                 input_ids=q_inputs["input_ids"],
+                 attention_mask=q_inputs["attention_mask"]
+             )  # [1, num_types]
+
+             task_idx = torch.argmax(task_logits, dim=1).item()
+             task_type = self.q_types[task_idx]  # map index -> general type
+
+             # ---- Question encoder for fusion ----
+             q_feat = self.question_encoder(**q_inputs).pooler_output  # [1, 768]
+
+             # ---- Image encoder ----
+             img_outputs = self.image_encoder(pixel_values=image_tensor)
+             img_feat = img_outputs.last_hidden_state  # [1, R, 768]
+
+             # ---- Fusion ----
+             fused = self.fusion_module(img_feat, q_feat, disease_vec)
+
+             # ---- Task-specific head ----
+             predictor = self.task_heads[task_type]
+             pred_out = predictor(fused)
+
+             # ---- Decode prediction ----
+             if task_type == "yesno":
+                 pred_label = "Yes" if torch.argmax(pred_out, dim=1).item() == 1 else "No"
+             elif task_type == "count":
+                 pred_label = str(int(round(pred_out.item())))  # round, matching eval()
+             else:  # categorical answer
+                 ans_idx = torch.argmax(pred_out, dim=1).item()
+                 if task_type in self.answer_vocabs and ans_idx < len(self.answer_vocabs[task_type]):
+                     inv_vocab = {v: k for k, v in self.answer_vocabs[task_type].items()}
+                     pred_label = inv_vocab.get(ans_idx, str(ans_idx))
+                 else:
+                     pred_label = str(ans_idx)
+
+         return pred_label
+
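End to end, training and checkpointing then look roughly like this (a sketch reusing `processed` and `collate_fn` from the `functions.py` example above; a proper validation split is assumed rather than the reuse shown here):

```python
from torch.utils.data import DataLoader

# 768 = ViT/BERT hidden size, 23 = disease-vector length, 512 = fusion hidden size
model = VQAModel(img_dim=768, ques_dim=768, disease_dim=23, hidden_dim=512)

train_loader = DataLoader(processed, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(processed, batch_size=16, collate_fn=collate_fn)  # placeholder split

# VQAModel.train shadows nn.Module.train and runs the full optimization loop itself.
model.train(epochs=1, data_train=processed, train_loader=train_loader)
model.save("vqa_model.pt")

val_loss, preds, answers = model.eval(val_loader)
print(f"val loss: {val_loss:.4f}")
```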
model_functions.py ADDED
@@ -0,0 +1,210 @@
+ import torch
+ from torch.nn import CrossEntropyLoss, MSELoss
+ import re
+
+ # METEOR needs nltk (plus its wordnet data); ROUGE needs `pip install rouge_score`.
+ from nltk.translate.meteor_score import meteor_score
+
+ from .models import disease_model
+ from .tpred import TaskPredictor
+
+
+ def forward_batch(images, input_ids, attention_mask, answers, question_classes=None,
+                   qtype_classifier=None, fusion_module=None, q_types=None,
+                   q_types_mapping=None, task_heads=None, device=None,
+                   image_encoder=None, question_encoder=None):
+     # Image encoding
+     img_outputs = image_encoder(pixel_values=images.to(device))
+     img_feat = img_outputs.last_hidden_state  # [B, R, 768]
+
+     # Question-type logits (DistilBERT classifier)
+     task_logits = qtype_classifier(input_ids=input_ids.to(device),
+                                    attention_mask=attention_mask.to(device))  # [B, num_types]
+
+     # Separate encoder for the question embedding used in fusion
+     q_feat = question_encoder(input_ids=input_ids.to(device),
+                               attention_mask=attention_mask.to(device)).pooler_output  # [B, 768]
+
+     # Disease model (placeholder; returns a [23] vector broadcast across the batch)
+     disease_vec = disease_model(images.to(device))
+
+     # Fusion
+     fused = fusion_module(img_feat, q_feat, disease_vec)
+
+     # Task-specific predictions, one per sample, routed by the *ground-truth* class
+     preds = []
+     for i, q_class in enumerate(question_classes):
+         mapped_type = q_types_mapping[q_class[0] if isinstance(q_class, list) else q_class]
+         predictor = task_heads[mapped_type]  # trained head
+         pred_out = predictor(fused[i].unsqueeze(0))
+         preds.append(pred_out)
+
+     return preds, answers, task_logits
+
+
+ def forward_batch1(images, input_ids, attention_mask, answers, true_q_classes=None,
+                    qtype_classifier=None, fusion_module=None, q_types=None,
+                    device=None, image_encoder=None, question_encoder=None):
+     # Legacy variant: routes each sample by the *predicted* question type and
+     # builds a fresh (untrained) TaskPredictor per sample.
+     disease_vec = disease_model(images)  # placeholder disease vector
+
+     # Encode image
+     img_outputs = image_encoder(pixel_values=images.to(device))
+     img_feat = img_outputs.last_hidden_state  # [B, R, 768]
+
+     # Encode question
+     q_feat = question_encoder(input_ids=input_ids.to(device),
+                               attention_mask=attention_mask.to(device)).pooler_output  # [B, 768]
+
+     # Predict the task type from the question
+     task_logits = qtype_classifier(input_ids=input_ids.to(device),
+                                    attention_mask=attention_mask.to(device))
+     task_pred = torch.argmax(task_logits, dim=1)  # predicted type index
+
+     # Fusion
+     fused = fusion_module(img_feat, q_feat, disease_vec)
+
+     # Task-specific predictions
+     preds = []
+     for i, t_idx in enumerate(task_pred):
+         task_type = q_types[t_idx]  # map index to type string
+         predictor = TaskPredictor(task_type).to(device)
+         preds.append(predictor(fused[i].unsqueeze(0)))
+
+     return preds, answers, task_logits
+
+
+ def extract_count(answer_str):
+     """
+     Try to convert an answer string into a number.
+     Returns None if it cannot be parsed.
+     """
+     try:
+         # Direct numeric
+         return float(answer_str)
+     except ValueError:
+         pass
+
+     # Handle number words like "one", "two", etc.
+     word2num = {
+         "zero": 0, "one": 1, "two": 2, "three": 3,
+         "four": 4, "five": 5, "six": 6,
+         "seven": 7, "eight": 8, "nine": 9, "ten": 10
+     }
+     tokens = answer_str.lower().split()
+     for t in tokens:
+         if t in word2num:
+             return float(word2num[t])
+
+     # Extract any digits from the string
+     numbers = re.findall(r"\d+", answer_str)
+     if numbers:
+         return float(numbers[0])
+
+     return None  # fallback
+
+
+ def compute_meteor(preds, answers, answer_vocabs, mapped_classes):
+     scores = []
+     for pred, ans, c in zip(preds, answers, mapped_classes):
+         if c not in answer_vocabs:
+             continue
+         # Get the predicted index
+         pred_idx = pred.argmax(dim=1).item()
+         # Map the index back to a string
+         inv_vocab = {v: k for k, v in answer_vocabs[c].items()}
+         pred_str = inv_vocab.get(pred_idx, "")
+         # METEOR score between the predicted and ground-truth answer
+         score = meteor_score([ans.split()], pred_str.split())
+         scores.append(score)
+     return sum(scores) / len(scores) if scores else 0.0
+
+
+ def compute_rouge(preds, answers, answer_vocabs, mapped_classes):
+     from rouge_score import rouge_scorer  # optional dependency: pip install rouge_score
+     scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
+     scores = []
+     for pred, ans, c in zip(preds, answers, mapped_classes):
+         if c not in answer_vocabs:
+             continue
+         pred_idx = pred.argmax(dim=1).item()
+         inv_vocab = {v: k for k, v in answer_vocabs[c].items()}
+         pred_str = inv_vocab.get(pred_idx, "")
+         score = scorer.score(ans, pred_str)["rougeL"].fmeasure
+         scores.append(score)
+     return sum(scores) / len(scores) if scores else 0.0
+
+
+ def compute_loss(preds, answers, task_logits, true_q_classes, answer_vocabs,
+                  q_types_mapping, q_types, task_heads):
+     """
+     preds: list of per-sample model predictions
+     answers: list of strings (descriptive answers)
+     task_logits: tensor [batch_size, num_task_types]
+     true_q_classes: list of fine-grained classes (or lists of them) per question
+     answer_vocabs: dict mapping {q_type: {answer: index}}
+     """
+     ce_loss = CrossEntropyLoss()
+     mse_loss = MSELoss()
+
+     total_loss = 0
+
+     # 1) Map fine-grained -> general classes
+     mapped_classes = [
+         q_types_mapping[c[0] if isinstance(c, list) else c]
+         for c in true_q_classes
+     ]
+
+     # 2) Question-type classification loss
+     true_task_types = torch.tensor(
+         [q_types.index(c) for c in mapped_classes],
+         device=task_logits.device
+     )
+     task_loss = ce_loss(task_logits, true_task_types)
+     total_loss += task_loss
+
+     # 3) Answer prediction loss (per sample)
+     for pred, ans, c in zip(preds, answers, mapped_classes):
+         if c == "count":
+             # For count, the answer must be numeric
+             try:
+                 ans_val = torch.tensor(float(ans), device=pred.device)
+                 total_loss += mse_loss(pred.squeeze(), ans_val)
+             except ValueError:
+                 print(f"[Warning] Skipping non-numeric count answer: {ans}")
+                 continue
+         else:
+             # For categorical tasks (yesno, single, multi, etc.)
+             if ans not in answer_vocabs.get(c, {}):
+                 print(f"[Warning] Skipping unseen or descriptive answer {ans} for task {c}")
+                 continue
+
+             ans_idx = answer_vocabs[c][ans]
+
+             if ans_idx >= pred.size(1):
+                 print(f"[Warning] Skipping answer {ans} for task {c}: "
+                       f"index {ans_idx} >= pred.size(1)")
+                 continue
+
+             ans_tensor = torch.tensor([ans_idx], device=pred.device)
+             total_loss += ce_loss(pred, ans_tensor)
+
+     meteor = compute_meteor(preds, answers, answer_vocabs, mapped_classes)
+     print(f"METEOR: {meteor:.4f}")
+
+     return total_loss / len(preds)
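
`extract_count` is worth a quick sanity check, since count answers arrive in mixed formats:

```python
for ans in ["3", "two polyps", "there are 12 findings", "none visible"]:
    print(repr(ans), "->", extract_count(ans))
# '3' -> 3.0
# 'two polyps' -> 2.0
# 'there are 12 findings' -> 12.0
# 'none visible' -> None
```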
models.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+ import torch.nn as nn
+
+ from transformers import (
+     AutoTokenizer, AutoModelForSequenceClassification,
+     AutoModel, AutoProcessor, VisionEncoderDecoderModel,
+     T5Tokenizer, T5ForConditionalGeneration
+ )
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ gen_name = "t5-base"
+ gen_tokenizer = T5Tokenizer.from_pretrained(gen_name)
+ gen_model = T5ForConditionalGeneration.from_pretrained(gen_name).to(device)
+
+ def generate_descriptive_answer(question, prediction, fused_features):
+     # Construct a prompt combining the prediction and context
+     prompt = f"Question: {question} | Prediction: {prediction} | Context: GI disease analysis"
+     inputs = gen_tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
+     outputs = gen_model.generate(**inputs, max_length=50)
+     return gen_tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ def disease_model(img):
+     # Placeholder: returns a fixed 23-dim zero vector (broadcast across the batch).
+     # Replace with a trained disease classifier, e.g. one logit per GI finding.
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     return torch.zeros(23).to(device)
+
+ router_name = "distilbert-base-uncased"
+ router_tokenizer = AutoTokenizer.from_pretrained(router_name)
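
A quick usage sketch for the T5 wrapper. Note that `fused_features` is accepted but not used by the current prompt construction, so `None` is passed here:

```python
desc = generate_descriptive_answer(
    question="Is there a polyp in the image?",
    prediction="Yes",
    fused_features=None,  # currently unused by the prompt construction
)
print(desc)
```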
qtype.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+ from torch import nn
+ from transformers import DistilBertModel
+
+ class QuestionTypeClassifier(nn.Module):
+     def __init__(self, num_types):
+         super().__init__()
+
+         # Load pre-trained DistilBERT
+         self.distilbert = DistilBertModel.from_pretrained("distilbert-base-uncased")
+
+         # Classification head
+         self.fc = nn.Linear(self.distilbert.config.hidden_size, num_types)
+
+     def forward(self, input_ids, attention_mask):
+         outputs = self.distilbert(
+             input_ids=input_ids,
+             attention_mask=attention_mask
+         )
+
+         # Take the first-token embedding (DistilBERT's [CLS] position)
+         cls_token = outputs.last_hidden_state[:, 0, :]  # [B, hidden]
+
+         logits = self.fc(cls_token)  # [B, num_types]
+         return logits
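
A minimal routing sketch, assuming the same six general types and the DistilBERT tokenizer used elsewhere in this commit:

```python
import torch
from transformers import AutoTokenizer

q_types = ["yesno", "single", "multi", "color", "location", "count"]
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
clf = QuestionTypeClassifier(num_types=len(q_types))

inputs = tokenizer("How many polyps are visible?", return_tensors="pt")
with torch.no_grad():
    logits = clf(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(q_types[logits.argmax(dim=1).item()])  # untrained weights, so the route is arbitrary
```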
tpred.py ADDED
@@ -0,0 +1,24 @@
+ # ---------------------------
+ # Step 5: Task-Specific Predictors
+ # ---------------------------
+ import torch.nn as nn
+
+ class TaskPredictor(nn.Module):
+     def __init__(self, task_type, hidden=512):
+         super().__init__()
+         if task_type == "yesno":
+             self.head = nn.Linear(hidden, 2)    # binary yes/no
+         elif task_type == "single":
+             self.head = nn.Linear(hidden, 10)   # single-choice, up to 10 classes
+         elif task_type == "multi":
+             self.head = nn.Linear(hidden, 10)   # multi-choice, up to 10 classes
+         elif task_type == "color":
+             self.head = nn.Linear(hidden, 5)    # fixed color set
+         elif task_type == "location":
+             self.head = nn.Linear(hidden, 6)    # fixed location set
+         elif task_type == "count":
+             self.head = nn.Linear(hidden, 1)    # scalar regression
+         else:
+             raise ValueError(f"Unknown task type: {task_type}")
+
+     def forward(self, x):
+         return self.head(x)
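
How the heads plug into the fused representation, as a sketch (hidden size 512 matches the fusion output used in this commit):

```python
import torch

heads = {t: TaskPredictor(t, hidden=512)
         for t in ["yesno", "single", "multi", "color", "location", "count"]}

fused = torch.randn(1, 512)  # one fused image-question vector
print(heads["yesno"](fused).shape)  # torch.Size([1, 2]) -> argmax gives Yes/No
print(heads["count"](fused).shape)  # torch.Size([1, 1]) -> rounded to an integer count
```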