# Source: Practicas_AIIA / files / inference.py
# Page residue: "bernabeSanchez's picture"
# Last commit: "Update files/inference.py" (c1082e9)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request, render_template
import os
# Flask application object that the prediction endpoint is registered on.
app = Flask(__name__)

# Directory where uploaded images are stored.
UPLOAD_FOLDER = 'static/uploads'
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER)
# Preprocessing pipeline applied to an image before it is fed to the network.
transform = transforms.Compose([
    # Resize to the 32x32 input size of the network.
    transforms.Resize(size=(32, 32)),
    # Convert the image to a tensor.
    transforms.ToTensor(),
    # Normalize each of the three channels with mean 0.5 and std 0.5.
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# Debugging aid (disabled): show the transformed image.
# imshow(transform(image))
class Net(nn.Module):
    """Small convolutional classifier that outputs 10 scores per image.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max pooling, feed
    three fully connected layers. ``fc1`` takes ``16 * 5 * 5`` features per
    sample, which is what a 3x32x32 input yields after both conv stages.
    """

    def __init__(self):
        super().__init__()
        # The attribute names below become the state_dict keys, so they must
        # stay as they are for saved checkpoints to keep loading.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the raw (pre-softmax) class scores for the batch ``x``."""
        # Convolutional stages: conv -> ReLU -> max pool, applied twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Keep the batch dimension and flatten everything else.
        features = x.flatten(start_dim=1)
        # Hidden fully connected layers with ReLU activations.
        for hidden in (self.fc1, self.fc2):
            features = F.relu(hidden(features))
        # The final layer has no activation; callers apply softmax if needed.
        return self.fc3(features)
# Class labels; each name's position in the tuple is the class index used to
# look it up.
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)
# Build the network and restore its saved weights from disk.
net = Net()
PATH = 'cifar_net.pth'
# NOTE: torch.load unpickles the file, so only load checkpoints that come
# from a trusted source.
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only host; the model is never moved off the CPU in this file.
net.load_state_dict(torch.load(PATH, map_location='cpu'))
net.eval()  # Put the network in evaluation mode.
def _safe_upload_name(raw_name):
    """Return ``raw_name`` reduced to a bare filename.

    The name is supplied by the client, so directory components (written
    with either slash style) are dropped to keep the saved file inside the
    upload folder.

    Raises:
        ValueError: if no usable name is left (empty name, '.' or '..').
    """
    name = os.path.basename((raw_name or '').replace('\\', '/'))
    if name in ('', '.', '..'):
        raise ValueError('No valid filename in the uploaded file')
    return name


def _classify(image):
    """Run the network on an RGB PIL image and return the top prediction.

    Returns:
        dict with 'predicted_class' (a name from ``classes``) and
        'probability' (the softmax probability as a percentage, rounded to
        two decimals).
    """
    # Add a leading batch dimension: the network takes batches of images.
    image_tensor = transform(image).unsqueeze(0)
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        output = net(image_tensor)
        # Softmax turns the raw scores into probabilities.
        probabilities = F.softmax(output, dim=1)
    # Highest probability and the index of the class it belongs to.
    max_prob, predicted_class = torch.max(probabilities, 1)
    return {
        'predicted_class': classes[predicted_class.item()],
        'probability': round(max_prob.item() * 100, 2)
    }


# Prediction endpoint.
@app.route('/', methods=['GET', 'POST'])
def predict():
    """Render the index page, classifying the uploaded image on POST.

    On GET the template is rendered with no prediction. On POST the file
    sent in the 'file' form field is saved to the upload folder, classified,
    and the template is rendered with the prediction and the saved path.

    Returns:
        The rendered 'index.html' template, or a JSON object with an
        'error' message if handling the upload fails.
    """
    prediction = None
    image_path = None
    if request.method == 'POST':
        try:
            # Get the image from the POST request.
            file = request.files['file']
            # Save the upload inside the upload folder, creating the folder
            # if it does not exist yet.
            upload_dir = app.config['UPLOAD_FOLDER']
            os.makedirs(upload_dir, exist_ok=True)
            image_path = os.path.join(upload_dir, _safe_upload_name(file.filename))
            file.save(image_path)
            # Re-open from disk (the upload stream was consumed by save())
            # and always convert to RGB so grayscale, palette, RGBA or CMYK
            # images all reach the network with three channels.
            with Image.open(image_path) as uploaded:
                image = uploaded.convert("RGB")
            prediction = _classify(image)
        except Exception as e:
            # Request-level boundary: log the traceback and report the
            # failure to the client instead of crashing the request.
            app.logger.exception('Prediction failed')
            return jsonify({'error': str(e)})
    return render_template('index.html', prediction=prediction, image_path=image_path)
if __name__ == '__main__':
    # The server listens on all interfaces, and the Werkzeug debugger lets
    # anyone who can reach it execute code, so debug mode is opt-in:
    # set FLASK_DEBUG=1 (or "true") in the environment to enable it.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true')
    # The port is passed as an integer, the type app.run documents.
    app.run(debug=debug_enabled, host="0.0.0.0", port=7860)