Spaces:
Sleeping
Sleeping
File size: 1,272 Bytes
af59988 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
"""
Inference functions for Pneumonia classification.
"""
import torch
import torch.nn as nn
from PIL import Image
from pathlib import Path
from typing import Union, Tuple
from .dataset import get_transforms
from .config import CLASS_NAMES, CHECKPOINT_PATH
def load_model(
    model: nn.Module,
    checkpoint_path: Path = CHECKPOINT_PATH,
    device: Union[str, torch.device] = "cpu",
) -> nn.Module:
    """Load trained weights from a checkpoint into ``model`` for inference.

    Args:
        model: An instantiated model whose architecture matches the checkpoint.
        checkpoint_path: Path to a file written with ``torch.save``; it must
            hold a mapping with the weights under the ``'model_state_dict'`` key.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to (e.g. ``"cpu"``, ``"cuda"`` or a ``torch.device``).

    Returns:
        The same ``model`` instance with the weights loaded, moved to
        ``device``, and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted sources. Since PyTorch 2.6 `weights_only` defaults to True, which
    # may reject checkpoints that store extra non-tensor objects.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # load_state_dict copies values into the model's existing parameters, so
    # map_location alone does not relocate the model. Move it explicitly so the
    # model lives on the same device that inference inputs are sent to.
    model.to(device)
    model.eval()
    return model
def predict_image(
    model: nn.Module,
    image: Union[str, Path, Image.Image],
    device: Union[str, torch.device],
    threshold: float = 0.5,
) -> Tuple[str, float]:
    """Classify a single image with a single-logit binary model.

    Args:
        model: Model whose output for one image is a single logit for the
            positive class (``CLASS_NAMES[1]``). It is switched to eval mode here.
        image: A file path or an already-loaded PIL image. In both cases the
            image is converted to RGB before the eval transforms are applied.
        device: Device the input tensor is sent to; should match the model's.
        threshold: Sigmoid probability above which ``CLASS_NAMES[1]`` is
            predicted; otherwise ``CLASS_NAMES[0]`` is returned.

    Returns:
        ``(class_name, confidence)``, where ``confidence`` is the probability
        assigned to the predicted class (``prob`` or ``1 - prob``).
    """
    model.eval()

    if isinstance(image, (str, Path)):
        # The context manager releases the file handle; convert() returns a
        # fully loaded copy, so the result remains valid after the file closes.
        with Image.open(image) as img:
            image = img.convert('RGB')
    elif isinstance(image, Image.Image) and image.mode != 'RGB':
        # Apply the same RGB conversion as the path branch, so that in-memory
        # images (e.g. grayscale 'L' or RGBA) are preprocessed identically.
        image = image.convert('RGB')

    transform = get_transforms(is_training=False)
    img_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim of 1

    with torch.no_grad():
        logit = model(img_tensor)
    # .item() requires the output to contain exactly one element.
    prob = torch.sigmoid(logit).item()

    if prob > threshold:
        return CLASS_NAMES[1], prob
    return CLASS_NAMES[0], 1 - prob
|