| import torch |
| from torch.nn import CrossEntropyLoss, MSELoss |
| import re |
| from models import disease_model |
| |
|
|
| |
|
|
| from nltk.translate.meteor_score import meteor_score |
|
|
|
|
def forward_batch(images, input_ids, attention_mask, answers, question_classes=None,qtype_classifier=None,fusion_module=None,q_types=None,q_types_mapping=None,task_heads=None,device=None,image_encoder=None,question_encoder=None):
    """
    Forward pass that routes each sample through the task head chosen by
    its ground-truth question class.

    Args:
        images: batch of pixel tensors accepted by ``image_encoder``.
        input_ids / attention_mask: tokenized questions.
        answers: ground-truth answer strings (passed through unchanged).
        question_classes: fine-grained class (or one-element list) per question.
        qtype_classifier: text classifier producing task-type logits.
        fusion_module: fuses image, question and disease features.
        q_types_mapping: fine-grained class -> coarse task type.
        task_heads: {task_type: prediction head}.
        device: target device for all inputs.
        image_encoder / question_encoder: backbone encoders.

    Returns:
        (preds, answers, task_logits) where preds is a list of per-sample
        head outputs and task_logits are the question-type logits.
    """
    # Move each input to the target device once, instead of calling
    # .to(device) repeatedly for every module that consumes it.
    images = images.to(device)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # Visual features: per-patch hidden states from the image encoder.
    img_feat = image_encoder(pixel_values=images).last_hidden_state

    # Question-type logits for the auxiliary task classifier.
    task_logits = qtype_classifier(input_ids=input_ids,
                                   attention_mask=attention_mask)

    # Pooled sentence embedding of the question.
    q_feat = question_encoder(input_ids=input_ids,
                              attention_mask=attention_mask).pooler_output

    # Disease-evidence vector from the auxiliary disease model.
    disease_vec = disease_model(images)

    # Multimodal fusion of image, question and disease features.
    fused = fusion_module(img_feat, q_feat, disease_vec)

    # Route each fused sample to the head for its ground-truth class.
    preds = []
    for i, q_class in enumerate(question_classes):
        # Entries may be wrapped in a one-element list; unwrap before mapping.
        mapped_type = q_types_mapping[q_class[0] if isinstance(q_class, list) else q_class]
        preds.append(task_heads[mapped_type](fused[i].unsqueeze(0)))

    return preds, answers, task_logits
|
|
|
|
def forward_batch1(images, input_ids, attention_mask, answers, true_q_classes=None,qtype_classifier=None,fusion_module=None,q_types=None,task_heads=None,device=None,image_encoder=None,question_encoder=None):
    """
    Forward pass that routes each sample by the *predicted* question type
    (argmax of the q-type classifier), unlike forward_batch which uses the
    ground-truth class.

    Args mirror forward_batch; ``task_heads``, ``device``,
    ``image_encoder`` and ``question_encoder`` are new keyword parameters
    (default None) — the previous body referenced them as undefined
    globals, which raised NameError at runtime.

    Returns:
        (preds, answers, task_logits)
    """
    # Move all inputs to the target device once.
    images = images.to(device)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # Disease-evidence vector; previously computed on the un-moved tensor,
    # which could cause a device mismatch with the encoders below.
    disease_vec = disease_model(images)

    # Visual features: per-patch hidden states from the image encoder.
    img_feat = image_encoder(pixel_values=images).last_hidden_state

    # Pooled sentence embedding of the question.
    q_feat = question_encoder(input_ids=input_ids,
                              attention_mask=attention_mask).pooler_output

    # Question-type logits; the old body read a non-existent `batch` dict
    # instead of the function's own input_ids/attention_mask arguments.
    task_logits = qtype_classifier(input_ids=input_ids,
                                   attention_mask=attention_mask)
    task_pred = torch.argmax(task_logits, dim=1)

    # Multimodal fusion of image, question and disease features.
    fused = fusion_module(img_feat, q_feat, disease_vec)

    # Route each sample to the shared, trained head for its predicted type.
    # The old body instantiated a fresh TaskPredictor per sample (undefined
    # name, and an untrained head even if it existed).
    preds = []
    for i, t_idx in enumerate(task_pred):
        task_type = q_types[t_idx]
        preds.append(task_heads[task_type](fused[i].unsqueeze(0)))

    return preds, answers, task_logits
| |
| |
| |
| |
| |
| |
|
|
|
|
def extract_count(answer_str):
    """
    Try to convert an answer string into a number.

    Parsing strategies, in order:
      1. Direct float() conversion ("3", "2.5").
      2. Number words zero..ten ("two lesions" -> 2.0).
      3. First numeric literal embedded in the string
         ("about 2.5 cm" -> 2.5).

    Returns:
        float, or None if the answer cannot be parsed.
    """
    try:
        return float(answer_str)
    # TypeError covers non-string inputs (e.g. None), which previously
    # propagated out of this function instead of yielding None.
    except (TypeError, ValueError):
        pass

    if not isinstance(answer_str, str):
        return None

    word2num = {
        "zero": 0, "one": 1, "two": 2, "three": 3,
        "four": 4, "five": 5, "six": 6,
        "seven": 7, "eight": 8, "nine": 9, "ten": 10
    }
    for token in answer_str.lower().split():
        if token in word2num:
            return float(word2num[token])

    # Fall back to the first embedded numeral; also accept decimals,
    # which the old `\d+` pattern truncated ("2.5" -> 2.0).
    match = re.search(r"\d+(?:\.\d+)?", answer_str)
    if match:
        return float(match.group())

    return None
|
|
|
|
def compute_meteor(preds, answers, answer_vocabs, mapped_classes):
    """
    Average METEOR score between decoded predictions and reference answers.

    Samples whose task type has no entry in ``answer_vocabs`` are skipped.
    Returns 0.0 when no sample could be scored.
    """
    scores = []
    for logits, reference, task in zip(preds, answers, mapped_classes):
        if task not in answer_vocabs:
            continue
        # Decode the argmax class index back to its answer string via the
        # inverted {answer: index} vocabulary; unknown indices decode to "".
        index_to_answer = {idx: word for word, idx in answer_vocabs[task].items()}
        hypothesis = index_to_answer.get(logits.argmax(dim=1).item(), "")
        scores.append(meteor_score([reference.split()], hypothesis.split()))
    if not scores:
        return 0.0
    return sum(scores) / len(scores)
|
|
|
|
def compute_rouge(preds, answers, answer_vocabs, mapped_classes):
    """
    Average ROUGE-L F-measure between decoded predictions and references.

    Samples whose task type has no entry in ``answer_vocabs`` are skipped.
    Returns 0.0 when no sample could be scored.
    """
    # Local import: `rouge_scorer` was never imported anywhere in this
    # module, so this function previously raised NameError on first call.
    from rouge_score import rouge_scorer

    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    scores = []
    for pred, ans, c in zip(preds, answers, mapped_classes):
        if c not in answer_vocabs:
            continue
        pred_idx = pred.argmax(dim=1).item()
        # Invert {answer: index} to decode the predicted index to a string.
        inv_vocab = {v: k for k, v in answer_vocabs[c].items()}
        pred_str = inv_vocab.get(pred_idx, "")
        scores.append(scorer.score(ans, pred_str)["rougeL"].fmeasure)
    return sum(scores) / len(scores) if scores else 0.0
|
|
def compute_loss(preds, answers, task_logits, true_q_classes, answer_vocabs,q_types_mapping,q_types,task_heads):
    """
    Combined loss: question-type classification loss plus a per-sample
    answer loss (MSE for "count" tasks, cross-entropy otherwise).

    Args:
        preds: list of per-sample prediction tensors.
        answers: list of ground-truth answer strings.
        task_logits: tensor [batch_size, num_task_types].
        true_q_classes: fine-grained class (or one-element list) per question.
        answer_vocabs: {task_type: {answer_string: index}}.
        q_types_mapping: fine-grained class -> coarse task type.
        q_types: ordered list of coarse task types (defines label indices).
        task_heads: unused here; kept for interface compatibility.

    Returns:
        Scalar loss averaged over the number of predictions (the un-averaged
        task loss when ``preds`` is empty, avoiding division by zero).
    """
    ce_loss = CrossEntropyLoss()
    mse_loss = MSELoss()
    total_loss = 0

    # Collapse fine-grained classes (possibly list-wrapped) to coarse types.
    mapped_classes = [
        q_types_mapping[c[0] if isinstance(c, list) else c]
        for c in true_q_classes
    ]

    # Auxiliary loss: the classifier must predict the coarse task type.
    true_task_types = torch.tensor(
        [q_types.index(c) for c in mapped_classes],
        device=task_logits.device
    )
    total_loss += ce_loss(task_logits, true_task_types)

    for pred, ans, c in zip(preds, answers, mapped_classes):
        if c == "count":
            # Use the module's extract_count helper so number words
            # ("two") and embedded numerals ("3 lesions") become valid
            # regression targets instead of being skipped via bare float().
            ans_val = extract_count(ans)
            if ans_val is None:
                print(f"[Warning] Skipping non-numeric count answer: {ans}")
                continue
            target = torch.tensor([ans_val], device=pred.device)
            total_loss += mse_loss(pred.squeeze(), target)
        else:
            if ans not in answer_vocabs.get(c, {}):
                print(f"[Warning] Skipping unseen or descriptive answer {ans} for task {c}")
                continue

            ans_idx = answer_vocabs[c][ans]

            # Guard against vocab indices beyond this head's output width.
            if ans_idx >= pred.size(1):
                print(f"[Warning] Skipping answer {ans} for task {c}: "
                      f"index {ans_idx} >= pred.size(1)")
                continue

            ans_tensor = torch.tensor([ans_idx], device=pred.device)
            total_loss += ce_loss(pred, ans_tensor)

    meteor = compute_meteor(preds, answers, answer_vocabs, mapped_classes)
    print(f"Validation METEOR: {meteor:.4f}")

    # Empty batch: return the un-averaged task loss rather than dividing
    # by zero (the old code raised ZeroDivisionError here).
    if not preds:
        return total_loss
    return total_loss / len(preds)