Spaces:
Sleeping
Sleeping
| """ | |
| Inference functions for Pneumonia classification. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from pathlib import Path | |
| from typing import Union, Tuple | |
| from .dataset import get_transforms | |
| from .config import CLASS_NAMES, CHECKPOINT_PATH | |
def load_model(
    model: nn.Module,
    checkpoint_path: Path = CHECKPOINT_PATH,
    device: Union[str, torch.device] = "cpu",
) -> nn.Module:
    """Load trained weights into ``model`` and prepare it for inference.

    Args:
        model: An instantiated network whose architecture matches the
            checkpoint. It is modified in place and also returned.
        checkpoint_path: Path to a file saved with ``torch.save``; it must
            be a mapping containing a ``'model_state_dict'`` entry.
        device: Device (e.g. ``"cpu"``, ``"cuda"``) onto which the weights
            are loaded and to which the model is moved.

    Returns:
        The same ``model`` instance, moved to ``device`` and in eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        RuntimeError: If the state dict does not match the model's
            parameters (raised by ``load_state_dict``).
    """
    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources; consider
    # weights_only=True if the checkpoint holds only tensors/plain containers.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # map_location only places the deserialized tensors; load_state_dict
    # copies them into the model's existing parameters, which stay on
    # whatever device the model already lives on. Move the model itself so
    # its parameters match the device inputs will be sent to.
    model.to(device)
    model.eval()
    return model
def predict_image(
    model: nn.Module,
    image: Union[str, Path, Image.Image],
    device: Union[str, torch.device],
    *,
    threshold: float = 0.5,
) -> Tuple[str, float]:
    """Classify a single image with a binary, sigmoid-output model.

    Args:
        model: Network already placed on ``device``; switched to eval mode
            here before inference.
        image: Path to an image file, or an already-loaded PIL image. Both
            are converted to RGB before preprocessing.
        device: Device the input tensor is moved to for inference.
        threshold: Sigmoid probability above which ``CLASS_NAMES[1]`` is
            predicted; otherwise ``CLASS_NAMES[0]``. Defaults to 0.5.

    Returns:
        ``(pred_class, confidence)`` where ``confidence`` is the model's
        probability for the predicted class.
    """
    model.eval()

    if isinstance(image, (str, Path)):
        # The context manager closes the underlying file handle; convert()
        # returns a fully loaded copy, so the result remains usable after
        # the file is closed.
        with Image.open(image) as img:
            image = img.convert('RGB')
    elif isinstance(image, Image.Image) and image.mode != 'RGB':
        # Keep preprocessing identical to the path branch: grayscale or
        # RGBA inputs are converted so the eval transform always gets RGB.
        image = image.convert('RGB')

    transform = get_transforms(is_training=False)
    # Add a batch dimension of 1 for the single image.
    img_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(img_tensor)
        # NOTE(review): .item() assumes the model emits exactly one logit
        # per image — confirm against the model definition.
        prob = torch.sigmoid(output).item()

    # Confidence is reported as the probability of the chosen class.
    if prob > threshold:
        return CLASS_NAMES[1], prob
    return CLASS_NAMES[0], 1 - prob