|
|
|
|
|
from PIL import Image |
|
|
import requests |
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
# Standard ImageNet-style preprocessing pipeline: resize to the fixed
# 224x224 input size, convert the PIL image to a float tensor in [0, 1],
# then normalize with the canonical ImageNet per-channel means and stds.
transformten = transforms.Compose([


transforms.Resize((224, 224)),


transforms.ToTensor(),


transforms.Normalize(mean=[0.485, 0.456, 0.406],


std=[0.229, 0.224, 0.225])


])
|
|
from collections import defaultdict |
|
|
from torch.utils.data import DataLoader |
|
|
import os |
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
# Module-level cache of decoded PIL images keyed by file path, so repeated
# dataset rows referencing the same image skip disk I/O and decoding.
# NOTE(review): grows without bound over a full dataset pass — confirm the
# image set fits in memory for the target environment.
image_cache = {}
|
|
|
|
|
def preprocess_image(image_source):
    """
    Preprocess a single image for inference.

    `image_source` can be a URL (string starting with "http"), a local
    file path (any other string), or an already-loaded PIL image.

    Returns the transformed tensor produced by `transformten` ([C, H, W]).

    Raises:
        ValueError: if `image_source` is neither a string nor a PIL image.
        requests.HTTPError: if a URL download returns an error status.
    """
    if isinstance(image_source, str):
        if image_source.startswith("http"):
            # Bound the request so a dead host cannot hang preprocessing
            # forever, and fail fast on HTTP errors instead of handing an
            # HTML error page to PIL.
            response = requests.get(image_source, stream=True, timeout=30)
            response.raise_for_status()
            image = Image.open(response.raw).convert("RGB")
        else:
            image = Image.open(image_source).convert("RGB")
    elif isinstance(image_source, Image.Image):
        image = image_source
    else:
        raise ValueError("Unsupported image_source type")

    # Apply the shared resize/to-tensor/normalize pipeline.
    image = transformten(image)

    return image
|
|
|
|
|
def preprocess_example(example):
    """
    Preprocess one dataset row (image + question) for training/inference.

    `example` is expected to carry "image" (path/URL string), "question",
    "answer" and "question_class" keys — TODO confirm against the dataset
    schema.

    Returns a dict with the transformed image tensor, tokenized question
    (input_ids / attention_mask, each [32]), and the pass-through label
    fields.
    """
    # Load the tokenizer once and cache it on the function object:
    # AutoTokenizer.from_pretrained hits disk (and possibly the network),
    # and this function runs once per dataset row.
    tokenizer = getattr(preprocess_example, "_tokenizer", None)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        preprocess_example._tokenizer = tokenizer

    # Resolve the image to its on-disk location inside the Kaggle input dir.
    image_name = example["image"].split("/")[-1]
    image_path = os.path.join("/kaggle/input/medico2025", image_name)

    # Reuse an already-decoded PIL image when the same file appears in
    # multiple rows; otherwise load, force RGB, and cache it.
    if image_path in image_cache:
        image = image_cache[image_path]
    else:
        image = Image.open(image_path)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_cache[image_path] = image

    # Shared resize/to-tensor/normalize pipeline.
    image = transformten(image)

    # Fixed-length (32-token) question encoding for the router model.
    q_inputs = tokenizer(example["question"],
                         return_tensors="pt",
                         truncation=True,
                         padding="max_length",
                         max_length=32)

    # Drop the singleton batch dimension added by return_tensors="pt".
    input_ids = q_inputs["input_ids"].squeeze(0)
    attention_mask = q_inputs["attention_mask"].squeeze(0)

    return {
        "image": image,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "answer": example["answer"],
        "question_class": example["question_class"],
        "image_url": example["image"],
    }
|
|
|
|
|
def normalize_answer(ans, q_type):
    """
    Map a free-text answer to a canonical label for question type `q_type`.

    Returns the canonical string, or None when the answer cannot be mapped
    (callers skip such rows when building vocabularies).
    """
    import re

    ans = ans.strip().lower()

    if q_type == "yesno":
        # Match whole words only, and check negations FIRST: the previous
        # substring test classified "no evidence of ..." as "Yes" (because
        # "evidence" matched before "no") and "normal" as "No" (because it
        # contains the substring "no").
        if re.search(r"\b(no|absent|none)\b", ans):
            return "No"
        if re.search(r"\b(yes|present|evidence)\b", ans):
            return "Yes"
        return None

    if q_type == "count":
        numbers = re.findall(r"\d+", ans)
        if numbers:
            return numbers[0]
        # Whole-word match so e.g. "none" is not read as "one".
        if re.search(r"\bone\b", ans):
            return "1"
        if re.search(r"\btwo\b", ans):
            return "2"
        return None

    if q_type == "color":
        for color in ["red","green","yellow","blue","white","black"]:
            if color in ans:
                return color
        return None

    if q_type == "location":
        for loc in ["upper","lower","left","right","central"]:
            if loc in ans:
                return loc
        return None

    if q_type in ["single","multi"]:
        return ans

    # Unknown question type: pass the cleaned answer through unchanged.
    return ans
|
|
|
|
|
|
|
|
def build_vocabs(dataset, q_types_mapping):
    """
    Build one answer->index vocabulary per general question class.

    Iterates the dataset rows, normalizes each answer for its general
    class, and assigns indices in first-seen order. Rows whose answer
    cannot be normalized (None) are skipped.
    """
    # One empty vocab per general class, even if no row ends up in it.
    task_vocabs = {cls: {} for cls in set(q_types_mapping.values())}

    for row in dataset:
        fine_class = row["question_class"]
        # Some rows store the fine-grained class as a one-element list.
        if isinstance(fine_class, list):
            fine_class = fine_class[0]

        general_class = q_types_mapping[fine_class]

        norm_ans = normalize_answer(row["answer"], general_class)
        if norm_ans is None:
            continue

        # First-seen order: a new answer gets the next free index.
        vocab = task_vocabs[general_class]
        vocab.setdefault(norm_ans, len(vocab))

    return task_vocabs
|
|
|
|
|
|
|
|
def build_answer_vocab(dataset, q_types_mapping):
    """
    Build raw answer->index vocabularies keyed by general question class.

    Unlike build_vocabs, answers are NOT normalized — every distinct raw
    answer string gets its own index, assigned in first-seen order per
    general class.
    """
    answer_vocab = defaultdict(dict)
    counters = defaultdict(int)

    for raw_ans, raw_class in zip(dataset["answer"], dataset["question_class"]):
        # Some rows store the class as a one-element list.
        fine_class = raw_class[0] if isinstance(raw_class, list) else raw_class
        general_class = q_types_mapping[fine_class]

        class_vocab = answer_vocab[general_class]
        if raw_ans not in class_vocab:
            class_vocab[raw_ans] = counters[general_class]
            counters[general_class] += 1

    return answer_vocab
|
|
|
|
|
|
|
|
|
|
|
def collate_fn(batch):
    """
    Collate a list of preprocessed examples into a single batch dict.

    Tensor-valued fields are stacked along a new leading batch dimension;
    fields that arrive as plain Python lists are converted to tensors
    first. String fields are simply gathered into Python lists.
    """
    def ensure_tensor(value):
        # Datasets that round-trip through serialization can hand tensors
        # back as nested lists; rebuild a tensor in that case, otherwise
        # pass the tensor through untouched.
        return torch.tensor(value) if isinstance(value, list) else value

    images = torch.stack([ensure_tensor(ex["image"]) for ex in batch])
    input_ids = torch.stack([ensure_tensor(ex["input_ids"]) for ex in batch])
    attention_mask = torch.stack([ensure_tensor(ex["attention_mask"]) for ex in batch])

    return {
        "images": images,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "answers": [ex["answer"] for ex in batch],
        "question_classes": [ex["question_class"] for ex in batch],
    }
|
|
|
|
|
|