Files changed (1) hide show
  1. main.py +80 -61
main.py CHANGED
@@ -1,88 +1,107 @@
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from ultralytics import YOLO
4
- import os, tempfile
5
 
6
  app = FastAPI()
7
-
8
- app.add_middleware(
9
- CORSMiddleware,
10
- allow_origins=["*"],
11
- allow_methods=["*"],
12
- allow_headers=["*"],
13
- )
14
-
15
- #tesssssssst
16
- # هذا لو كنت ابغى يعتمد على مودل موجود بالهقنق فيس
17
- # MODEL_URL = "https://huggingface.co/Rahaf2001/sabiq-road-detection/resolve/main/best.pt"
18
-
19
- # def download_model():
20
- # if not os.path.exists("best.pt"):
21
- # print("Downloading model...")
22
- # r = requests.get(MODEL_URL)
23
- # with open("best.pt", "wb") as f:
24
- # f.write(r.content)
25
- # print(" Model ready")
26
 
27
  print("Loading model...")
28
  model = YOLO("best.pt")
29
  print("Model ready")
30
 
31
- CLASS_NAMES = {
32
- 0: 'crack',
33
- 1: 'other',
34
- 2: 'pothole'
35
- }
 
 
 
36
 
37
- def get_severity(conf, area):
38
  if conf > 0.85 and area > 0.05:
39
- return 'high'
40
  elif conf > 0.65:
41
- return 'medium'
42
- else:
43
- return 'low'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  @app.get("/")
46
  def root():
47
  return {"status": "SABIQ API running"}
48
 
 
49
  @app.post("/detect")
50
  async def detect(file: UploadFile = File(...)):
51
  contents = await file.read()
 
52
 
53
- suffix = '.' + file.filename.split('.')[-1]
54
  with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
55
  tmp.write(contents)
56
  tmp_path = tmp.name
57
 
58
- results = model.predict(
59
- source = tmp_path,
60
- conf = 0.25,
61
- verbose = False,
62
- stream = True
63
- )
64
-
65
- all_detections = []
66
- frame_num = 0
67
-
68
- for result in results:
69
- for box in result.boxes:
70
- cls = int(box.cls)
71
- conf = float(box.conf)
72
- xywhn = box.xywhn[0].tolist()
73
- area = xywhn[2] * xywhn[3]
74
- all_detections.append({
75
- "damage_type": CLASS_NAMES.get(cls, 'other'),
76
- "confidence" : round(conf, 3),
77
- "severity" : get_severity(conf, area),
78
- "bbox" : xywhn,
79
- "frame" : frame_num
80
- })
81
- frame_num += 1
82
-
83
- os.unlink(tmp_path)
84
 
85
- return {
86
- "total" : len(all_detections),
87
- "detections": all_detections
88
- }
 
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from ultralytics import YOLO
4
+ import os, tempfile, random
5
 
6
  app = FastAPI()
7
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  print("Loading model...")
10
  model = YOLO("best.pt")
11
  print("Model ready")
12
 
13
+ CLASS_NAMES = {0: "crack", 1: "other", 2: "pothole"}
14
+ RIYADH_LAT = (24.55, 24.85)
15
+ RIYADH_LNG = (46.55, 46.85)
16
+
17
+
18
+ def random_riyadh():
19
+ return round(random.uniform(*RIYADH_LAT), 6), round(random.uniform(*RIYADH_LNG), 6)
20
+
21
 
22
+ def severity(conf, area):
23
  if conf > 0.85 and area > 0.05:
24
+ return "high"
25
  elif conf > 0.65:
26
+ return "medium"
27
+ return "low"
28
+
29
+
30
+ def is_image(name):
31
+ return name.lower().rsplit(".", 1)[-1] in ("jpg", "jpeg", "png", "bmp", "webp")
32
+
33
+
34
+ def process_image(path):
35
+ results = model.predict(source=path, conf=0.25, verbose=False)
36
+ out = []
37
+ for r in results:
38
+ for box in r.boxes:
39
+ cls = int(box.cls)
40
+ conf = float(box.conf)
41
+ xywhn = box.xywhn[0].tolist()
42
+ lat, lng = random_riyadh()
43
+ out.append({
44
+ "damage_type": CLASS_NAMES.get(cls, "other"),
45
+ "confidence": round(conf, 3),
46
+ "severity": severity(conf, xywhn[2] * xywhn[3]),
47
+ "bbox": xywhn,
48
+ "frame": 0,
49
+ "latitude": lat,
50
+ "longitude": lng,
51
+ })
52
+ return out
53
+
54
+
55
+ def process_video(path):
56
+ results = model.track(
57
+ source=path, conf=0.25, tracker="bytetrack.yaml",
58
+ stream=True, verbose=False, save=True,
59
+ )
60
+
61
+ seen = {} # track_id -> best detection
62
+
63
+ for frame_idx, r in enumerate(results):
64
+ if r.boxes is None or r.boxes.id is None:
65
+ continue
66
+
67
+ for tid, cls, conf, xywhn in zip(
68
+ r.boxes.id.int().tolist(),
69
+ r.boxes.cls.int().tolist(),
70
+ r.boxes.conf.tolist(),
71
+ r.boxes.xywhn.tolist(),
72
+ ):
73
+ if tid not in seen or conf > seen[tid]["confidence"]:
74
+ lat, lng = random_riyadh() if tid not in seen else (seen[tid]["latitude"], seen[tid]["longitude"])
75
+ seen[tid] = {
76
+ "damage_type": CLASS_NAMES.get(cls, "other"),
77
+ "confidence": round(conf, 3),
78
+ "severity": severity(conf, xywhn[2] * xywhn[3]),
79
+ "bbox": xywhn,
80
+ "frame": frame_idx,
81
+ "latitude": lat,
82
+ "longitude": lng,
83
+ }
84
+
85
+ return list(seen.values())
86
+
87
 
88
  @app.get("/")
89
  def root():
90
  return {"status": "SABIQ API running"}
91
 
92
+
93
  @app.post("/detect")
94
  async def detect(file: UploadFile = File(...)):
95
  contents = await file.read()
96
+ suffix = "." + file.filename.split(".")[-1]
97
 
 
98
  with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
99
  tmp.write(contents)
100
  tmp_path = tmp.name
101
 
102
+ try:
103
+ detections = process_image(tmp_path) if is_image(file.filename) else process_video(tmp_path)
104
+ finally:
105
+ os.unlink(tmp_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ return {"total": len(detections), "detections": detections}