Geospotter / app.py
JorjoPM's picture
Update app.py
81f53ff verified
"""GeoSpotter — clasificador de imágenes satelitales con explicación NLP.
Proyecto final de Computer Vision, máster MIOTI.
"""
# ── Monkey-patch para bug de gradio_client en Gradio 5.x ──────────────────
# Error: "argument of type 'bool' is not iterable" en _json_schema_to_python_type
# cuando additionalProperties es un booleano en lugar de un dict.
import gradio_client.utils as _gcu
# Keep a handle on the library implementation; the wrapper below delegates to it.
_original = _gcu._json_schema_to_python_type


def _patched(schema, defs=None):
    """Guard gradio_client's schema-to-type conversion against boolean schemas.

    JSON Schema allows a bare boolean where a schema object is expected. The
    wrapped function assumes a dict, so:
      * a boolean ``schema`` is reported directly as the type ``"Any"``;
      * a boolean ``additionalProperties`` entry is swapped for ``{}`` (an
        empty, unconstrained schema) before delegating.
    Everything else is passed through to the original function unchanged.
    """
    if isinstance(schema, bool):
        return "Any"
    if isinstance(schema, dict) and isinstance(schema.get("additionalProperties"), bool):
        # Copy rather than mutate: the caller's schema dict stays untouched.
        schema = dict(schema, additionalProperties={})
    return _original(schema, defs)


# Rebind the module attribute so recursive lookups inside gradio_client also
# go through the guarded version.
_gcu._json_schema_to_python_type = _patched
# ──────────────────────────────────────────────────────────────────────────
import gradio as gr
import torch
import torch.nn as nn
import timm
import json
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository from which the app downloads, at startup, the
# fine-tuned weights ("convnext_weights.pth") and the class list
# ("class_names.json").
WEIGHTS_REPO = "JorjoPM/geospotter-convnext-weights"
class AdaptiveConcatPool2d(nn.Module):
    """Reduce a feature map to 1x1 with both max- and average-pooling.

    For a 4-D ``(N, C, H, W)`` input the result is ``(N, 2*C, 1, 1)``: the
    max-pooled channels come first, followed by the average-pooled ones.
    The order matters because the trained head's weights expect it.
    """

    def __init__(self):
        super(AdaptiveConcatPool2d, self).__init__()
        # Registration order (ap, then mp) is kept for identical module repr/children.
        self.ap = nn.AdaptiveAvgPool2d(output_size=1)
        self.mp = nn.AdaptiveMaxPool2d(output_size=1)

    def forward(self, x):
        max_feats = self.mp(x)
        avg_feats = self.ap(x)
        return torch.cat((max_feats, avg_feats), dim=1)
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one.

    ``(N, d1, d2, ...) -> (N, d1 * d2 * ...)``; a 1-D ``(N,)`` input becomes
    ``(N, 1)``.
    """

    def forward(self, x):
        # reshape instead of view: view raises on non-contiguous tensors,
        # while reshape copies only when it has to and otherwise returns a
        # view exactly like before.
        return x.reshape(x.size(0), -1)
def build_model(num_classes, arch="convnext_tiny.fb_in22k"):
    """Assemble the GeoSpotter classifier: a timm backbone plus a pooling head.

    Args:
        num_classes: number of output logits (one per class).
        arch: timm model name used as the feature extractor. The default is
            the architecture the app's checkpoint is loaded into; another
            value only makes sense together with matching weights.

    Returns:
        ``nn.Sequential(backbone, head)``. Parameter names ("0.*" for the
        backbone, "1.<idx>.*" for the head layers) must line up with the
        checkpoint, so the head's layer order must not change.
    """
    # pretrained=False: weights come from the checkpoint, not a timm download.
    # num_classes=0 / global_pool="": drop timm's classifier and pooling so the
    # backbone returns raw feature maps for the custom head.
    backbone = timm.create_model(
        arch, pretrained=False,
        num_classes=0, global_pool="",
    )
    # AdaptiveConcatPool2d stacks max- and avg-pooled features, doubling the
    # channel count. For convnext_tiny (768 channels) this is the 1536 the
    # head used to hard-code.
    n_features = 2 * backbone.num_features
    head = nn.Sequential(
        AdaptiveConcatPool2d(),
        Flatten(),
        nn.BatchNorm1d(n_features, eps=1e-05, momentum=0.1),
        nn.Dropout(p=0.25),
        nn.Linear(n_features, 512, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512, eps=1e-05, momentum=0.1),
        nn.Dropout(p=0.5),
        nn.Linear(512, num_classes, bias=False),
    )
    return nn.Sequential(backbone, head)
# Fetch the fine-tuned weights and the class list from the Hub at import time
# (hf_hub_download returns the path of the locally cached file).
print(f"Descargando pesos desde {WEIGHTS_REPO}...")
weights_path = hf_hub_download(repo_id=WEIGHTS_REPO, filename="convnext_weights.pth")
classes_path = hf_hub_download(repo_id=WEIGHTS_REPO, filename="class_names.json")
# predict() indexes VOCAB with integer output positions, so the JSON must be a
# list of class names in the model's output order.
with open(classes_path, encoding="utf-8") as f:
    VOCAB = json.load(f)
model = build_model(len(VOCAB))
# The checkpoint is a downloaded file: weights_only=True restricts unpickling
# to tensors and plain containers, which is all a state_dict needs.
state_dict_raw = torch.load(weights_path, map_location="cpu", weights_only=True)
# Backbone keys in the checkpoint carry an extra "model." segment
# ("0.model.<param>"), while build_model() places the timm backbone directly at
# index 0 ("0.<param>"). Strip that segment so the keys line up; all other keys
# pass through unchanged.
state_dict_fixed = {
    ("0." + k[len("0.model."):] if k.startswith("0.model.") else k): v
    for k, v in state_dict_raw.items()
}
# strict=False tolerates key mismatches. However, any model parameter missing
# from the checkpoint silently keeps its random initialization, so report
# mismatches instead of hiding them.
load_report = model.load_state_dict(state_dict_fixed, strict=False)
if load_report.missing_keys:
    print(f"AVISO: {len(load_report.missing_keys)} parametros del modelo no estan en el "
          f"checkpoint y quedan sin cargar: {load_report.missing_keys[:5]}")
if load_report.unexpected_keys:
    print(f"AVISO: {len(load_report.unexpected_keys)} claves del checkpoint no se usan: "
          f"{load_report.unexpected_keys[:5]}")
model.eval()
print(f"Modelo cargado. {len(VOCAB)} clases.")
# Preprocessing applied to each uploaded image before the forward pass:
# resize to exactly 224x224 (the aspect ratio is not preserved), convert to a
# float tensor, and normalize with the standard ImageNet channel mean/std.
# NOTE(review): presumably mirrors the validation transforms used during
# training; confirm, since any mismatch silently degrades accuracy.
inference_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Class pairs regarded as easy to confuse. Keys are an unordered frozenset of
# the two raw class labels, so lookup does not depend on which one ranked
# first. Values are a Spanish explanation that generate_explanation() appends
# when the top-2 predictions form one of these pairs and the top probability
# is above 0.5 but not above 0.85.
KNOWN_CONFUSIONS = {
    frozenset(["palace", "church"]):
        "ambos son edificios monumentales con patios interiores y geometria rectilinea vistos desde arriba",
    frozenset(["rectangular_farmland", "terrace"]):
        "ambas clases muestran patrones de paralelogramos vegetales, distinguibles solo por la pendiente del terreno",
    frozenset(["runway", "airport"]):
        "una pista de aterrizaje es un componente de un aeropuerto; la frontera entre ambas clases es semantica mas que visual",
    frozenset(["medium_residential", "dense_residential"]):
        "ambas son areas residenciales y se diferencian solo por la densidad de edificacion",
    frozenset(["commercial_area", "industrial_area"]):
        "ambas tienen edificios grandes de techo plano y zonas pavimentadas; la diferencia esta en el tipo de actividad",
    frozenset(["bridge", "overpass"]):
        "ambas son estructuras lineales que cruzan un obstaculo; un puente cruza agua y un overpass cruza otra via",
    frozenset(["river", "lake"]):
        "ambas son masas de agua; los rios son alargados y los lagos cerrados, pero meandros y desembocaduras confunden",
}
# Spanish display names for the 45 NWPU-RESISC45 class labels (insertion
# order kept alphabetical by raw label).
ES_LABELS = {
    "airplane": "avion",
    "airport": "aeropuerto",
    "baseball_diamond": "campo de beisbol",
    "basketball_court": "cancha de baloncesto",
    "beach": "playa",
    "bridge": "puente",
    "chaparral": "matorral",
    "church": "iglesia",
    "circular_farmland": "campo circular",
    "cloud": "nube",
    "commercial_area": "area comercial",
    "dense_residential": "zona residencial densa",
    "desert": "desierto",
    "forest": "bosque",
    "freeway": "autopista",
    "golf_course": "campo de golf",
    "ground_track_field": "pista de atletismo",
    "harbor": "puerto",
    "industrial_area": "area industrial",
    "intersection": "interseccion",
    "island": "isla",
    "lake": "lago",
    "meadow": "pradera",
    "medium_residential": "zona residencial media",
    "mobile_home_park": "parque de caravanas",
    "mountain": "montana",
    "overpass": "paso elevado",
    "palace": "palacio",
    "parking_lot": "aparcamiento",
    "railway": "vias de tren",
    "railway_station": "estacion de tren",
    "rectangular_farmland": "campo rectangular",
    "river": "rio",
    "roundabout": "rotonda",
    "runway": "pista de aterrizaje",
    "sea_ice": "hielo marino",
    "ship": "barco",
    "snowberg": "iceberg",
    "sparse_residential": "zona residencial dispersa",
    "stadium": "estadio",
    "storage_tank": "deposito",
    "tennis_court": "pista de tenis",
    "terrace": "terraza agricola",
    "thermal_power_station": "central termica",
    "wetland": "humedal",
}


def es(label):
    """Return the Spanish display name of a raw class label.

    Labels not present in ES_LABELS fall back to the label itself with
    underscores replaced by spaces (e.g. ``"foo_bar"`` -> ``"foo bar"``).
    """
    translated = ES_LABELS.get(label)
    if translated is None:
        translated = label.replace("_", " ")
    return translated
def generate_explanation(top_classes, top_probs):
    """Build a Spanish, Markdown-formatted explanation of a ranked prediction.

    Args:
        top_classes: raw class labels sorted by descending probability. The
            first two are always read; a third is read only when the top
            probability is not above 0.5.
        top_probs: probabilities aligned with ``top_classes``.

    Returns:
        One of three messages, chosen by the top probability:
          * above 0.85: high-confidence statement about the top class;
          * above 0.5: probable top class plus the runner-up, with the reason
            from KNOWN_CONFUSIONS when the top-2 pair is listed there;
          * otherwise: an "uncertain" message listing the top three classes.
    """
    first, second = top_classes[0], top_classes[1]
    conf_first, conf_second = top_probs[0], top_probs[1]
    name_first, name_second = es(first), es(second)

    if conf_first > 0.85:
        message = (
            f"La imagen muestra con alta confianza un/a **{name_first}** ({conf_first:.0%}). "
            "El modelo identifica claramente los patrones visuales de esta clase."
        )
    elif conf_first > 0.5:
        lead = f"La imagen muestra probablemente un/a **{name_first}** ({conf_first:.0%})"
        # Keys are unordered pairs, so (first, second) and (second, first) match alike.
        reason = KNOWN_CONFUSIONS.get(frozenset([first, second]))
        if reason is not None:
            message = (
                f"{lead}, aunque el modelo tambien considera un/a **{name_second}** "
                f"({conf_second:.0%}). Esta confusion es conocida: {reason}."
            )
        else:
            message = (
                f"{lead}. El modelo tambien considera un/a **{name_second}** "
                f"({conf_second:.0%}), lo que sugiere elementos visuales compartidos."
            )
    else:
        third, conf_third = top_classes[2], top_probs[2]
        message = (
            f"El modelo no esta seguro. La prediccion mas probable es **{name_first}** "
            f"({conf_first:.0%}), pero podria ser **{name_second}** ({conf_second:.0%}) "
            f"o **{es(third)}** ({conf_third:.0%}). "
            "La baja confianza sugiere una imagen ambigua o con encuadre atipico."
        )
    return message
def format_predictions(top_classes, top_probs):
    """Render ranked predictions as HTML horizontal bars.

    Args:
        top_classes: raw class labels, shown via their Spanish name (``es``).
        top_probs: probabilities aligned with ``top_classes``; they are
            multiplied by 100 to display percentages (assumed in [0, 1]).

    Returns:
        Concatenated HTML with one block per class: the name, the percentage
        with one decimal, and a bar whose width is the truncated integer
        percentage. NOTE(review): names are inserted without HTML escaping.
    """
    def _row(cls, prob):
        # One class entry; each HTML line is joined with "\n" and the entry
        # starts with a newline, exactly as the original template.
        pct = prob * 100
        html_lines = [
            "",
            '<div style="margin-bottom:10px">',
            '<div style="display:flex; justify-content:space-between; margin-bottom:3px">',
            f'<span style="font-weight:600">{es(cls)}</span>',
            f'<span style="color:#888">{pct:.1f}%</span>',
            "</div>",
            '<div style="background:#e0e0e0; border-radius:4px; height:12px; width:100%">',
            f'<div style="background:#2563eb; border-radius:4px; height:12px; width:{int(pct)}%"></div>',
            "</div>",
            "</div>",
        ]
        return "\n".join(html_lines)

    return "".join(_row(cls, prob) for cls, prob in zip(top_classes, top_probs))
def predict(img):
    """Classify an uploaded satellite image.

    Args:
        img: path of the uploaded image (the UI's ``gr.Image`` uses
            ``type="filepath"``), or None when nothing was uploaded.

    Returns:
        ``(predictions_html, explanation)``: HTML bars for the top-5 classes
        and a Spanish Markdown explanation. When ``img`` is None, returns the
        same prompt asking the user to upload an image in both outputs.
    """
    if img is None:
        return "<p>Por favor, sube una imagen satelital.</p>", "Por favor, sube una imagen satelital."
    # Context manager releases the file handle even for formats Pillow keeps
    # open after loading (e.g. multi-frame images); convert() returns an
    # independent, fully loaded image that stays valid after closing.
    with Image.open(img) as im:
        rgb = im.convert("RGB")
    x = inference_tfms(rgb).unsqueeze(0)
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)[0]
    # topk yields values and indices already sorted in descending order; k is
    # clamped so a vocabulary with fewer than 5 classes still works.
    top_vals, top_idx = probs.topk(min(5, probs.numel()))
    top5_classes = [VOCAB[i] for i in top_idx.tolist()]
    top5_probs = top_vals.tolist()
    predictions_html = format_predictions(top5_classes, top5_probs)
    explanation = generate_explanation(top5_classes, top5_probs)
    return predictions_html, explanation
# Markdown intro rendered under the page title.
# NOTE(review): the class count ("45 categorias") and the 95.5% accuracy are
# hard-coded text, not derived from VOCAB or measured by this app. The figure
# is quoted "con TTA", but predict() runs a single forward pass without
# test-time augmentation.
description = """
**GeoSpotter** clasifica imagenes satelitales en 45 categorias (NWPU-RESISC45) y genera
una explicacion textual sobre lo que muestra la imagen.
**Modelo:** ConvNeXt-tiny fine-tuned, 95.5% accuracy con TTA en validacion.
Sube una imagen aerea/satelital (recomendado ~256x256 px).
Proyecto final de Computer Vision.
"""
# Gradio UI: image upload and button in one column, and the top-5 bars (HTML)
# and textual explanation (Markdown) in a second column of the same row.
with gr.Blocks(title="GeoSpotter") as demo:
    gr.Markdown("# GeoSpotter - Clasificador de Imagenes Satelitales")
    gr.Markdown(description)
    with gr.Row():
        # Input column: the image reaches predict() as a file path.
        with gr.Column():
            input_img = gr.Image(type="filepath", label="Imagen satelital")
            btn = gr.Button("Clasificar", variant="primary")
        # Output column: filled from predict()'s (html, markdown) tuple.
        with gr.Column():
            output_labels = gr.HTML()
            output_text = gr.Markdown()
    # Inference is only wired to the button click, not to image upload.
    btn.click(fn=predict, inputs=input_img, outputs=[output_labels, output_text])
if __name__ == "__main__":
    demo.launch()