Spaces:
Sleeping
Sleeping
Commit 路
31ebc5f
1
Parent(s): accfeb0
Upload 5 files
Browse files- files/cifar_net.pth +3 -0
- files/inference.py +96 -0
- files/static/uploads/avion_img.jpg +0 -0
- files/static/uploads/cat_image.jpg +0 -0
- files/templates/index.html +25 -0
files/cifar_net.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d68a16753138b325ae729b384dc00046eebd1eefd2c871b8efe9e10bd5a0f7a0
|
| 3 |
+
size 251604
|
files/inference.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchvision.transforms as transforms
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from flask import Flask, jsonify, request, render_template
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
app = Flask(__name__)
|
| 10 |
+
|
| 11 |
+
# Directorio de carga de im谩genes
|
| 12 |
+
UPLOAD_FOLDER = 'static/uploads'
|
| 13 |
+
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Aplicar la transformaci贸n
|
| 17 |
+
transform = transforms.Compose([
|
| 18 |
+
transforms.Resize((32, 32)), # Ajustar al tama帽o de entrada de la red
|
| 19 |
+
transforms.ToTensor(),
|
| 20 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
| 21 |
+
])
|
| 22 |
+
|
| 23 |
+
# Mostrar la imagen
|
| 24 |
+
# imshow(transform(image))
|
| 25 |
+
|
| 26 |
+
class Net(nn.Module):
|
| 27 |
+
def __init__(self):
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.conv1 = nn.Conv2d(3, 6, 5)
|
| 30 |
+
self.pool = nn.MaxPool2d(2, 2)
|
| 31 |
+
self.conv2 = nn.Conv2d(6, 16, 5)
|
| 32 |
+
self.fc1 = nn.Linear(16 * 5 * 5, 120)
|
| 33 |
+
self.fc2 = nn.Linear(120, 84)
|
| 34 |
+
self.fc3 = nn.Linear(84, 10)
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
x = self.pool(F.relu(self.conv1(x)))
|
| 38 |
+
x = self.pool(F.relu(self.conv2(x)))
|
| 39 |
+
x = torch.flatten(x, 1) # flatten all dimensions except batch
|
| 40 |
+
x = F.relu(self.fc1(x))
|
| 41 |
+
x = F.relu(self.fc2(x))
|
| 42 |
+
x = self.fc3(x)
|
| 43 |
+
return x
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
classes = ('plane', 'car', 'bird', 'cat',
|
| 47 |
+
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
| 48 |
+
net = Net()
|
| 49 |
+
PATH='cifar_net.pth'
|
| 50 |
+
net.load_state_dict(torch.load(PATH))
|
| 51 |
+
net.eval() # Establecer la red en modo de evaluaci贸n
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Endpoint para hacer predicciones
|
| 56 |
+
@app.route('/', methods=['GET', 'POST'])
|
| 57 |
+
def predict():
|
| 58 |
+
prediction = None
|
| 59 |
+
image_path = None
|
| 60 |
+
|
| 61 |
+
if request.method == 'POST':
|
| 62 |
+
try:
|
| 63 |
+
# Obtener la imagen desde la solicitud POST
|
| 64 |
+
file = request.files['file']
|
| 65 |
+
|
| 66 |
+
# Guardar la imagen cargada en el directorio de carga
|
| 67 |
+
image_path = os.path.join(app.config['UPLOAD_FOLDER'], file.filename)
|
| 68 |
+
file.save(image_path)
|
| 69 |
+
|
| 70 |
+
# Aplicar la transformaci贸n a la imagen
|
| 71 |
+
image = Image.open(file)
|
| 72 |
+
image_tensor = transform(image).unsqueeze(0)
|
| 73 |
+
|
| 74 |
+
# Obtener la salida del modelo
|
| 75 |
+
output = net(image_tensor)
|
| 76 |
+
|
| 77 |
+
# Aplicar softmax para obtener las probabilidades
|
| 78 |
+
probabilities = F.softmax(output, dim=1)
|
| 79 |
+
|
| 80 |
+
# Obtener la clase predicha y la probabilidad m谩xima
|
| 81 |
+
max_prob, predicted_class = torch.max(probabilities, 1)
|
| 82 |
+
predicted_class_name = classes[predicted_class.item()]
|
| 83 |
+
|
| 84 |
+
# Almacenar el resultado de la predicci贸n
|
| 85 |
+
prediction = {
|
| 86 |
+
'predicted_class': predicted_class_name,
|
| 87 |
+
'probability': round(max_prob.item() * 100, 2)
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
except Exception as e:
|
| 91 |
+
return jsonify({'error': str(e)})
|
| 92 |
+
return render_template('index.html', prediction=prediction, image_path=image_path)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == '__main__':
|
| 96 |
+
app.run(debug=True, host="0.0.0.0", port="7860")
|
files/static/uploads/avion_img.jpg
ADDED
|
files/static/uploads/cat_image.jpg
ADDED
|
files/templates/index.html
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 7 |
+
<title>Predicci贸n de Im谩genes</title>
|
| 8 |
+
</head>
|
| 9 |
+
<body>
|
| 10 |
+
<h1>Sube una imagen para predecir</h1>
|
| 11 |
+
<form action="/predict" method="post" enctype="multipart/form-data">
|
| 12 |
+
<input type="file" name="file" accept="image/*" required>
|
| 13 |
+
<br>
|
| 14 |
+
<input type="submit" value="Predecir">
|
| 15 |
+
</form>
|
| 16 |
+
<br>
|
| 17 |
+
{% if prediction %}
|
| 18 |
+
<h2>Resultado de la Predicci贸n:</h2>
|
| 19 |
+
<p>Clase: {{ prediction['predicted_class'] }}</p>
|
| 20 |
+
<p>Probabilidad: {{ prediction['probability'] }}%</p>
|
| 21 |
+
<img src="{{ image_path }}" alt="Imagen Cargada">
|
| 22 |
+
<!-- <p>{{ image_path }}</p> -->
|
| 23 |
+
{% endif %}
|
| 24 |
+
</body>
|
| 25 |
+
</html>
|