|
|
import torch |
|
|
from torch.nn import CrossEntropyLoss, MSELoss |
|
|
import re |
|
|
from models import disease_model |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from nltk.translate.meteor_score import meteor_score |
|
|
|
|
|
|
|
|
def forward_batch(images, input_ids, attention_mask, answers, question_classes=None,qtype_classifier=None,fusion_module=None,q_types=None,q_types_mapping=None,task_heads=None,device=None,image_encoder=None,question_encoder=None):
    """Run one multimodal VQA forward pass routed by ground-truth question classes.

    Encodes the image and question, predicts the question/task type, fuses the
    two modalities with a disease-feature vector, and sends each fused sample
    through the task head selected by its (true) question class.

    Returns:
        tuple: (preds, answers, task_logits) — per-sample head outputs,
        the untouched ground-truth answers, and the question-type logits.
    """
    # Move shared inputs to the target device once instead of per call.
    pixels = images.to(device)
    ids = input_ids.to(device)
    mask = attention_mask.to(device)

    # Visual features: full token grid from the vision backbone.
    visual_feat = image_encoder(pixel_values=pixels).last_hidden_state

    # Question-type logits from the lightweight classifier.
    task_logits = qtype_classifier(input_ids=ids, attention_mask=mask)

    # Pooled sentence embedding for the question text.
    text_feat = question_encoder(input_ids=ids, attention_mask=mask).pooler_output

    # Auxiliary disease-feature vector from the shared disease model.
    disease_vec = disease_model(pixels)

    fused = fusion_module(visual_feat, text_feat, disease_vec)

    # Dispatch each fused sample to the head of its ground-truth task type.
    def _head_for(q_class):
        raw = q_class[0] if isinstance(q_class, list) else q_class
        return task_heads[q_types_mapping[raw]]

    preds = [
        _head_for(q_class)(fused[i].unsqueeze(0))
        for i, q_class in enumerate(question_classes)
    ]

    return preds, answers, task_logits
|
|
|
|
|
|
|
|
def forward_batch1(images, input_ids, attention_mask, answers, true_q_classes=None,qtype_classifier=None,fusion_module=None,q_types=None,task_heads=None,device=None,image_encoder=None,question_encoder=None):
    """Inference-style forward pass that PREDICTS the question/task type.

    Unlike ``forward_batch`` (which routes by ground-truth classes), this
    variant picks each sample's task head via argmax over the question-type
    classifier logits.

    Args:
        images: batch of pixel tensors.
        input_ids / attention_mask: tokenized question batch.
        answers: ground-truth answers, passed through unchanged.
        true_q_classes: unused here; kept for signature parity with callers.
        qtype_classifier: question-type classifier module.
        fusion_module: fuses image, question, and disease features.
        q_types: list mapping predicted class index -> task-type name.
        task_heads: dict of task-type name -> trained prediction head.
            Preferred over constructing a fresh (untrained) ``TaskPredictor``
            per sample, which the legacy fallback still does.
        device / image_encoder / question_encoder: previously read from
            globals (NameError when run standalone); now injectable, matching
            ``forward_batch``.

    Returns:
        tuple: (preds, answers, task_logits).
    """
    # Move shared inputs to the target device once.
    pixels = images.to(device)
    ids = input_ids.to(device)
    mask = attention_mask.to(device)

    # BUG FIX: the original passed raw `images` (possibly CPU) to
    # disease_model while every other module received device tensors.
    disease_vec = disease_model(pixels)

    img_feat = image_encoder(pixel_values=pixels).last_hidden_state

    q_feat = question_encoder(input_ids=ids, attention_mask=mask).pooler_output

    # BUG FIX: the original referenced an undefined `batch` dict here; use
    # the tensors that were actually passed in.
    task_logits = qtype_classifier(input_ids=ids, attention_mask=mask)
    task_pred = torch.argmax(task_logits, dim=1)

    fused = fusion_module(img_feat, q_feat, disease_vec)

    preds = []
    for i, t_idx in enumerate(task_pred):
        task_type = q_types[t_idx]
        if task_heads is not None:
            # Use the shared, trained head — consistent with forward_batch.
            predictor = task_heads[task_type]
        else:
            # Legacy fallback: fresh per-sample predictor (untrained weights).
            predictor = TaskPredictor(task_type).to(device)
        preds.append(predictor(fused[i].unsqueeze(0)))

    return preds, answers, task_logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_count(answer_str):
    """Parse a count-style answer into a float.

    Tries, in order: direct float conversion, spelled-out number words
    (zero..ten), then the first numeric token (integer or decimal) found
    anywhere in the string.

    Args:
        answer_str: answer text; non-string input (e.g. None) is tolerated.

    Returns:
        The parsed value as a float, or None if nothing numeric is found.
    """
    # Fast path: the whole answer is already numeric ("3", "2.5").
    # TypeError is caught too — float(None) raised uncaught in the original.
    try:
        return float(answer_str)
    except (TypeError, ValueError):
        pass

    # Everything below needs string operations; bail out on non-strings
    # instead of raising AttributeError on .lower().
    if not isinstance(answer_str, str):
        return None

    word2num = {
        "zero": 0, "one": 1, "two": 2, "three": 3,
        "four": 4, "five": 5, "six": 6,
        "seven": 7, "eight": 8, "nine": 9, "ten": 10,
    }
    for token in answer_str.lower().split():
        if token in word2num:
            return float(word2num[token])

    # Embedded number: also match decimals ("3.5 cm" -> 3.5, not 3.0 as the
    # original integer-only pattern produced).
    numbers = re.findall(r"\d+(?:\.\d+)?", answer_str)
    if numbers:
        return float(numbers[0])

    return None
|
|
|
|
|
|
|
|
def compute_meteor(preds, answers, answer_vocabs, mapped_classes):
    """Average METEOR score over samples whose task has an answer vocabulary.

    Each prediction is decoded via argmax through the inverse of its task's
    answer vocabulary; samples whose mapped class has no vocabulary are
    skipped. Returns 0.0 when nothing could be scored.
    """
    total, scored = 0.0, 0
    for logits, reference, task in zip(preds, answers, mapped_classes):
        vocab = answer_vocabs.get(task)
        if vocab is None:
            continue

        # Decode the predicted class index back to its answer string.
        index_to_answer = {idx: text for text, idx in vocab.items()}
        decoded = index_to_answer.get(logits.argmax(dim=1).item(), "")

        total += meteor_score([reference.split()], decoded.split())
        scored += 1
    return total / scored if scored else 0.0
|
|
|
|
|
|
|
|
def compute_rouge(preds, answers, answer_vocabs, mapped_classes):
    """Average ROUGE-L F-measure over samples with an answer vocabulary.

    Mirrors ``compute_meteor``: decode each argmax prediction through the
    inverted task vocabulary and score it against the reference answer.

    NOTE(review): `rouge_scorer` is not imported anywhere in this file —
    presumably `from rouge_score import rouge_scorer` is missing at the top
    of the file; confirm before running this path.
    """
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    collected = []
    for logits, reference, task in zip(preds, answers, mapped_classes):
        vocab = answer_vocabs.get(task)
        if vocab is None:
            continue
        index_to_answer = {idx: text for text, idx in vocab.items()}
        decoded = index_to_answer.get(logits.argmax(dim=1).item(), "")
        collected.append(scorer.score(reference, decoded)["rougeL"].fmeasure)
    if not collected:
        return 0.0
    return sum(collected) / len(collected)
|
|
|
|
|
def compute_loss(preds, answers, task_logits, true_q_classes, answer_vocabs,q_types_mapping,q_types,task_heads):
    """Combined question-type + per-answer loss for one batch.

    Sums a cross-entropy loss over the question-type logits with a per-sample
    answer loss: MSE regression for "count" tasks, cross-entropy over the
    task's answer vocabulary otherwise. Also prints a METEOR score for the
    batch as a side effect (kept from the original implementation).

    Args:
        preds: list of per-sample head output tensors (shape [1, C] each).
        answers: list of ground-truth answer strings.
        task_logits: tensor [batch_size, num_task_types].
        true_q_classes: list of fine-grained classes (string or [string]).
        answer_vocabs: dict {task_type: {answer: index}}.
        q_types_mapping: dict collapsing fine-grained class -> task type.
        q_types: ordered list of task types (index = classifier label).
        task_heads: unused here; retained for call-site compatibility.

    Returns:
        Scalar tensor: total loss averaged over the number of predictions.
    """
    ce_loss = CrossEntropyLoss()
    mse_loss = MSELoss()

    total_loss = 0

    # Collapse fine-grained question classes to their coarse task type.
    mapped_classes = [
        q_types_mapping[c[0] if isinstance(c, list) else c]
        for c in true_q_classes
    ]

    true_task_types = torch.tensor(
        [q_types.index(c) for c in mapped_classes],
        device=task_logits.device
    )

    # Question-type classification loss over the whole batch.
    total_loss += ce_loss(task_logits, true_task_types)

    # Per-sample answer loss, routed by task type.
    for pred, ans, c in zip(preds, answers, mapped_classes):
        if c == "count":
            # Counts are regressed. Use the shared extract_count helper so
            # spelled-out numbers ("two") also contribute, instead of the
            # bare float() the original duplicated here.
            ans_val = extract_count(ans)
            if ans_val is None:
                print(f"[Warning] Skipping non-numeric count answer: {ans}")
                continue
            target = torch.tensor([float(ans_val)], device=pred.device)
            # reshape(-1) keeps prediction and target both 1-D, avoiding the
            # broadcast-shape warning squeeze() triggers on [1, 1] outputs.
            total_loss += mse_loss(pred.reshape(-1), target)
        else:
            vocab = answer_vocabs.get(c, {})
            if ans not in vocab:
                print(f"[Warning] Skipping unseen or descriptive answer {ans} for task {c}")
                continue

            ans_idx = vocab[ans]
            # Guard against heads smaller than the vocabulary.
            if ans_idx >= pred.size(1):
                print(f"[Warning] Skipping answer {ans} for task {c}: "
                      f"index {ans_idx} >= pred.size(1)")
                continue

            ans_tensor = torch.tensor([ans_idx], device=pred.device)
            total_loss += ce_loss(pred, ans_tensor)

    # Side effect kept from the original: report METEOR for this batch.
    meteor = compute_meteor(preds, answers, answer_vocabs, mapped_classes)
    print(f"Validation METEOR: {meteor:.4f}")

    # Guard the ZeroDivisionError the original raised on an empty batch.
    return total_loss / len(preds) if preds else total_loss