# PneumoniaAPI/src/predict.py
# GitHub Actions
# Auto-deploy from GitHub: 495db78a06be79166200269bb14d9e9b1e8906d6
# af59988
"""
Inference functions for Pneumonia classification.
"""
import torch
import torch.nn as nn
from PIL import Image
from pathlib import Path
from typing import Union, Tuple
from .dataset import get_transforms
from .config import CLASS_NAMES, CHECKPOINT_PATH
def load_model(
    model: nn.Module,
    checkpoint_path: Path = CHECKPOINT_PATH,
    device: str = "cpu",
) -> nn.Module:
    """Restore weights from a checkpoint file and put the model in eval mode.

    Args:
        model: Network whose parameters are overwritten in place via
            ``load_state_dict``.
        checkpoint_path: File handed to ``torch.load``. The loaded object must
            be indexable with the key ``'model_state_dict'``.
        device: Forwarded to ``torch.load`` as ``map_location``. The model
            itself is not moved to this device by this function.

    Returns:
        The same ``model`` object that was passed in, after ``eval()`` has
        been called on it.
    """
    # NOTE(review): depending on the installed torch version's ``weights_only``
    # default, torch.load may unpickle arbitrary objects -- only point this at
    # checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location=device)['model_state_dict']
    model.load_state_dict(state_dict)

    model.eval()
    return model
def predict_image(
    model: nn.Module,
    image: Union[str, Path, Image.Image],
    device: torch.device
) -> Tuple[str, float]:
    """Predict the class of a single image with a single-logit binary model.

    Args:
        model: Network called on a batch of one image. Its output must contain
            exactly one element, because ``.item()`` is applied to the
            sigmoid of that output.
        image: Either a filesystem path (``str`` / ``Path``) or an already
            opened PIL image. Both kinds of input are normalised to RGB so
            that they go through identical preprocessing.
        device: Device the input tensor is moved to before the forward pass.
            The model is not moved by this function.

    Returns:
        ``(class_name, confidence)`` where ``class_name`` is ``CLASS_NAMES[1]``
        when the sigmoid probability is strictly above 0.5 and
        ``CLASS_NAMES[0]`` otherwise, and ``confidence`` is the probability
        assigned to the returned class (``prob`` or ``1 - prob``), so it is
        never below 0.5.
    """
    model.eval()

    if isinstance(image, (str, Path)):
        # The context manager releases the file handle deterministically;
        # convert() returns an independent, fully loaded copy.
        with Image.open(image) as source:
            image = source.convert('RGB')
    elif image.mode != 'RGB':
        # Path inputs are always converted to RGB, so do the same for PIL
        # inputs (e.g. grayscale 'L' or 'RGBA') to keep the tensor layout
        # fed to the model independent of how the image was supplied.
        image = image.convert('RGB')

    # Evaluation-time transforms (is_training=False), then add a batch axis.
    transform = get_transforms(is_training=False)
    batch = transform(image).unsqueeze(0).to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        prob = torch.sigmoid(model(batch)).item()

    if prob > 0.5:
        return CLASS_NAMES[1], prob
    return CLASS_NAMES[0], 1 - prob