File size: 1,272 Bytes
af59988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
Inference functions for Pneumonia classification.
"""

import torch
import torch.nn as nn
from PIL import Image
from pathlib import Path
from typing import Union, Tuple

from .dataset import get_transforms
from .config import CLASS_NAMES, CHECKPOINT_PATH


def load_model(
    model: nn.Module,
    checkpoint_path: Path = CHECKPOINT_PATH,
    device: Union[str, torch.device] = "cpu",
) -> nn.Module:
    """Load checkpoint weights into ``model`` and prepare it for inference.

    Args:
        model: An instantiated network whose architecture matches the
            checkpoint's state dict.
        checkpoint_path: File written with ``torch.save``. It may hold either a
            dict with the weights under ``'model_state_dict'`` or a bare
            state dict.
        device: Device the checkpoint tensors are mapped onto and the model is
            moved to.

    Returns:
        The same ``model`` instance, moved to ``device`` and set to eval mode.

    Raises:
        RuntimeError: From ``load_state_dict`` when the checkpoint keys or
            shapes do not match ``model``.
    """
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Accept both the training-loop wrapper format and a bare state dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    model.load_state_dict(state_dict)
    # map_location only places the *checkpoint* tensors. load_state_dict copies
    # them into the model's existing parameters, so the model itself must be
    # moved explicitly or it stays on its original device.
    model.to(device)
    model.eval()
    return model


def predict_image(
    model: nn.Module,
    image: Union[str, Path, Image.Image],
    device: Union[str, torch.device],
    *,
    threshold: float = 0.5,
) -> Tuple[str, float]:
    """Classify a single image with a binary, single-logit model.

    Args:
        model: Network producing one logit for the image. It is switched to
            eval mode as a side effect.
        image: Path to an image file, or an already-loaded PIL image. Both are
            converted to RGB before the evaluation transforms are applied.
        device: Device the input tensor is moved to. It should match the
            device the model lives on.
        threshold: Sigmoid probability above which the image is assigned
            ``CLASS_NAMES[1]``; otherwise ``CLASS_NAMES[0]`` is returned.

    Returns:
        A ``(class_name, confidence)`` pair. ``confidence`` is the model's
        probability for the returned class (``prob`` or ``1 - prob``).
    """
    model.eval()

    if isinstance(image, (str, Path)):
        # The context manager closes the file handle that Image.open keeps
        # open; convert() loads the pixel data before the file is closed.
        with Image.open(image) as img:
            image = img.convert('RGB')
    elif image.mode != 'RGB':
        # Keep in-memory inputs consistent with the path branch so grayscale
        # or RGBA images get the same channel layout.
        image = image.convert('RGB')

    transform = get_transforms(is_training=False)
    img_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        output = model(img_tensor)
        # NOTE(review): assumes the model emits exactly one logit for the
        # batch of one (e.g. shape (1, 1)); .item() raises otherwise.
        prob = torch.sigmoid(output).item()

    is_positive = prob > threshold
    pred_class = CLASS_NAMES[1] if is_positive else CLASS_NAMES[0]
    confidence = prob if is_positive else 1 - prob

    return pred_class, confidence