# Page residue captured with this file: "Spaces: Sleeping".
"""GeoSpotter — satellite image classifier with an NLP explanation.

Final project for Computer Vision, MIOTI master's.
"""
# ── Monkey-patch for a gradio_client bug seen with Gradio 5.x ─────────────
# Error: "argument of type 'bool' is not iterable" in
# _json_schema_to_python_type when additionalProperties is a boolean
# instead of a dict.
import gradio_client.utils as _gcu

# The target is a private helper, so it may not exist in every gradio_client
# release; fall back to None and skip the patch instead of failing at import.
_original = getattr(_gcu, "_json_schema_to_python_type", None)


def _patched(schema, defs=None):
    """Wrap the original schema converter so boolean schemas do not crash it.

    Args:
        schema: JSON-schema fragment; may be a dict or a bare boolean.
        defs: Passed through unchanged to the original converter.

    Returns:
        ``"Any"`` when ``schema`` itself is a boolean; otherwise whatever the
        original converter returns for the (possibly sanitised) schema.
    """
    if isinstance(schema, bool):
        return "Any"
    if isinstance(schema, dict) and isinstance(schema.get("additionalProperties"), bool):
        # Copy instead of mutating the caller's schema in place.
        schema = {**schema, "additionalProperties": {}}
    return _original(schema, defs)


if _original is not None:
    _gcu._json_schema_to_python_type = _patched
# ──────────────────────────────────────────────────────────────────────────
import html
import json

import gradio as gr
import timm
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
# Hub repository id from which the weights and class-name files are downloaded.
WEIGHTS_REPO = "JorjoPM/geospotter-convnext-weights"
class AdaptiveConcatPool2d(nn.Module):
    """Concatenate adaptive max pooling and adaptive average pooling.

    The two pooled tensors are joined along dim 1 (max first, then average),
    so the output has twice as many channels as the input.

    Args:
        size: Target spatial output size handed to both pooling layers.
            Defaults to 1, which matches the previous hard-coded behaviour.
    """

    def __init__(self, size=1):
        super().__init__()
        # Attribute names are kept as-is; neither layer has parameters.
        self.ap = nn.AdaptiveAvgPool2d(size)
        self.mp = nn.AdaptiveMaxPool2d(size)

    def forward(self, x):
        """Return ``cat([max_pool(x), avg_pool(x)], dim=1)``."""
        return torch.cat([self.mp(x), self.ap(x)], dim=1)
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), -1)``.

        ``reshape`` is used rather than ``view`` so that non-contiguous
        inputs are accepted too; for contiguous inputs the result is the same.
        """
        return x.reshape(x.size(0), -1)
def build_model(num_classes):
    """Build the classifier: a timm ConvNeXt backbone followed by a custom head.

    Args:
        num_classes: Number of output logits of the final linear layer.

    Returns:
        ``nn.Sequential(backbone, head)``. The order and indices of the layers
        are part of the contract: the checkpoint loaded elsewhere in this file
        addresses parameters by these positions.
    """
    # num_classes=0 and global_pool="" ask timm for no classifier and no
    # pooling; pooling is done by the head below.
    backbone = timm.create_model(
        "convnext_tiny.fb_in22k",
        pretrained=False,
        num_classes=0,
        global_pool="",
    )

    # Concat pooling doubles the channel count.
    # NOTE(review): 1536 presumably equals 2 x 768 backbone channels — confirm.
    pooled_features = 1536
    hidden_features = 512

    head = nn.Sequential(
        AdaptiveConcatPool2d(),
        Flatten(),
        nn.BatchNorm1d(pooled_features, eps=1e-05, momentum=0.1),
        nn.Dropout(p=0.25),
        nn.Linear(pooled_features, hidden_features, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(hidden_features, eps=1e-05, momentum=0.1),
        nn.Dropout(p=0.5),
        nn.Linear(hidden_features, num_classes, bias=False),
    )
    return nn.Sequential(backbone, head)
# ── Download artefacts and load the model (runs once at import time) ──────
print(f"Descargando pesos desde {WEIGHTS_REPO}...")
weights_path = hf_hub_download(repo_id=WEIGHTS_REPO, filename="convnext_weights.pth")
classes_path = hf_hub_download(repo_id=WEIGHTS_REPO, filename="class_names.json")

# Explicit encoding so the result does not depend on the platform default.
with open(classes_path, encoding="utf-8") as f:
    VOCAB = json.load(f)

model = build_model(len(VOCAB))

# SECURITY: torch.load unpickles the file. weights_only=True restricts
# unpickling to tensors and plain containers, which is all a state dict needs.
state_dict_raw = torch.load(weights_path, map_location="cpu", weights_only=True)

# Checkpoint keys that start with "0.model." are renamed to start with "0."
# so they line up with the module names produced by build_model.
_OLD_PREFIX = "0.model."
state_dict_fixed = {
    ("0." + k[len(_OLD_PREFIX):]) if k.startswith(_OLD_PREFIX) else k: v
    for k, v in state_dict_raw.items()
}

# strict=False keeps loading best-effort, but mismatches are reported so that
# a wrong key mapping cannot silently leave layers at their initial values.
_load_result = model.load_state_dict(state_dict_fixed, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        "AVISO: pesos no coincidentes. "
        f"Faltan {len(_load_result.missing_keys)} claves: {_load_result.missing_keys[:5]}; "
        f"sobran {len(_load_result.unexpected_keys)} claves: {_load_result.unexpected_keys[:5]}"
    )

model.eval()
print(f"Modelo cargado. {len(VOCAB)} clases.")
# Preprocessing applied to every uploaded image before inference:
# resize to 224x224, convert to a tensor, then normalise per channel.
inference_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Hand-written explanations for pairs of classes, used by
# generate_explanation. Keys are frozensets, so the order of the two labels
# in a lookup does not matter.
KNOWN_CONFUSIONS = {
    frozenset(["palace", "church"]):
        "ambos son edificios monumentales con patios interiores y geometria rectilinea vistos desde arriba",
    frozenset(["rectangular_farmland", "terrace"]):
        "ambas clases muestran patrones de paralelogramos vegetales, distinguibles solo por la pendiente del terreno",
    frozenset(["runway", "airport"]):
        "una pista de aterrizaje es un componente de un aeropuerto; la frontera entre ambas clases es semantica mas que visual",
    frozenset(["medium_residential", "dense_residential"]):
        "ambas son areas residenciales y se diferencian solo por la densidad de edificacion",
    frozenset(["commercial_area", "industrial_area"]):
        "ambas tienen edificios grandes de techo plano y zonas pavimentadas; la diferencia esta en el tipo de actividad",
    frozenset(["bridge", "overpass"]):
        "ambas son estructuras lineales que cruzan un obstaculo; un puente cruza agua y un overpass cruza otra via",
    frozenset(["river", "lake"]):
        "ambas son masas de agua; los rios son alargados y los lagos cerrados, pero meandros y desembocaduras confunden",
}
# Spanish display names for the class labels (accent-free, as in the
# original strings).
ES_LABELS = {
    "airplane": "avion", "airport": "aeropuerto", "baseball_diamond": "campo de beisbol",
    "basketball_court": "cancha de baloncesto", "beach": "playa", "bridge": "puente",
    "chaparral": "matorral", "church": "iglesia", "circular_farmland": "campo circular",
    "cloud": "nube", "commercial_area": "area comercial", "dense_residential": "zona residencial densa",
    "desert": "desierto", "forest": "bosque", "freeway": "autopista", "golf_course": "campo de golf",
    "ground_track_field": "pista de atletismo", "harbor": "puerto", "industrial_area": "area industrial",
    "intersection": "interseccion", "island": "isla", "lake": "lago", "meadow": "pradera",
    "medium_residential": "zona residencial media", "mobile_home_park": "parque de caravanas",
    "mountain": "montana", "overpass": "paso elevado", "palace": "palacio",
    "parking_lot": "aparcamiento", "railway": "vias de tren", "railway_station": "estacion de tren",
    "rectangular_farmland": "campo rectangular", "river": "rio", "roundabout": "rotonda",
    "runway": "pista de aterrizaje", "sea_ice": "hielo marino", "ship": "barco",
    "snowberg": "iceberg", "sparse_residential": "zona residencial dispersa", "stadium": "estadio",
    "storage_tank": "deposito", "tennis_court": "pista de tenis", "terrace": "terraza agricola",
    "thermal_power_station": "central termica", "wetland": "humedal",
}
def es(label: str) -> str:
    """Return the Spanish display name for a class label.

    Labels missing from ``ES_LABELS`` fall back to the label itself with
    underscores replaced by spaces.
    """
    return ES_LABELS.get(label, label.replace("_", " "))
def generate_explanation(top_classes, top_probs):
    """Build a Spanish Markdown sentence describing the top predictions.

    Three confidence bands are used, based on the best probability ``p1``:
    above 0.85 (confident), above 0.5 (probable, mentions the runner-up and a
    known-confusion note when the pair is in ``KNOWN_CONFUSIONS``), and the
    rest (uncertain, lists up to two alternatives).

    Args:
        top_classes: Class labels, ordered from most to least probable.
        top_probs: Probabilities in ``[0, 1]``, aligned with ``top_classes``.

    Returns:
        The explanation text. With three or more candidates the output is the
        same as before; shorter inputs now produce a shortened sentence
        instead of raising ``IndexError``.
    """
    if not top_classes:
        return "El modelo no ha devuelto ninguna prediccion."

    # (raw label, Spanish label, probability) for each candidate.
    ranked = [(cls, es(cls), prob) for cls, prob in zip(top_classes, top_probs)]
    top1, t1, p1 = ranked[0]

    if p1 > 0.85:
        return (f"La imagen muestra con alta confianza un/a **{t1}** ({p1:.0%}). "
                f"El modelo identifica claramente los patrones visuales de esta clase.")

    if p1 > 0.5:
        if len(ranked) < 2:
            # No runner-up to mention.
            return f"La imagen muestra probablemente un/a **{t1}** ({p1:.0%})."
        top2, t2, p2 = ranked[1]
        key = frozenset([top1, top2])
        if key in KNOWN_CONFUSIONS:
            return (f"La imagen muestra probablemente un/a **{t1}** ({p1:.0%}), aunque el modelo "
                    f"tambien considera un/a **{t2}** ({p2:.0%}). Esta confusion es conocida: "
                    f"{KNOWN_CONFUSIONS[key]}.")
        return (f"La imagen muestra probablemente un/a **{t1}** ({p1:.0%}). El modelo tambien "
                f"considera un/a **{t2}** ({p2:.0%}), lo que sugiere elementos visuales compartidos.")

    # Low confidence: list up to two alternatives, if there are any.
    alternatives = " o ".join(f"**{t}** ({p:.0%})" for _, t, p in ranked[1:3])
    hedge = f", pero podria ser {alternatives}" if alternatives else ""
    return (f"El modelo no esta seguro. La prediccion mas probable es **{t1}** ({p1:.0%})"
            f"{hedge}. "
            f"La baja confianza sugiere una imagen ambigua o con encuadre atipico.")
def format_predictions(top_classes, top_probs):
    """Render the predictions as HTML rows with a label, percentage and bar.

    Args:
        top_classes: Class labels to display.
        top_probs: Probabilities in ``[0, 1]``, aligned with ``top_classes``.

    Returns:
        One HTML string with a ``<div>`` row per prediction; an empty string
        when there are no predictions.
    """
    rows = []
    for cls, prob in zip(top_classes, top_probs):
        # Labels come from a downloaded file, so escape them before they are
        # interpolated into markup.
        label = html.escape(es(cls))
        pct = prob * 100
        bar_width = int(pct)  # whole-percent width of the filled bar
        rows.append(
            "\n"
            '<div style="margin-bottom:10px">\n'
            '<div style="display:flex; justify-content:space-between; margin-bottom:3px">\n'
            f'<span style="font-weight:600">{label}</span>\n'
            f'<span style="color:#888">{pct:.1f}%</span>\n'
            "</div>\n"
            '<div style="background:#e0e0e0; border-radius:4px; height:12px; width:100%">\n'
            f'<div style="background:#2563eb; border-radius:4px; height:12px; width:{bar_width}%"></div>\n'
            "</div>\n"
            "</div>"
        )
    return "".join(rows)
def predict(img):
    """Classify an uploaded image and explain the result.

    Args:
        img: Path of the image file, or ``None`` when nothing was uploaded
            (the UI input is configured with ``type="filepath"``).

    Returns:
        A ``(predictions_html, explanation)`` tuple for the HTML and Markdown
        output components.
    """
    if img is None:
        return "<p>Por favor, sube una imagen satelital.</p>", "Por favor, sube una imagen satelital."

    # Context manager so the file handle is released deterministically.
    with Image.open(img) as raw:
        x = inference_tfms(raw.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)[0]

    # topk returns values and indices already sorted in descending order;
    # k is clamped so a model with fewer than five classes still works.
    k = min(5, probs.numel())
    top_values, top_indices = torch.topk(probs, k)
    top5_classes = [VOCAB[i] for i in top_indices.tolist()]
    top5_probs = top_values.tolist()

    predictions_html = format_predictions(top5_classes, top5_probs)
    explanation = generate_explanation(top5_classes, top5_probs)
    return predictions_html, explanation
# Markdown shown under the page title.
description = """
**GeoSpotter** clasifica imagenes satelitales en 45 categorias (NWPU-RESISC45) y genera
una explicacion textual sobre lo que muestra la imagen.
**Modelo:** ConvNeXt-tiny fine-tuned, 95.5% accuracy con TTA en validacion.
Sube una imagen aerea/satelital (recomendado ~256x256 px).
Proyecto final de Computer Vision.
"""
# ── Gradio UI ─────────────────────────────────────────────────────────────
# NOTE(review): the nesting below was reconstructed from a flattened paste
# (input and button in one column, outputs in the other) — confirm layout.
with gr.Blocks(title="GeoSpotter") as demo:
    gr.Markdown("# GeoSpotter - Clasificador de Imagenes Satelitales")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="filepath", label="Imagen satelital")
            btn = gr.Button("Clasificar", variant="primary")
        with gr.Column():
            output_labels = gr.HTML()
            output_text = gr.Markdown()
    # Wire the button to predict: HTML bars go to the first output,
    # the explanation text to the second.
    btn.click(fn=predict, inputs=input_img, outputs=[output_labels, output_text])


if __name__ == "__main__":
    demo.launch()