Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
class Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel 64x64 images.

    Encoder: three 3x3 stride-2 convolutions (1 -> 32 -> 64 -> 128 channels,
    64 -> 32 -> 16 -> 8 pixels per side) followed by a linear layer that
    compresses the flattened 128*8*8 feature map into a 32-value latent code.
    Decoder: the mirror image -- a linear layer back to 128*8*8 values and
    three stride-2 transposed convolutions up to (1, 64, 64).

    Hidden layers use ``tanh``; the final layer uses ``sigmoid`` so the
    reconstruction lies in the range (0, 1).

    The attribute names ``conv1``..``conv6``, ``fc1`` and ``fc2`` (and the
    order in which they are created) determine the ``state_dict`` keys, so
    they must stay in sync with any saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        # Encoder layers: (N, 1, 64, 64) -> (N, 128, 8, 8) -> (N, 32).
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 32)
        # Decoder layers: (N, 32) -> (N, 128, 8, 8) -> (N, 1, 64, 64).
        # output_padding=1 makes each transposed conv exactly double the size.
        self.fc2 = nn.Linear(32, 128 * 8 * 8)
        self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv6 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def encode(self, x):
        """Compress a batch of images into latent codes of shape (N, 32)."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.tanh(conv(features))
        # Flatten everything but the batch dimension before the linear layer.
        flat = features.view(features.shape[0], -1)
        return torch.tanh(self.fc1(flat))

    def decode(self, x):
        """Expand latent codes back into (N, 1, 64, 64) images in (0, 1)."""
        expanded = torch.tanh(self.fc2(x))
        grid = expanded.view(expanded.shape[0], 128, 8, 8)
        for deconv in (self.conv4, self.conv5):
            grid = torch.tanh(deconv(grid))
        return torch.sigmoid(self.conv6(grid))

    def forward(self, x):
        """Return the reconstruction of ``x`` (encode, then decode)."""
        code = self.encode(x)
        return self.decode(code)
# Instantiate the network in inference mode and restore its trained weights.
# map_location="cpu" lets a checkpoint saved from a GPU session load on a
# CPU-only host. (load_state_dict does not change the train/eval flag, so
# calling eval() first is equivalent to calling it afterwards.)
# NOTE(review): torch.load deserializes via pickle unless weights_only=True;
# only point it at trusted, bundled checkpoints.
model = Autoencoder().eval()
model.load_state_dict(torch.load("autoencoder.pth", map_location="cpu"))
# Preprocessing pipeline: collapse the image to one grayscale channel,
# resize it to the 64x64 resolution the autoencoder was built for, and
# convert it to a float tensor (ToTensor scales 8-bit pixels to [0, 1]).
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
    ]
)
# Reconstruction-error cutoff (mean squared error between an image and its
# reconstruction). Errors strictly above this value are reported as
# anomalous. Adjustable.
THRESHOLD = 1e-2
def detectar_anomalia(imagen):
    """Classify an image as anomalous or normal from its reconstruction error.

    The image is preprocessed with the module-level ``transform``, passed
    through the autoencoder ``model``, and the mean squared error between
    the input tensor and its reconstruction is compared with ``THRESHOLD``.

    Args:
        imagen: PIL image from the Gradio ``Image`` input (``type="pil"``),
            or ``None`` when the user submits without uploading anything.

    Returns:
        ``"Anómala"`` if the reconstruction error exceeds ``THRESHOLD``,
        otherwise ``"Normal"``.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable
            message instead of failing inside the preprocessing pipeline.
    """
    if imagen is None:
        raise gr.Error("Sube una imagen antes de analizar.")
    # Add a batch dimension: the model expects (N, C, H, W).
    img_tensor = transform(imagen).unsqueeze(0)
    with torch.no_grad():
        reconstruida = model(img_tensor)
    mse = torch.mean((img_tensor - reconstruida) ** 2).item()
    return "Anómala" if mse > THRESHOLD else "Normal"
# Gradio interface: the user uploads an image (delivered to the handler as a
# PIL image) and the label returned by detectar_anomalia is displayed.
demo = gr.Interface(
    fn=detectar_anomalia,
    inputs=gr.Image(type="pil", label="Sube una imagen para analizar"),
    outputs=gr.Label(label="Resultado"),
    # Example images referenced by relative path; they must ship with the app.
    examples=["anomalous.png", "normal.png"],
    title="Detección de Anomalías con Autoencoder (PyTorch)",
    description=(
        "Este Space utiliza un autoencoder entrenado con PyTorch para "
        "detectar anomalías en imágenes de textiles."
    ),
)
demo.launch()