Myloiose commited on
Commit
185285d
verified
1 Parent(s): 07dc2c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -33
app.py CHANGED
@@ -1,39 +1,54 @@
1
- from fastapi import FastAPI, UploadFile, File
2
- from fastapi.middleware.cors import CORSMiddleware
3
- import requests
4
- import base64
5
- import logging
 
 
6
 
7
- # Configuraci贸n b谩sica
8
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Agregar soporte de CORS
11
- app.add_middleware(
12
- CORSMiddleware,
13
- allow_origins=["*"], # Cambia esto a tus dominios espec铆ficos si es necesario
14
- allow_methods=["*"],
15
- allow_headers=["*"],
16
- )
17
 
18
- # URL del modelo en Hugging Face Space
19
- HUGGINGFACE_API = "https://mobilenetv1-tflite-demo.myloiose.hf.space/predict"
 
 
 
 
 
 
 
 
 
 
20
 
21
- # Endpoint para hacer la predicci贸n
22
- @app.post("/predict/")
23
  async def predict(file: UploadFile = File(...)):
24
- # Leer el archivo y convertirlo a base64
25
- img_bytes = await file.read()
26
- img_b64 = "data:image/jpeg;base64," + base64.b64encode(img_bytes).decode()
27
-
28
- # Crear el payload
29
- payload = {"data": [img_b64]}
30
-
31
- try:
32
- # Realizar la petici贸n al Space de Hugging Face
33
- response = requests.post(HUGGINGFACE_API, json=payload)
34
- response.raise_for_status() # Asegura que no haya errores de red
35
- return response.json() # Retorna la respuesta del Space
36
- except requests.exceptions.RequestException as e:
37
- logging.error(f"Error al conectar con Hugging Face: {str(e)}")
38
- return {"error": "Error en la predicci贸n", "details": str(e)}
39
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from PIL import Image
6
+ import io
7
+ import uvicorn
8
 
9
+ # Cargar etiquetas y modelo
10
+ with open("labels_mobilenet_quant_v1_224.txt", "r") as f:
11
+ labels = f.read().splitlines()
12
+
13
+ interpreter = tf.lite.Interpreter(model_path="mobilenet_v1_1.0_224_quant.tflite")
14
+ interpreter.allocate_tensors()
15
+
16
+ input_details = interpreter.get_input_details()
17
+ output_details = interpreter.get_output_details()
18
+
19
+ # Clasificaci贸n
20
+ def classify_image(image: Image.Image):
21
+ image = image.convert("RGB").resize((224, 224))
22
+ input_data = np.expand_dims(np.array(image, dtype=np.uint8), axis=0)
23
+
24
+ interpreter.set_tensor(input_details[0]["index"], input_data)
25
+ interpreter.invoke()
26
+ output_data = interpreter.get_tensor(output_details[0]["index"])
27
 
28
+ # Ajuste de cuantizaci贸n
29
+ output_scale, output_zero_point = output_details[0]["quantization"]
30
+ output = output_scale * (output_data.astype(np.float32) - output_zero_point)
 
 
 
 
31
 
32
+ pred_idx = np.argmax(output[0])
33
+ pred_label = labels[pred_idx] if pred_idx < len(labels) else "Etiqueta no encontrada"
34
+
35
+ # Softmax
36
+ exp_scores = np.exp(output[0] - np.max(output[0]))
37
+ probabilities = exp_scores / np.sum(exp_scores)
38
+ confidence = probabilities[pred_idx]
39
+
40
+ return pred_label, float(confidence)
41
+
42
+ # Crear API
43
+ app = FastAPI()
44
 
45
+ @app.post("/predict")
 
46
  async def predict(file: UploadFile = File(...)):
47
+ contents = await file.read()
48
+ image = Image.open(io.BytesIO(contents))
49
+ label, conf = classify_image(image)
50
+ return JSONResponse({"label": label, "confidence": f"{conf*100:.2f}%"})
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # Para correr localmente
53
+ # if __name__ == "__main__":
54
+ # uvicorn.run(app, host="0.0.0.0", port=7860)