Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from ultralytics import YOLO | |
| import os, tempfile, random | |
# ASGI application object; CORS is configured to accept any origin, method
# and header.
# NOTE(review): wildcard CORS is wide open -- confirm this is intended.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# The YOLO weights are loaded once at import time and shared by every request.
print("Loading model...")
model = YOLO("best.pt")
print("Model ready")
# Maps the model's integer class index to the damage label used in responses.
CLASS_NAMES = {
    0: "crack",
    1: "other",
    2: "pothole",
}

# (min, max) bounds, in degrees, from which random_riyadh() draws coordinates.
RIYADH_LAT = (24.55, 24.85)
RIYADH_LNG = (46.55, 46.85)
def random_riyadh():
    """Return a random ``(latitude, longitude)`` pair within the module bounds.

    Latitude is drawn uniformly from ``RIYADH_LAT`` first, then longitude from
    ``RIYADH_LNG``; both values are rounded to six decimal places.
    """
    lat_low, lat_high = RIYADH_LAT
    latitude = random.uniform(lat_low, lat_high)
    lng_low, lng_high = RIYADH_LNG
    longitude = random.uniform(lng_low, lng_high)
    return round(latitude, 6), round(longitude, 6)
def severity(conf, area):
    """Grade a detection as ``"high"``, ``"medium"`` or ``"low"``.

    Args:
        conf: Detection confidence.
        area: Bounding-box area (callers pass normalized width * height).

    Returns:
        ``"high"`` when ``conf > 0.85`` and ``area > 0.05``; otherwise
        ``"medium"`` when ``conf > 0.65``; otherwise ``"low"``.
    """
    confident_and_large = conf > 0.85 and area > 0.05
    if confident_and_large:
        return "high"
    return "medium" if conf > 0.65 else "low"
def is_image(name):
    """Return True if *name* ends in a recognised image file extension.

    The check is case-insensitive and looks only at the text after the last
    dot. A name without any dot has no extension and is never an image, so a
    bare ``"png"`` is rejected. ``None`` or an empty name returns False.

    Args:
        name: File name to inspect (may be ``None`` or empty).

    Returns:
        bool: True for ``jpg``, ``jpeg``, ``png``, ``bmp`` or ``webp``.
    """
    if not name:
        return False
    # rpartition yields an empty separator when there is no dot at all, which
    # distinguishes "png" (no extension) from ".png" / "x.png".
    _, dot, ext = name.lower().rpartition(".")
    return bool(dot) and ext in ("jpg", "jpeg", "png", "bmp", "webp")
def process_image(path):
    """Run the detector on one image and return a record per detected box.

    Args:
        path: Path of the image file handed to ``model.predict``.

    Returns:
        list[dict]: One dict per box with keys ``damage_type``, ``confidence``
        (rounded to 3 places), ``severity``, ``bbox``, ``frame`` (always 0),
        ``latitude`` and ``longitude``. The coordinates come from
        ``random_riyadh()``; they are random, not read from the image.
    """
    detections = []
    for result in model.predict(source=path, conf=0.25, verbose=False):
        for box in result.boxes:
            class_id = int(box.cls)
            score = float(box.conf)
            # xywhn: normalized box; indices 2 and 3 are used as width/height.
            bbox = box.xywhn[0].tolist()
            latitude, longitude = random_riyadh()
            box_area = bbox[2] * bbox[3]
            record = dict(
                damage_type=CLASS_NAMES.get(class_id, "other"),
                confidence=round(score, 3),
                severity=severity(score, box_area),
                bbox=bbox,
                frame=0,
                latitude=latitude,
                longitude=longitude,
            )
            detections.append(record)
    return detections
def process_video(path):
    """Track detections through a video and return the best one per track.

    The video is streamed through ``model.track`` with the ByteTrack
    configuration. For every track id, only the detection with the highest
    confidence seen so far is kept. A track's random coordinates are assigned
    once, when the track first appears, and reused on later updates.

    Args:
        path: Path of the video file handed to ``model.track``.

    Returns:
        list[dict]: One dict per track id, in order of first appearance, with
        keys ``damage_type``, ``confidence`` (rounded to 3 places),
        ``severity``, ``bbox``, ``frame`` (index of the kept detection),
        ``latitude`` and ``longitude``.
    """
    results = model.track(
        source=path,
        conf=0.25,
        tracker="bytetrack.yaml",
        stream=True,
        verbose=False,
        save=False,
    )
    seen = {}
    # Unrounded best confidence per track. Comparing against the rounded value
    # stored in the record could let a lower-confidence detection replace a
    # better one (or skip a better one) when rounding crosses the new value.
    best_conf = {}
    for frame_idx, r in enumerate(results):
        boxes = r.boxes
        # Frames with no boxes or no assigned track ids carry nothing to keep.
        if boxes is None or boxes.id is None:
            continue
        for tid, cls, conf, xywhn in zip(
            boxes.id.int().tolist(),
            boxes.cls.int().tolist(),
            boxes.conf.tolist(),
            boxes.xywhn.tolist(),
        ):
            previous = seen.get(tid)
            if previous is None:
                # New track: draw one coordinate pair (a single call).
                lat, lng = random_riyadh()
            elif conf > best_conf[tid]:
                # Better detection of a known track: keep its location stable.
                lat, lng = previous["latitude"], previous["longitude"]
            else:
                continue
            best_conf[tid] = conf
            seen[tid] = {
                "damage_type": CLASS_NAMES.get(cls, "other"),
                "confidence": round(conf, 3),
                "severity": severity(conf, xywhn[2] * xywhn[3]),
                "bbox": xywhn,
                "frame": frame_idx,
                "latitude": lat,
                "longitude": lng,
            }
    return list(seen.values())
# NOTE(review): no route decorator is visible on this function here;
# presumably it is meant to be registered on `app` -- confirm.
def root():
    """Return a small status payload indicating the API is up."""
    payload = {"status": "SABIQ API running"}
    return payload
# NOTE(review): no route decorator is visible on this function here;
# presumably it is meant to be registered on `app` -- confirm.
async def detect(file: UploadFile = File(...)):
    """Run damage detection on an uploaded image or video.

    The upload is written to a temporary file, processed with
    ``process_image`` when ``is_image`` accepts the file name and with
    ``process_video`` otherwise, and the temporary file is always removed.
    Processing is called synchronously inside this coroutine.

    Args:
        file: The uploaded file.

    Returns:
        dict: ``{"total": <number of detections>, "detections": [...]}``.
    """
    contents = await file.read()
    # The client-supplied name is untrusted and may be missing.
    filename = file.filename or ""
    # Keep only a plain alphanumeric extension as the temp-file suffix, so a
    # name without a dot, or one carrying path separators, cannot leak into
    # the temporary path.
    _, dot, ext = filename.rpartition(".")
    suffix = "." + ext if dot and ext.isalnum() else ""

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            # Record the path before writing so a failed write is cleaned up.
            tmp_path = tmp.name
            tmp.write(contents)
        if is_image(filename):
            detections = process_image(tmp_path)
        else:
            detections = process_video(tmp_path)
    finally:
        if tmp_path is not None:
            os.unlink(tmp_path)
    return {"total": len(detections), "detections": detections}