| | import os |
| | import json |
| | import math |
| | import tempfile |
| | from typing import Any, Dict, List, Tuple |
| |
|
| | import gradio as gr |
| | import numpy as np |
| | import cv2 |
| | from PIL import Image, ImageDraw, ImageFont |
| |
|
| | from inference_sdk import InferenceHTTPClient |
| | from inference_sdk.http.errors import HTTPCallErrorError |
| |
|
| |
|
| | |
# Roboflow workspace and workflow that every request in this app targets.
WORKSPACE = "ata-assignment-1-mkqz4"
WORKFLOW_ID = "custom-workflow-4"

# Shared HTTP client for the Roboflow serverless endpoint. The API key is
# read from the environment at import time, so a missing ROBOFLOW_API_KEY
# raises KeyError here rather than on the first request.
client = InferenceHTTPClient(
    api_key=os.environ["ROBOFLOW_API_KEY"],
    api_url="https://serverless.roboflow.com",
)
| |
|
| |
|
| | |
| | def _extract_predictions(result: Any) -> List[Dict[str, Any]]: |
| | """ |
| | Workflows can return different shapes. We try common patterns. |
| | Returns a list of prediction dicts with either: |
| | - x,y,width,height (+ class, confidence), OR |
| | - bbox dict |
| | """ |
| | |
| | item = result[0] if isinstance(result, list) and result else result |
| |
|
| | if not isinstance(item, dict): |
| | return [] |
| |
|
| | |
| | if isinstance(item.get("predictions"), dict) and isinstance(item["predictions"].get("predictions"), list): |
| | return item["predictions"]["predictions"] |
| |
|
| | |
| | for v in item.values(): |
| | if isinstance(v, dict) and isinstance(v.get("predictions"), list): |
| | return v["predictions"] |
| |
|
| | |
| | if isinstance(item.get("predictions"), list): |
| | return item["predictions"] |
| |
|
| | return [] |
| |
|
| |
|
def _draw_boxes_pil(img: Image.Image, preds: List[Dict[str, Any]], conf_thresh: float) -> Image.Image:
    """Draw a labelled bounding box for each sufficiently confident prediction.

    Two geometry encodings are understood:

    * center format: ``x``, ``y``, ``width``, ``height`` keys on the prediction
    * corner format: a ``bbox`` dict with ``x1/y1/x2/y2`` or
      ``left/top/right/bottom`` keys

    Predictions in any other shape are skipped.

    Args:
        img: Source image; drawing happens on the result of
            ``img.convert("RGB")``.
        preds: Prediction dicts. Confidence is read from ``confidence``
            (falling back to ``conf``, then 0.0) and the label from ``class``
            (falling back to ``label``, then ``"obj"``).
        conf_thresh: Predictions with confidence below this value are skipped.

    Returns:
        The RGB image with boxes and ``"<class> <confidence>"`` labels drawn
        using ImageDraw's default colors.
    """
    img = img.convert("RGB")
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.load_default()
    except Exception:
        # Best effort: fall back to whatever ImageDraw uses when font is None.
        font = None

    for p in preds:
        conf = float(p.get("confidence", p.get("conf", 0.0)))
        if conf < conf_thresh:
            continue

        cls = p.get("class", p.get("label", "obj"))

        if all(k in p for k in ("x", "y", "width", "height")):
            # Center format: convert to corner coordinates.
            cx, cy = float(p["x"]), float(p["y"])
            half_w, half_h = float(p["width"]) / 2, float(p["height"]) / 2
            x1, y1, x2, y2 = cx - half_w, cy - half_h, cx + half_w, cy + half_h
        elif isinstance(p.get("bbox"), dict):
            b = p["bbox"]
            if all(k in b for k in ("x1", "y1", "x2", "y2")):
                x1, y1, x2, y2 = map(float, (b["x1"], b["y1"], b["x2"], b["y2"]))
            elif all(k in b for k in ("left", "top", "right", "bottom")):
                x1, y1, x2, y2 = map(float, (b["left"], b["top"], b["right"], b["bottom"]))
            else:
                continue
        else:
            continue

        # ImageDraw.rectangle requires x1 >= x0 and y1 >= y0; normalise so a
        # prediction with swapped corners or a negative size cannot abort
        # the whole render.
        x1, x2 = sorted((x1, x2))
        y1, y2 = sorted((y1, y2))

        draw.rectangle([x1, y1, x2, y2], width=3)
        label = f"{cls} {conf:.2f}"
        # Place the label just above the box, clamped to the top image edge.
        draw.text((x1, max(0, y1 - 14)), label, font=font)

    return img
| |
|
| |
|
def _run_on_image_path(image_path: str, use_cache: bool) -> Any:
    """Run the configured workflow on a single image file.

    Args:
        image_path: Path of the image, passed to the workflow under the
            ``"image"`` input name.
        use_cache: Forwarded unchanged as the client's ``use_cache`` option.

    Returns:
        Whatever ``client.run_workflow`` returns, unmodified.
    """
    workflow_inputs = {"image": image_path}
    response = client.run_workflow(
        workspace_name=WORKSPACE,
        workflow_id=WORKFLOW_ID,
        images=workflow_inputs,
        use_cache=use_cache,
    )
    return response
| |
|
| |
|
| | |
def infer_image(image_path: str, use_cache: bool, conf_thresh: float):
    """Run the workflow on one image and return it annotated with boxes.

    Args:
        image_path: Path of the uploaded image, or None when nothing was
            uploaded.
        use_cache: Forwarded to the workflow call.
        conf_thresh: Minimum confidence for a prediction to be drawn.

    Returns:
        ``(annotated_image, raw_result)`` on success, or ``(None, error_dict)``
        when no image was supplied or the Roboflow request failed.
    """
    if image_path is None:
        return None, {"error": "No image uploaded."}

    # Only the network call is guarded: it is the step expected to raise
    # HTTPCallErrorError.
    try:
        result = _run_on_image_path(image_path, use_cache=use_cache)
    except HTTPCallErrorError as e:
        return None, {
            "error": "Roboflow request failed",
            "status_code": getattr(e, "status_code", None),
            "api_message": getattr(e, "api_message", str(e)),
            "description": str(e),
        }

    preds = _extract_predictions(result)

    # Close the source file deterministically; the drawing helper works on a
    # converted image, so the annotated result does not depend on the handle.
    with Image.open(image_path) as img:
        annotated = _draw_boxes_pil(img, preds, conf_thresh=conf_thresh)

    return annotated, result
| |
|
| |
|
| | |
def _remove_file_quietly(path: str) -> None:
    """Delete *path*, ignoring the error if it is already gone or locked."""
    try:
        os.remove(path)
    except OSError:
        pass


def infer_video(video_path: str, use_cache: bool, conf_thresh: float, fps_out: int, sample_every_n: int):
    """Annotate a video by running the workflow on every Nth frame.

    Frames that are not sampled reuse the most recent predictions, so boxes
    persist between inference calls. The annotated frames are written to a
    temporary ``.mp4`` file.

    Args:
        video_path: Path of the uploaded video, or None when nothing was
            uploaded.
        use_cache: Forwarded to each workflow call.
        conf_thresh: Minimum confidence for a prediction to be drawn.
        fps_out: Frame rate of the written video.
        sample_every_n: Run inference on every Nth frame; values below 1 are
            treated as 1.

    Returns:
        ``(output_path, summary)`` on success, where ``summary`` records the
        frame rates, counters and up to three raw workflow results; or
        ``(None, error_dict)`` when the input is missing or unreadable, the
        output cannot be created, or a Roboflow request fails.
    """
    if video_path is None:
        return None, {"error": "No video uploaded."}

    # UI sliders may deliver floats, and a stride below 1 would break the
    # modulo test in the frame loop.
    sample_every_n = max(1, int(sample_every_n))

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        return None, {"error": "Could not open video."}

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    in_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back when FPS reads as 0

    # Reserve an output path; the handle is closed so OpenCV can open it.
    tmp_out = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    tmp_out.close()

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(tmp_out.name, fourcc, float(fps_out), (width, height))
    if not writer.isOpened():
        # Without this check every write would be dropped silently and an
        # empty file would be handed back to the UI.
        cap.release()
        writer.release()
        _remove_file_quietly(tmp_out.name)
        return None, {"error": "Could not create output video."}

    frame_idx = 0
    last_preds: List[Dict[str, Any]] = []
    summary = {
        "input_fps": in_fps,
        "output_fps": fps_out,
        "sample_every_n_frames": sample_every_n,
        "frames_processed": 0,
        "workflow_calls": 0,
        "example_results": [],
    }

    completed = False
    failure = None
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break

            if frame_idx % sample_every_n == 0:
                # The workflow client takes a file path, so the frame goes
                # through a scratch JPEG. The handle is closed before OpenCV
                # writes to it and the file is removed once the call returns.
                tmp_img = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
                tmp_img.close()
                try:
                    cv2.imwrite(tmp_img.name, frame)
                    result = _run_on_image_path(tmp_img.name, use_cache=use_cache)
                finally:
                    _remove_file_quietly(tmp_img.name)

                last_preds = _extract_predictions(result)
                summary["workflow_calls"] += 1

                # Keep only a few raw results so the summary stays small.
                if len(summary["example_results"]) < 3:
                    summary["example_results"].append(result)

            # OpenCV frames are BGR; PIL drawing expects RGB, then convert back.
            pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            pil = _draw_boxes_pil(pil, last_preds, conf_thresh=conf_thresh)
            out_frame = cv2.cvtColor(np.array(pil), cv2.COLOR_RGB2BGR)

            writer.write(out_frame)
            summary["frames_processed"] += 1
            frame_idx += 1

        completed = True

    except HTTPCallErrorError as e:
        failure = {
            "error": "Roboflow request failed during video processing",
            "status_code": getattr(e, "status_code", None),
            "api_message": getattr(e, "api_message", str(e)),
            "description": str(e),
        }
    finally:
        cap.release()
        writer.release()
        if not completed:
            # Do not leave a partial video behind when processing failed.
            _remove_file_quietly(tmp_out.name)

    if failure is not None:
        return None, failure

    return tmp_out.name, summary
| |
|
| |
|
| | |
# Gradio front end: one tab per media type, each wiring its inputs to the
# matching inference function and showing the annotated media next to a JSON
# panel.
with gr.Blocks(title="Roboflow Workflow Runner (Image + Video)") as demo:
    gr.Markdown(
        "# Roboflow Workflow Runner (Image + Video)\n"
        "Upload an image or a video, run your workflow, and see bounding boxes."
    )

    with gr.Tab("Image"):
        # Inputs
        img_in = gr.Image(label="Upload an image", type="filepath")
        img_cache = gr.Checkbox(label="Use cache (faster for repeat requests)", value=True)
        img_conf = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.25,
            step=0.05,
            label="Confidence threshold",
        )

        # Trigger and outputs
        img_btn = gr.Button(value="Run on Image")
        img_out = gr.Image(label="Annotated image", type="pil")
        img_json = gr.JSON(label="Raw workflow result")

        img_btn.click(
            fn=infer_image,
            inputs=[img_in, img_cache, img_conf],
            outputs=[img_out, img_json],
        )

    with gr.Tab("Video"):
        # Inputs
        vid_in = gr.Video(label="Upload a video")
        vid_cache = gr.Checkbox(
            label="Use cache (usually OFF for video, but you can try)",
            value=True,
        )
        vid_conf = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.25,
            step=0.05,
            label="Confidence threshold",
        )
        sample_every_n = gr.Slider(
            minimum=1,
            maximum=30,
            value=5,
            step=1,
            label="Run inference every N frames (higher = cheaper/faster)",
        )
        fps_out = gr.Slider(
            minimum=5,
            maximum=30,
            value=15,
            step=1,
            label="Output video FPS",
        )

        # Trigger and outputs
        vid_btn = gr.Button(value="Run on Video")
        vid_out = gr.Video(label="Annotated video")
        vid_summary = gr.JSON(label="Video summary (includes a few sample results)")

        # Input order must match infer_video's parameter order.
        vid_btn.click(
            fn=infer_video,
            inputs=[vid_in, vid_cache, vid_conf, fps_out, sample_every_n],
            outputs=[vid_out, vid_summary],
        )

# Start the app as soon as the module is executed.
demo.launch()
| |
|