Rahaf2001 commited on
Commit
19cb9c2
·
verified ·
1 Parent(s): 0da372a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +54 -15
main.py CHANGED
@@ -2,7 +2,8 @@ from fastapi import FastAPI, UploadFile, File
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from ultralytics import YOLO
4
  from PIL import Image
5
- import io, requests, os
 
6
 
7
  app = FastAPI()
8
 
@@ -21,36 +22,74 @@ def download_model():
21
  r = requests.get(MODEL_URL)
22
  with open("best.pt", "wb") as f:
23
  f.write(r.content)
24
- print(" Model ready")
25
 
26
  download_model()
27
  model = YOLO("best.pt")
28
 
29
- CLASS_NAMES = { 0: 'crack', 1: 'other',2: 'pothole'}
 
 
 
 
30
 
31
  def get_severity(conf, area):
32
- if conf > 0.85 and area > 0.05: return 'high'
33
- elif conf > 0.65: return 'medium'
34
- else: return 'low'
 
 
 
35
 
36
  @app.get("/")
37
  def root():
38
- return {"status": "SABIQ API running "}
39
 
40
  @app.post("/detect")
41
  async def detect(file: UploadFile = File(...)):
42
- img = Image.open(io.BytesIO(await file.read()))
43
- results = model(img)
44
- detections = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  for box in results[0].boxes:
46
- cls = int(box.cls)
47
- conf = float(box.conf)
48
- xywhn = box.xywhn[0].tolist()
49
- area = xywhn[2] * xywhn[3]
 
50
  detections.append({
51
  "damage_type": CLASS_NAMES.get(cls, 'other'),
52
  "confidence" : round(conf, 3),
53
  "severity" : get_severity(conf, area),
54
  "bbox" : xywhn
55
  })
56
- return {"total": len(detections), "detections": detections}
 
 
 
 
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from ultralytics import YOLO
4
  from PIL import Image
5
+ import io, requests, os, tempfile
6
+ import cv2
7
 
8
  app = FastAPI()
9
 
 
22
  r = requests.get(MODEL_URL)
23
  with open("best.pt", "wb") as f:
24
  f.write(r.content)
25
+ print(" Model ready")
26
 
27
  download_model()
28
  model = YOLO("best.pt")
29
 
30
+ CLASS_NAMES = {
31
+ 0: 'crack',
32
+ 1: 'other',
33
+ 2: 'pothole'
34
+ }
35
 
36
  def get_severity(conf, area):
37
+ if conf > 0.85 and area > 0.05:
38
+ return 'high'
39
+ elif conf > 0.65:
40
+ return 'medium'
41
+ else:
42
+ return 'low'
43
 
44
  @app.get("/")
45
  def root():
46
+ return {"status": "SABIQ API running"}
47
 
48
  @app.post("/detect")
49
  async def detect(file: UploadFile = File(...)):
50
+ contents = await file.read()
51
+ filename = file.filename.lower()
52
+
53
+ # لو فيديو
54
+ if any(filename.endswith(ext) for ext in ['.mp4', '.avi', '.mov', '.mkv']):
55
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
56
+ tmp.write(contents)
57
+ tmp_path = tmp.name
58
+
59
+ cap = cv2.VideoCapture(tmp_path)
60
+ ret, frame = cap.read()
61
+ cap.release()
62
+ os.unlink(tmp_path)
63
+
64
+ if not ret:
65
+ return {"error": "cannot read video", "total": 0, "detections": []}
66
+
67
+ img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
68
+
69
+ # لو صورة
70
+ else:
71
+ try:
72
+ img = Image.open(io.BytesIO(contents))
73
+ except Exception as e:
74
+ return {"error": str(e), "total": 0, "detections": []}
75
+
76
+ results = model(img)
77
+ detections = []
78
+
79
  for box in results[0].boxes:
80
+ cls = int(box.cls)
81
+ conf = float(box.conf)
82
+ xywhn = box.xywhn[0].tolist()
83
+ area = xywhn[2] * xywhn[3]
84
+
85
  detections.append({
86
  "damage_type": CLASS_NAMES.get(cls, 'other'),
87
  "confidence" : round(conf, 3),
88
  "severity" : get_severity(conf, area),
89
  "bbox" : xywhn
90
  })
91
+
92
+ return {
93
+ "total" : len(detections),
94
+ "detections": detections
95
+ }