| from pathlib import Path |
|
|
| import torch |
| from PIL import Image |
|
|
| from train import build_model, build_rgb_transform |
|
|
|
|
# Display labels for the EuroSAT land-cover classes. Position i names the
# class for output logit i: predict_topk() indexes this list with the top-k
# indices, and load_rgb_model() passes len(CLASS_NAMES) to build_model() as
# num_classes.
CLASS_NAMES = [
    "Annual Crop",  # 0
    "Forest",  # 1
    "Herbaceous Vegetation",  # 2
    "Highway",  # 3
    "Industrial Buildings",  # 4
    "Pasture",  # 5
    "Permanent Crop",  # 6
    "Residential Buildings",  # 7
    "River",  # 8
    "SeaLake",  # 9
]

# Checkpoint loaded by load_rgb_model() when no explicit path is supplied.
DEFAULT_CHECKPOINT_PATH = Path("weights/rgb_e15_best.pt")
|
|
|
|
def load_rgb_model(checkpoint_path: str | Path = DEFAULT_CHECKPOINT_PATH) -> torch.nn.Module:
    """Load the EuroSAT-RGB ResNet-50 checkpoint for CPU inference.

    Args:
        checkpoint_path: Location of a checkpoint saved with ``torch.save``
            whose ``"model_state_dict"`` entry holds the model weights.
            Defaults to ``DEFAULT_CHECKPOINT_PATH``.

    Returns:
        The model built by ``build_model`` for ``len(CLASS_NAMES)`` classes
        and 3 input channels, with the checkpoint weights loaded and switched
        to eval mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not point to a file.
        KeyError: If the loaded checkpoint has no ``"model_state_dict"`` entry.
    """
    checkpoint_path = Path(checkpoint_path)
    # is_file() rather than exists(): a directory at this path would otherwise
    # pass the check and fail later inside torch.load with a less clear error.
    if not checkpoint_path.is_file():
        raise FileNotFoundError(
            f"RGB checkpoint not found at {checkpoint_path}. "
            # Name the path actually requested, not just the default one.
            f"Add {checkpoint_path} before running the demo."
        )

    device = torch.device("cpu")
    # NOTE(review): build_model is presumably responsible for placing the
    # model on `device` -- confirm against train.build_model.
    model = build_model(num_classes=len(CLASS_NAMES), device=device, in_channels=3)
    # map_location remaps tensors saved on another device (e.g. CUDA) to CPU.
    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code; only load trusted checkpoints. Consider weights_only=True once the
    # checkpoint is confirmed to hold only tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    # Inference mode for layers such as batch norm / dropout.
    model.eval()
    return model
|
|
|
|
@torch.no_grad()
def predict_topk(
    model: torch.nn.Module, image: Image.Image, top_k: int = 3
) -> list[tuple[str, float]]:
    """Run RGB inference and return the ``top_k`` most likely class names.

    Args:
        model: Classifier whose output logits are indexed like ``CLASS_NAMES``
            (e.g. the model returned by ``load_rgb_model``). It is NOT switched
            to eval mode here; pass a model that is already in eval mode.
        image: Input image; converted to RGB before the eval transform from
            ``build_rgb_transform(train=False)`` is applied.
        top_k: Number of predictions to return. Values larger than the number
            of model outputs are clamped to that number.

    Returns:
        ``(class_name, probability)`` pairs ordered by descending softmax
        probability.

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    # Build the input on whatever device the model's parameters live on, so a
    # model moved off the CPU still works; fall back to CPU for modules that
    # have no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    transform = build_rgb_transform(train=False)
    # unsqueeze(0) adds the batch dimension for a single image.
    tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    logits = model(tensor)
    probs = torch.softmax(logits, dim=1).squeeze(0)
    # torch.topk raises when k exceeds the number of elements; clamp instead.
    k = min(top_k, probs.numel())
    top_probs, top_indices = torch.topk(probs, k=k)
    return [
        (CLASS_NAMES[int(class_idx)], float(prob))
        for prob, class_idx in zip(top_probs, top_indices, strict=True)
    ]
|
|