# PAIR / utils.py
# Uploaded by fantos ("Upload 5 files", commit 2a034f9, verified).
import datetime as dt
import io
import math
import os
from typing import List, Optional, Tuple

import numpy as np
import pydicom
import torch
from PIL import Image
# Side length, in pixels, of the square images that _to_tensor resizes to.
IMG_SIZE: int = 224
def _read_image(file_path: str) -> Image.Image:
    """Load an image file from disk as an 8-bit grayscale ("L" mode) PIL image.

    Files with a ``.dcm`` extension (case-insensitive) are decoded with pydicom:
    the pixel data is min-max scaled to 0..255, MONOCHROME1 data is inverted so
    that high values render bright, and only the first frame of a multi-frame
    object is kept. Any other file is opened with PIL and converted to grayscale.

    Args:
        file_path: Path of the image file to read.

    Returns:
        A single-frame PIL image in mode "L".
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext != ".dcm":
        # The context manager closes the file handle PIL may otherwise keep open;
        # convert() materializes the pixels, so the result outlives the handle.
        with Image.open(file_path) as img:
            return img.convert("L")  # grayscale

    ds = pydicom.dcmread(file_path)
    arr = ds.pixel_array.astype(np.float32)

    # Multi-frame objects decode with a leading frame axis, which
    # Image.fromarray cannot handle; keep the first frame only.
    if int(getattr(ds, "NumberOfFrames", 1) or 1) > 1:
        arr = arr[0]

    # Min-max normalize to 0..255 (a constant image becomes all zeros).
    arr = arr - arr.min()
    peak = arr.max()
    if peak > 0:
        arr = arr / peak
    arr = (arr * 255.0).clip(0, 255).astype(np.uint8)

    # In MONOCHROME1 the minimum sample value is displayed as white; invert so the
    # output has the same polarity as MONOCHROME2 data and ordinary image files.
    if getattr(ds, "PhotometricInterpretation", "") == "MONOCHROME1":
        arr = 255 - arr

    # Presumably color DICOM decodes to (rows, cols, 3) and becomes an RGB image;
    # convert so that both branches return mode "L".
    return Image.fromarray(arr).convert("L")
def _to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a float32 tensor of shape (3, IMG_SIZE, IMG_SIZE).

    The image is resized directly to IMG_SIZE x IMG_SIZE. There is no crop or
    pad, so the aspect ratio is not preserved. Pixel values are divided by 255,
    mapping 8-bit data into [0, 1]. Grayscale ("L") input is replicated into
    three identical channels.

    Args:
        img: Source image. Modes other than "L" and "RGB" (palette, alpha,
            bilevel, ...) are first converted to "RGB".

    Returns:
        A contiguous float32 tensor laid out as channels x height x width.
    """
    if img.mode not in ("L", "RGB"):
        # Palette images would otherwise yield raw palette indices, and alpha/LA
        # modes a channel count other than 3; normalize to 8-bit RGB first.
        img = img.convert("RGB")
    img = img.resize((IMG_SIZE, IMG_SIZE))
    arr = np.array(img, dtype=np.float32) / 255.0
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=0)  # H x W -> 3 x H x W
    else:
        arr = arr.transpose(2, 0, 1)  # H x W x C -> C x H x W
    # transpose() only changes strides; hand torch a contiguous buffer.
    return torch.from_numpy(np.ascontiguousarray(arr))
def load_exam_as_batch(file_paths: List[str]) -> torch.Tensor:
    """Read every image of an exam and stack them into one batch tensor.

    Each path is loaded with _read_image and converted with _to_tensor.

    Args:
        file_paths: Image file paths (DICOM or any PIL-readable format).

    Returns:
        A float32 tensor of shape (N, 3, IMG_SIZE, IMG_SIZE), where N is the
        number of paths. When no paths are given, an empty batch with N == 0
        is returned instead of raising.
    """
    imgs = [_to_tensor(_read_image(p)) for p in file_paths]
    if not imgs:
        # torch.stack rejects an empty sequence, so build the empty batch explicitly.
        return torch.empty((0, 3, IMG_SIZE, IMG_SIZE), dtype=torch.float32)
    return torch.stack(imgs, dim=0)  # N x 3 x H x W
def _mean_or_zero(values: List[float]) -> float:
    """Return the arithmetic mean of ``values`` as a float, or 0.0 when empty."""
    return float(np.mean(values)) if len(values) > 0 else 0.0


def aggregate_predictions(days_list: List[float], proba_list: List[float]) -> Tuple[float, float]:
    """Average per-image predictions into a single exam-level prediction.

    Each list is averaged independently. An empty list contributes 0.0, rather
    than the NaN (plus RuntimeWarning) that ``np.mean([])`` would produce.

    Args:
        days_list: Per-image day predictions.
        proba_list: Per-image probability predictions.

    Returns:
        ``(mean_days, mean_proba)`` as Python floats.
    """
    return _mean_or_zero(days_list), _mean_or_zero(proba_list)
def clamp_days(d: float, *, lo: float = 1.0, hi: float = 300.0) -> float:
    """Clamp a day count into the closed interval ``[lo, hi]``.

    Args:
        d: Value to clamp.
        lo: Lower bound (default 1.0).
        hi: Upper bound (default 300.0).

    Returns:
        ``d`` limited to ``[lo, hi]``, as a float. A NaN ``d`` yields ``hi``,
        because ``min(hi, nan)`` returns its first argument.

    Raises:
        ValueError: If ``lo`` is greater than ``hi``.
    """
    if lo > hi:
        raise ValueError(f"lo ({lo}) must not exceed hi ({hi})")
    return float(max(lo, min(hi, d)))
def today_plus_days(days: float, *, base: Optional[dt.date] = None) -> str:
    """Return the ISO-8601 date (``YYYY-MM-DD``) that lies ``days`` days after ``base``.

    Args:
        days: Offset in days. It is rounded with Python's ``round``, which sends
            exact halves to the even integer (e.g. 2.5 -> 2, 3.5 -> 4).
        base: Starting date. Defaults to ``dt.date.today()`` (the local date)
            when omitted; pass an explicit date for deterministic results.

    Returns:
        The target date formatted with ``date.isoformat()``.

    Raises:
        ValueError: If ``days`` is NaN.
        OverflowError: If ``days`` is infinite or the target date is out of range.
    """
    start = dt.date.today() if base is None else base
    target = start + dt.timedelta(days=int(round(days)))
    return target.isoformat()