Spaces:
Sleeping
Sleeping
| import os | |
| import subprocess | |
| import shutil | |
| import base64 | |
| import json | |
| from typing import Optional | |
| from fastapi import FastAPI, UploadFile, File, Request, HTTPException, Form | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from ultralytics import YOLO | |
| import cv2 | |
| import numpy as np | |
| from pathlib import Path | |
| import uuid | |
| import time | |
| from fastapi import BackgroundTasks | |
| from fastapi.responses import FileResponse | |
app = FastAPI()

# Filesystem layout: everything the service writes lives under "uploads",
# next to this source file.
BASE_DIR = Path(__file__).resolve().parent
UPLOAD_DIR = BASE_DIR / "uploads"
MODEL_DIR = UPLOAD_DIR / "models"
VIDEO_DIR = UPLOAD_DIR / "videos"
RESULT_DIR = UPLOAD_DIR / "results"
TEMP_DIR = UPLOAD_DIR / "temp"

# Create the working directories at import time; harmless if they exist.
for d in (MODEL_DIR, TEMP_DIR, VIDEO_DIR, RESULT_DIR):
    d.mkdir(parents=True, exist_ok=True)

# Process-wide mutable state shared by the handlers below.
current_model = None  # loaded YOLO instance, or None until a model is uploaded
model_name = ""  # file name of the loaded model
video_tasks = {}  # task_id: {"progress": P, "status": S, "result": R}

# Static assets and HTML templates are served from folders beside this file.
app.mount(
    "/static",
    StaticFiles(directory=str(BASE_DIR / "static")),
    name="static",
)
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
async def read_root(request: Request):
    """Render ``index.html`` with the current model-loading state."""
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed text - confirm how it is registered on `app`.
    page_context = {
        "model_loaded": current_model is not None,
        "model_name": model_name,
    }
    return templates.TemplateResponse(
        request=request,
        name="index.html",
        context=page_context,
    )
async def upload_model(file: UploadFile = File(...)):
    """Store an uploaded ``.pt`` weights file and make it the active model.

    The file is written into ``MODEL_DIR`` and loaded with ``YOLO``. The
    module-level ``current_model`` / ``model_name`` are replaced only after the
    load succeeds; a file that fails to load is deleted again.

    Args:
        file: Uploaded weights file; its name must end in ``.pt``.

    Returns:
        ``{"status": "success", "message": ...}`` once the model is loaded.

    Raises:
        HTTPException: 400 when the name is missing or does not end in
            ``.pt``; 500 when ``YOLO`` cannot load the stored file.
    """
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed text - confirm how it is registered on `app`.
    global current_model, model_name

    # The client controls `filename`. Keep only its last path component (for
    # both "/" and "\" separators) so a name such as "../../app.pt" cannot be
    # written outside MODEL_DIR. A missing name becomes "" and is rejected.
    client_name = (file.filename or "").replace("\\", "/")
    safe_name = Path(client_name).name
    if not safe_name.endswith(".pt"):
        raise HTTPException(status_code=400, detail="Only .pt files are supported")

    file_path = MODEL_DIR / safe_name
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    try:
        loaded_model = YOLO(str(file_path))
    except Exception as e:
        # Do not keep a weights file that cannot be loaded.
        file_path.unlink(missing_ok=True)
        raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}") from e

    current_model = loaded_model
    model_name = safe_name
    return {"status": "success", "message": f"Model {model_name} loaded successfully"}
def apply_roi_filter(results, roi, img_w, img_h):
    """Keep only detections whose box centre lies inside a percentage ROI.

    Args:
        results: Detection result exposing ``boxes``. Each box provides
            ``xyxy[0]`` as ``(x1, y1, x2, y2)`` and ``boxes`` can be indexed
            with a list of positions. ``results.boxes`` is replaced in place.
        roi: Mapping with ``x1``, ``y1``, ``x2``, ``y2`` expressed as
            percentages of the image size, or a falsy value for "no ROI".
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        ``(results, [left, top, right, bottom])`` with the ROI in pixels, or
        ``(results, [])`` untouched when no ROI is given.
    """
    if not roi:
        return results, []

    # Percent -> pixels. Each axis is sorted so an ROI whose corners arrive
    # reversed (e.g. dragged from bottom-right to top-left) still describes
    # the same rectangle instead of silently rejecting every detection.
    left, right = sorted((int(roi['x1'] * img_w / 100), int(roi['x2'] * img_w / 100)))
    top, bottom = sorted((int(roi['y1'] * img_h / 100), int(roi['y2'] * img_h / 100)))

    kept = []
    for idx, box in enumerate(results.boxes):
        bx1, by1, bx2, by2 = box.xyxy[0].tolist()
        centre_x = (bx1 + bx2) / 2
        centre_y = (by1 + by2) / 2
        # Edges are inclusive: a centre exactly on the border counts as inside.
        if left <= centre_x <= right and top <= centre_y <= bottom:
            kept.append(idx)

    results.boxes = results.boxes[kept]
    return results, [left, top, right, bottom]
def draw_roi_on_img(img, roi_coords):
    """Blend a labelled ROI rectangle onto ``img``.

    Args:
        img: Image array to annotate; it is not modified.
        roi_coords: ``[x1, y1, x2, y2]`` in pixels, or a falsy value to skip
            drawing.

    Returns:
        ``img`` itself when there is no ROI, otherwise a new image in which
        the outlined copy is mixed 60/40 with the original.
    """
    if roi_coords:
        left, top, right, bottom = roi_coords
        colour = (0, 255, 255)
        # Draw on a copy, then blend it back so the outline and label come
        # out semi-transparent.
        canvas = img.copy()
        cv2.rectangle(canvas, (left, top), (right, bottom), colour, 2)
        cv2.putText(
            canvas,
            "ROI ZONE",
            (left + 5, top + 25),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            colour,
            2,
        )
        return cv2.addWeighted(canvas, 0.6, img, 0.4, 0)
    return img
async def run_inference(
    file: UploadFile = File(...),
    conf_min: float = Form(0.25),
    conf_max: float = Form(1.0),
    roi: Optional[str] = Form(None)
):
    """Run the loaded model on one uploaded image.

    Detections are kept when their confidence is within
    ``[conf_min, conf_max]`` and, if an ROI is supplied, when their box centre
    falls inside it.

    Args:
        file: Image file to analyse.
        conf_min: Confidence threshold passed to the model.
        conf_max: Upper confidence bound; only applied when below 1.0.
        roi: Optional JSON object with ``x1``, ``y1``, ``x2``, ``y2`` given as
            percentages of the image size.

    Returns:
        Dict with ``status``, the annotated image as a base64 JPEG data URL
        (``image``), the detection ``count`` and the per-detection ``boxes``.

    Raises:
        HTTPException: 400 when no model is loaded, the ROI is malformed or
            the upload is not a decodable image; 500 when the annotated image
            cannot be encoded.
    """
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed text - confirm how it is registered on `app`.
    # Bind the model once so the whole request uses the same instance.
    model = current_model
    if model is None:
        raise HTTPException(status_code=400, detail="No model loaded. Please upload a model first.")

    # Parse ROI if present; client mistakes are reported as 400 rather than
    # escaping as an unhandled exception.
    roi_data = None
    if roi:
        try:
            roi_data = json.loads(roi)
        except json.JSONDecodeError as exc:
            raise HTTPException(status_code=400, detail="Invalid ROI: not valid JSON") from exc
        if roi_data and not (
            isinstance(roi_data, dict)
            and all(key in roi_data for key in ("x1", "y1", "x2", "y2"))
        ):
            raise HTTPException(status_code=400, detail="Invalid ROI: expected x1, y1, x2 and y2")

    # Read image
    contents = await file.read()
    img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=400, detail="Invalid image file")
    h, w = img.shape[:2]

    # Run inference with min threshold
    results = model(img, conf=conf_min)[0]

    # Apply max confidence filtering
    if conf_max < 1.0:
        indices = [i for i, box in enumerate(results.boxes) if float(box.conf[0]) <= conf_max]
        results.boxes = results.boxes[indices]

    # Apply ROI filtering
    results, roi_coords = apply_roi_filter(results, roi_data, w, h)

    # Draw detections, then the ROI box on top
    annotated_img = draw_roi_on_img(results.plot(), roi_coords)

    # Encode to base64
    encoded_ok, buffer = cv2.imencode('.jpg', annotated_img)
    if not encoded_ok:
        raise HTTPException(status_code=500, detail="Failed to encode result image")
    img_str = base64.b64encode(buffer).decode('utf-8')

    # Extract box info
    boxes = [
        {
            "cls": int(box.cls[0]),
            "conf": float(box.conf[0]),
            "xyxy": box.xyxy[0].tolist(),
        }
        for box in results.boxes
    ]
    return {
        "status": "success",
        "image": f"data:image/jpeg;base64,{img_str}",
        "count": len(results.boxes),
        "boxes": boxes
    }
def process_video_task(task_id: str, input_path: str, output_path: str, conf_min: float, conf_max: float, roi: Optional[dict]):
    """Annotate a video frame by frame, then transcode it to H.264.

    Progress and outcome are reported through ``video_tasks[task_id]``:
    ``progress`` runs 0-90 while frames are processed, 95 while transcoding
    and 100 on completion; ``status`` becomes ``"transcoding"``,
    ``"completed"`` or ``"error"`` (with ``message``). The input file and the
    intermediate file are always deleted afterwards.

    Args:
        task_id: Key of the task entry in ``video_tasks``.
        input_path: Path of the uploaded source video.
        output_path: Path where the final H.264 file is written.
        conf_min: Confidence threshold passed to the model.
        conf_max: Upper confidence bound; only applied when below 1.0.
        roi: Parsed ROI mapping (percent coordinates) or ``None``.
    """
    # Temporary path for OpenCV output
    temp_output = str(RESULT_DIR / f"temp_{task_id}.mp4")
    # Use one model instance for the whole video, even if another is uploaded
    # while this task is running.
    model = current_model
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            video_tasks[task_id]["status"] = "error"
            video_tasks[task_id]["message"] = "Could not open video file"
            return

        fps = cap.get(cv2.CAP_PROP_FPS)
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Using mp4v for the intermediate file
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_output, fourcc, fps, (w, h))
        if not out.isOpened():
            # Fail with a clear message instead of a later ffmpeg error.
            raise RuntimeError("Could not open video writer for intermediate file")

        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Inference with min threshold
            results = model(frame, conf=conf_min)[0]

            # Apply max confidence filtering
            if conf_max < 1.0:
                indices = [i for i, box in enumerate(results.boxes) if float(box.conf[0]) <= conf_max]
                results.boxes = results.boxes[indices]

            # Apply ROI filtering
            results, roi_coords = apply_roi_filter(results, roi, w, h)

            # Draw detections, then the ROI box on top
            annotated_frame = draw_roi_on_img(results.plot(), roi_coords)
            out.write(annotated_frame)
            frame_count += 1

            # Update progress (0-90% for processing). The reported frame count
            # can be zero, negative or an underestimate, so skip the division
            # when it is unusable and never report more than 90 here.
            if total_frames > 0:
                progress = min(90, int((frame_count / total_frames) * 90))
                video_tasks[task_id]["progress"] = progress

        # The writer must be closed before ffmpeg reads the intermediate file.
        cap.release()
        out.release()
        cap = out = None  # already closed; nothing left for `finally`

        # Transcode to H.264 for web compatibility
        video_tasks[task_id]["progress"] = 95
        video_tasks[task_id]["status"] = "transcoding"
        ffmpeg_cmd = [
            'ffmpeg', '-y', '-i', temp_output,
            '-c:v', 'libx264', '-preset', 'ultrafast', '-crf', '28',
            '-pix_fmt', 'yuv420p', '-c:a', 'aac', '-b:a', '128k',
            output_path
        ]
        subprocess.run(ffmpeg_cmd, check=True, capture_output=True)

        video_tasks[task_id]["progress"] = 100
        video_tasks[task_id]["status"] = "completed"
        video_tasks[task_id]["result_url"] = f"/video-result/{task_id}"
    except Exception as e:
        video_tasks[task_id]["status"] = "error"
        video_tasks[task_id]["message"] = str(e)
        # Do not leave a partially written result behind for a failed task.
        if os.path.exists(output_path):
            os.remove(output_path)
    finally:
        # Close any handle still open (an exception mid-loop skips the normal
        # release above) before deleting the files it refers to.
        for handle in (cap, out):
            if handle is not None:
                handle.release()
        # Cleanup files
        for leftover in (input_path, temp_output):
            if os.path.exists(leftover):
                os.remove(leftover)
async def run_video_inference(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    conf_min: float = Form(0.25),
    conf_max: float = Form(1.0),
    roi: Optional[str] = Form(None)
):
    """Accept a video upload and schedule it for background processing.

    The upload is saved under ``VIDEO_DIR``, a task entry is created in
    ``video_tasks`` and ``process_video_task`` is queued as a background task.

    Args:
        background_tasks: Task queue that runs the processing job.
        file: Video file to analyse.
        conf_min: Confidence threshold passed to the model.
        conf_max: Upper confidence bound; only applied when below 1.0.
        roi: Optional JSON object with ``x1``, ``y1``, ``x2``, ``y2`` given as
            percentages of the frame size.

    Returns:
        ``{"status": "success", "task_id": ...}``.

    Raises:
        HTTPException: 400 when no model is loaded or the ROI is malformed.
    """
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed text - confirm how it is registered on `app`.
    if current_model is None:
        raise HTTPException(status_code=400, detail="No model loaded. Please upload a model first.")

    # Parse ROI; client mistakes are reported as 400 rather than escaping as
    # an unhandled exception here or later inside the background task.
    roi_data = None
    if roi:
        try:
            roi_data = json.loads(roi)
        except json.JSONDecodeError as exc:
            raise HTTPException(status_code=400, detail="Invalid ROI: not valid JSON") from exc
        if roi_data and not (
            isinstance(roi_data, dict)
            and all(key in roi_data for key in ("x1", "y1", "x2", "y2"))
        ):
            raise HTTPException(status_code=400, detail="Invalid ROI: expected x1, y1, x2 and y2")

    task_id = str(uuid.uuid4())

    # The client controls `filename`. Keep only its last path component so
    # separators cannot point the write at another (or non-existent)
    # directory, and fall back to a fixed name when none was sent.
    client_name = (file.filename or "upload").replace("\\", "/")
    safe_name = Path(client_name).name or "upload"
    input_path = VIDEO_DIR / f"{task_id}_{safe_name}"
    output_path = RESULT_DIR / f"processed_{task_id}.mp4"

    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # NOTE(review): entries are never removed from `video_tasks` in the
    # visible code, so this dict grows for the life of the process.
    video_tasks[task_id] = {
        "progress": 0,
        "status": "processing",
        "filename": file.filename
    }
    background_tasks.add_task(process_video_task, task_id, str(input_path), str(output_path), conf_min, conf_max, roi_data)
    return {"status": "success", "task_id": task_id}
async def get_video_progress(task_id: str):
    """Return the tracked state of a video task, or 404 if it is unknown."""
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed text - confirm how it is registered on `app`.
    if task_id in video_tasks:
        return video_tasks[task_id]
    raise HTTPException(status_code=404, detail="Task not found")
async def get_video_result(task_id: str):
    """Serve the finished H.264 result of a video task.

    Args:
        task_id: Identifier returned when the video was submitted.

    Returns:
        ``FileResponse`` streaming ``processed_<task_id>.mp4`` as
        ``video/mp4``.

    Raises:
        HTTPException: 404 when the file does not exist or the task is
            tracked but has not reached the ``"completed"`` status.
    """
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed text; the worker publishes "/video-result/{task_id}" as the
    # result URL, presumably this handler's path - confirm.
    output_path = RESULT_DIR / f"processed_{task_id}.mp4"

    # The output file appears on disk as soon as transcoding starts, so its
    # existence alone does not mean it is complete. When the task is still
    # tracked, require the "completed" status; untracked ids fall back to the
    # plain existence check.
    task = video_tasks.get(task_id)
    unfinished = task is not None and task.get("status") != "completed"
    if unfinished or not output_path.exists():
        raise HTTPException(status_code=404, detail="Result not found or still processing")
    return FileResponse(path=output_path, filename=f"inference_{task_id}.mp4", media_type="video/mp4")
if __name__ == "__main__":
    import uvicorn

    # Hugging Face compatibility: take the port from the PORT environment
    # variable and fall back to 7860 when it is not set.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)