|
|
import gradio as gr |
|
|
import torch |
|
|
from torchvision import models, transforms |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# Load MobileNetV2 with its ImageNet-1k weights. The `weights=` enum replaces
# the `pretrained=True` flag, which torchvision deprecated in 0.13;
# IMAGENET1K_V1 is the checkpoint that `pretrained=True` resolved to.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)

# The model is only used for inference, so put dropout / batch-norm layers
# into their evaluation behavior.
model.eval()
|
|
|
|
|
|
|
|
# Preprocessing applied to every input image before it is fed to the model:
# resize the shorter side to 256, take the central 224x224 crop, convert to a
# float tensor in [0, 1], then standardize each channel.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # The ImageNet-pretrained weights expect inputs standardized with these
    # per-channel statistics; feeding raw [0, 1] pixels degrades predictions.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Human-readable class names, indexed by the model's output position.
# Taken from the metadata bundled with the pretrained weights so that the list
# has one correctly ordered entry per output class (a hand-written, truncated
# list makes `labels[idx]` fail for most predictions).
labels = list(models.MobileNet_V2_Weights.IMAGENET1K_V1.meta["categories"])
|
|
|
|
|
def clasificar(imagen):
    """Return the predicted class name for a PIL image.

    Args:
        imagen: PIL image to classify, or ``None`` when no image was supplied
            (e.g. the form was submitted empty).

    Returns:
        The entry of ``labels`` at the index of the highest model score, or
        ``None`` when ``imagen`` is ``None``.

    Raises:
        IndexError: if ``labels`` has no entry for the predicted class index.
    """
    # Nothing to classify; return an empty result instead of failing inside
    # the preprocessing pipeline.
    if imagen is None:
        return None

    # The network takes 3-channel input, so normalize grayscale / RGBA /
    # palette images to RGB before preprocessing. `unsqueeze(0)` adds the
    # batch dimension the model expects.
    img_t = transform(imagen.convert("RGB")).unsqueeze(0)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        salida = model(img_t)

    # Index of the highest score for the single image in the batch.
    idx = salida[0].argmax().item()
    return labels[idx]
|
|
|
|
|
|
|
|
# Build the web UI around `clasificar` (PIL image in, label out) and start
# serving it.
demo = gr.Interface(
    fn=clasificar,
    inputs=gr.Image(type="pil"),
    outputs="label",
    title="Clasificador de Imágenes",
    description="Clasifica imágenes usando MobileNetV2",
)
demo.launch()