# medico2025/functions.py
import os
import re
from collections import defaultdict

import requests
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Standard ImageNet preprocessing pipeline shared by all image loaders below.
transformten = transforms.Compose([
    transforms.Resize((224, 224)),                    # adjust size for your model
    transforms.ToTensor(),                            # convert to tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet normalization
                         std=[0.229, 0.224, 0.225]),
])

# In-memory cache so repeated rows referencing the same file are opened once.
image_cache = {}
def preprocess_image(image_source):
    """
    Preprocess a single image for inference.
    `image_source` can be a URL, a local file path, or a PIL image.
    Returns a tensor [3, 224, 224].
    """
    if isinstance(image_source, str):
        if image_source.startswith("http"):  # URL
            image = Image.open(requests.get(image_source, stream=True).raw).convert("RGB")
        else:  # local path
            image = Image.open(image_source).convert("RGB")
    elif isinstance(image_source, Image.Image):  # already a PIL image
        image = image_source
    else:
        raise ValueError("Unsupported image_source type")
    # Apply the same transform used during training:
    # Resize(224) -> ToTensor() -> Normalize()
    return transformten(image)  # torch.Tensor [3, 224, 224]
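# Usage sketch for preprocess_image (the URL and path below are hypothetical
# placeholders, not assets shipped with this repo):
#
#   img_t = preprocess_image("https://example.com/frame.jpg")         # URL
#   img_t = preprocess_image("/kaggle/input/medico2025/img_001.jpg")  # file
#   batch = img_t.unsqueeze(0)  # add a batch dim -> [1, 3, 224, 224]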
# Load the tokenizer once at module level; re-loading it inside
# preprocess_example on every row would dominate preprocessing time.
router_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def preprocess_example(example):
    # Resolve the image from the dataset directory by its basename.
    image_name = example["image"].split("/")[-1]
    image_path = os.path.join("/kaggle/input/medico2025", image_name)

    # Reuse a cached PIL image if this file was opened before.
    if image_path in image_cache:
        image = image_cache[image_path]
    else:
        image = Image.open(image_path)
        if image.mode != "RGB":
            image = image.convert("RGB")
        image_cache[image_path] = image  # cache the loaded image object

    # Apply the shared normalize/transform pipeline (Resize + ToTensor + Normalize).
    image = transformten(image)

    # Tokenize the question to a fixed length so batches stack cleanly.
    q_inputs = router_tokenizer(example["question"],
                                return_tensors="pt",
                                truncation=True,
                                padding="max_length",
                                max_length=32)

    # q_inputs is a BatchEncoding with batch_size=1 tensors, so squeeze
    # away the batch dimension.
    input_ids = q_inputs["input_ids"].squeeze(0)  # torch.Tensor [seq_len]
    attention_mask = q_inputs["attention_mask"].squeeze(0)

    # Pack features
    return {
        "image": image,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "answer": example["answer"],
        "question_class": example["question_class"],
        "image_url": example["image"],
    }
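# Sketch of the intended wiring with a Hugging Face dataset ("dataset" is
# assumed to be loaded elsewhere with "image", "question", "answer" and
# "question_class" columns). Note that Dataset.map serializes tensors to
# plain lists, which is why collate_fn below re-tensorizes list inputs:
#
#   processed = dataset.map(preprocess_example)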
def normalize_answer(ans, q_type):
    """Map a free-text answer onto a small canonical label set per question type."""
    ans = ans.strip().lower()
    if q_type == "yesno":
        # Check negations first so "no evidence of ..." maps to "No", and use
        # word boundaries so "no" does not match inside words like "normal".
        if re.search(r"\b(no|absent|none)\b", ans):
            return "No"
        if re.search(r"\b(yes|present|evidence)\b", ans):
            return "Yes"
        return None  # ambiguous
    if q_type == "count":
        # Extract a numeric value, falling back to common number words.
        numbers = re.findall(r"\d+", ans)
        if numbers:
            return numbers[0]
        if re.search(r"\bone\b", ans):
            return "1"
        if re.search(r"\btwo\b", ans):
            return "2"
        return None
    if q_type == "color":
        for color in ["red", "green", "yellow", "blue", "white", "black"]:
            if color in ans:
                return color
        return None
    if q_type == "location":
        # Simplify locations to a small fixed set.
        for loc in ["upper", "lower", "left", "right", "central"]:
            if loc in ans:
                return loc
        return None
    if q_type in ["single", "multi"]:
        return ans  # keep original; the choice set could also be restricted here
    return ans
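# Expected behaviour of normalize_answer on a few illustrative answers
# (made-up inputs, not drawn from the dataset):
#
#   normalize_answer("No evidence of polyps.", "yesno")   -> "No"
#   normalize_answer("Yes, a polyp is present", "yesno")  -> "Yes"
#   normalize_answer("there are 3 polyps", "count")       -> "3"
#   normalize_answer("pale yellow mucosa", "color")       -> "yellow"
#   normalize_answer("in the upper region", "location")   -> "upper"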
def build_vocabs(dataset, q_types_mapping):
    # Build task-specific vocabularies: one answer->index dict per general class.
    task_vocabs = {}
    for general_class in set(q_types_mapping.values()):
        task_vocabs[general_class] = {}
    for row in dataset:
        fine_class = row["question_class"]
        # question_class may be a list of labels; use the first one.
        if isinstance(fine_class, list):
            fine_class = fine_class[0]
        general_class = q_types_mapping[fine_class]
        norm_ans = normalize_answer(row["answer"], general_class)
        if norm_ans is None:
            continue  # skip answers that could not be normalized
        if norm_ans not in task_vocabs[general_class]:
            task_vocabs[general_class][norm_ans] = len(task_vocabs[general_class])
    return task_vocabs
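# Hypothetical shape of the fine->general mapping build_vocabs expects; the
# real mapping lives in the training script, not in this module:
#
#   q_types_mapping = {
#       "polyp_presence": "yesno",
#       "polyp_count":    "count",
#       "polyp_color":    "color",
#       "polyp_location": "location",
#   }
#   task_vocabs = build_vocabs(train_dataset, q_types_mapping)
#   # e.g. task_vocabs["yesno"] -> {"No": 0, "Yes": 1} (order follows the data)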
def build_answer_vocab(dataset, q_types_mapping):
    answer_vocab = defaultdict(dict)
    counters = defaultdict(int)
    for ans, q_class in zip(dataset["answer"], dataset["question_class"]):
        # question_class may be a list; pick the first label.
        if isinstance(q_class, list):
            q_class = q_class[0]
        general_class = q_types_mapping[q_class]
        if ans not in answer_vocab[general_class]:
            answer_vocab[general_class][ans] = counters[general_class]
            counters[general_class] += 1
    return answer_vocab
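# Unlike build_vocabs, build_answer_vocab indexes the *raw* answer strings
# without normalization, so near-duplicates stay separate:
#
#   answer_vocab = build_answer_vocab(train_dataset, q_types_mapping)
#   # answer_vocab["yesno"] may hold both "yes" and "Yes, clearly visible"
#   # as distinct entries.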
def collate_fn(batch):
    # Dataset.map stores tensors as plain lists, so each field may arrive as
    # a list or a tensor depending on set_format; re-tensorize either way.
    def to_tensor(value):
        return torch.tensor(value) if isinstance(value, list) else value

    images = torch.stack([to_tensor(item["image"]) for item in batch])
    input_ids = torch.stack([to_tensor(item["input_ids"]) for item in batch])
    attention_mask = torch.stack([to_tensor(item["attention_mask"]) for item in batch])
    answers = [item["answer"] for item in batch]  # keep as list for label encoding later
    q_classes = [item["question_class"] for item in batch]
    return {
        "images": images,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "answers": answers,
        "question_classes": q_classes,
    }
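# Sketch of how collate_fn plugs into the DataLoader imported above; the
# batch size and the "processed" dataset are placeholders:
#
#   loader = DataLoader(processed, batch_size=16, shuffle=True,
#                       collate_fn=collate_fn)
#   for batch in loader:
#       images = batch["images"]                  # [B, 3, 224, 224]
#       input_ids = batch["input_ids"]            # [B, 32]
#       attention_mask = batch["attention_mask"]  # [B, 32]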