taller_CNN / app.py
NICOMOSHE's picture
Update app.py
4abf07a verified
"""
Gradio App para Hugging Face Spaces
Clasificador de imágenes CNN
"""
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
from pathlib import Path
# Display name and ordered class labels for every dataset the app can serve.
# The position of a label inside "classes" is the class index it stands for.
MODELS_INFO = dict(
    fashion_mnist=dict(
        name="Fashion-MNIST",
        classes=[
            "T-shirt/top",
            "Trouser",
            "Pullover",
            "Dress",
            "Coat",
            "Sandal",
            "Shirt",
            "Sneaker",
            "Bag",
            "Ankle boot",
        ],
    ),
    cifar10=dict(
        name="CIFAR-10",
        classes=[
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        ],
    ),
    svhn=dict(
        name="SVHN",
        classes=[str(digit) for digit in range(10)],
    ),
    emnist=dict(
        name="EMNIST",
        # 10 digits, then 26 upper-case letters, then 26 lower-case letters.
        classes=(
            list("0123456789")
            + [chr(ord("A") + offset) for offset in range(26)]
            + [chr(ord("a") + offset) for offset in range(26)]
        ),
    ),
    tiny_imagenet=dict(
        name="Tiny ImageNet",
        # Placeholder labels: class_0 ... class_199.
        classes=["class_%d" % index for index in range(200)],
    ),
)

# Cache of models already loaded from local checkpoint files, keyed by the
# same names used in MODELS_INFO.
loaded_models = dict()
def get_transform(model_name):
    """Build the preprocessing pipeline that matches ``model_name``.

    Args:
        model_name: Dataset/model key. "fashion_mnist" and "emnist" share a
            grayscale 28x28 pipeline, "cifar10" and "svhn" share a 32x32
            three-channel pipeline, and every other value falls through to
            the 64x64 pipeline used for "tiny_imagenet".

    Returns:
        A freshly constructed ``transforms.Compose`` that resizes the image,
        converts it to a tensor and normalises it.
    """
    if model_name in ("fashion_mnist", "emnist"):
        # Single-channel 28x28 input, normalised with mean 0.5 / std 0.5.
        steps = [
            transforms.Grayscale(),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    elif model_name in ("cifar10", "svhn"):
        # Three-channel 32x32 input, each channel normalised with 0.5 / 0.5.
        steps = [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    else:
        # tiny_imagenet (and any unrecognised name): 64x64 input with
        # per-channel mean/std constants -- presumably the usual ImageNet
        # statistics.
        steps = [
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    return transforms.Compose(steps)
def load_model(model_name):
    """Return the trained network for ``model_name``, loading it on first use.

    A model that loaded successfully is stored in ``loaded_models`` and
    returned from there on later calls; failed attempts are not cached, so
    they are retried on the next call.

    Args:
        model_name: One of "fashion_mnist", "cifar10", "svhn", "emnist" or
            "tiny_imagenet".

    Returns:
        The model in eval mode with its weights loaded, or ``None`` when the
        name is not supported, the checkpoint file does not exist, or any
        error occurs while reading the checkpoint or building the model (the
        error is printed).
    """
    if model_name in loaded_models:
        return loaded_models[model_name]

    # Reject unsupported names *before* deriving a path from them; otherwise
    # an arbitrary string would decide which file torch.load unpickles.
    supported = ("fashion_mnist", "cifar10", "svhn", "emnist", "tiny_imagenet")
    if model_name not in supported:
        return None

    checkpoint_path = f"checkpoints/{model_name}/best_model.pth"
    if not os.path.exists(checkpoint_path):
        return None

    try:
        # NOTE(review): torch.load unpickles the file, so checkpoints must
        # come from a trusted source. Consider weights_only=True if the
        # checkpoints are known to contain only tensors and plain containers.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Accept both a full training checkpoint (weights stored under
        # "model_state_dict") and a bare state dict.
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        # Build the architecture that matches the checkpoint. The imports
        # stay inside the try block so that a missing module is reported as
        # "model unavailable" instead of crashing the caller.
        if model_name == "tiny_imagenet":
            from torchvision.models import resnet18
            model = resnet18(weights=None)
            # Replace the classification head with a 200-way output.
            model.fc = torch.nn.Linear(512, 200)
        elif model_name == "fashion_mnist":
            from models.cnn_builder import FashionCNN
            model = FashionCNN(num_classes=10)
        elif model_name == "cifar10":
            from models.cnn_builder import CIFAR10CNN
            model = CIFAR10CNN(num_classes=10)
        elif model_name == "svhn":
            from models.cnn_builder import SVHNNet
            model = SVHNNet(num_classes=10)
        else:
            # "emnist": the only supported name left after the check above.
            from models.cnn_builder import EMNISTNet
            model = EMNISTNet(num_classes=62)

        model.load_state_dict(state_dict)
        model.eval()
        loaded_models[model_name] = model
        return model
    except Exception as e:
        # Best effort: any failure is logged and reported as "no model".
        print(f"Error cargando modelo {model_name}: {e}")
        return None
def predict(image, model_name):
    """Classify ``image`` with the selected model and format the result.

    Args:
        image: PIL image supplied by the UI, or ``None`` when nothing was
            uploaded.
        model_name: Key of ``MODELS_INFO`` selecting the model to use.

    Returns:
        A Markdown string with the predicted label, its confidence and the
        three most probable classes, or a short message when no image was
        given or the model could not be loaded.
    """
    if image is None:
        return "Por favor sube una imagen"

    model = load_model(model_name)
    if model is None:
        return f"Modelo {model_name} no disponible aún"

    class_names = MODELS_INFO[model_name]["classes"]
    transform = get_transform(model_name)

    # Bring every input to RGB, not only RGBA ones: the colour pipelines
    # normalise three channels, so grayscale, palette or CMYK uploads would
    # otherwise reach them with the wrong number of channels. The grayscale
    # pipelines reduce RGB back to a single channel themselves.
    img = image if image.mode == 'RGB' else image.convert('RGB')
    img_tensor = transform(img).unsqueeze(0)  # add the batch dimension

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(img_tensor)
        probs = torch.nn.functional.softmax(output, dim=1)

    # topk returns its entries in descending order, so the first one is the
    # overall prediction and the headline always agrees with the list.
    top3 = torch.topk(probs, 3, dim=1)
    ranked = [
        (class_names[index], prob * 100)
        for index, prob in zip(top3.indices[0].tolist(), top3.values[0].tolist())
    ]
    pred, conf = ranked[0]
    results = [f"**{cls}**: {prob:.1f}%" for cls, prob in ranked]

    return f"## 🎯 Predicción: **{pred}**\n\n**Confianza:** {conf:.1f}%\n\n### Top 3:\n" + "\n".join(results)
# Gradio interface: model selector, image upload and a Markdown result panel.
with gr.Blocks(title="Clasificador CNN") as demo:
    # NOTE(review): this heading originally started with an emoji whose bytes
    # were lost to an encoding error; restore it if the intended glyph is
    # known.
    gr.Markdown("# Clasificador de Imágenes CNN")
    gr.Markdown("Selecciona un modelo y sube una imagen para clasificar")

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            model_choice = gr.Dropdown(
                choices=list(MODELS_INFO.keys()),
                value="fashion_mnist",
                label="Modelo",
            )
            input_image = gr.Image(type="pil", label="Sube una imagen")
            btn = gr.Button("Clasificar 🎯", variant="primary")

        # Right column: formatted prediction.
        with gr.Column():
            output = gr.Markdown()

    btn.click(fn=predict, inputs=[input_image, model_choice], outputs=output)

    # One example per model, derived from the same keys as the dropdown
    # choices so that an example can never name a model the dropdown does not
    # offer (the hand-written list contained "fashion_mNIST").
    gr.Examples(
        examples=[[name] for name in MODELS_INFO],
        inputs=model_choice,
    )

demo.launch()