PlatanosCalidad / app.py
JairoDanielMT's picture
Deployment of Bananas Quality App in Hugging Face Space
c95301d
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import cv2
import numpy as np
from PIL import Image
import gradio as gr
# Image preprocessing: resize to 512x512, convert to a tensor, then normalize
# each of the three channels with mean 0.5 and standard deviation 0.5.
_preprocessing_steps = [
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocessing_steps)
class CNN(nn.Module):
    """Small convolutional classifier: conv -> ReLU -> max-pool -> two linear layers.

    Args:
        num_classes: Number of output logits. Defaults to 4
            (baja, regular, excelente, mala).
        image_size: Side length, in pixels, of the square input images.
            Defaults to 512; it determines the input width of ``fc1``.

    The defaults reproduce the original fixed architecture, so ``CNN()`` has
    the same parameter names and shapes as before.
    """

    def __init__(self, num_classes=4, image_size=512):
        super().__init__()
        # Attribute names are kept as-is: the state_dict keys are derived
        # from them, so renaming would break loading of saved weights.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # conv1 preserves the spatial size (3x3 kernel, stride 1, padding 1);
        # the 2x2 max-pool with stride 2 then halves each side.
        pooled_size = image_size // 2
        self.fc1 = nn.Linear(32 * pooled_size * pooled_size, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return the raw class logits for a batch of images.

        Args:
            x: Tensor of shape (batch, 3, image_size, image_size).

        Returns:
            Tensor of shape (batch, num_classes) with unnormalized scores.
        """
        features = self.maxpool(self.relu(self.conv1(x)))
        # Flatten everything except the batch dimension for the linear layers.
        flat = features.view(features.size(0), -1)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)
# Inference runs on the CPU.
device = torch.device('cpu')

# Build the network, restore the previously saved weights and switch the
# model to evaluation mode.
# NOTE(review): torch.load unpickles the checkpoint file; only load weight
# files from a trusted source.
model = CNN()
model.to(device)
model.load_state_dict(
    torch.load('calidadplatano.pth', map_location=device)
)
model.eval()
# Classify a single input image.
def classify_image(input_image):
    """Predict the quality class of a banana image.

    Args:
        input_image: Image as a NumPy array, as delivered by the Gradio
            image input (RGB channel order).

    Returns:
        A ``(predicted_class, confidence)`` tuple: the label with the highest
        softmax probability (one of 'baja', 'regular', 'excelente', 'mala')
        and that probability as a Python float in [0, 1].

    Raises:
        ValueError: If no image was provided.
    """
    if input_image is None:
        raise ValueError("No image provided")

    # Gradio already hands over RGB data, so no BGR->RGB conversion is
    # applied here: it would swap the red and blue channels.
    # NOTE(review): assumes the model was trained on RGB images - confirm.
    pil_image = Image.fromarray(input_image).convert("RGB")
    batch = transform(pil_image).unsqueeze(0).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(batch)

    probabilities = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()
    class_labels = ['baja', 'regular', 'excelente', 'mala']
    best_index = int(np.argmax(probabilities))
    return class_labels[best_index], float(probabilities[best_index])
# Define the graphical user interface components.
# The `gr.inputs` / `gr.outputs` namespaces are deprecated in Gradio 3 and
# no longer available in Gradio 4, so prefer the top-level components when
# the installed version provides them and fall back to the legacy namespaces
# on older releases.
if hasattr(gr, "Image") and hasattr(gr, "Label"):
    inputs = gr.Image()
    outputs = gr.Label(num_top_classes=1)
else:
    inputs = gr.inputs.Image()
    outputs = gr.outputs.Label(num_top_classes=1)
def process_image(input_image):
    """Classify *input_image* and format the result for display.

    Args:
        input_image: Image passed through unchanged to ``classify_image``.

    Returns:
        A string of the form ``"<label> (<percentage>%)"`` where the
        percentage is the confidence times 100, rounded to two decimals.
    """
    label, score = classify_image(input_image)
    percentage = round(score * 100, 2)
    return f"{label} ({percentage!s}%)"
# Texts shown at the top of the web page.
title = "Clasificación de calidad de plátanos"
description = "Carga una imagen de plátano y obtén la clasificación de calidad."

# Wire the callback to the input/output components and start the app.
iface = gr.Interface(
    fn=process_image,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
)
iface.launch()