bernabeSanchez commited on
Commit
31ebc5f
1 Parent(s): accfeb0

Upload 5 files

Browse files
files/cifar_net.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d68a16753138b325ae729b384dc00046eebd1eefd2c871b8efe9e10bd5a0f7a0
3
+ size 251604
files/inference.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ from flask import Flask, jsonify, request, render_template
7
+ import os
8
+
9
+ app = Flask(__name__)
10
+
11
+ # Directorio de carga de im谩genes
12
+ UPLOAD_FOLDER = 'static/uploads'
13
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
14
+
15
+
16
+ # Aplicar la transformaci贸n
17
+ transform = transforms.Compose([
18
+ transforms.Resize((32, 32)), # Ajustar al tama帽o de entrada de la red
19
+ transforms.ToTensor(),
20
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
21
+ ])
22
+
23
+ # Mostrar la imagen
24
+ # imshow(transform(image))
25
+
26
+ class Net(nn.Module):
27
+ def __init__(self):
28
+ super().__init__()
29
+ self.conv1 = nn.Conv2d(3, 6, 5)
30
+ self.pool = nn.MaxPool2d(2, 2)
31
+ self.conv2 = nn.Conv2d(6, 16, 5)
32
+ self.fc1 = nn.Linear(16 * 5 * 5, 120)
33
+ self.fc2 = nn.Linear(120, 84)
34
+ self.fc3 = nn.Linear(84, 10)
35
+
36
+ def forward(self, x):
37
+ x = self.pool(F.relu(self.conv1(x)))
38
+ x = self.pool(F.relu(self.conv2(x)))
39
+ x = torch.flatten(x, 1) # flatten all dimensions except batch
40
+ x = F.relu(self.fc1(x))
41
+ x = F.relu(self.fc2(x))
42
+ x = self.fc3(x)
43
+ return x
44
+
45
+
46
+ classes = ('plane', 'car', 'bird', 'cat',
47
+ 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
48
+ net = Net()
49
+ PATH='cifar_net.pth'
50
+ net.load_state_dict(torch.load(PATH))
51
+ net.eval() # Establecer la red en modo de evaluaci贸n
52
+
53
+
54
+
55
+ # Endpoint para hacer predicciones
56
+ @app.route('/', methods=['GET', 'POST'])
57
+ def predict():
58
+ prediction = None
59
+ image_path = None
60
+
61
+ if request.method == 'POST':
62
+ try:
63
+ # Obtener la imagen desde la solicitud POST
64
+ file = request.files['file']
65
+
66
+ # Guardar la imagen cargada en el directorio de carga
67
+ image_path = os.path.join(app.config['UPLOAD_FOLDER'], file.filename)
68
+ file.save(image_path)
69
+
70
+ # Aplicar la transformaci贸n a la imagen
71
+ image = Image.open(file)
72
+ image_tensor = transform(image).unsqueeze(0)
73
+
74
+ # Obtener la salida del modelo
75
+ output = net(image_tensor)
76
+
77
+ # Aplicar softmax para obtener las probabilidades
78
+ probabilities = F.softmax(output, dim=1)
79
+
80
+ # Obtener la clase predicha y la probabilidad m谩xima
81
+ max_prob, predicted_class = torch.max(probabilities, 1)
82
+ predicted_class_name = classes[predicted_class.item()]
83
+
84
+ # Almacenar el resultado de la predicci贸n
85
+ prediction = {
86
+ 'predicted_class': predicted_class_name,
87
+ 'probability': round(max_prob.item() * 100, 2)
88
+ }
89
+
90
+ except Exception as e:
91
+ return jsonify({'error': str(e)})
92
+ return render_template('index.html', prediction=prediction, image_path=image_path)
93
+
94
+
95
+ if __name__ == '__main__':
96
+ app.run(debug=True, host="0.0.0.0", port="7860")
files/static/uploads/avion_img.jpg ADDED
files/static/uploads/cat_image.jpg ADDED
files/templates/index.html ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>Predicci贸n de Im谩genes</title>
8
+ </head>
9
+ <body>
10
+ <h1>Sube una imagen para predecir</h1>
11
+ <form action="/predict" method="post" enctype="multipart/form-data">
12
+ <input type="file" name="file" accept="image/*" required>
13
+ <br>
14
+ <input type="submit" value="Predecir">
15
+ </form>
16
+ <br>
17
+ {% if prediction %}
18
+ <h2>Resultado de la Predicci贸n:</h2>
19
+ <p>Clase: {{ prediction['predicted_class'] }}</p>
20
+ <p>Probabilidad: {{ prediction['probability'] }}%</p>
21
+ <img src="{{ image_path }}" alt="Imagen Cargada">
22
+ <!-- <p>{{ image_path }}</p> -->
23
+ {% endif %}
24
+ </body>
25
+ </html>