artscope / app.py
sattoru96's picture
fix app
45cc4e4 verified
"""ArtScope — app de Gradio para HuggingFace Space.
Pipeline:
imagen del cuadro
-> predicción con ConvNeXt + ViT en ensemble (promedio de probas)
-> Grad-CAM sobre el ConvNeXt
-> descripción del estilo predicho con Claude (opcional)
Pasos para desplegar:
1. Crea un Space (Gradio) en HuggingFace.
2. Sube este archivo + requirements.txt + README.md.
3. En Settings -> Variables and secrets, añade ANTHROPIC_API_KEY (opcional).
4. Sustituye HF_USER por tu usuario antes de subir.
"""
import os
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms as T
from huggingface_hub import from_pretrained_fastai
from fastai.vision.all import PILImage
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
# ---------------------------------------------------------------
# Configuración
# ---------------------------------------------------------------
HF_USER = "sattoru96" # <-- sustitúyelo por tu usuario de HF antes de subir
REPO_CNN = f"{HF_USER}/artscope-convnext"
REPO_VIT = f"{HF_USER}/artscope-vit"
# ---------------------------------------------------------------
# Carga de modelos (al iniciar el Space, una sola vez)
# ---------------------------------------------------------------
print("Descargando modelos del Hub...")
learn_cnn = from_pretrained_fastai(REPO_CNN)
learn_vit = from_pretrained_fastai(REPO_VIT)
LABELS = list(learn_cnn.dls.vocab)
print(f"Modelos listos. {len(LABELS)} clases: {LABELS}")
# Grad-CAM lo montamos solo sobre el ConvNeXt: es más rápido y la visualización
# es más interpretable en este tipo de arquitectura.
learn_cnn.model.eval()
# learn_cnn.model[0] es un TimmBody (wrapper de fastai); el modelo timm puro está en .model
_cnn_timm = learn_cnn.model[0].model
_target_layer = _cnn_timm.stages[-1].blocks[-1]
cam = GradCAM(model=learn_cnn.model, target_layers=[_target_layer])
# ---------------------------------------------------------------
# Cliente Claude vía REST directo (evitamos el SDK por problemas de deps en el Space)
# ---------------------------------------------------------------
import requests
ANTHROPIC_KEY = os.environ.get("ANTHROPIC_API_KEY")
ANTHROPIC_URL = "https://api.anthropic.com/v1/messages"
print(f"[init] ANTHROPIC_API_KEY presente: {bool(ANTHROPIC_KEY)} "
f"(longitud: {len(ANTHROPIC_KEY) if ANTHROPIC_KEY else 0})")
def llm_describe(top_style: str, second_style: str) -> str:
"""Genera una descripción del movimiento detectado usando Claude vía REST."""
if not ANTHROPIC_KEY:
return (
"_(Descripción LLM desactivada. Para activarla, añade "
"`ANTHROPIC_API_KEY` en los Secrets del Space.)_"
)
prompt = (
f"Eres un guía de museo experto. Acabo de mostrar un cuadro a un clasificador "
f"y dice que es {top_style.replace('_', ' ')}, con {second_style.replace('_', ' ')} "
f"como segunda opción. En 4-5 frases en español, explica qué rasgos visuales "
f"definen al {top_style.replace('_', ' ')} y por qué podría confundirse con "
f"{second_style.replace('_', ' ')}. Tono divulgativo, sin tecnicismos innecesarios."
)
headers = {
"x-api-key": ANTHROPIC_KEY.strip(),
"anthropic-version": "2023-06-01",
"content-type": "application/json",
}
payload = {
"model": "claude-haiku-4-5-20251001",
"max_tokens": 400,
"messages": [{"role": "user", "content": prompt}],
}
try:
r = requests.post(ANTHROPIC_URL, headers=headers, json=payload, timeout=30)
if r.status_code != 200:
print(f"[claude] HTTP {r.status_code}: {r.text[:500]}")
return f"_(Error HTTP {r.status_code} de Claude: {r.text[:200]})_"
data = r.json()
return data["content"][0]["text"]
except Exception as e:
import traceback
traceback.print_exc()
return f"_(Error llamando a Claude: `{type(e).__name__}: {e}`)_"
# ---------------------------------------------------------------
# Función principal
# ---------------------------------------------------------------
def predict(img):
if img is None:
return None, None, ""
# Preparación
pil_img = Image.fromarray(np.array(img)).convert("RGB")
fastai_img = PILImage.create(pil_img)
# Ensemble: promedio de probabilidades de ambos modelos
_, _, probs_cnn = learn_cnn.predict(fastai_img)
_, _, probs_vit = learn_vit.predict(fastai_img)
probs = ((probs_cnn + probs_vit) / 2).numpy()
# Top-3
order = sorted(range(len(LABELS)), key=lambda i: -probs[i])[:3]
label_dict = {LABELS[i]: float(probs[i]) for i in order}
# Grad-CAM
preprocess = T.Compose([T.Resize((224, 224)), T.ToTensor()])
tensor = preprocess(pil_img).unsqueeze(0)
if torch.cuda.is_available():
tensor = tensor.cuda()
grayscale = cam(input_tensor=tensor)[0]
rgb = np.array(preprocess(pil_img).permute(1, 2, 0))
cam_img = show_cam_on_image(rgb, grayscale, use_rgb=True)
# Descripción
description = llm_describe(LABELS[order[0]], LABELS[order[1]])
return label_dict, cam_img, description
# ---------------------------------------------------------------
# UI
# ---------------------------------------------------------------
DESCRIPTION = (
"# 🎨 ArtScope\n"
"Sube un cuadro y descubre a qué **movimiento artístico** pertenece, dónde "
"está mirando el modelo (mapa Grad-CAM) y qué hace especial a ese estilo "
"(descripción generada por Claude).\n\n"
"*Modelo: ensemble ConvNeXt-tiny + ViT-small, fine-tuned sobre un subset de "
"WikiArt con 10 movimientos.*"
)
with gr.Blocks(title="ArtScope", theme=gr.themes.Soft()) as demo:
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column(scale=1):
inp = gr.Image(type="numpy", label="Sube un cuadro")
btn = gr.Button("Analizar", variant="primary")
gr.Markdown(
"**Estilos soportados**: "
+ ", ".join(s.replace("_", " ") for s in LABELS)
)
with gr.Column(scale=1):
out_label = gr.Label(num_top_classes=3, label="Top movimientos")
out_cam = gr.Image(label="Dónde mira el modelo (Grad-CAM)")
out_desc = gr.Markdown(label="Descripción del estilo")
btn.click(predict, inputs=inp, outputs=[out_label, out_cam, out_desc])
if __name__ == "__main__":
demo.launch()