# davidlsan's picture
# Add Streamlit app source and RGB model weights
# 9d33171 verified
from pathlib import Path
import torch
from PIL import Image
from train import build_model, build_rgb_transform
# Human-readable EuroSAT land-cover labels. A name's position in this list is
# the class index that predict_topk maps model outputs back to, so the order
# must not change. It presumably mirrors the label order used in train.py;
# confirm there before editing.
CLASS_NAMES: list[str] = [
    "Annual Crop",
    "Forest",
    "Herbaceous Vegetation",
    "Highway",
    "Industrial Buildings",
    "Pasture",
    "Permanent Crop",
    "Residential Buildings",
    "River",
    "SeaLake",
]

# Checkpoint shipped with the demo; a relative path, so it is resolved against
# the current working directory when opened.
DEFAULT_CHECKPOINT_PATH = Path("weights") / "rgb_e15_best.pt"
def load_rgb_model(checkpoint_path: str | Path = DEFAULT_CHECKPOINT_PATH) -> torch.nn.Module:
    """Load the EuroSAT-RGB ResNet-50 checkpoint for CPU inference.

    Args:
        checkpoint_path: Location of the ``.pt`` checkpoint. Either a dict that
            stores the weights under ``"model_state_dict"`` or a bare
            ``state_dict`` mapping parameter names to tensors.

    Returns:
        The model built via ``build_model`` with the checkpoint weights loaded,
        mapped to CPU and switched to eval mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not point to a file.
    """
    checkpoint_path = Path(checkpoint_path)
    # is_file() rather than exists(): a directory would pass exists() and then
    # fail inside torch.load with a far less helpful error.
    if not checkpoint_path.is_file():
        raise FileNotFoundError(
            f"RGB checkpoint not found at {checkpoint_path}. "
            f"Add {checkpoint_path} before running the demo."
        )
    device = torch.device("cpu")
    model = build_model(num_classes=len(CLASS_NAMES), device=device, in_channels=3)
    # NOTE(review): torch.load unpickles the file. Only point this at trusted
    # checkpoints; consider weights_only=True once it is confirmed that the
    # checkpoint holds only tensors/primitive containers.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Accept both the wrapped training format and a bare state_dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    return model
@torch.no_grad()
def predict_topk(
    model: torch.nn.Module, image: Image.Image, top_k: int = 3
) -> list[tuple[str, float]]:
    """Run RGB inference on one image and return the most likely classes.

    Args:
        model: Classifier whose output indices correspond to ``CLASS_NAMES``.
        image: Input image in any PIL mode; it is converted to RGB first.
        top_k: Maximum number of predictions to return. Values larger than the
            number of model outputs are clamped instead of raising.

    Returns:
        ``(class_name, probability)`` pairs in descending order of probability,
        where probabilities come from a softmax over the model's logits.
    """
    transform = build_rgb_transform(train=False)
    # Add a leading batch dimension for the single image.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    logits = model(batch)
    probs = torch.softmax(logits, dim=1).squeeze(0)
    # torch.topk raises RuntimeError when k exceeds the number of elements;
    # clamp so e.g. top_k=20 simply returns every class.
    k = min(top_k, probs.numel())
    top_probs, top_indices = torch.topk(probs, k=k)
    return [
        (CLASS_NAMES[int(class_idx)], float(prob))
        for prob, class_idx in zip(top_probs, top_indices, strict=True)
    ]