gatosvsperros / app.py
RodrigoGariv's picture
Create app.py
4800f22 verified
import gradio as gr
import torch
from torchvision import models, transforms
from PIL import Image
# Carga modelo base preentrenado
model = models.mobilenet_v2(pretrained=True)
model.eval()
# Transformaciones de imagen
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor()
])
# Labels simplificados (imagenet)
# Nota: La lista de etiquetas está incompleta en la imagen.
# Deberías obtener la lista completa de etiquetas de ImageNet si necesitas todas.
labels = ["tench", "goldfish", "great white shark", "...", "Egyptian cat", "tabby cat", "tiger cat", "Persian cat", ...]
def clasificar(imagen):
# Aplica las transformaciones y añade una dimensión de batch
img_t = transform(imagen).unsqueeze(0)
# Desactiva el cálculo de gradientes para la inferencia
with torch.no_grad():
# Pasa la imagen por el modelo
salida = model(img_t)
# Obtiene el índice de la clase con la mayor probabilidad
idx = salida[0].argmax().item()
# Devuelve la etiqueta correspondiente a ese índice
return labels[idx]
# Crea la interfaz de Gradio
gr.Interface(fn=clasificar,
inputs=gr.Image(type="pil"), # La entrada es una imagen PIL
outputs="label", # La salida es una etiqueta de texto
title="Clasificador de Imágenes",
description="Clasifica imágenes usando MobileNetV2").launch()