|
|
import datetime as dt
import io
import math
import os
from typing import List, Optional, Tuple

import numpy as np
import pydicom
import torch
from PIL import Image
|
|
|
|
|
# Side length, in pixels, of the square that _to_tensor resizes every image to
# before converting it to a tensor.
IMG_SIZE = 224
|
|
|
|
|
def _read_image(file_path: str) -> Image.Image:
    """Load an image file from disk as an 8-bit PIL image.

    Files whose extension is ``.dcm`` (case-insensitive) are decoded with
    pydicom and min-max scaled to the full 0-255 range. Every other file is
    opened with Pillow and converted to single-channel ``"L"`` mode.

    Args:
        file_path: Path to the image file.

    Returns:
        A ``PIL.Image.Image``. Non-DICOM inputs are always mode ``"L"``. For
        DICOM inputs the mode follows the shape of the uint8 pixel array
        (``"L"`` for a 2-D array).
    """
    ext = os.path.splitext(file_path)[1].lower()

    if ext == ".dcm":
        ds = pydicom.dcmread(file_path)
        arr = ds.pixel_array.astype(np.float32)

        # Per-file min-max normalisation to [0, 1]. A constant image has
        # peak == 0 and stays all zeros, which avoids dividing by zero.
        arr = arr - arr.min()
        peak = arr.max()
        if peak > 0:
            arr = arr / peak

        # MONOCHROME1 stores inverted grayscale (the minimum value is meant to
        # be displayed as white). Flip it so that higher output values are
        # brighter, as in MONOCHROME2 and the non-DICOM branch below.
        if getattr(ds, "PhotometricInterpretation", None) == "MONOCHROME1":
            arr = 1.0 - arr

        arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
        return Image.fromarray(arr)

    # The context manager releases the file handle deterministically.
    # convert() returns a new, fully loaded image that does not depend on
    # the (now closed) source file.
    with Image.open(file_path) as im:
        return im.convert("L")
|
|
|
|
|
def _to_tensor(img: Image.Image) -> torch.Tensor:
    """Resize a PIL image and convert it to a channels-first float32 tensor.

    The image is resized to ``IMG_SIZE x IMG_SIZE`` and its pixel values are
    divided by 255. This assumes 8-bit input such as mode ``"L"`` or
    ``"RGB"``.

    Args:
        img: Image to convert.

    Returns:
        A float32 tensor. A 2-D (single-plane) image is repeated into three
        identical channels, giving ``(3, IMG_SIZE, IMG_SIZE)``. An image with
        a trailing channel axis is transposed from ``(H, W, C)`` to
        ``(C, H, W)``.
    """
    resized = img.resize((IMG_SIZE, IMG_SIZE))
    pixels = np.array(resized).astype(np.float32) / 255.0

    if pixels.ndim == 2:
        # Grayscale: repeat the single plane so the result has 3 channels.
        pixels = np.stack((pixels,) * 3, axis=0)
    elif pixels.ndim == 3:
        # Channels-last -> channels-first.
        pixels = np.transpose(pixels, (2, 0, 1))

    return torch.from_numpy(pixels)
|
|
|
|
|
def load_exam_as_batch(file_paths: List[str]) -> torch.Tensor:
    """Load several image files and stack them into a single batch tensor.

    Each path is read with ``_read_image`` and converted with ``_to_tensor``.
    The per-image tensors are stacked along a new leading dimension, in the
    order the paths were given.

    Args:
        file_paths: Paths of the images to load.

    Returns:
        A tensor whose first dimension indexes ``file_paths``.

    Raises:
        RuntimeError: Raised by ``torch.stack`` when ``file_paths`` is empty
            or the per-image tensors have mismatched shapes.
    """
    per_image = tuple(_to_tensor(_read_image(path)) for path in file_paths)
    return torch.stack(per_image, dim=0)
|
|
|
|
|
def aggregate_predictions(days_list: List[float], proba_list: List[float]) -> Tuple[float, float]:
    """Average a list of day predictions and a list of probabilities.

    Args:
        days_list: Day values to average.
        proba_list: Probability values to average.

    Returns:
        ``(mean_days, mean_proba)`` as Python floats. Each mean falls back to
        ``0.0`` when its own list is empty, so the result never contains NaN.
    """

    def _mean_or_zero(values: List[float]) -> float:
        # np.mean of an empty sequence returns nan and emits a RuntimeWarning.
        # Use 0.0 instead, the value the original returned for empty input.
        # len() rather than truthiness keeps numpy arrays working too.
        return float(np.mean(values)) if len(values) else 0.0

    return _mean_or_zero(days_list), _mean_or_zero(proba_list)
|
|
|
|
|
def clamp_days(d: float) -> float:
    """Restrict a day count to the closed interval [1.0, 300.0].

    Args:
        d: Day value to clamp.

    Returns:
        ``d`` as a ``float`` when it lies strictly inside the bounds, and
        otherwise the nearer bound. A NaN input yields ``300.0``, because
        every comparison with NaN is false and so NaN is never treated as
        below the upper bound.
    """
    floor_days, ceiling_days = 1.0, 300.0
    # Apply the upper bound first, then the lower one. This is the same
    # evaluation order as max(floor, min(ceiling, d)).
    capped = d if d < ceiling_days else ceiling_days
    bounded = capped if capped > floor_days else floor_days
    return float(bounded)
|
|
|
|
|
def today_plus_days(days: float, *, base: Optional[dt.date] = None) -> str:
    """Return the ISO-8601 date that lies ``days`` days after a base date.

    Args:
        days: Offset in days, rounded to a whole number with Python's
            built-in ``round``. Ties go to the even integer, so ``2.5``
            becomes ``2``. Negative offsets move backwards.
        base: Date to count from. Defaults to ``datetime.date.today()``
            (the local calendar date), evaluated at call time. Pass an
            explicit date for reproducible results.

    Returns:
        The target date formatted as ``YYYY-MM-DD``.

    Raises:
        ValueError: If ``days`` is NaN.
        OverflowError: If ``days`` is infinite or the resulting date falls
            outside the range of ``datetime.date``.
    """
    start = dt.date.today() if base is None else base
    target = start + dt.timedelta(days=int(round(days)))
    return target.isoformat()
|
|
|