File size: 1,705 Bytes
2a034f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import datetime as dt
import io
import math
import os
from typing import List, Optional, Tuple

import numpy as np
import pydicom
import torch
from PIL import Image
# Side length, in pixels, of the square images produced by _to_tensor.
IMG_SIZE: int = 224
def _read_image(file_path: str) -> Image.Image:
    """Load an image file as a PIL image with 8-bit pixel values.

    Files with a ``.dcm`` extension (case-insensitive) are decoded with
    pydicom and min-max scaled to 0..255. Every other file is opened with
    PIL and converted to single-channel grayscale (mode ``"L"``).

    Args:
        file_path: Path to a ``.dcm`` file or any format PIL can open.

    Returns:
        A fully loaded PIL image that does not keep the source file open.
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext == ".dcm":
        ds = pydicom.dcmread(file_path)
        arr = ds.pixel_array.astype(np.float32)
        # MONOCHROME1 stores inverted grayscale (minimum value is displayed
        # as white). Negating makes "brighter" mean "larger" for every
        # input, matching MONOCHROME2 and the PIL branch below. The min-max
        # step that follows maps the negated values back into 0..255.
        photometric = str(getattr(ds, "PhotometricInterpretation", "")).strip().upper()
        if photometric == "MONOCHROME1":
            arr = -arr
        # Normalize to 0..255; a constant image maps to all zeros.
        arr = arr - arr.min()
        peak = arr.max()
        if peak > 0:
            arr = arr / peak
        arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
        return Image.fromarray(arr)
    # The context manager releases the file handle. convert() loads the
    # pixels and returns a new image, so the result stays valid after close.
    with Image.open(file_path) as img:
        return img.convert("L")  # grayscale
def _to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float32 tensor of shape 3 x IMG_SIZE x IMG_SIZE.

    The image is resized directly to IMG_SIZE x IMG_SIZE. The aspect ratio
    is NOT preserved: no center-crop or padding is applied. Pixel values
    are divided by 255, so 8-bit input maps to 0..1. Grayscale input is
    replicated into three identical channels.

    Args:
        img: Source image. Modes other than ``"L"`` and ``"RGB"`` are
            converted to ``"RGB"`` first.

    Returns:
        A channels-first (C x H x W) float32 tensor with C == 3.
    """
    # Modes such as RGBA/LA/P/1 would otherwise produce 4, 2 or 0 channel
    # axes, or values not on a 0..255 scale. Normalise to plain RGB.
    if img.mode not in ("L", "RGB"):
        img = img.convert("RGB")
    img = img.resize((IMG_SIZE, IMG_SIZE))
    arr = np.asarray(img, dtype=np.float32) / 255.0
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=0)  # HxW -> 3xHxW
    else:
        arr = arr.transpose(2, 0, 1)  # HxWx3 -> 3xHxW
    # The transpose yields a strided view; make it contiguous for torch.
    return torch.from_numpy(np.ascontiguousarray(arr))
def load_exam_as_batch(file_paths: List[str]) -> torch.Tensor:
    """Read every image of an exam and stack them into one batch tensor.

    Each path is read with ``_read_image`` and converted with ``_to_tensor``.
    Batch order follows the order of ``file_paths``.

    Args:
        file_paths: Image paths (``.dcm`` or any PIL-readable format).

    Returns:
        A tensor of shape ``(N, C, IMG_SIZE, IMG_SIZE)`` with one
        ``_to_tensor`` output per path (C == 3 for grayscale/RGB images).
        An empty ``file_paths`` yields an empty ``(0, 3, IMG_SIZE,
        IMG_SIZE)`` float32 batch instead of the RuntimeError that
        ``torch.stack`` raises on an empty list.
    """
    imgs = [_to_tensor(_read_image(p)) for p in file_paths]
    if not imgs:
        return torch.empty((0, 3, IMG_SIZE, IMG_SIZE), dtype=torch.float32)
    return torch.stack(imgs, dim=0)  # Nx3xHxW
def aggregate_predictions(days_list: List[float], proba_list: List[float]) -> Tuple[float, float]:
    """Average lists of day predictions and probabilities.

    Args:
        days_list: Day predictions to average (presumably one per image of
            an exam; confirm against the caller).
        proba_list: Probabilities to average (same assumption).

    Returns:
        ``(mean_days, mean_proba)`` as Python floats. If ``days_list`` is
        empty, the result is ``(0.0, 0.0)`` regardless of ``proba_list``.
        If only ``proba_list`` is empty, its mean is reported as ``0.0``
        rather than the NaN (plus RuntimeWarning) that ``np.mean([])``
        would produce.
    """
    # len() rather than truthiness, so NumPy arrays are accepted as well.
    if len(days_list) == 0:
        return 0.0, 0.0
    mean_days = float(np.mean(days_list))
    mean_proba = float(np.mean(proba_list)) if len(proba_list) > 0 else 0.0
    return mean_days, mean_proba
def clamp_days(d: float, *, lo: float = 1.0, hi: float = 300.0) -> float:
    """Clamp ``d`` into the closed interval ``[lo, hi]`` and return a float.

    Args:
        d: Value to clamp.
        lo: Lower bound (default 1.0, the previously hard-coded minimum).
        hi: Upper bound (default 300.0, the previously hard-coded maximum).

    Returns:
        ``d`` limited to ``[lo, hi]``, always as a Python float.
        NOTE(review): a NaN ``d`` comes back as ``hi``, because ``min()``
        keeps its first argument when the comparison with NaN is False.
        Confirm this fallback is intended.

    Raises:
        ValueError: If ``lo > hi``.
    """
    if lo > hi:
        raise ValueError(f"lo ({lo}) must not exceed hi ({hi})")
    return float(max(lo, min(hi, d)))
def today_plus_days(days: float, *, base: Optional[dt.date] = None) -> str:
    """Return the ISO-8601 date that lies ``days`` days after ``base``.

    Args:
        days: Offset in days. It is rounded to a whole day with Python's
            built-in ``round``, so ties go to the even integer (2.5 -> 2,
            3.5 -> 4). Negative values move into the past.
        base: Start date. When omitted, ``dt.date.today()`` (the local
            date) is used, matching the original behaviour. Pass an explicit
            date for reproducible results.

    Returns:
        The target date formatted as ``YYYY-MM-DD``.

    Raises:
        ValueError: If ``days`` is NaN (``round`` cannot convert it).
        OverflowError: If ``days`` is infinite or the result is out of the
            range ``datetime.date`` supports.
    """
    start = dt.date.today() if base is None else base
    target = start + dt.timedelta(days=int(round(days)))
    return target.isoformat()
|