Stroke-ia commited on
Commit
a55aeb4
·
verified ·
1 Parent(s): dc5057c

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +67 -26
api.py CHANGED
@@ -1,15 +1,19 @@
1
  from fastapi import FastAPI, UploadFile, File, HTTPException, Security, Depends
2
  from fastapi.security.api_key import APIKeyHeader
3
- from fastapi.responses import FileResponse, JSONResponse
4
  import uvicorn
5
- import shutil
6
- import os
 
 
7
  from ultralytics import YOLO
 
 
8
 
9
  # ==========================
10
  # 🔑 Sécurité : API Key
11
  # ==========================
12
- API_KEY = "12345" # <-- Change ici avant de partager
13
  api_key_header = APIKeyHeader(name="X-API-Key")
14
 
15
  def verify_api_key(api_key: str = Security(api_key_header)):
@@ -43,13 +47,13 @@ async def predict(
43
  api_key: str = Depends(verify_api_key)
44
  ):
45
  try:
46
- # Sauvegarde temporaire
47
- temp_file = f"temp_{file.filename}"
48
- with open(temp_file, "wb") as buffer:
49
- shutil.copyfileobj(file.file, buffer)
50
 
51
  # Prédiction YOLO
52
- results = model.predict(temp_file)
53
 
54
  output = []
55
  for r in results:
@@ -57,16 +61,13 @@ async def predict(
57
  output.append({
58
  "class": r.names[int(box.cls[0].item())],
59
  "confidence": float(box.conf[0].item()),
60
- "bbox": box.xyxy[0].tolist() # [x1, y1, x2, y2]
61
  })
62
 
63
  return JSONResponse(content={"predictions": output})
64
 
65
  except Exception as e:
66
  raise HTTPException(status_code=500, detail=str(e))
67
- finally:
68
- if os.path.exists(temp_file):
69
- os.remove(temp_file)
70
 
71
  # ==========================
72
  # 📦 Endpoint Image
@@ -77,25 +78,65 @@ async def predict_image(
77
  api_key: str = Depends(verify_api_key)
78
  ):
79
  try:
80
- # Sauvegarde temporaire
81
- temp_file = f"temp_{file.filename}"
82
- with open(temp_file, "wb") as buffer:
83
- shutil.copyfileobj(file.file, buffer)
84
 
85
- # Prédiction YOLO avec sauvegarde image annotée
86
- results = model.predict(temp_file, save=True, project="runs", name="detect", exist_ok=True)
 
87
 
88
- # Récupérer la dernière image annotée
89
- output_dir = results[0].save_dir
90
- output_path = os.path.join(output_dir, os.path.basename(temp_file))
 
 
91
 
92
- return FileResponse(output_path, media_type="image/png")
93
 
94
  except Exception as e:
95
  raise HTTPException(status_code=500, detail=str(e))
96
- finally:
97
- if os.path.exists(temp_file):
98
- os.remove(temp_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  # ==========================
101
  # 🚀 Lancement local
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException, Security, Depends
2
  from fastapi.security.api_key import APIKeyHeader
3
+ from fastapi.responses import JSONResponse, StreamingResponse
4
  import uvicorn
5
+ import io
6
+ import numpy as np
7
+ from PIL import Image
8
+ import cv2
9
  from ultralytics import YOLO
10
+ import requests
11
+ import os
12
 
13
  # ==========================
14
  # 🔑 Sécurité : API Key
15
  # ==========================
16
+ API_KEY = "1234" # <-- Change ici avant de partager
17
  api_key_header = APIKeyHeader(name="X-API-Key")
18
 
19
  def verify_api_key(api_key: str = Security(api_key_header)):
 
47
  api_key: str = Depends(verify_api_key)
48
  ):
49
  try:
50
+ # Lire directement en mémoire
51
+ contents = await file.read()
52
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
53
+ np_image = np.array(image)
54
 
55
  # Prédiction YOLO
56
+ results = model.predict(np_image, conf=0.5, verbose=False)
57
 
58
  output = []
59
  for r in results:
 
61
  output.append({
62
  "class": r.names[int(box.cls[0].item())],
63
  "confidence": float(box.conf[0].item()),
64
+ "bbox": box.xyxy[0].tolist()
65
  })
66
 
67
  return JSONResponse(content={"predictions": output})
68
 
69
  except Exception as e:
70
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
71
 
72
  # ==========================
73
  # 📦 Endpoint Image
 
78
  api_key: str = Depends(verify_api_key)
79
  ):
80
  try:
81
+ # Lire directement en mémoire
82
+ contents = await file.read()
83
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
84
+ np_image = np.array(image)
85
 
86
+ # Prédiction YOLO + image annotée
87
+ results = model.predict(np_image, conf=0.5, verbose=False)
88
+ annotated = results[0].plot()
89
 
90
+ # Convertir en bytes pour StreamingResponse
91
+ annotated_pil = Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
92
+ img_byte_arr = io.BytesIO()
93
+ annotated_pil.save(img_byte_arr, format="PNG")
94
+ img_byte_arr.seek(0)
95
 
96
+ return StreamingResponse(img_byte_arr, media_type="image/png")
97
 
98
  except Exception as e:
99
  raise HTTPException(status_code=500, detail=str(e))
100
+
101
+ # ==========================
102
+ # 🧪 Endpoint Test interne
103
+ # ==========================
104
+ @app.get("/test_request/")
105
+ async def test_request():
106
+ """
107
+ Test interne de l'API déployée sur Hugging Face.
108
+ Utilise une image locale 'test.jpg' (⚠️ à placer dans ton repo Space).
109
+ """
110
+ try:
111
+ file_path = "test.jpg" # ⚠️ Mets une image dans ton Space
112
+ base_url = "https://stroke-ia-api.hf.space" # ⚠️ adapte au nom exact de ton Space
113
+
114
+ if not os.path.exists(file_path):
115
+ return {"error": f"{file_path} introuvable dans le Space."}
116
+
117
+ # Test JSON
118
+ url_predict = f"{base_url}/v1/predict/"
119
+ files = {"file": open(file_path, "rb")}
120
+ headers = {"X-API-Key": API_KEY}
121
+ response = requests.post(url_predict, files=files, headers=headers)
122
+ json_result = response.json()
123
+
124
+ # Test image annotée
125
+ url_img = f"{base_url}/v1/predict_image/"
126
+ files = {"file": open(file_path, "rb")}
127
+ response_img = requests.post(url_img, files=files, headers=headers)
128
+
129
+ with open("result.png", "wb") as f:
130
+ f.write(response_img.content)
131
+
132
+ return {
133
+ "message": "✅ Test request exécuté sur Hugging Face API. Résultats sauvegardés.",
134
+ "json_result": json_result,
135
+ "saved_image": "result.png"
136
+ }
137
+
138
+ except Exception as e:
139
+ return {"error": str(e)}
140
 
141
  # ==========================
142
  # 🚀 Lancement local