# sabiq-api / main.py
# Last commit 46a15d0 by Rahaf2001: "Update main.py"
import os
import random
import tempfile
import threading

from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from ultralytics import YOLO
app = FastAPI()

# Fully open CORS policy: any origin, method and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# The detector is loaded once at import time and shared by every request.
print("Loading model...")
model = YOLO("best.pt")
print("Model ready")
# Model class index -> damage label. Presumably this order matches the
# classes best.pt was trained with — confirm against the training config.
CLASS_NAMES = dict(enumerate(("crack", "other", "pothole")))

# (low, high) bounds passed to random.uniform when generating placeholder
# coordinates for detections.
RIYADH_LAT = (24.55, 24.85)
RIYADH_LNG = (46.55, 46.85)
def random_riyadh():
    """Return a random ``(latitude, longitude)`` inside the Riyadh bounds.

    Both values are drawn uniformly from ``RIYADH_LAT`` / ``RIYADH_LNG``
    (latitude first) and rounded to 6 decimal places. The point does not
    depend on the uploaded media; it is a placeholder location.
    """
    latitude = random.uniform(*RIYADH_LAT)
    longitude = random.uniform(*RIYADH_LNG)
    return round(latitude, 6), round(longitude, 6)
def severity(conf, area):
    """Classify a detection as ``"high"``, ``"medium"`` or ``"low"``.

    Args:
        conf: Detection confidence.
        area: Box area as a fraction of the frame (callers pass the
            product of the normalized width and height).

    Returns:
        ``"high"`` only when the detection is both very confident
        (> 0.85) and large (> 0.05 of the frame), ``"medium"`` for any
        other confidence above 0.65, otherwise ``"low"``.
    """
    confident_and_large = conf > 0.85 and area > 0.05
    if confident_and_large:
        return "high"
    return "medium" if conf > 0.65 else "low"
def is_image(name, extensions=("jpg", "jpeg", "png", "bmp", "webp")):
    """Return True when *name* ends in one of the image *extensions*.

    The extension is the text after the last ``.``, compared
    case-insensitively. Names without any ``.`` (e.g. a bare ``"jpg"``),
    empty names and ``None`` have no extension and are not images.

    Args:
        name: Upload filename; may be ``None``.
        extensions: Lower-case extensions, without the dot, that count as
            images.
    """
    if not name or "." not in name:
        return False
    return name.rsplit(".", 1)[1].lower() in extensions
def process_image(path):
    """Run single-image detection and return one dict per detected box.

    Args:
        path: Filesystem path of the image passed to ``model.predict``.

    Returns:
        A list of dicts with keys ``damage_type``, ``confidence`` (rounded
        to 3 places), ``severity``, ``bbox`` (the normalized xywh box from
        Ultralytics ``Boxes.xywhn``), ``frame`` (always 0), and
        ``latitude`` / ``longitude`` (random placeholders from
        ``random_riyadh``).
    """
    detections = []
    for result in model.predict(source=path, conf=0.25, verbose=False):
        for box in result.boxes:
            label = CLASS_NAMES.get(int(box.cls), "other")
            score = float(box.conf)
            bbox = box.xywhn[0].tolist()
            latitude, longitude = random_riyadh()
            # bbox[2] * bbox[3] is normalized width * height, i.e. the
            # fraction of the image the box covers.
            detections.append(
                {
                    "damage_type": label,
                    "confidence": round(score, 3),
                    "severity": severity(score, bbox[2] * bbox[3]),
                    "bbox": bbox,
                    "frame": 0,
                    "latitude": latitude,
                    "longitude": longitude,
                }
            )
    return detections
def process_video(path):
    """Track damage through a video and keep one detection per track.

    Frames are streamed through ``model.track`` with ByteTrack. For each
    track ID, the entry from the frame with the highest confidence is kept.
    The placeholder coordinates drawn when the track is first seen are kept
    for the life of the track.

    Args:
        path: Filesystem path of the video passed to ``model.track``.

    Returns:
        A list of detection dicts (same keys as ``process_image``), in
        first-seen track order; ``frame`` is the index of the frame that
        produced the best-confidence detection.
    """
    results = model.track(
        source=path,
        conf=0.25,
        tracker="bytetrack.yaml",
        stream=True,
        verbose=False,
        save=False,
    )
    seen = {}
    # Raw (unrounded) best confidence per track. The stored "confidence" is
    # rounded, so comparing against it could let a slightly weaker
    # detection replace the best one.
    best_conf = {}
    for frame_idx, r in enumerate(results):
        boxes = r.boxes
        # Without tracker IDs there is nothing to deduplicate on.
        if boxes is None or boxes.id is None:
            continue
        for tid, cls, conf, xywhn in zip(
            boxes.id.int().tolist(),
            boxes.cls.int().tolist(),
            boxes.conf.tolist(),
            boxes.xywhn.tolist(),
        ):
            if tid in best_conf and conf <= best_conf[tid]:
                continue
            best_conf[tid] = conf
            previous = seen.get(tid)
            if previous is None:
                # Draw one (lat, lng) pair so the point is a single sample.
                lat, lng = random_riyadh()
            else:
                lat, lng = previous["latitude"], previous["longitude"]
            seen[tid] = {
                "damage_type": CLASS_NAMES.get(cls, "other"),
                "confidence": round(conf, 3),
                "severity": severity(conf, xywhn[2] * xywhn[3]),
                "bbox": xywhn,
                "frame": frame_idx,
                "latitude": lat,
                "longitude": lng,
            }
    return list(seen.values())
@app.get("/")
def root():
    """Liveness endpoint: report that the service is up."""
    status = "SABIQ API running"
    return {"status": status}
# Serializes inference. The module-level YOLO ``model`` is shared by all
# requests, and Ultralytics documents that one model instance should not be
# used from several threads concurrently.
_inference_lock = threading.Lock()

# Bytes per read when spooling an upload to disk, so large videos are not
# held in memory all at once.
_UPLOAD_CHUNK = 1024 * 1024


def _run_inference(path, video):
    """Run the video or image pipeline on *path* while holding the model lock."""
    with _inference_lock:
        return process_video(path) if video else process_image(path)


@app.post("/detect")
async def detect(file: UploadFile = File(...)):
    """Detect road damage in an uploaded image or video.

    The upload is written to a temporary file, because YOLO reads from a path.
    It is routed to ``process_image`` when ``is_image`` accepts the
    filename and to ``process_video`` otherwise. The temporary file is always
    removed afterwards.

    Returns:
        ``{"total": <number of detections>, "detections": [...]}``.
    """
    filename = file.filename or ""
    # Keep only the base name's extension ("" if there is none). YOLO picks
    # its loader from the suffix. A naive split(".")[-1] would yield the
    # whole name when there is no dot, or a fragment containing "/".
    suffix = os.path.splitext(os.path.basename(filename))[1]
    video = not is_image(filename)

    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    # Record the path before writing so a failed write is still cleaned up.
    tmp_path = tmp.name
    try:
        with tmp:
            while True:
                chunk = await file.read(_UPLOAD_CHUNK)
                if not chunk:
                    break
                tmp.write(chunk)
        # Inference is long, blocking work. Run it off the event loop so
        # other requests (e.g. GET /) are still served in the meantime.
        detections = await run_in_threadpool(_run_inference, tmp_path, video)
    finally:
        os.unlink(tmp_path)
    return {"total": len(detections), "detections": detections}