Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from flask import Flask, jsonify, request, render_template | |
| import os | |
# Flask application instance.
app = Flask(__name__)

# Directory where uploaded images are written; relative paths resolve
# against the process's current working directory.
UPLOAD_FOLDER = 'static/uploads'
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER)
# Preprocessing applied to each uploaded image before it is fed to the network.
transform = transforms.Compose(
    [
        # Net's fc1 layer takes 16 * 5 * 5 features, which implies 32x32 inputs.
        transforms.Resize((32, 32)),
        # PIL image -> float tensor laid out as (C, H, W) with values in [0, 1].
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1]; needs 3 channels.
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
class Net(nn.Module):
    """Small LeNet-style CNN producing 10 class scores for 3x32x32 images.

    The attribute names (conv1, conv2, fc1, fc2, fc3) must not change: they
    are the state-dict keys that ``load_state_dict`` matches when the trained
    weights are loaded.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two 5x5 convolutions, each followed by 2x2 max-pooling.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head: 16 feature maps of 5x5 remain after the second pooling
        # (32 -> 28 -> 14 -> 10 -> 5 for a 32x32 input).
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return unnormalized class scores (logits) of shape (batch, 10)."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(torch.relu(conv(x)))
        # Keep the batch dimension, flatten everything else.
        x = x.flatten(1)
        for fc in (self.fc1, self.fc2):
            x = torch.relu(fc(x))
        return self.fc3(x)
# The 10 CIFAR-10 label names, indexed by the network's output position.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Build the network and load the trained weights once, at import time.
net = Net()
PATH = 'cifar_net.pth'
# map_location='cpu' lets a checkpoint saved from a GPU run load on a
# CPU-only host instead of failing with a CUDA deserialization error.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (consider weights_only=True on torch versions that support it).
net.load_state_dict(torch.load(PATH, map_location='cpu'))
# Switch to inference mode. Net has no dropout/batch-norm layers, so this is
# currently a no-op, but it keeps inference correct if such layers are added.
net.eval()
def _safe_filename(raw_name):
    """Reduce a client-supplied file name to a bare file name.

    Strips any directory components (both ``/`` and ``\\`` separators) so an
    upload cannot be written outside the upload folder.

    Returns:
        The file name, or ``None`` if nothing usable remains.
    """
    name = os.path.basename((raw_name or '').replace('\\', '/'))
    if name in ('', '.', '..'):
        return None
    return name


def _save_upload(file, filename):
    """Save the uploaded file into the configured upload folder and return its path."""
    upload_dir = app.config['UPLOAD_FOLDER']
    # The folder may not exist yet (e.g. fresh checkout or container).
    os.makedirs(upload_dir, exist_ok=True)
    image_path = os.path.join(upload_dir, filename)
    file.save(image_path)
    return image_path


def _classify(image_path):
    """Run the network on the image stored at ``image_path``.

    Returns:
        dict with ``predicted_class`` (label name from ``classes``) and
        ``probability`` (top softmax probability as a percentage, rounded
        to 2 decimals).

    Raises:
        OSError: if Pillow cannot open or decode the file (Pillow's
            ``UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    # convert() loads the pixel data before the context manager closes the
    # file. Converting every mode (not just RGBA/P) to RGB guarantees the
    # 3 channels that the Normalize step in `transform` expects (e.g. L, CMYK).
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0)  # add a batch dimension
    # Inference only: don't build the autograd graph.
    with torch.no_grad():
        output = net(image_tensor)
    probabilities = F.softmax(output, dim=1)
    max_prob, predicted_class = torch.max(probabilities, 1)
    return {
        'predicted_class': classes[predicted_class.item()],
        'probability': round(max_prob.item() * 100, 2),
    }


@app.route('/', methods=['GET', 'POST'])
def predict():
    """Render the upload page; on POST, classify the uploaded image.

    GET renders ``index.html`` with no prediction. POST expects a multipart
    form field named ``file``. The image is saved under
    ``app.config['UPLOAD_FOLDER']`` and classified, and the page is rendered
    with ``prediction`` and ``image_path``. Failures are returned as JSON
    ``{'error': ...}`` with status 400 (bad upload) or 500 (server-side error).
    """
    prediction = None
    image_path = None
    if request.method == 'POST':
        file = request.files.get('file')
        filename = _safe_filename(file.filename) if file is not None else None
        if filename is None:
            return jsonify({'error': "No valid file uploaded in form field 'file'"}), 400
        try:
            image_path = _save_upload(file, filename)
        except OSError as e:
            app.logger.exception('Failed to save upload')
            return jsonify({'error': str(e)}), 500
        try:
            prediction = _classify(image_path)
        except OSError as e:
            # The upload is not a readable image.
            return jsonify({'error': str(e)}), 400
        except Exception as e:
            app.logger.exception('Prediction failed')
            return jsonify({'error': str(e)}), 500
    return render_template('index.html', prediction=prediction, image_path=image_path)
if __name__ == '__main__':
    # Keep the Werkzeug interactive debugger off by default: with host 0.0.0.0
    # it is reachable from the network and allows executing arbitrary code.
    # Opt in for local development with FLASK_DEBUG=1.
    debug = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    # Listen on port 7860 unless the PORT environment variable overrides it.
    app.run(debug=debug, host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))