File size: 1,705 Bytes
2a034f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import datetime as dt
import io
import math
import os
from typing import List, Optional, Tuple

import numpy as np
import pydicom
import torch
from PIL import Image

# Side length in pixels of the square model input: _to_tensor resizes every
# image to IMG_SIZE x IMG_SIZE.
IMG_SIZE = 224

def _read_image(file_path: str) -> Image.Image:
    """Load an image file from disk as an 8-bit PIL image.

    DICOM files (``.dcm`` / ``.dicom`` extension, case-insensitive) are decoded
    with pydicom and min-max normalized to 0..255. ``MONOCHROME1`` images, whose
    lowest stored value is displayed as white, are inverted so that bright
    always means high intensity, like ``MONOCHROME2`` DICOMs and ordinary
    grayscale files. Any other extension is opened with Pillow and converted
    to mode ``"L"`` (grayscale).

    Args:
        file_path: Path to the image file.

    Returns:
        A PIL image built from uint8 data. Non-DICOM input is always mode
        ``"L"``; a single-frame grayscale DICOM yields a 2-D (mode ``"L"``) image.
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext in (".dcm", ".dicom"):
        ds = pydicom.dcmread(file_path)
        # astype() copies, so the in-place ops below never touch ds's buffer.
        # NOTE(review): multi-frame DICOMs produce a 3-D pixel_array that
        # Image.fromarray will reject; they are not handled here.
        arr = ds.pixel_array.astype(np.float32)
        # Min-max normalize to 0..1; a constant image stays all zeros.
        arr -= arr.min()
        peak = arr.max()
        if peak > 0:
            arr /= peak
        photometric = str(getattr(ds, "PhotometricInterpretation", "")).strip().upper()
        if photometric == "MONOCHROME1":
            # In MONOCHROME1 the minimum value is displayed as white; flip it
            # to the MONOCHROME2 convention used by every other input.
            arr = 1.0 - arr
        arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
        return Image.fromarray(arr)

    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it. convert() returns a new, fully loaded image that remains
    # valid after the source image is closed.
    with Image.open(file_path) as img:
        return img.convert("L")

def _to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float32 ``3 x IMG_SIZE x IMG_SIZE`` tensor in [0, 1].

    The image is resized directly to ``IMG_SIZE x IMG_SIZE``. Aspect ratio is
    not preserved, and there is no crop or pad step. Pixel values are divided
    by 255, so 8-bit input (as produced by ``_read_image``) is assumed.

    Channel handling always yields exactly 3 channels:
      * 2-D (grayscale) input, or 1-2 channel input (e.g. ``"LA"``): the first
        channel is replicated three times.
      * 3+ channel input (e.g. ``"RGB"``, ``"RGBA"``): the first three channels
        are kept and any extra (alpha) channel is dropped.

    Args:
        img: Source image.

    Returns:
        A contiguous float32 tensor of shape ``(3, IMG_SIZE, IMG_SIZE)``.

    Raises:
        ValueError: If the image's array is neither 2-D nor 3-D.
    """
    img = img.resize((IMG_SIZE, IMG_SIZE))
    arr = np.asarray(img, dtype=np.float32) / 255.0
    if arr.ndim == 2:
        arr = arr[np.newaxis]  # HxW -> 1xHxW
    elif arr.ndim == 3:
        arr = arr.transpose(2, 0, 1)  # HWC -> CHW
    else:
        raise ValueError(f"unsupported image array with ndim={arr.ndim}")

    if arr.shape[0] < 3:
        # Grayscale (+ optional alpha): replicate the intensity channel.
        arr = np.repeat(arr[:1], 3, axis=0)
    else:
        # Drop alpha / extra channels so every image stacks to Nx3xHxW.
        arr = arr[:3]
    return torch.from_numpy(np.ascontiguousarray(arr))

def load_exam_as_batch(file_paths: List[str]) -> torch.Tensor:
    """Read every file of one exam and return the images as a single batch.

    Each path is passed through ``_read_image`` and then ``_to_tensor``, in the
    order given. The per-image tensors are stacked along a new leading batch
    dimension (size ``len(file_paths)``). An empty ``file_paths`` goes straight
    to ``torch.stack``, which raises.
    """
    per_image = tuple(_to_tensor(_read_image(path)) for path in file_paths)
    return torch.stack(per_image, dim=0)

def aggregate_predictions(days_list: List[float], proba_list: List[float]) -> Tuple[float, float]:
    """Average per-image predictions into a single ``(days, probability)`` pair.

    Args:
        days_list: Per-image day predictions.
        proba_list: Per-image probabilities. Presumably parallel to
            ``days_list`` (one entry per image); lengths are not checked.

    Returns:
        ``(mean_days, mean_proba)`` as Python floats. Returns ``(0.0, 0.0)``
        when ``days_list`` is empty. An empty ``proba_list`` averages to
        ``0.0`` rather than NaN.
    """
    if len(days_list) == 0:
        return 0.0, 0.0
    mean_days = float(np.mean(days_list))
    # np.mean of an empty sequence returns NaN and emits a RuntimeWarning;
    # apply the same 0.0 fallback already used for an empty days_list.
    mean_proba = float(np.mean(proba_list)) if len(proba_list) > 0 else 0.0
    return mean_days, mean_proba

def clamp_days(d: float, *, lo: float = 1.0, hi: float = 300.0) -> float:
    """Clamp a day count into the closed interval ``[lo, hi]``.

    Args:
        d: Value to clamp.
        lo: Lower bound (keyword-only, default 1.0).
        hi: Upper bound (keyword-only, default 300.0).

    Returns:
        ``d`` limited to ``[lo, hi]``, as a Python float. NaN input returns ``hi``.

    Raises:
        ValueError: If ``lo > hi``.
    """
    if lo > hi:
        raise ValueError(f"lo ({lo}) must not exceed hi ({hi})")
    # Argument order is deliberate: every comparison with NaN is False, so
    # min(hi, nan) keeps hi and NaN maps to the upper bound (historical behavior).
    return float(max(lo, min(hi, d)))

def today_plus_days(days: float, *, base: Optional[dt.date] = None) -> str:
    """Return the ISO-8601 date that is ``days`` days after ``base``.

    Args:
        days: Offset in days. It is rounded with Python's built-in ``round``,
            which sends ties to the even integer (e.g. 2.5 -> 2, 3.5 -> 4).
        base: Starting date (keyword-only). ``None`` means ``dt.date.today()``,
            i.e. the current local (system-timezone) calendar date. Passing
            an explicit date makes the result deterministic.

    Returns:
        The target date formatted as ``YYYY-MM-DD``.

    Raises:
        ValueError: If ``days`` is NaN (it cannot be rounded to an int).
        OverflowError: If ``days`` is infinite or the target date is out of range.
    """
    start = dt.date.today() if base is None else base
    target = start + dt.timedelta(days=int(round(days)))
    return target.isoformat()