"""GeoSpotter: satellite-image classifier with a natural-language explanation.

Final project for the Computer Vision course of the MIOTI master's programme.
At import time the module downloads fine-tuned ConvNeXt-tiny weights and the
class list from the Hugging Face Hub, then builds a Gradio UI that shows the
top-5 predicted classes and a short Spanish explanation of the prediction.
"""

import json

# ── Monkey-patch for a gradio_client bug seen with Gradio 5.x ────────────────
# Error: "argument of type 'bool' is not iterable" in _json_schema_to_python_type
# when a schema (or its "additionalProperties") is a boolean instead of a dict.
# Installed before `gradio` is imported below.
import gradio_client.utils as _gcu

_original = _gcu._json_schema_to_python_type


def _patched(schema, defs=None):
    """Wrap gradio_client's schema-to-type converter so boolean schemas work.

    A bare boolean schema maps to "Any"; a boolean "additionalProperties" is
    replaced by an empty dict before delegating to the original converter.
    """
    if isinstance(schema, bool):
        return "Any"
    if isinstance(schema, dict) and isinstance(schema.get("additionalProperties"), bool):
        schema = {**schema, "additionalProperties": {}}
    return _original(schema, defs)


_gcu._json_schema_to_python_type = _patched
# ──────────────────────────────────────────────────────────────────────────────

import gradio as gr  # noqa: E402  (kept after the patch above)
import timm  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
from PIL import Image  # noqa: E402
from torchvision import transforms  # noqa: E402

WEIGHTS_REPO = "JorjoPM/geospotter-convnext-weights"

# Number of predictions shown in the UI.
TOP_K = 5

# Backbone keys in the checkpoint carry this prefix; in build_model() the
# backbone is element 0 of an nn.Sequential, so "0.model.<x>" must become
# "0.<x>". (Presumably the checkpoint was saved from a wrapper that stores the
# backbone as `.model` — assumption, not visible here.)
_CKPT_BACKBONE_PREFIX = "0.model."


class AdaptiveConcatPool2d(nn.Module):
    """Concatenate adaptive max- and average-pooling down to a 1x1 map.

    For an input with C channels the output has 2*C channels (max first, then avg).
    """

    def __init__(self):
        super().__init__()
        self.ap = nn.AdaptiveAvgPool2d(1)
        self.mp = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        return torch.cat([self.mp(x), self.ap(x)], dim=1)


class Flatten(nn.Module):
    """Flatten every dimension except the first (batch) one."""

    def forward(self, x):
        return x.view(x.size(0), -1)


def build_model(num_classes):
    """Build the ConvNeXt-tiny backbone plus the custom classification head.

    The layer order of the head must match the saved checkpoint, because weights
    are loaded by positional key ("1.2.weight", ...).

    Args:
        num_classes: number of output logits.

    Returns:
        ``nn.Sequential(backbone, head)`` with uninitialised (not pretrained) weights.
    """
    backbone = timm.create_model(
        "convnext_tiny.fb_in22k",
        pretrained=False,  # weights come from the fine-tuned checkpoint instead
        num_classes=0,  # no timm classifier ...
        global_pool="",  # ... and no timm pooling: the head pools by itself
    )
    head = nn.Sequential(
        AdaptiveConcatPool2d(),
        Flatten(),
        # 1536 = 2 x backbone channels (max and avg pooling concatenated).
        nn.BatchNorm1d(1536, eps=1e-05, momentum=0.1),
        nn.Dropout(p=0.25),
        nn.Linear(1536, 512, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512, eps=1e-05, momentum=0.1),
        nn.Dropout(p=0.5),
        nn.Linear(512, num_classes, bias=False),
    )
    return nn.Sequential(backbone, head)


def _remap_checkpoint_keys(state_dict):
    """Rename "0.model.<x>" keys to "0.<x>"; all other keys are kept unchanged."""
    n = len(_CKPT_BACKBONE_PREFIX)
    return {
        ("0." + key[n:] if key.startswith(_CKPT_BACKBONE_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def _load_weights(net, path):
    """Load the checkpoint at *path* into *net* (non-strict) and report mismatches.

    The checkpoint dict only lives inside this function, so it can be garbage
    collected once its tensors have been copied into *net*.
    """
    # weights_only=True restricts unpickling to tensors and plain containers;
    # the file is downloaded from a remote repository, so avoid full pickle.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    result = net.load_state_dict(_remap_checkpoint_keys(state_dict), strict=False)
    # strict=False silently skips keys that do not match. Surface them so that a
    # wrong key mapping (leaving layers randomly initialised) is noticed.
    if result.missing_keys or result.unexpected_keys:
        print(
            f"AVISO: {len(result.missing_keys)} claves ausentes y "
            f"{len(result.unexpected_keys)} claves inesperadas al cargar los pesos. "
            f"Ausentes: {result.missing_keys[:5]} Inesperadas: {result.unexpected_keys[:5]}"
        )


print(f"Descargando pesos desde {WEIGHTS_REPO}...")
weights_path = hf_hub_download(repo_id=WEIGHTS_REPO, filename="convnext_weights.pth")
classes_path = hf_hub_download(repo_id=WEIGHTS_REPO, filename="class_names.json")

with open(classes_path, encoding="utf-8") as f:
    # Sequence of class names, indexed by logit position in predict().
    VOCAB = json.load(f)

model = build_model(len(VOCAB))
_load_weights(model, weights_path)
model.eval()
print(f"Modelo cargado. {len(VOCAB)} clases.")

# Preprocessing applied to every uploaded image: fixed 224x224 resize, then
# normalisation with the ImageNet mean/std.
inference_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Class pairs (order-independent, hence frozenset) with a known, explainable
# confusion; used by generate_explanation() when top-1 and top-2 form such a pair.
KNOWN_CONFUSIONS = {
    frozenset(["palace", "church"]): "ambos son edificios monumentales con patios interiores y geometria rectilinea vistos desde arriba",
    frozenset(["rectangular_farmland", "terrace"]): "ambas clases muestran patrones de paralelogramos vegetales, distinguibles solo por la pendiente del terreno",
    frozenset(["runway", "airport"]): "una pista de aterrizaje es un componente de un aeropuerto; la frontera entre ambas clases es semantica mas que visual",
    frozenset(["medium_residential", "dense_residential"]): "ambas son areas residenciales y se diferencian solo por la densidad de edificacion",
    frozenset(["commercial_area", "industrial_area"]): "ambas tienen edificios grandes de techo plano y zonas pavimentadas; la diferencia esta en el tipo de actividad",
    frozenset(["bridge", "overpass"]): "ambas son estructuras lineales que cruzan un obstaculo; un puente cruza agua y un overpass cruza otra via",
    frozenset(["river", "lake"]): "ambas son masas de agua; los rios son alargados y los lagos cerrados, pero meandros y desembocaduras confunden",
}

# Spanish display names for the class identifiers.
ES_LABELS = {
    "airplane": "avion",
    "airport": "aeropuerto",
    "baseball_diamond": "campo de beisbol",
    "basketball_court": "cancha de baloncesto",
    "beach": "playa",
    "bridge": "puente",
    "chaparral": "matorral",
    "church": "iglesia",
    "circular_farmland": "campo circular",
    "cloud": "nube",
    "commercial_area": "area comercial",
    "dense_residential": "zona residencial densa",
    "desert": "desierto",
    "forest": "bosque",
    "freeway": "autopista",
    "golf_course": "campo de golf",
    "ground_track_field": "pista de atletismo",
    "harbor": "puerto",
    "industrial_area": "area industrial",
    "intersection": "interseccion",
    "island": "isla",
    "lake": "lago",
    "meadow": "pradera",
    "medium_residential": "zona residencial media",
    "mobile_home_park": "parque de caravanas",
    "mountain": "montana",
    "overpass": "paso elevado",
    "palace": "palacio",
    "parking_lot": "aparcamiento",
    "railway": "vias de tren",
    "railway_station": "estacion de tren",
    "rectangular_farmland": "campo rectangular",
    "river": "rio",
    "roundabout": "rotonda",
    "runway": "pista de aterrizaje",
    "sea_ice": "hielo marino",
    "ship": "barco",
    "snowberg": "iceberg",
    "sparse_residential": "zona residencial dispersa",
    "stadium": "estadio",
    "storage_tank": "deposito",
    "tennis_court": "pista de tenis",
    "terrace": "terraza agricola",
    "thermal_power_station": "central termica",
    "wetland": "humedal",
}


def es(label):
    """Return the Spanish name of *label*, or *label* with "_" replaced by spaces."""
    return ES_LABELS.get(label, label.replace("_", " "))


def generate_explanation(top_classes, top_probs):
    """Build a Markdown explanation of the prediction from the ranked classes.

    Args:
        top_classes: class names sorted by decreasing probability.
        top_probs: matching probabilities. At least two entries are required,
            and three when the top-1 probability is <= 0.5.

    Returns:
        A Spanish Markdown sentence. Its wording depends on the top-1 probability:
        > 0.85 confident, > 0.5 probable (mentioning a known confusion when the
        top-2 pair is in KNOWN_CONFUSIONS), otherwise uncertain listing the top-3.
    """
    top1, top2 = top_classes[0], top_classes[1]
    p1, p2 = top_probs[0], top_probs[1]
    t1, t2 = es(top1), es(top2)
    if p1 > 0.85:
        return (f"La imagen muestra con alta confianza un/a **{t1}** ({p1:.0%}). "
                f"El modelo identifica claramente los patrones visuales de esta clase.")
    if p1 > 0.5:
        key = frozenset([top1, top2])
        if key in KNOWN_CONFUSIONS:
            return (f"La imagen muestra probablemente un/a **{t1}** ({p1:.0%}), aunque el modelo "
                    f"tambien considera un/a **{t2}** ({p2:.0%}). Esta confusion es conocida: "
                    f"{KNOWN_CONFUSIONS[key]}.")
        return (f"La imagen muestra probablemente un/a **{t1}** ({p1:.0%}). El modelo tambien "
                f"considera un/a **{t2}** ({p2:.0%}), lo que sugiere elementos visuales compartidos.")
    t3, p3 = es(top_classes[2]), top_probs[2]
    return (f"El modelo no esta seguro. La prediccion mas probable es **{t1}** ({p1:.0%}), "
            f"pero podria ser **{t2}** ({p2:.0%}) o **{t3}** ({p3:.0%}). "
            f"La baja confianza sugiere una imagen ambigua o con encuadre atipico.")


def format_predictions(top_classes, top_probs):
    """Render the predictions as HTML rows: Spanish label, percentage and a bar.

    NOTE(review): the HTML template of this function was lost in the copy under
    review (only `label`, `pct` and `bar_width` survived). The markup below is a
    minimal reconstruction; restore the original styling if it is available.

    Args:
        top_classes: class names (as stored in VOCAB) in display order.
        top_probs: matching probabilities, as produced by predict().

    Returns:
        One HTML string with one row per class.
    """
    rows = []
    for cls, prob in zip(top_classes, top_probs):
        label = es(cls)
        pct = prob * 100
        bar_width = int(pct)  # bar length as an integer CSS percentage
        rows.append(
            f'<div style="margin:6px 0">'
            f'<div style="display:flex;justify-content:space-between">'
            f"<span>{label}</span><span>{pct:.1f}%</span></div>"
            f'<div style="background:#e5e7eb;border-radius:4px;height:8px">'
            f'<div style="background:#2563eb;border-radius:4px;height:8px;width:{bar_width}%"></div>'
            f"</div></div>"
        )
    return "\n".join(rows)


def predict(img):
    """Classify an uploaded image and explain the result.

    Args:
        img: file path of the uploaded image (the input is
            ``gr.Image(type="filepath")``), or None when nothing was uploaded.

    Returns:
        ``(predictions_html, explanation_markdown)`` for the two output components.
    """
    if img is None:
        # NOTE(review): the header and this guard were lost in the copy under
        # review; only the message text survived. The HTML wrapper is reconstructed.
        return ("<p>Por favor, sube una imagen satelital.</p>",
                "Por favor, sube una imagen satelital.")
    # Close the file handle once the pixels have been converted.
    with Image.open(img) as im:
        x = inference_tfms(im.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)[0]
    top = torch.topk(probs, k=min(TOP_K, len(VOCAB)))  # sorted, highest first
    top_classes = [VOCAB[i] for i in top.indices.tolist()]
    top_probs = top.values.tolist()
    predictions_html = format_predictions(top_classes, top_probs)
    explanation = generate_explanation(top_classes, top_probs)
    return predictions_html, explanation


description = """
**GeoSpotter** clasifica imagenes satelitales en 45 categorias (NWPU-RESISC45) y genera una explicacion textual sobre lo que muestra la imagen.

**Modelo:** ConvNeXt-tiny fine-tuned, 95.5% accuracy con TTA en validacion.

Sube una imagen aerea/satelital (recomendado ~256x256 px).

Proyecto final de Computer Vision.
"""

with gr.Blocks(title="GeoSpotter") as demo:
    gr.Markdown("# GeoSpotter - Clasificador de Imagenes Satelitales")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="filepath", label="Imagen satelital")
            btn = gr.Button("Clasificar", variant="primary")
        with gr.Column():
            output_labels = gr.HTML()
            output_text = gr.Markdown()
    btn.click(fn=predict, inputs=input_img, outputs=[output_labels, output_text])

if __name__ == "__main__":
    demo.launch()