File size: 7,944 Bytes
62305fe
 
 
a9673bf
62305fe
 
 
 
 
a9673bf
62305fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import torch 
from torch.nn import CrossEntropyLoss, MSELoss
import re
from models import disease_model
#!pip install rouge_score

#from rouge_score import rouge_scorer

from nltk.translate.meteor_score import meteor_score


def forward_batch(images, input_ids, attention_mask, answers, question_classes=None,qtype_classifier=None,fusion_module=None,q_types=None,q_types_mapping=None,task_heads=None,device=None,image_encoder=None,question_encoder=None):
    """Run one multi-task VQA forward pass using the supplied (trained) modules.

    Pipeline: encode image, classify question type, encode question, compute a
    disease vector, fuse all three, then route each sample to its trained
    task-specific head.

    Args:
        images: batch of image tensors accepted by ``image_encoder`` and
            ``disease_model`` (presumably pixel values; shape set by caller).
        input_ids / attention_mask: tokenized question batch.
        answers: ground-truth answers, passed through unchanged.
        question_classes: per-sample fine-grained question class (string, or a
            one-element list holding the string).
        qtype_classifier: model mapping tokens -> task-type logits.
        fusion_module: fuses image features, question features, disease vector.
        q_types: unused here; kept for signature parity with callers.
        q_types_mapping: dict mapping fine-grained class -> general task type.
        task_heads: dict mapping general task type -> trained prediction head.
        device: torch device the encoders live on.
        image_encoder / question_encoder: vision and text encoders.

    Returns:
        (preds, answers, task_logits) where ``preds`` is a list of one
        head-output tensor per sample and ``task_logits`` are the question-type
        classification logits.
    """
    # Image encoding; last_hidden_state keeps per-region features for fusion.
    img_outputs = image_encoder(pixel_values=images.to(device))
    img_feat = img_outputs.last_hidden_state  # presumably [B, regions, 768] — confirm against encoder

    # Question-type logits from the dedicated classifier (used for the
    # classification loss downstream, not for routing here).
    task_logits = qtype_classifier(input_ids=input_ids.to(device),
                                   attention_mask=attention_mask.to(device))

    # Separate encoder provides the pooled question embedding for fusion.
    q_feat = question_encoder(input_ids=input_ids.to(device),
                              attention_mask=attention_mask.to(device)).pooler_output

    # Disease prediction vector from the pretrained disease model.
    disease_vec = disease_model(images.to(device))

    # Multimodal fusion of image, question, and disease signals.
    fused = fusion_module(img_feat, q_feat, disease_vec)

    # Route each sample to the trained head for its GROUND-TRUTH question
    # class (teacher forcing); classes may arrive as ["x"] or "x".
    preds = []
    for i, q_class in enumerate(question_classes):
        mapped_type = q_types_mapping[q_class[0] if isinstance(q_class, list) else q_class]
        predictor = task_heads[mapped_type]
        preds.append(predictor(fused[i].unsqueeze(0)))

    return preds, answers, task_logits


def forward_batch1(images, input_ids, attention_mask, answers, true_q_classes=None,qtype_classifier=None,fusion_module=None,q_types=None,device=None,image_encoder=None,question_encoder=None,task_heads=None):
    """Inference-style forward pass that routes by the PREDICTED question type.

    Unlike ``forward_batch`` (which routes by ground-truth class), this variant
    uses ``argmax`` over the question-type logits to pick the task head.

    Bug fixes vs. the previous revision:
      * the classifier was called with an undefined ``batch`` variable
        (NameError) — it now uses the ``input_ids``/``attention_mask`` args;
      * ``image_encoder``/``question_encoder``/``device``/``task_heads`` were
        read from (nonexistent) globals — they are now keyword parameters,
        added at the end of the signature so existing callers are unaffected;
      * a fresh, untrained ``TaskPredictor`` was built per sample — the
        trained ``task_heads`` are used instead, matching ``forward_batch``.

    Returns:
        (preds, answers, task_logits) — see ``forward_batch``.
    """
    # Disease vector from the pretrained disease model.
    disease_vec = disease_model(images)

    # Encode image regions.
    img_outputs = image_encoder(pixel_values=images.to(device))
    img_feat = img_outputs.last_hidden_state

    # Pooled question embedding for fusion.
    q_feat = question_encoder(input_ids=input_ids.to(device),
                              attention_mask=attention_mask.to(device)).pooler_output

    # Predict the task type from the question tokens (was: undefined `batch`).
    task_logits = qtype_classifier(input_ids=input_ids.to(device),
                                   attention_mask=attention_mask.to(device))
    task_pred = torch.argmax(task_logits, dim=1)  # predicted type index per sample

    # Multimodal fusion.
    fused = fusion_module(img_feat, q_feat, disease_vec)

    # Route each sample to the trained head for its PREDICTED type.
    preds = []
    for i, t_idx in enumerate(task_pred):
        task_type = q_types[t_idx]          # index -> type name
        predictor = task_heads[task_type]   # trained head (was: fresh TaskPredictor)
        preds.append(predictor(fused[i].unsqueeze(0)))

    return preds, answers, task_logits


def extract_count(answer_str):
    """Convert an answer into a numeric count.

    Tries, in order: direct float conversion (also accepts bare ints/floats),
    number-word lookup ("zero".."ten"), then the first run of digits embedded
    in the string.

    Args:
        answer_str: the answer text, a bare number, or None.

    Returns:
        The count as a float, or None when nothing numeric can be extracted.
        (Previously a None or other non-string input raised TypeError /
        AttributeError; it is now handled.)
    """
    if answer_str is None:
        return None

    # Direct numeric: "3", "3.5", or an actual int/float.
    try:
        return float(answer_str)
    except (ValueError, TypeError):
        pass

    # Beyond this point we need string operations.
    if not isinstance(answer_str, str):
        return None

    # Spelled-out small numbers ("one", "two", ...).
    word2num = {
        "zero": 0, "one": 1, "two": 2, "three": 3,
        "four": 4, "five": 5, "six": 6,
        "seven": 7, "eight": 8, "nine": 9, "ten": 10,
    }
    for token in answer_str.lower().split():
        if token in word2num:
            return float(word2num[token])

    # First run of digits anywhere in the string (e.g. "12 lesions").
    match = re.search(r"\d+", answer_str)
    return float(match.group()) if match else None


def compute_meteor(preds, answers, answer_vocabs, mapped_classes):
    """Average METEOR score over the batch.

    For each sample whose task type has an answer vocabulary, the argmax
    prediction index is mapped back to its answer string and scored against
    the ground-truth answer. Samples with no vocabulary are skipped.

    Returns 0.0 when no sample could be scored.
    """
    collected = []
    for logits, gold, qtype in zip(preds, answers, mapped_classes):
        if qtype not in answer_vocabs:
            continue
        # Invert {answer: index} so the argmax index yields a string.
        index_to_answer = {idx: text for text, idx in answer_vocabs[qtype].items()}
        predicted = index_to_answer.get(logits.argmax(dim=1).item(), "")
        collected.append(meteor_score([gold.split()], predicted.split()))
    if not collected:
        return 0.0
    return sum(collected) / len(collected)


def compute_rouge(preds, answers, answer_vocabs, mapped_classes):
    """Average ROUGE-L F-measure over the batch.

    For each sample whose task type has an answer vocabulary, the argmax
    prediction index is mapped back to its answer string and scored against
    the ground truth. Samples with no vocabulary are skipped; returns 0.0
    when nothing could be scored.

    Raises:
        ImportError: if the optional ``rouge_score`` package is missing.
    """
    # The module-level import is commented out (L430), so referencing
    # ``rouge_scorer`` here was a guaranteed NameError. Import lazily with a
    # clear, actionable error instead.
    try:
        from rouge_score import rouge_scorer
    except ImportError as err:
        raise ImportError(
            "compute_rouge requires the 'rouge_score' package: "
            "pip install rouge-score"
        ) from err

    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    scores = []
    for pred, ans, c in zip(preds, answers, mapped_classes):
        if c not in answer_vocabs:
            continue
        pred_idx = pred.argmax(dim=1).item()
        inv_vocab = {v: k for k, v in answer_vocabs[c].items()}
        pred_str = inv_vocab.get(pred_idx, "")
        scores.append(scorer.score(ans, pred_str)["rougeL"].fmeasure)
    return sum(scores) / len(scores) if scores else 0.0

def compute_loss(preds, answers, task_logits, true_q_classes, answer_vocabs,q_types_mapping,q_types,task_heads):
    """Combined loss: question-type classification + per-sample answer loss.

    Args:
        preds: list of per-sample head outputs (one tensor each).
        answers: list of ground-truth answer strings.
        task_logits: tensor [batch_size, num_task_types].
        true_q_classes: fine-grained class per question (string or [string]).
        answer_vocabs: dict {task_type: {answer_string: index}}.
        q_types_mapping: dict fine-grained class -> general task type.
        q_types: ordered list of general task types (index = class id).
        task_heads: unused here (prediction already happened in forward_batch);
            kept so the call signature stays unchanged.

    Returns:
        Scalar loss averaged over the batch. Count answers use MSE via
        ``extract_count`` (handles "two", "2 nodules", ...); categorical
        answers use cross-entropy; unparseable/unseen answers are skipped
        with a warning.
    """
    ce_loss = CrossEntropyLoss()
    mse_loss = MSELoss()

    total_loss = 0

    # 1) Map fine-grained -> general classes (classes may arrive as ["x"]).
    mapped_classes = [
        q_types_mapping[c[0] if isinstance(c, list) else c]
        for c in true_q_classes
    ]

    # 2) Question-type classification loss.
    true_task_types = torch.tensor(
        [q_types.index(c) for c in mapped_classes],
        device=task_logits.device
    )
    total_loss += ce_loss(task_logits, true_task_types)

    # 3) Answer prediction loss (per sample).
    for pred, ans, c in zip(preds, answers, mapped_classes):
        if c == "count":
            # Regression target; extract_count also parses number words and
            # digits embedded in text (was: bare float(ans)).
            ans_val = extract_count(ans)
            if ans_val is None:
                print(f"[Warning] Skipping non-numeric count answer: {ans}")
                continue
            target = torch.tensor([ans_val], device=pred.device)
            total_loss += mse_loss(pred.squeeze(), target)

        else:
            # Categorical tasks (yesno, single, multi, ...).
            if ans not in answer_vocabs.get(c, {}):
                print(f"[Warning] Skipping unseen or descriptive answer {ans} for task {c}")
                continue

            ans_idx = answer_vocabs[c][ans]

            # Guard against vocab indices beyond this head's output width.
            if ans_idx >= pred.size(1):
                print(f"[Warning] Skipping answer {ans} for task {c}: "
                      f"index {ans_idx} >= pred.size(1)")
                continue

            ans_tensor = torch.tensor([ans_idx], device=pred.device)
            total_loss += ce_loss(pred, ans_tensor)

    meteor = compute_meteor(preds, answers, answer_vocabs, mapped_classes)
    print(f"Validation METEOR: {meteor:.4f}")

    # Average over the batch; an empty batch previously raised
    # ZeroDivisionError — return the (type-classification) loss unscaled.
    return total_loss / len(preds) if preds else total_loss