import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request, render_template
import os

app = Flask(__name__)

# Directory where uploaded images are stored; its path is passed to the
# template as `image_path`.
UPLOAD_FOLDER = 'static/uploads'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# FileStorage.save() does not create missing directories, so make sure the
# upload folder exists before the first request arrives.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Preprocessing applied to every uploaded image before inference.
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # match the network's expected input size
    transforms.ToTensor(),
    # Map each channel from [0, 1] to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


class Net(nn.Module):
    """Small LeNet-style CNN producing 10 class logits from a 3x32x32 image."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 32x32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw (unnormalized) logits, one row per image in the batch."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)  # keep the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Label for each output index of Net (index i <-> logit i).
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

net = Net()
PATH = 'cifar_net.pth'
# map_location="cpu" lets a checkpoint that was saved on a GPU load on a
# CPU-only host.
# NOTE: torch.load unpickles the file; only ever point PATH at a trusted
# checkpoint.
net.load_state_dict(torch.load(PATH, map_location="cpu"))
net.eval()  # put the network in evaluation mode


def _safe_upload_name(filename):
    """Return the bare file name from an untrusted upload name, or None.

    Any directory components (either separator style) are stripped so a
    crafted name such as "../../app.py" cannot escape UPLOAD_FOLDER.
    Returns None when nothing usable remains ("", ".", "..").
    """
    name = os.path.basename(filename.replace("\\", "/"))
    if name in ("", ".", ".."):
        return None
    return name


def _classify(image):
    """Classify a PIL image with `net`.

    Returns a dict with 'predicted_class' (an entry of `classes`) and
    'probability' (the top softmax probability as a percentage, rounded to
    two decimals).
    """
    # conv1 expects 3 input channels, so convert every non-RGB mode
    # (L, LA, CMYK, RGBA, P, ...), not just RGBA/P.
    if image.mode != "RGB":
        image = image.convert("RGB")
    batch = transform(image).unsqueeze(0)  # add batch dim -> (1, 3, 32, 32)
    with torch.no_grad():  # inference only: skip building the autograd graph
        probabilities = F.softmax(net(batch), dim=1)
    max_prob, predicted_class = torch.max(probabilities, 1)
    return {
        'predicted_class': classes[predicted_class.item()],
        'probability': round(max_prob.item() * 100, 2),
    }


# Endpoint that serves the upload form and, on POST, the prediction.
@app.route('/', methods=['GET', 'POST'])
def predict():
    """Render index.html; on POST, save the uploaded image and classify it.

    Error cases return JSON {'error': message}:
    - 400 if no usable file was uploaded;
    - 500 if saving or classifying the image fails.
    """
    prediction = None
    image_path = None
    if request.method == 'POST':
        upload = request.files.get('file')
        if upload is None or not upload.filename:
            return jsonify({'error': 'No file was uploaded'}), 400
        filename = _safe_upload_name(upload.filename)
        if filename is None:
            return jsonify({'error': 'Invalid file name'}), 400
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        try:
            upload.save(image_path)
            # Read back the saved copy, not the already-consumed upload
            # stream; the context manager closes the file handle.
            with Image.open(image_path) as image:
                prediction = _classify(image)
        except Exception as e:
            # Request boundary: log the traceback and report the failure.
            app.logger.exception("Prediction failed for %s", image_path)
            return jsonify({'error': str(e)}), 500

    return render_template('index.html', prediction=prediction, image_path=image_path)


if __name__ == '__main__':
    # Never expose the interactive Werkzeug debugger on 0.0.0.0 by default;
    # opt in explicitly with FLASK_DEBUG=1 for local development.
    app.run(debug=os.environ.get("FLASK_DEBUG") == "1", host="0.0.0.0", port=7860)