Spaces:
Sleeping
Sleeping
File size: 3,276 Bytes
bc84e9e d593ddf bc84e9e d593ddf bc84e9e da951e1 bc84e9e da951e1 d593ddf bc84e9e d593ddf 19cb9c2 d593ddf 19cb9c2 d593ddf 46a15d0 d593ddf 46a15d0 d593ddf 46a15d0 d593ddf 46a15d0 d593ddf 46a15d0 d593ddf 46a15d0 d593ddf bc84e9e 19cb9c2 bc84e9e d593ddf bc84e9e 19cb9c2 46a15d0 f9da080 c0b49be f9da080 d593ddf f9da080 46a15d0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 | from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from ultralytics import YOLO
import os, tempfile, random
# FastAPI application exposing the detection endpoints defined below.
app = FastAPI()

# NOTE(review): every origin, method and header is allowed cross-origin;
# narrow allow_origins if the API should not be callable from any site.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the YOLO weights once at import time; every request reuses `model`.
print("Loading model...")
model = YOLO("best.pt")
print("Model ready")
# Model class index -> label reported as "damage_type" in API responses.
CLASS_NAMES = dict(enumerate(("crack", "other", "pothole")))

# (min, max) latitude / longitude ranges sampled by random_riyadh().
RIYADH_LAT = (24.55, 24.85)
RIYADH_LNG = (46.55, 46.85)
def random_riyadh():
    """Return a random ``(latitude, longitude)`` pair inside the Riyadh box.

    Latitude is drawn uniformly from RIYADH_LAT, then longitude from
    RIYADH_LNG; both are rounded to 6 decimal places. The values come from
    the ``random`` module, not from any real position data.
    """
    lat_low, lat_high = RIYADH_LAT
    lng_low, lng_high = RIYADH_LNG
    latitude = round(random.uniform(lat_low, lat_high), 6)
    longitude = round(random.uniform(lng_low, lng_high), 6)
    return latitude, longitude
def severity(conf, area):
    """Bucket a detection into "low", "medium" or "high" severity.

    Args:
        conf: Detection confidence score.
        area: Box area; the callers in this module pass normalized
            width * height.

    Returns:
        "high" when conf > 0.85 and area > 0.05, otherwise "medium" when
        conf > 0.65, otherwise "low".
    """
    # Written as `not >` (rather than `<=`) so NaN still falls into "low".
    if not conf > 0.65:
        return "low"
    large_and_confident = conf > 0.85 and area > 0.05
    return "high" if large_and_confident else "medium"
def is_image(name):
    """Return True if *name* ends in a still-image file extension.

    Only the text after the last dot is inspected, case-insensitively,
    against jpg/jpeg/png/bmp/webp.

    Args:
        name: Client-supplied file name; may be None.

    Returns:
        False for None, the empty string, or a name without any dot (so a
        bare name such as "jpg" is not mistaken for an extension).
    """
    if not name or "." not in name:
        return False
    extension = name.rsplit(".", 1)[1].lower()
    return extension in ("jpg", "jpeg", "png", "bmp", "webp")
def process_image(path):
    """Run YOLO detection on a single image file.

    ``model.predict`` is called with ``conf=0.25``; every returned box
    becomes one detection dict.

    Args:
        path: Filesystem path of the image.

    Returns:
        A list of dicts with keys damage_type, confidence (rounded to 3 dp),
        severity, bbox (normalized x, y, w, h), frame (always 0), latitude
        and longitude.
    """
    detections = []
    predictions = model.predict(source=path, conf=0.25, verbose=False)
    for prediction in predictions:
        for box in prediction.boxes:
            class_id = int(box.cls)
            score = float(box.conf)
            bbox = box.xywhn[0].tolist()
            # NOTE(review): coordinates are random placeholders drawn
            # independently for each box, not derived from the image.
            latitude, longitude = random_riyadh()
            record = {
                "damage_type": CLASS_NAMES.get(class_id, "other"),
                "confidence": round(score, 3),
                "severity": severity(score, bbox[2] * bbox[3]),
                "bbox": bbox,
                "frame": 0,
                "latitude": latitude,
                "longitude": longitude,
            }
            detections.append(record)
    return detections
def process_video(path):
    """Track objects through a video and keep the best detection per track.

    Frames are run through ``model.track`` with the ByteTrack config
    (``bytetrack.yaml``) and ``stream=True``, so results are consumed frame
    by frame. Frames with no boxes or no track ids are skipped.

    For every track id, the detection with the highest confidence seen so
    far is kept. Each track is given one random coordinate pair on first
    sight, and that pair is reused when a better detection replaces it.

    Args:
        path: Filesystem path of the video.

    Returns:
        A list of dicts (one per track, in first-seen order) with keys
        damage_type, confidence (rounded to 3 dp), severity, bbox
        (normalized x, y, w, h), frame (index of the kept detection),
        latitude and longitude.
    """
    results = model.track(
        source=path,
        conf=0.25,
        tracker="bytetrack.yaml",
        stream=True,
        verbose=False,
        save=False,
    )
    seen = {}       # track id -> detection dict returned to the caller
    best_conf = {}  # track id -> unrounded confidence of that detection
    for frame_idx, r in enumerate(results):
        boxes = r.boxes
        if boxes is None or boxes.id is None:
            continue
        for tid, cls, conf, xywhn in zip(
            boxes.id.int().tolist(),
            boxes.cls.int().tolist(),
            boxes.conf.tolist(),
            boxes.xywhn.tolist(),
        ):
            # Compare against the raw confidence: the stored one is rounded
            # to 3 dp and would let a slightly weaker detection win.
            if tid in best_conf and not conf > best_conf[tid]:
                continue
            if tid in seen:
                # Keep a track's location stable across replacements.
                lat = seen[tid]["latitude"]
                lng = seen[tid]["longitude"]
            else:
                # One draw gives a coherent (lat, lng) pair.
                lat, lng = random_riyadh()
            best_conf[tid] = conf
            seen[tid] = {
                "damage_type": CLASS_NAMES.get(cls, "other"),
                "confidence": round(conf, 3),
                "severity": severity(conf, xywhn[2] * xywhn[3]),
                "bbox": xywhn,
                "frame": frame_idx,
                "latitude": lat,
                "longitude": lng,
            }
    return list(seen.values())
@app.get("/")
def root():
    """Return a static payload confirming the API process is running."""
    payload = {"status": "SABIQ API running"}
    return payload
@app.post("/detect")
async def detect(file: UploadFile = File(...)):
    """Run the detector on an uploaded image or video.

    The upload is written to a temporary file, because the model is given a
    file path. It is routed to process_image when its name has an image
    extension (per is_image) and to process_video otherwise. The temporary
    file is always removed, even if writing or inference fails.

    Returns:
        {"total": <number of detections>, "detections": [<dict>, ...]}
    """
    filename = file.filename or ""
    # Keep only the extension of the base name. The extension must survive
    # because the model is handed the file by path, and the
    # client-controlled name must not inject path separators into the
    # temp-file path.
    suffix = os.path.splitext(os.path.basename(filename))[1]

    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        # Stream the upload to disk in 1 MiB chunks instead of holding the
        # whole body (possibly a large video) in memory at once.
        with os.fdopen(fd, "wb") as tmp:
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                tmp.write(chunk)
        # NOTE(review): inference is synchronous and blocks the event loop
        # while it runs. It is deliberately not moved to a worker thread,
        # because ultralytics advises against sharing one model instance
        # across threads.
        if is_image(filename):
            detections = process_image(tmp_path)
        else:
            detections = process_video(tmp_path)
    finally:
        os.unlink(tmp_path)
    return {"total": len(detections), "detections": detections}
|