File size: 6,523 Bytes
b14f628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45cc4e4
b14f628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e22cf33
 
5ad6e4f
b14f628
 
 
e22cf33
b14f628
e22cf33
 
b14f628
e22cf33
 
 
b14f628
 
 
e22cf33
 
b14f628
 
 
 
 
 
 
 
 
 
 
e22cf33
 
 
 
 
 
 
 
 
 
b14f628
e22cf33
 
 
 
 
 
b14f628
e22cf33
 
 
b14f628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""ArtScope — app de Gradio para HuggingFace Space.

Pipeline:
  imagen del cuadro
    -> predicción con ConvNeXt + ViT en ensemble (promedio de probas)
    -> Grad-CAM sobre el ConvNeXt
    -> descripción del estilo predicho con Claude (opcional)

Pasos para desplegar:
  1. Crea un Space (Gradio) en HuggingFace.
  2. Sube este archivo + requirements.txt + README.md.
  3. En Settings -> Variables and secrets, añade ANTHROPIC_API_KEY (opcional).
  4. Sustituye HF_USER por tu usuario antes de subir.
"""
import os
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms as T

from huggingface_hub import from_pretrained_fastai
from fastai.vision.all import PILImage
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

# ---------------------------------------------------------------
# Configuración
# ---------------------------------------------------------------
HF_USER = "sattoru96"  # <-- sustitúyelo por tu usuario de HF antes de subir
REPO_CNN = f"{HF_USER}/artscope-convnext"
REPO_VIT = f"{HF_USER}/artscope-vit"

# ---------------------------------------------------------------
# Carga de modelos (al iniciar el Space, una sola vez)
# ---------------------------------------------------------------
print("Descargando modelos del Hub...")
learn_cnn = from_pretrained_fastai(REPO_CNN)
learn_vit = from_pretrained_fastai(REPO_VIT)
LABELS = list(learn_cnn.dls.vocab)
print(f"Modelos listos. {len(LABELS)} clases: {LABELS}")

# Grad-CAM lo montamos solo sobre el ConvNeXt: es más rápido y la visualización
# es más interpretable en este tipo de arquitectura.
learn_cnn.model.eval()
# learn_cnn.model[0] es un TimmBody (wrapper de fastai); el modelo timm puro está en .model
_cnn_timm = learn_cnn.model[0].model
_target_layer = _cnn_timm.stages[-1].blocks[-1]
cam = GradCAM(model=learn_cnn.model, target_layers=[_target_layer])

# ---------------------------------------------------------------
# Cliente Claude vía REST directo (evitamos el SDK por problemas de deps en el Space)
# ---------------------------------------------------------------
import requests

ANTHROPIC_KEY = os.environ.get("ANTHROPIC_API_KEY")
ANTHROPIC_URL = "https://api.anthropic.com/v1/messages"
print(f"[init] ANTHROPIC_API_KEY presente: {bool(ANTHROPIC_KEY)} "
      f"(longitud: {len(ANTHROPIC_KEY) if ANTHROPIC_KEY else 0})")


def llm_describe(top_style: str, second_style: str) -> str:
    """Genera una descripción del movimiento detectado usando Claude vía REST."""
    if not ANTHROPIC_KEY:
        return (
            "_(Descripción LLM desactivada. Para activarla, añade "
            "`ANTHROPIC_API_KEY` en los Secrets del Space.)_"
        )
    prompt = (
        f"Eres un guía de museo experto. Acabo de mostrar un cuadro a un clasificador "
        f"y dice que es {top_style.replace('_', ' ')}, con {second_style.replace('_', ' ')} "
        f"como segunda opción. En 4-5 frases en español, explica qué rasgos visuales "
        f"definen al {top_style.replace('_', ' ')} y por qué podría confundirse con "
        f"{second_style.replace('_', ' ')}. Tono divulgativo, sin tecnicismos innecesarios."
    )
    headers = {
        "x-api-key": ANTHROPIC_KEY.strip(),
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    payload = {
        "model": "claude-haiku-4-5-20251001",
        "max_tokens": 400,
        "messages": [{"role": "user", "content": prompt}],
    }
    try:
        r = requests.post(ANTHROPIC_URL, headers=headers, json=payload, timeout=30)
        if r.status_code != 200:
            print(f"[claude] HTTP {r.status_code}: {r.text[:500]}")
            return f"_(Error HTTP {r.status_code} de Claude: {r.text[:200]})_"
        data = r.json()
        return data["content"][0]["text"]
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"_(Error llamando a Claude: `{type(e).__name__}: {e}`)_"


# ---------------------------------------------------------------
# Función principal
# ---------------------------------------------------------------
def predict(img):
    if img is None:
        return None, None, ""

    # Preparación
    pil_img = Image.fromarray(np.array(img)).convert("RGB")
    fastai_img = PILImage.create(pil_img)

    # Ensemble: promedio de probabilidades de ambos modelos
    _, _, probs_cnn = learn_cnn.predict(fastai_img)
    _, _, probs_vit = learn_vit.predict(fastai_img)
    probs = ((probs_cnn + probs_vit) / 2).numpy()

    # Top-3
    order = sorted(range(len(LABELS)), key=lambda i: -probs[i])[:3]
    label_dict = {LABELS[i]: float(probs[i]) for i in order}

    # Grad-CAM
    preprocess = T.Compose([T.Resize((224, 224)), T.ToTensor()])
    tensor = preprocess(pil_img).unsqueeze(0)
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    grayscale = cam(input_tensor=tensor)[0]
    rgb = np.array(preprocess(pil_img).permute(1, 2, 0))
    cam_img = show_cam_on_image(rgb, grayscale, use_rgb=True)

    # Descripción
    description = llm_describe(LABELS[order[0]], LABELS[order[1]])

    return label_dict, cam_img, description


# ---------------------------------------------------------------
# UI
# ---------------------------------------------------------------
DESCRIPTION = (
    "# 🎨 ArtScope\n"
    "Sube un cuadro y descubre a qué **movimiento artístico** pertenece, dónde "
    "está mirando el modelo (mapa Grad-CAM) y qué hace especial a ese estilo "
    "(descripción generada por Claude).\n\n"
    "*Modelo: ensemble ConvNeXt-tiny + ViT-small, fine-tuned sobre un subset de "
    "WikiArt con 10 movimientos.*"
)

with gr.Blocks(title="ArtScope", theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type="numpy", label="Sube un cuadro")
            btn = gr.Button("Analizar", variant="primary")
            gr.Markdown(
                "**Estilos soportados**: "
                + ", ".join(s.replace("_", " ") for s in LABELS)
            )
        with gr.Column(scale=1):
            out_label = gr.Label(num_top_classes=3, label="Top movimientos")
            out_cam = gr.Image(label="Dónde mira el modelo (Grad-CAM)")
            out_desc = gr.Markdown(label="Descripción del estilo")

    btn.click(predict, inputs=inp, outputs=[out_label, out_cam, out_desc])

if __name__ == "__main__":
    demo.launch()