Spaces:
Sleeping
Sleeping
File size: 3,605 Bytes
"""High-level predictor for disaster building damage assessment."""
import logging
from dataclasses import dataclass
from pathlib import Path
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from .config import InferenceConfig
from .model import build_model, load_checkpoint
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Human-readable labels, indexed by the class id that the predictor returns.
# NOTE(review): the tsunami labels go from T1 straight to T3 with no T2
# entry -- confirm this matches the ordering of the model's output classes.
CLASS_NAMES: list[str] = [
    "被害なし",  # no damage
    "E1(地震大)",  # earthquake, large
    "E2(地震中)",  # earthquake, medium
    "E3(地震小)",  # earthquake, small
    "T1(津波大)",  # tsunami, large
    "T3(津波小)",  # tsunami, small
]

# A prediction whose top probability is below this value is flagged as rejected.
REJECTION_THRESHOLD = 0.5
@dataclass(frozen=True)
class PredictionResult:
    """Immutable container for a single prediction.

    Fields cannot be reassigned after construction (``frozen=True``). Note
    that ``probabilities`` is still a regular list, so its *contents* are not
    protected and instances are not hashable.

    Attributes
    ----------
    class_id : int
        Index of the highest-probability class.
    class_name : str
        Display label corresponding to ``class_id``.
    confidence : float
        Probability assigned to ``class_id``.
    probabilities : list[float]
        Per-class probabilities, in class-id order.
    rejected : bool
        ``True`` when the prediction was flagged as too uncertain to trust.
    """

    class_id: int
    class_name: str
    confidence: float
    probabilities: list[float]
    rejected: bool
class Predictor:
    """Singleton predictor -- load once, infer many times.

    Every ``Predictor()`` call returns the same object, so the checkpoint is
    loaded at most once regardless of how many call sites construct it.

    Usage::

        predictor = Predictor()
        predictor.initialize(checkpoint_dir=Path("..."), device="cuda")
        result = predictor.predict(image)
    """

    _instance: "Predictor | None" = None

    def __new__(cls) -> "Predictor":
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Set only when the instance is first created, so later
            # Predictor() calls never reset an already-loaded predictor.
            cls._instance._initialized = False
        return cls._instance

    def initialize(self, checkpoint_dir: Path, device: str = "cuda") -> None:
        """Load model and prepare transforms.

        Only the first successful call does any work. Later calls log a
        message and return, even if they pass different arguments.

        Parameters
        ----------
        checkpoint_dir : Path
            Directory containing ``best_model.pth``. A plain string path is
            accepted as well.
        device : str
            Target device (``"cuda"`` or ``"cpu"``).
        """
        if self._initialized:
            logger.info("Predictor already initialized, skipping.")
            return

        config = InferenceConfig()
        model = build_model(config, device)
        # Path(...) is a no-op for Path inputs and lets callers pass a str.
        checkpoint_path = Path(checkpoint_dir) / "best_model.pth"
        model = load_checkpoint(model, checkpoint_path, device)
        # Inference must not run with dropout / batch-norm in training mode.
        # eval() is idempotent, so this is harmless if load_checkpoint
        # already switched the model over.
        if isinstance(model, torch.nn.Module):
            model.eval()

        self.model = model
        self.device = device
        self.transform = transforms.Compose(
            [
                # Resize the shorter side to 570 px, then take the central
                # 518x518 patch.
                transforms.Resize(570),
                transforms.CenterCrop(518),
                transforms.ToTensor(),
                # Channel statistics that torchvision documents for its
                # ImageNet-pretrained models.
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )
        # Flip the flag last so a failed load leaves the predictor unusable
        # rather than half-initialized.
        self._initialized = True
        logger.info("Predictor initialized on %s", device)

    def predict(self, image: Image.Image) -> PredictionResult:
        """Run inference on a single PIL image.

        Parameters
        ----------
        image : Image.Image
            Input image (any mode -- will be converted to RGB).

        Returns
        -------
        PredictionResult
            Prediction with class, confidence, and rejection flag. The
            result is marked rejected when the top probability is below
            ``REJECTION_THRESHOLD``.

        Raises
        ------
        RuntimeError
            If :meth:`initialize` has not been called.
        """
        if not self._initialized:
            raise RuntimeError("Predictor not initialized. Call initialize() first.")

        # Add a leading batch dimension of 1 before moving to the device.
        batch = self.transform(image.convert("RGB")).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # The model returns a mapping; the "full" entry holds the logits
            # used for classification.
            logits = self.model(batch)["full"]
            # .float() keeps the CPU-side argmax working for reduced-precision
            # models; it is a no-op for float32 outputs.
            probs_t = F.softmax(logits, dim=-1).squeeze(0).float().cpu()

        # Take the argmax on the tensor directly instead of converting to a
        # list and rebuilding a tensor from it.
        class_id = int(probs_t.argmax().item())
        probs = probs_t.tolist()
        confidence = probs[class_id]

        return PredictionResult(
            class_id=class_id,
            class_name=CLASS_NAMES[class_id],
            confidence=confidence,
            probabilities=probs,
            rejected=confidence < REJECTION_THRESHOLD,
        )
|