restructured files
- __init__.py +14 -0
- functions.py +197 -0
- fussionmodel.py +87 -0
- model.py +330 -0
- model_functions.py +210 -0
- models.py +29 -0
- qtype.py +25 -0
- tpred.py +24 -0
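
Taken together, the restructured modules are meant to be driven roughly as follows. This is a minimal end-to-end sketch, assuming the package directory is named "vqa" and that a dataset with image/question/question_class/answer columns is available; the dataset identifier below is a placeholder, not a real one:

from torch.utils.data import DataLoader
from datasets import load_dataset

from vqa import VQAModel, preprocess_example, collate_fn  # package name "vqa" is an assumption

raw = load_dataset("path/to/medico2025-vqa", split="train")  # hypothetical dataset id
data_train = raw.map(preprocess_example)  # .map turns tensors into lists; collate_fn converts them back

train_loader = DataLoader(data_train, batch_size=16, shuffle=True, collate_fn=collate_fn)

# Dims follow the encoders used in model.py: ViT/BERT are 768-d, the disease stub is 23-d.
model = VQAModel(img_dim=768, ques_dim=768, disease_dim=23, hidden_dim=512)
model.train(epochs=1, data_train=data_train, train_loader=train_loader)
model.save("vqa_model.pt")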
__init__.py
ADDED
@@ -0,0 +1,14 @@

# The old flat-layout imports are kept below for reference; the package now uses relative imports.
"""
from qtype import QuestionTypeClassifier
from tpred import TaskPredictor
from model import VQAModel
from fussionmodel import CoAttentionFusion
from functions import preprocess_example, preprocess_image, collate_fn
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
"""

from .functions import preprocess_example, preprocess_image, collate_fn
from .model import VQAModel
functions.py
ADDED
@@ -0,0 +1,197 @@

from PIL import Image
import requests
import torch
import torchvision.transforms as transforms
import os
import re
from collections import defaultdict
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Same preprocessing as used during training: resize, tensorize, ImageNet-normalize.
transformten = transforms.Compose([
    transforms.Resize((224, 224)),                      # adjust size for your model
    transforms.ToTensor(),                              # convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],    # ImageNet normalization
                         std=[0.229, 0.224, 0.225])
])

# Load the tokenizer once at module level instead of once per example.
router_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

image_cache = {}


def preprocess_image(image_source):
    """
    Preprocess a single image for inference.
    `image_source` can be a URL, a local file path, or a PIL image.
    Returns a tensor [C, H, W].
    """
    if isinstance(image_source, str):
        if image_source.startswith("http"):  # URL
            image = Image.open(requests.get(image_source, stream=True).raw).convert("RGB")
        else:  # local path
            image = Image.open(image_source).convert("RGB")
    elif isinstance(image_source, Image.Image):  # already a PIL image
        image = image_source
    else:
        raise ValueError("Unsupported image_source type")

    # Apply the same transform used during training: Resize(224) -> ToTensor() -> Normalize()
    return transformten(image)  # torch.Tensor [3, 224, 224]


def preprocess_example(example):
    # Resolve the image from the dataset (filenames map into the local Kaggle input dir)
    image_name = example["image"].split("/")[-1]
    image_path = os.path.join("/kaggle/input/medico2025", image_name)

    # Check if the image is already in our cache
    if image_path in image_cache:
        image = image_cache[image_path]
    else:
        image = Image.open(image_path)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_cache[image_path] = image  # cache the loaded PIL image

    # Apply the normalize/transform pipeline (Resize + ToTensor + Normalize)
    image = transformten(image)

    # Tokenize the question
    q_inputs = router_tokenizer(example["question"],
                                return_tensors="pt",
                                truncation=True,
                                padding="max_length",
                                max_length=32)

    # q_inputs is a BatchEncoding with batch_size=1 tensors, so we squeeze
    input_ids = q_inputs["input_ids"].squeeze(0)            # torch.Tensor [seq_len]
    attention_mask = q_inputs["attention_mask"].squeeze(0)

    # Pack features
    return {
        "image": image,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "answer": example["answer"],
        "question_class": example["question_class"],
        "image_url": example["image"],
    }


def normalize_answer(ans, q_type):
    ans = ans.strip().lower()

    if q_type == "yesno":
        if "yes" in ans or "present" in ans or "evidence" in ans:
            return "Yes"
        elif "no" in ans or "absent" in ans or "none" in ans:
            return "No"
        else:
            return None  # ambiguous

    if q_type == "count":
        # Extract a numeric value or return None
        numbers = re.findall(r"\d+", ans)
        if numbers:
            return numbers[0]
        elif "one" in ans: return "1"
        elif "two" in ans: return "2"
        return None

    if q_type == "color":
        for color in ["red", "green", "yellow", "blue", "white", "black"]:
            if color in ans:
                return color
        return None

    if q_type == "location":
        # Simplify locations to a small fixed set
        for loc in ["upper", "lower", "left", "right", "central"]:
            if loc in ans:
                return loc
        return None

    if q_type in ["single", "multi"]:
        return ans  # keep the original answer; the choice set can also be restricted

    return ans


def build_vocabs(dataset, q_types_mapping):
    # Build task-specific vocabularies
    task_vocabs = {}
    for general_class in set(q_types_mapping.values()):
        task_vocabs[general_class] = {}

    for row in dataset:
        fine_class = row["question_class"]

        # question_class may be a list of labels; use the first
        if isinstance(fine_class, list):
            fine_class = fine_class[0]

        general_class = q_types_mapping[fine_class]

        norm_ans = normalize_answer(row["answer"], general_class)
        if norm_ans is None:
            continue  # skip unnormalizable answers

        if norm_ans not in task_vocabs[general_class]:
            idx = len(task_vocabs[general_class])
            task_vocabs[general_class][norm_ans] = idx

    return task_vocabs


def build_answer_vocab(dataset, q_types_mapping):
    answer_vocab = defaultdict(dict)
    counters = defaultdict(int)

    for ans, q_class in zip(dataset["answer"], dataset["question_class"]):
        # q_class might be a list; pick the first (if multiple labels)
        if isinstance(q_class, list):
            q_class = q_class[0]

        general_class = q_types_mapping[q_class]

        if ans not in answer_vocab[general_class]:
            answer_vocab[general_class][ans] = counters[general_class]
            counters[general_class] += 1

    return answer_vocab


def collate_fn(batch):
    # Fields may arrive as plain lists (e.g. after datasets .map); coerce to tensors before stacking
    images = torch.stack([torch.tensor(item["image"]) if isinstance(item["image"], list)
                          else item["image"] for item in batch])
    input_ids = torch.stack([torch.tensor(item["input_ids"]) if isinstance(item["input_ids"], list)
                             else item["input_ids"] for item in batch])
    attention_mask = torch.stack([torch.tensor(item["attention_mask"]) if isinstance(item["attention_mask"], list)
                                  else item["attention_mask"] for item in batch])

    answers = [item["answer"] for item in batch]  # keep as list for label encoding later
    q_classes = [item["question_class"] for item in batch]
    return {
        "images": images,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "answers": answers,
        "question_classes": q_classes,
    }
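
For reference, normalize_answer maps free-text answers onto each task's small label set before the vocabularies are built. A quick illustration with toy rows (not real dataset entries; the "vqa" package name is an assumption):

from vqa.functions import normalize_answer, build_vocabs

print(normalize_answer("There is evidence of a polyp", "yesno"))  # -> "Yes"
print(normalize_answer("two polyps were found", "count"))         # -> "2"
print(normalize_answer("a pale yellow lesion", "color"))          # -> "yellow"
print(normalize_answer("unclear", "yesno"))                       # -> None (ambiguous, skipped)

# build_vocabs then indexes only the normalizable answers, per general class
rows = [
    {"question_class": "abnormality_presence", "answer": "Yes"},
    {"question_class": "polyp_count", "answer": "two"},
]
mapping = {"abnormality_presence": "yesno", "polyp_count": "count"}
print(build_vocabs(rows, mapping))  # -> {'yesno': {'Yes': 0}, 'count': {'2': 0}}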
fussionmodel.py
ADDED
@@ -0,0 +1,87 @@

from transformers import ViTModel, BertModel
import torch.nn.functional as F

import torch
import torch.nn as nn


class CoAttentionFusion(nn.Module):
    def __init__(self, img_dim, ques_dim, disease_dim, hidden_dim, answer_vocab):
        super(CoAttentionFusion, self).__init__()

        self.img_proj = nn.Linear(img_dim, hidden_dim)
        self.ques_proj = nn.Linear(ques_dim, hidden_dim)
        self.dis_proj = nn.Linear(disease_dim, hidden_dim)

        self.att_img = nn.Linear(hidden_dim, 1)
        self.att_dis = nn.Linear(hidden_dim, 1)
        self.fusion = nn.Linear(hidden_dim * 3, hidden_dim)

        # Store the answer vocab inside the model for later use
        self.answer_vocab = answer_vocab

    def forward(self, img_feat, ques_feat, dis_vec):
        # Project each modality into the shared hidden space
        img_proj = torch.tanh(self.img_proj(img_feat))     # [B, R, H]
        ques_proj = torch.tanh(self.ques_proj(ques_feat))  # [B, H]

        dis_vec = dis_vec.to(torch.float32)
        dis_proj = torch.tanh(self.dis_proj(dis_vec))      # [H] for the current stub; broadcasts below

        # Expand the question vector across image regions for alignment
        ques_proj = ques_proj.unsqueeze(1)                         # [B, 1, H]
        ques_expand = ques_proj.expand(-1, img_proj.size(1), -1)   # [B, R, H]
        img_co = img_proj * ques_expand                            # [B, R, H]

        # Attention-weighted pooling over image regions
        att_img_weights = torch.sigmoid(self.att_img(img_co))  # [B, R, 1]
        img_att = (att_img_weights * img_proj).sum(1)           # [B, H]

        # Co-attention with the disease vector
        dis_co = dis_proj * ques_proj                            # [B, 1, H]
        att_dis_weights = torch.sigmoid(self.att_dis(dis_co))   # [B, 1, 1]
        dis_att = att_dis_weights * dis_proj                     # [B, 1, H]

        # Guard against unbatched inputs
        if img_att.dim() == 1:
            img_att = img_att.unsqueeze(0)      # [1, H]
        if ques_proj.dim() == 1:
            ques_proj = ques_proj.unsqueeze(0)  # [1, H]
        if dis_att.dim() == 1:
            dis_att = dis_att.unsqueeze(0)      # [1, H]

        # Concatenate the attended features
        ques_proj_flat = ques_proj.squeeze(1)  # [B, H]
        dis_att_flat = dis_att.squeeze(1)      # [B, H]

        joint_feat = torch.cat([img_att, ques_proj_flat, dis_att_flat], dim=1)  # [B, 3H]
        fused = torch.tanh(self.fusion(joint_feat))  # [B, H]
        return fused
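
A minimal shape check for the fusion block, assuming ViT patch features of 197 regions x 768 dims, a pooled 768-d BERT question vector, and the 23-d disease stub from models.py:

import torch
from vqa.fussionmodel import CoAttentionFusion  # package name "vqa" is an assumption

fusion = CoAttentionFusion(img_dim=768, ques_dim=768, disease_dim=23,
                           hidden_dim=512, answer_vocab={})

img_feat = torch.randn(4, 197, 768)  # [B, R, 768] ViT patch features
q_feat = torch.randn(4, 768)         # [B, 768] pooled question features
dis_vec = torch.zeros(23)            # matches the disease_model stub

fused = fusion(img_feat, q_feat, dis_vec)
print(fused.shape)  # torch.Size([4, 512])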
model.py
ADDED
@@ -0,0 +1,330 @@

import torch
import torch.nn as nn
import os
from .qtype import QuestionTypeClassifier
from .functions import build_vocabs, build_answer_vocab, collate_fn, preprocess_example, normalize_answer, preprocess_image
from .models import disease_model, device, generate_descriptive_answer, router_tokenizer, gen_model
from .tpred import TaskPredictor
from .model_functions import compute_loss, compute_meteor, compute_rouge, extract_count, forward_batch
from .fussionmodel import BertModel, CoAttentionFusion, ViTModel, F


class VQAModel(nn.Module):
    # Note: train() and eval() below intentionally replace nn.Module.train/eval
    # with full training and evaluation loops.
    def __init__(self, img_dim, ques_dim, disease_dim, hidden_dim):
        super(VQAModel, self).__init__()
        self.qtype_classifier = None
        self.answer_classifier = None
        self.epochs = 1
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.hidden_dim = hidden_dim
        self.input_dim = 768
        self.ques_dim = ques_dim
        self.disease_dim = disease_dim
        self.img_dim = img_dim
        self.fusion_module = None
        self.question_encoder = BertModel.from_pretrained("bert-base-uncased").to(self.device)
        self.image_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224").to(self.device)
        self.optimizer = None
        self.answer_vocabs = None
        self.task_vocabs = None
        self.train_data = None
        self.train_loader = None
        self.q_types = ["yesno", "single", "multi", "color", "location", "count"]
        # Create task-specific heads (trainable); moved to the device so they can
        # consume the fused features directly
        self.task_heads = nn.ModuleDict({
            t: TaskPredictor(t, hidden=hidden_dim) for t in self.q_types
        }).to(self.device)
        # Map fine-grained dataset question classes to the six general task types
        self.q_types_mapping = {
            'abnormality_color': 'color',
            'landmark_color': 'color',
            'abnormality_location': 'location',
            'instrument_location': 'location',
            'landmark_location': 'location',
            'finding_count': 'count',
            'instrument_count': 'count',
            'polyp_count': 'count',
            'abnormality_presence': 'yesno',
            'box_artifact_presence': 'yesno',
            'finding_presence': 'yesno',
            'instrument_presence': 'yesno',
            'landmark_presence': 'yesno',
            'text_presence': 'yesno',
            'polyp_removal_status': 'yesno',
            'polyp_type': 'single',
            'polyp_size': 'single',
            'procedure_type': 'single',
        }

    def train(self, epochs, data_train, train_loader):
        self.epochs = epochs
        self.train_data = data_train
        self.train_loader = train_loader
        self.answer_vocabs = build_answer_vocab(self.train_data, self.q_types_mapping)
        self.task_vocabs = build_vocabs(self.train_data, self.q_types_mapping)
        self.qtype_classifier = QuestionTypeClassifier(num_types=len(self.q_types)).to(self.device)
        self.answer_classifier = nn.Linear(self.hidden_dim, len(self.answer_vocabs))  # match hidden_dim
        self.fusion_module = CoAttentionFusion(img_dim=self.img_dim,
                                               ques_dim=self.ques_dim,
                                               disease_dim=self.disease_dim,
                                               hidden_dim=self.hidden_dim,
                                               answer_vocab=self.answer_vocabs).to(self.device)
        # Include the task heads so they actually receive optimizer updates
        self.optimizer = torch.optim.AdamW(list(self.fusion_module.parameters()) +
                                           list(self.question_encoder.parameters()) +
                                           list(self.image_encoder.parameters()) +
                                           list(self.qtype_classifier.parameters()) +
                                           list(self.task_heads.parameters()), lr=2e-5)
        for epoch in range(self.epochs):
            self.fusion_module.train()
            self.qtype_classifier.train()
            total_loss = 0
            for batch in self.train_loader:
                self.optimizer.zero_grad()
                preds, answers, task_logits = forward_batch(
                    batch["images"],
                    batch["input_ids"],
                    batch["attention_mask"],
                    batch["answers"],
                    batch["question_classes"],  # fine-grained classes from the dataset
                    qtype_classifier=self.qtype_classifier,
                    fusion_module=self.fusion_module,
                    q_types=self.q_types,
                    q_types_mapping=self.q_types_mapping,
                    task_heads=self.task_heads,
                    device=self.device,
                    image_encoder=self.image_encoder,
                    question_encoder=self.question_encoder
                )
                loss = compute_loss(preds,
                                    answers,
                                    task_logits,
                                    batch["question_classes"],
                                    answer_vocabs=self.answer_vocabs,
                                    q_types_mapping=self.q_types_mapping,
                                    q_types=self.q_types,
                                    task_heads=self.task_heads)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            print(f"Epoch {epoch}, Train Loss: {total_loss / len(train_loader)}")

    def eval(self, val_loader):
        """
        Evaluate the model on the validation set.

        Args:
            val_loader: DataLoader for validation data.

        Returns:
            avg_loss: average validation loss
            all_preds: list of predicted labels
            all_answers: list of ground-truth answers
        """
        self.fusion_module.eval()
        self.question_encoder.eval()
        self.image_encoder.eval()
        self.qtype_classifier.eval()
        for head in self.task_heads.values():
            head.eval()

        total_loss = 0.0
        all_preds, all_answers = [], []

        with torch.no_grad():
            for batch in val_loader:
                images = batch["images"].to(self.device)
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                answers = batch["answers"]
                q_classes = batch["question_classes"]

                # ---- Disease vector ----
                disease_vec = disease_model(images)

                # ---- Question type classifier ----
                task_logits = self.qtype_classifier(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )  # [B, num_types]

                # Map fine-grained classes to general task types
                mapped_classes = [
                    self.q_types_mapping[c[0] if isinstance(c, list) else c]
                    for c in q_classes
                ]

                # ---- Encoders ----
                q_feat = self.question_encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                ).pooler_output  # [B, 768]

                img_outputs = self.image_encoder(pixel_values=images)
                img_feat = img_outputs.last_hidden_state  # [B, R, 768]

                # ---- Fusion ----
                fused = self.fusion_module(img_feat, q_feat, disease_vec)

                # ---- Predict per sample ----
                pred_tensors = []
                batch_preds = []
                for i, task_type in enumerate(mapped_classes):
                    predictor = self.task_heads[task_type]
                    pred_tensor = predictor(fused[i].unsqueeze(0))  # [1, C], or [1, 1] for count
                    pred_tensors.append(pred_tensor)

                    if task_type == "yesno":
                        pred_label = "Yes" if torch.argmax(pred_tensor, dim=1).item() == 1 else "No"
                    elif task_type == "count":
                        pred_val = pred_tensor.squeeze()
                        pred_label = str(int(round(pred_val.item())))
                    else:
                        ans_idx = torch.argmax(pred_tensor, dim=1).item()
                        if task_type in self.answer_vocabs and ans_idx < len(self.answer_vocabs[task_type]):
                            inv_vocab = {v: k for k, v in self.answer_vocabs[task_type].items()}
                            pred_label = inv_vocab.get(ans_idx, str(ans_idx))
                        else:
                            pred_label = str(ans_idx)

                    batch_preds.append(pred_label)

                # ---- Compute batch loss using the same per-sample prediction tensors ----
                batch_loss = compute_loss(
                    preds=pred_tensors,
                    answers=answers,
                    task_logits=task_logits,
                    true_q_classes=q_classes,
                    answer_vocabs=self.answer_vocabs,
                    q_types_mapping=self.q_types_mapping,
                    q_types=self.q_types,
                    task_heads=self.task_heads
                )
                total_loss += batch_loss.item()

                all_preds.extend(batch_preds)
                all_answers.extend(answers)

        avg_loss = total_loss / len(val_loader)
        return avg_loss, all_preds, all_answers

    def load(self, load_path="vqa_model.pt"):
        checkpoint = torch.load(load_path, map_location=self.device, weights_only=False)
        self.task_vocabs = checkpoint["task_vocabs"]
        self.answer_vocabs = checkpoint["answer_vocabs"]
        self.fusion_module = CoAttentionFusion(
            img_dim=self.img_dim, ques_dim=self.ques_dim, disease_dim=self.disease_dim, hidden_dim=self.hidden_dim,
            answer_vocab=checkpoint["answer_vocabs"]
        ).to(self.device)
        self.fusion_module.load_state_dict(checkpoint["fusion_module"])
        self.question_encoder.load_state_dict(checkpoint["question_encoder"])
        self.image_encoder.load_state_dict(checkpoint["image_encoder"])
        # The classifier is only built in train(), so instantiate it before loading weights
        self.qtype_classifier = QuestionTypeClassifier(num_types=len(self.q_types)).to(self.device)
        self.qtype_classifier.load_state_dict(checkpoint["qtype_classifier"])

        for k, v in checkpoint["task_heads"].items():
            self.task_heads[k].load_state_dict(v)

        # Recreate the optimizer over the restored parameters
        self.optimizer = torch.optim.AdamW(
            list(self.fusion_module.parameters()) +
            list(self.question_encoder.parameters()) +
            list(self.image_encoder.parameters()) +
            list(self.qtype_classifier.parameters()) +
            list(self.task_heads.parameters()),
            lr=2e-5
        )
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        print("Model and components loaded successfully")

    def save(self, save_path="vqa_model.pt"):
        torch.save({
            "fusion_module": self.fusion_module.state_dict(),
            "question_encoder": self.question_encoder.state_dict(),
            "image_encoder": self.image_encoder.state_dict(),
            "qtype_classifier": self.qtype_classifier.state_dict(),
            "task_heads": {k: v.state_dict() for k, v in self.task_heads.items()},
            "optimizer": self.optimizer.state_dict(),
            "epochs": self.epochs,
            "answer_vocabs": self.answer_vocabs,
            "task_vocabs": self.task_vocabs
        }, save_path)
        print(f"Model saved at {save_path}")

    def predict(self, image, question):
        self.fusion_module.eval()
        self.question_encoder.eval()
        self.image_encoder.eval()
        self.qtype_classifier.eval()

        with torch.no_grad():
            # ---- Preprocess image ----
            image_tensor = preprocess_image(image).unsqueeze(0).to(self.device)

            # ---- Disease vector ----
            disease_vec = disease_model(image_tensor)

            # ---- Encode question ----
            q_inputs = router_tokenizer(
                question,
                return_tensors="pt",
                truncation=True,
                padding=True
            ).to(self.device)

            # DistilBERT classifier for the question type
            task_logits = self.qtype_classifier(
                input_ids=q_inputs["input_ids"],
                attention_mask=q_inputs["attention_mask"]
            )  # [1, num_types]

            task_idx = torch.argmax(task_logits, dim=1).item()
            task_type = self.q_types[task_idx]  # map index to general type

            # ---- Question encoder for fusion ----
            q_feat = self.question_encoder(**q_inputs).pooler_output  # [1, 768]

            # ---- Image encoder ----
            img_outputs = self.image_encoder(pixel_values=image_tensor)
            img_feat = img_outputs.last_hidden_state  # [1, R, 768]

            # ---- Fusion ----
            fused = self.fusion_module(img_feat, q_feat, disease_vec)

            # ---- Task-specific head ----
            predictor = self.task_heads[task_type]  # use the trained head
            pred_out = predictor(fused)

            # ---- Decode prediction ----
            if task_type == "yesno":
                pred_label = "Yes" if torch.argmax(pred_out, dim=1).item() == 1 else "No"

            elif task_type == "count":
                pred_label = str(int(pred_out.item()))

            else:  # categorical answer
                ans_idx = torch.argmax(pred_out, dim=1).item()
                if task_type in self.answer_vocabs and ans_idx < len(self.answer_vocabs[task_type]):
                    inv_vocab = {v: k for k, v in self.answer_vocabs[task_type].items()}
                    pred_label = inv_vocab.get(ans_idx, str(ans_idx))
                else:
                    pred_label = str(ans_idx)

            return pred_label
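
A rough lifecycle sketch for VQAModel; paths and hyperparameters are illustrative, and data_train/train_loader come from the pipeline in functions.py (package name "vqa" is an assumption):

from vqa import VQAModel

# Training builds the vocabularies, the q-type classifier, and the fusion module
model = VQAModel(img_dim=768, ques_dim=768, disease_dim=23, hidden_dim=512)
model.train(epochs=3, data_train=data_train, train_loader=train_loader)
model.save("vqa_model.pt")

# Later: restore everything and answer a single question
model = VQAModel(img_dim=768, ques_dim=768, disease_dim=23, hidden_dim=512)
model.load("vqa_model.pt")
answer = model.predict("sample.jpg", "Is a polyp present?")  # hypothetical image path
print(answer)  # e.g. "Yes"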
model_functions.py
ADDED
@@ -0,0 +1,210 @@

import torch
from torch.nn import CrossEntropyLoss, MSELoss
import re

from nltk.translate.meteor_score import meteor_score
from .models import disease_model
from .tpred import TaskPredictor


def forward_batch(images, input_ids, attention_mask, answers, question_classes=None,
                  qtype_classifier=None, fusion_module=None, q_types=None,
                  q_types_mapping=None, task_heads=None, device=None,
                  image_encoder=None, question_encoder=None):
    # Image encoding
    img_outputs = image_encoder(pixel_values=images.to(device))
    img_feat = img_outputs.last_hidden_state  # [B, R, 768]

    # Question-type logits (DistilBERT classifier)
    task_logits = qtype_classifier(input_ids=input_ids.to(device),
                                   attention_mask=attention_mask.to(device))  # [B, num_types]

    # Separate encoder for the question embeddings used in fusion
    q_feat = question_encoder(input_ids=input_ids.to(device),
                              attention_mask=attention_mask.to(device)).pooler_output  # [B, 768]

    # Disease model (currently a zero-vector stub, see models.py)
    disease_vec = disease_model(images.to(device))  # [23]

    # Fusion
    fused = fusion_module(img_feat, q_feat, disease_vec)

    # Task-specific predictions: one head output per sample
    preds = []
    for i, q_class in enumerate(question_classes):
        mapped_type = q_types_mapping[q_class[0] if isinstance(q_class, list) else q_class]
        predictor = task_heads[mapped_type]  # trained, task-specific head
        pred_out = predictor(fused[i].unsqueeze(0))
        preds.append(pred_out)

    return preds, answers, task_logits


def forward_batch1(images, input_ids, attention_mask, answers, true_q_classes=None,
                   qtype_classifier=None, fusion_module=None, q_types=None,
                   device=None, image_encoder=None, question_encoder=None):
    """
    Legacy variant kept for reference: routes each sample through a freshly
    constructed (untrained) TaskPredictor chosen by the *predicted* question type.
    """
    # Disease vector (placeholder disease model)
    disease_vec = disease_model(images.to(device))  # [23]

    # Encode image
    img_outputs = image_encoder(pixel_values=images.to(device))
    img_feat = img_outputs.last_hidden_state  # [B, R, 768]

    # Encode question
    q_feat = question_encoder(input_ids=input_ids.to(device),
                              attention_mask=attention_mask.to(device)).pooler_output  # [B, 768]

    # Predict the task type from the question
    task_logits = qtype_classifier(input_ids=input_ids.to(device),
                                   attention_mask=attention_mask.to(device))
    task_pred = torch.argmax(task_logits, dim=1)  # predicted type index

    # Fusion
    fused = fusion_module(img_feat, q_feat, disease_vec)

    # Task-specific predictions
    preds = []
    for i, t_idx in enumerate(task_pred):
        task_type = q_types[t_idx]  # map index to type string
        predictor = TaskPredictor(task_type).to(device)
        preds.append(predictor(fused[i].unsqueeze(0)))

    return preds, answers, task_logits


def extract_count(answer_str):
    """
    Try to convert an answer string into a number.
    Returns None if it cannot be parsed.
    """
    try:
        # Direct numeric
        return float(answer_str)
    except ValueError:
        pass

    # Handle words like "one", "two", etc.
    word2num = {
        "zero": 0, "one": 1, "two": 2, "three": 3,
        "four": 4, "five": 5, "six": 6,
        "seven": 7, "eight": 8, "nine": 9, "ten": 10
    }
    tokens = answer_str.lower().split()
    for t in tokens:
        if t in word2num:
            return float(word2num[t])

    # Extract any digits from the string
    numbers = re.findall(r"\d+", answer_str)
    if numbers:
        return float(numbers[0])

    return None  # fallback


def compute_meteor(preds, answers, answer_vocabs, mapped_classes):
    scores = []
    for pred, ans, c in zip(preds, answers, mapped_classes):
        if c not in answer_vocabs:
            continue
        # Get the predicted index
        pred_idx = pred.argmax(dim=1).item()
        # Map the index back to an answer string
        inv_vocab = {v: k for k, v in answer_vocabs[c].items()}
        pred_str = inv_vocab.get(pred_idx, "")
        # METEOR score between predicted and ground-truth answer
        score = meteor_score([ans.split()], pred_str.split())
        scores.append(score)
    return sum(scores) / len(scores) if scores else 0.0


def compute_rouge(preds, answers, answer_vocabs, mapped_classes):
    from rouge_score import rouge_scorer  # requires: pip install rouge_score
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    scores = []
    for pred, ans, c in zip(preds, answers, mapped_classes):
        if c not in answer_vocabs:
            continue
        pred_idx = pred.argmax(dim=1).item()
        inv_vocab = {v: k for k, v in answer_vocabs[c].items()}
        pred_str = inv_vocab.get(pred_idx, "")
        score = scorer.score(ans, pred_str)["rougeL"].fmeasure
        scores.append(score)
    return sum(scores) / len(scores) if scores else 0.0


def compute_loss(preds, answers, task_logits, true_q_classes, answer_vocabs, q_types_mapping, q_types, task_heads):
    """
    preds: list of per-sample prediction tensors
    answers: list of answer strings (descriptive answers)
    task_logits: tensor [batch_size, num_task_types]
    true_q_classes: fine-grained question classes (string or list per sample)
    answer_vocabs: dict mapping {q_type: {answer: index}}
    q_types_mapping / q_types: fine-to-general class map and general type list
    task_heads: ModuleDict of task-specific heads
    """
    ce_loss = CrossEntropyLoss()
    mse_loss = MSELoss()

    total_loss = 0

    # 1) Map fine-grained classes to general classes
    mapped_classes = [
        q_types_mapping[c[0] if isinstance(c, list) else c]
        for c in true_q_classes
    ]

    # 2) Question-type classification loss
    true_task_types = torch.tensor(
        [q_types.index(c) for c in mapped_classes],
        device=task_logits.device
    )
    task_loss = ce_loss(task_logits, true_task_types)
    total_loss += task_loss

    # 3) Answer prediction loss (per sample)
    for pred, ans, c in zip(preds, answers, mapped_classes):
        if c == "count":
            # For count, the answer must be numeric
            try:
                ans_val = float(ans)
                ans_val = torch.tensor([ans_val], device=pred.device)
                total_loss += mse_loss(pred.squeeze(), ans_val)
            except ValueError:
                print(f"[Warning] Skipping non-numeric count answer: {ans}")
                continue

        else:
            # For categorical tasks (yesno, single, multi, etc.)
            if ans not in answer_vocabs.get(c, {}):
                print(f"[Warning] Skipping unseen or descriptive answer {ans} for task {c}")
                continue

            ans_idx = answer_vocabs[c][ans]

            if ans_idx >= pred.size(1):
                print(f"[Warning] Skipping answer {ans} for task {c}: "
                      f"index {ans_idx} >= pred.size(1)")
                continue

            ans_tensor = torch.tensor([ans_idx], device=pred.device)
            total_loss += ce_loss(pred, ans_tensor)

    meteor = compute_meteor(preds, answers, answer_vocabs, mapped_classes)
    print(f"METEOR on this batch: {meteor:.4f}")

    return total_loss / len(preds)
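
extract_count tolerates several answer formats; a few examples of what it returns (package name "vqa" is an assumption):

from vqa.model_functions import extract_count

print(extract_count("3"))                  # 3.0
print(extract_count("two polyps"))         # 2.0
print(extract_count("about 12 in total"))  # 12.0
print(extract_count("none visible"))       # None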
models.py
ADDED
@@ -0,0 +1,29 @@

import torch
import torch.nn as nn

from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    AutoModel, AutoProcessor, VisionEncoderDecoderModel,
    T5Tokenizer, T5ForConditionalGeneration
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gen_name = "t5-base"
gen_tokenizer = T5Tokenizer.from_pretrained(gen_name)
gen_model = T5ForConditionalGeneration.from_pretrained(gen_name).to(device)


def generate_descriptive_answer(question, prediction, fused_features):
    # Construct a prompt combining the prediction and context
    # (fused_features is currently unused; reserved for conditioning the generator)
    prompt = f"Question: {question} | Prediction: {prediction} | Context: GI disease analysis"
    inputs = gen_tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    outputs = gen_model.generate(**inputs, max_length=50)
    return gen_tokenizer.decode(outputs[0], skip_special_tokens=True)


def disease_model(img):
    # Placeholder stub: returns a zero disease vector of length 23.
    # Replace with a trained disease classifier that maps images to [B, 23].
    return torch.zeros(23).to(device)


router_name = "distilbert-base-uncased"
router_tokenizer = AutoTokenizer.from_pretrained(router_name)
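
generate_descriptive_answer wraps the class-level prediction in a T5 prompt; a usage sketch (the output text will vary with the model):

from vqa.models import generate_descriptive_answer  # package name "vqa" is an assumption

answer = generate_descriptive_answer(
    question="Is a polyp present?",
    prediction="Yes",
    fused_features=None  # currently unused by the function
)
print(answer)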
qtype.py
ADDED
@@ -0,0 +1,25 @@

import torch
from torch import nn
from transformers import DistilBertModel


class QuestionTypeClassifier(nn.Module):
    def __init__(self, num_types):
        super().__init__()

        # Load pre-trained DistilBERT
        self.distilbert = DistilBertModel.from_pretrained("distilbert-base-uncased")

        # Classification head
        self.fc = nn.Linear(self.distilbert.config.hidden_size, num_types)

    def forward(self, input_ids, attention_mask):
        outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Take the first-token embedding (DistilBERT uses the first token as [CLS])
        cls_token = outputs.last_hidden_state[:, 0, :]  # [B, hidden]

        logits = self.fc(cls_token)  # [B, num_types]
        return logits
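
A quick forward-pass check for the question-type classifier, using the same DistilBERT tokenizer as the rest of the pipeline (package name "vqa" is an assumption):

import torch
from transformers import AutoTokenizer
from vqa.qtype import QuestionTypeClassifier

tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
clf = QuestionTypeClassifier(num_types=6)

batch = tok(["Is a polyp present?", "What color is the lesion?"],
            return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    logits = clf(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(logits.shape)  # torch.Size([2, 6])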
tpred.py
ADDED
@@ -0,0 +1,24 @@

# ---------------------------
# Step 5: Task-Specific Predictors
# ---------------------------
import torch.nn as nn


class TaskPredictor(nn.Module):
    def __init__(self, task_type, hidden=512):
        super().__init__()
        if task_type == "yesno":
            self.head = nn.Linear(hidden, 2)   # binary yes/no
        elif task_type == "single":
            self.head = nn.Linear(hidden, 10)  # single-choice over up to 10 answers
        elif task_type == "multi":
            self.head = nn.Linear(hidden, 10)  # multi-choice over up to 10 answers
        elif task_type == "color":
            self.head = nn.Linear(hidden, 5)   # note: normalize_answer emits up to 6 colors; widen if needed
        elif task_type == "location":
            self.head = nn.Linear(hidden, 6)   # location classes (normalize_answer emits 5)
        elif task_type == "count":
            self.head = nn.Linear(hidden, 1)   # scalar regression
        else:
            raise ValueError("Unknown task")

    def forward(self, x):
        return self.head(x)
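
And a matching sanity check for the heads: each consumes the fused [B, hidden] vector and emits task-shaped logits, or a single scalar for count:

import torch
from vqa.tpred import TaskPredictor  # package name "vqa" is an assumption

fused = torch.randn(4, 512)
print(TaskPredictor("yesno")(fused).shape)  # torch.Size([4, 2])
print(TaskPredictor("count")(fused).shape)  # torch.Size([4, 1])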