# Sporalize Labs 3D Analysis Engine — FastAPI backend (Hugging Face Space).
| import os | |
| import json | |
| import time | |
| import uuid | |
| import shutil | |
| import traceback | |
| import re | |
| import sys | |
| import importlib.util | |
| import threading | |
| import subprocess | |
| from urllib.parse import urlsplit, urlunsplit | |
| import cv2 | |
| import numpy as np | |
| from fastapi import FastAPI, File, UploadFile, Form, Request, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.concurrency import run_in_threadpool | |
| from fastapi.staticfiles import StaticFiles | |
| from typing import List | |
| from huggingface_hub import hf_hub_download, snapshot_download, HfApi | |
# FastAPI application instance for the analysis engine.
app = FastAPI(title="Sporalize Labs 3D Analysis Engine")
# Absolute directory containing this source file; anchor for bundled assets/weights.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
def default_runtime_root():
    """Pick a writable runtime directory: /data when present, else a dot-dir in $HOME."""
    data_mount = "/data"
    if os.path.isdir(data_mount):
        return os.path.join(data_mount, "sporalize_runtime")
    home = os.path.expanduser("~")
    return os.path.join(home, ".sporalize_runtime")
# Root for mutable runtime state; env override wins, else /data (persistent
# disk on HF Spaces) or a dot-directory in the user's home.
RUNTIME_ROOT = os.environ.get("SPORALIZE_RUNTIME_DIR", default_runtime_root())
# Caches for downloaded pipeline assets and model weights.
ASSETS_RUNTIME_ROOT = os.environ.get("SPORALIZE_ASSETS_DIR", os.path.join(RUNTIME_ROOT, "assets"))
WEIGHTS_RUNTIME_ROOT = os.environ.get("SPORALIZE_WEIGHTS_DIR", os.path.join(RUNTIME_ROOT, "weights"))
# Session storage lives next to the source tree when writable, otherwise under
# the runtime root (source dir may be read-only in containerized deployments).
DEFAULT_LOCAL_STORAGE_ROOT = os.path.join(CURRENT_DIR, "Storage")
if not os.path.isdir("/data") and not os.access(CURRENT_DIR, os.W_OK):
    DEFAULT_LOCAL_STORAGE_ROOT = os.path.join(RUNTIME_ROOT, "storage")
STORAGE_ROOT = os.environ.get(
    "SPORALIZE_STORAGE_DIR",
    os.path.join("/data", "sporalize_storage") if os.path.isdir("/data") else DEFAULT_LOCAL_STORAGE_ROOT,
)
# Optional Hugging Face dataset repo used to persist/restore session storage.
STORAGE_DATASET_REPO_ID = os.environ.get("SPORALIZE_STORAGE_REPO_ID", "Shoraky/SporalizeLabs-runtime-private").strip()
STORAGE_DATASET_REPO_TYPE = os.environ.get("SPORALIZE_STORAGE_REPO_TYPE", "dataset").strip()
STORAGE_DATASET_PATH = os.environ.get("SPORALIZE_STORAGE_DATASET_PATH", "Storage").strip("/").strip()
# Minimum seconds between dataset -> local storage syncs (throttle).
STORAGE_SYNC_INTERVAL_SECONDS = float(os.environ.get("SPORALIZE_STORAGE_SYNC_INTERVAL_SECONDS", "20"))
# Lock + timestamp guarding the throttled sync in sync_storage_from_hf().
_storage_sync_lock = threading.Lock()
_storage_last_sync_ts = 0.0
# Per-weight resolution specs consumed by ensure_weight_file(): an env-var
# override path, a bundled local fallback, and the Hugging Face repo/file to
# download from when no local copy exists.
DEFAULT_WEIGHT_SPECS = {
    "POSE_PATH": {
        "filename": "vitpose-s-coco_25.onnx",
        "repo_id": os.environ.get("SPORALIZE_POSE_MODEL_REPO_ID", "JunkyByte/easy_ViTPose"),
        "repo_type": os.environ.get("SPORALIZE_POSE_MODEL_REPO_TYPE", "model"),
        "repo_file": os.environ.get("SPORALIZE_POSE_MODEL_FILE", "onnx/coco_25/vitpose-25-s.onnx"),
        "override_env": "SPORALIZE_POSE_MODEL_PATH",
        "local_fallback": os.path.join(CURRENT_DIR, "Weights", "vitpose-s-coco_25.onnx"),
    },
    "YOLO_PATH": {
        "filename": "yolov8m.pt",
        "repo_id": os.environ.get("SPORALIZE_YOLO_MODEL_REPO_ID", "Ultralytics/YOLOv8"),
        "repo_type": os.environ.get("SPORALIZE_YOLO_MODEL_REPO_TYPE", "model"),
        "repo_file": os.environ.get("SPORALIZE_YOLO_MODEL_FILE", "yolov8m.pt"),
        "override_env": "SPORALIZE_YOLO_MODEL_PATH",
        "local_fallback": os.path.join(CURRENT_DIR, "Weights", "yolov8m.pt"),
    },
}
# Lazily-populated runtime: resolved pipeline root, its run_pipeline callable,
# and the resolved local weight paths. Filled by ensure_runtime_ready().
runtime_state = {
    "ready": False,
    "pipeline_root": None,
    "run_pipeline": None,
    "weights": {},
}
def get_hf_token():
    """Return the Hugging Face auth token from the environment, if configured."""
    primary = os.environ.get("HF_TOKEN")
    return primary if primary else os.environ.get("HUGGING_FACE_HUB_TOKEN")
def hf_storage_enabled():
    """True when both a dataset repo id and a storage prefix are configured."""
    return bool(STORAGE_DATASET_REPO_ID) and bool(STORAGE_DATASET_PATH)
def hf_storage_path(*parts: str) -> str:
    """Join path segments under the configured dataset storage prefix with '/'."""
    segments = [STORAGE_DATASET_PATH]
    for part in parts:
        if part is None or not str(part).strip("/"):
            continue
        segments.append(part.strip("/").replace("\\", "/"))
    return "/".join(segment for segment in segments if segment)
def sync_storage_from_hf(force: bool = False):
    """Refresh local STORAGE_ROOT from the configured HF dataset repo.

    Throttled to at most once per STORAGE_SYNC_INTERVAL_SECONDS unless *force*
    is set; the lock plus a re-check prevents concurrent duplicate syncs.
    """
    global _storage_last_sync_ts
    if not hf_storage_enabled():
        return
    now = time.time()
    # Cheap unlocked pre-check so the hot path skips the lock entirely.
    if not force and (now - _storage_last_sync_ts) < STORAGE_SYNC_INTERVAL_SECONDS:
        return
    with _storage_sync_lock:
        # Re-check under the lock: another thread may have just finished a sync.
        now = time.time()
        if not force and (now - _storage_last_sync_ts) < STORAGE_SYNC_INTERVAL_SECONDS:
            return
        sync_cache_root = os.path.join(RUNTIME_ROOT, "storage-sync-cache")
        os.makedirs(sync_cache_root, exist_ok=True)
        local_repo_dir = os.path.join(sync_cache_root, safe_name(STORAGE_DATASET_REPO_ID))
        snapshot_download(
            repo_id=STORAGE_DATASET_REPO_ID,
            repo_type=STORAGE_DATASET_REPO_TYPE,
            token=get_hf_token(),
            local_dir=local_repo_dir,
            allow_patterns=[f"{STORAGE_DATASET_PATH}/**"],
        )
        source_storage = os.path.join(local_repo_dir, STORAGE_DATASET_PATH)
        if os.path.isdir(source_storage):
            # Replace local storage wholesale so remote deletions also propagate.
            if os.path.isdir(STORAGE_ROOT):
                shutil.rmtree(STORAGE_ROOT, ignore_errors=True)
            shutil.copytree(source_storage, STORAGE_ROOT, dirs_exist_ok=True)
        _storage_last_sync_ts = time.time()
def push_session_to_hf(player_id: str, session_id: str, session_dir: str):
    """Upload one session directory to the HF dataset repo (no-op when disabled)."""
    if not hf_storage_enabled():
        return
    repo_path = hf_storage_path(safe_name(player_id), safe_name(session_id))
    message = f"Add session {safe_name(session_id)} for player {safe_name(player_id)}"
    client = HfApi(token=get_hf_token())
    client.upload_folder(
        repo_id=STORAGE_DATASET_REPO_ID,
        repo_type=STORAGE_DATASET_REPO_TYPE,
        folder_path=session_dir,
        path_in_repo=repo_path,
        commit_message=message,
    )
def delete_session_from_hf(player_id: str, session_id: str):
    """Remove a session folder from the HF dataset repo (no-op when disabled)."""
    if not hf_storage_enabled():
        return
    repo_path = hf_storage_path(safe_name(player_id), safe_name(session_id))
    message = f"Delete session {safe_name(session_id)} for player {safe_name(player_id)}"
    client = HfApi(token=get_hf_token())
    client.delete_folder(
        repo_id=STORAGE_DATASET_REPO_ID,
        repo_type=STORAGE_DATASET_REPO_TYPE,
        path_in_repo=repo_path,
        commit_message=message,
    )
def delete_player_from_hf(player_id: str):
    """Remove a player's entire folder from the HF dataset repo (no-op when disabled)."""
    if not hf_storage_enabled():
        return
    client = HfApi(token=get_hf_token())
    client.delete_folder(
        repo_id=STORAGE_DATASET_REPO_ID,
        repo_type=STORAGE_DATASET_REPO_TYPE,
        path_in_repo=hf_storage_path(safe_name(player_id)),
        commit_message=f"Delete player {safe_name(player_id)} storage",
    )
def path_has_session_data(directory: str):
    """True when any directory below *directory* contains a session.json file."""
    if not os.path.isdir(directory):
        return False
    return any(
        "session.json" in file_names
        for _path, _subdirs, file_names in os.walk(directory)
    )
def seed_storage_if_needed(seed_dir: str, target_dir: str):
    """Copy the bundled seed storage into *target_dir*, never over real sessions."""
    if not os.path.isdir(seed_dir):
        return
    os.makedirs(target_dir, exist_ok=True)
    if path_has_session_data(target_dir):
        # Target already holds user sessions — leave it alone.
        return
    shutil.copytree(seed_dir, target_dir, dirs_exist_ok=True)
def resolve_pipeline_root():
    """Locate the directory holding pipeline.py and its ViTPose assets.

    Prefers a local bundle next to this file; otherwise downloads an assets
    snapshot from the repo named by SPORALIZE_ASSETS_REPO_ID and seeds
    STORAGE_ROOT from its bundled Storage/ directory on first run.

    Raises:
        RuntimeError: no local bundle exists and SPORALIZE_ASSETS_REPO_ID is unset.
    """
    local_pipeline = os.path.join(CURRENT_DIR, "pipeline.py")
    local_vitpose = os.path.join(CURRENT_DIR, "ViTPose")
    if os.path.isfile(local_pipeline) and os.path.isdir(local_vitpose):
        return CURRENT_DIR
    repo_id = os.environ.get("SPORALIZE_ASSETS_REPO_ID")
    if not repo_id:
        raise RuntimeError(
            "SPORALIZE_ASSETS_REPO_ID is required when Backend/pipeline.py is not bundled locally."
        )
    assets_dir = os.path.join(ASSETS_RUNTIME_ROOT, safe_name(repo_id))
    snapshot_download(
        repo_id=repo_id,
        repo_type=os.environ.get("SPORALIZE_ASSETS_REPO_TYPE", "dataset"),
        revision=os.environ.get("SPORALIZE_ASSETS_REVISION"),
        token=get_hf_token(),
        local_dir=assets_dir,
        allow_patterns=["pipeline.py", "ViTPose/**", "Storage/**", "Weights/**"],
    )
    # Seed local storage from the snapshot only when no sessions exist yet.
    seed_storage_if_needed(os.path.join(assets_dir, "Storage"), STORAGE_ROOT)
    return assets_dir
def load_pipeline_callable(pipeline_root: str):
    """Import pipeline.py from *pipeline_root* and return its run_pipeline callable.

    Raises RuntimeError when the file, the import spec, or the callable is missing.
    """
    pipeline_path = os.path.join(pipeline_root, "pipeline.py")
    if not os.path.isfile(pipeline_path):
        raise RuntimeError(f"pipeline.py was not found at {pipeline_path}")
    if pipeline_root not in sys.path:
        sys.path.insert(0, pipeline_root)
    module_name = "sporalize_runtime_pipeline"
    # Drop any stale module so a forced reload picks up fresh code.
    sys.modules.pop(module_name, None)
    spec = importlib.util.spec_from_file_location(module_name, pipeline_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Unable to create import spec for {pipeline_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    entry_point = getattr(module, "run_pipeline", None)
    if entry_point is None:
        raise RuntimeError("run_pipeline was not found in the resolved pipeline module")
    return entry_point
def ensure_weight_file(spec: dict, pipeline_root: str):
    """Resolve a model weight file to a local path.

    Resolution order: env-var override -> pipeline bundle -> repo-local
    fallback -> runtime download cache -> fresh Hugging Face Hub download.

    Returns the absolute path of the resolved weight file.
    """
    override_path = os.environ.get(spec["override_env"])
    if override_path and os.path.isfile(override_path):
        return override_path
    pipeline_weight = os.path.join(pipeline_root, "Weights", spec["filename"])
    if os.path.isfile(pipeline_weight):
        return pipeline_weight
    local_fallback = spec.get("local_fallback")
    if local_fallback and os.path.isfile(local_fallback):
        return local_fallback
    os.makedirs(WEIGHTS_RUNTIME_ROOT, exist_ok=True)
    # Fix: hf_hub_download(local_dir=...) preserves the repo's sub-directory
    # layout, so a previous download lives at <root>/<repo_file>, not
    # <root>/<filename>. Check both locations so cached weights (e.g. the pose
    # model under onnx/coco_25/) are never re-downloaded on every cold start.
    for candidate in (
        os.path.join(WEIGHTS_RUNTIME_ROOT, spec["filename"]),
        os.path.join(WEIGHTS_RUNTIME_ROOT, *spec["repo_file"].split("/")),
    ):
        if os.path.isfile(candidate):
            return candidate
    return hf_hub_download(
        repo_id=spec["repo_id"],
        repo_type=spec.get("repo_type", "model"),
        filename=spec["repo_file"],
        token=get_hf_token(),
        local_dir=WEIGHTS_RUNTIME_ROOT,
    )
def ensure_runtime_ready(force: bool = False):
    """Initialize the pipeline module and model weights once; return runtime_state.

    Idempotent unless *force* is set. Populates runtime_state with the resolved
    pipeline root, its run_pipeline callable, and local weight paths.
    NOTE(review): not lock-guarded — concurrent first calls may initialize twice.
    """
    if runtime_state["ready"] and not force:
        return runtime_state
    os.makedirs(RUNTIME_ROOT, exist_ok=True)
    os.makedirs(STORAGE_ROOT, exist_ok=True)
    pipeline_root = resolve_pipeline_root()
    run_pipeline = load_pipeline_callable(pipeline_root)
    weight_paths = {name: ensure_weight_file(spec, pipeline_root) for name, spec in DEFAULT_WEIGHT_SPECS.items()}
    runtime_state.update({
        "ready": True,
        "pipeline_root": pipeline_root,
        "run_pipeline": run_pipeline,
        "weights": weight_paths,
    })
    return runtime_state
# Serve session artifacts (videos, clips, reports) directly from disk.
os.makedirs(STORAGE_ROOT, exist_ok=True)
app.mount("/storage", StaticFiles(directory=STORAGE_ROOT), name="storage")
# Per-client processing progress snapshots and cancellation flags, keyed by clientId.
progress_store = {}
cancel_store = {}
def safe_name(value: str) -> str:
    """Sanitize *value* into a filesystem-safe name.

    Alphanumerics plus '-', '_', '.' are kept; every other character becomes
    '_'. Leading/trailing dots and underscores are stripped, and an empty
    result falls back to "item".
    """
    cleaned = "".join(
        ch if ch.isalnum() or ch in "-_." else "_" for ch in str(value)
    ).strip("._")
    return cleaned or "item"
def session_storage_paths(player_id: str, session_id: str):
    """Return the (player_dir, session_dir, videos_dir) triple under STORAGE_ROOT."""
    player_dir = os.path.join(STORAGE_ROOT, safe_name(player_id))
    session_dir = os.path.join(player_dir, safe_name(session_id))
    return player_dir, session_dir, os.path.join(session_dir, "videos")
def list_session_files():
    """Collect every session.json under STORAGE_ROOT, newest first."""
    found = [
        os.path.join(root, "session.json")
        for root, _, files in os.walk(STORAGE_ROOT)
        if "session.json" in files
    ]
    found.sort(key=os.path.getmtime, reverse=True)
    return found
def build_storage_url(request: Request, *parts: str) -> str:
    """Build an absolute public URL for a file under the /storage mount.

    All but the last path part are sanitized with safe_name(); the final part
    (the file name) only has backslashes normalized. Base-URL precedence:
    SPORALIZE_PUBLIC_BASE_URL env override, then X-Forwarded-Proto/Host
    headers (reverse-proxy aware), then the request's own base URL.
    """
    relative = "/".join(safe_name(part) if idx < len(parts) - 1 else part.replace("\\", "/") for idx, part in enumerate(parts))
    configured_public_base = os.environ.get("SPORALIZE_PUBLIC_BASE_URL", "").strip()
    if configured_public_base:
        return configured_public_base.rstrip("/") + "/storage/" + relative.lstrip("/")
    # First value of a comma-separated forwarded header is the original client hop.
    forwarded_proto = request.headers.get("x-forwarded-proto", "").split(",")[0].strip().lower()
    forwarded_host = request.headers.get("x-forwarded-host", "").split(",")[0].strip()
    if forwarded_proto in ("http", "https") and forwarded_host:
        return f"{forwarded_proto}://{forwarded_host}".rstrip("/") + "/storage/" + relative.lstrip("/")
    base_url = str(request.base_url).rstrip("/")
    # Proto forwarded without a host: keep our host but honor the forwarded scheme.
    if forwarded_proto in ("http", "https"):
        parsed = urlsplit(base_url)
        base_url = urlunsplit((forwarded_proto, parsed.netloc, parsed.path, "", "")).rstrip("/")
    return base_url + "/storage/" + relative.lstrip("/")
def parse_video_timecode(value, fps=30.0):
    """Convert a timecode into non-negative seconds.

    Accepts plain numbers (seconds), "HH:MM:SS:FF" frame timecodes, and —
    fixed here — "MM:SS" / "HH:MM:SS" strings, which previously fell through
    to float() and silently collapsed to 0.0. Unparseable input yields 0.0
    (malformed 4-part timecodes no longer raise).
    """
    if value is None:
        return 0.0
    if isinstance(value, (int, float, np.integer, np.floating)):
        return max(0.0, float(value))
    parts = str(value).split(":")
    try:
        if len(parts) == 4:
            # HH:MM:SS:FF — the final field is a frame number within the second.
            h, m, s, f = [int(float(part or 0)) for part in parts]
            return max(0.0, (h * 3600) + (m * 60) + s + (f / max(1.0, float(fps))))
        if len(parts) == 3:
            h, m, s = [float(part or 0) for part in parts]
            return max(0.0, (h * 3600) + (m * 60) + s)
        if len(parts) == 2:
            m, s = [float(part or 0) for part in parts]
            return max(0.0, (m * 60) + s)
        return max(0.0, float(value))
    except Exception:
        return 0.0
def detect_camera_id(file_name: str):
    """Extract the camera id from a "_cam_<N>_" token in *file_name*; None if absent."""
    match = re.search(r"_cam_(\d+)_", file_name)
    return int(match.group(1)) if match else None
def split_camera_video_variants(videos_dir: str):
    """Group per-camera videos into (primary, fallback) maps keyed by camera id.

    Original uploads are the primary and their ``.web.mp4`` re-encodes are the
    fallback; when only the web variant exists it is promoted to primary.
    """
    grouped: dict[int, dict[str, str]] = {}
    for file_name in sorted(os.listdir(videos_dir)):
        camera_id = detect_camera_id(file_name)
        if camera_id is None:
            continue
        full_path = os.path.join(videos_dir, file_name)
        if not os.path.isfile(full_path):
            continue
        variant = "web" if file_name.lower().endswith(".web.mp4") else "base"
        grouped.setdefault(camera_id, {})[variant] = full_path
    primary_map = {}
    fallback_map = {}
    for camera_id in sorted(grouped):
        variants = grouped[camera_id]
        base_path = variants.get("base")
        web_path = variants.get("web")
        if base_path:
            primary_map[camera_id] = base_path
            if web_path:
                fallback_map[camera_id] = web_path
        elif web_path:
            primary_map[camera_id] = web_path
    return primary_map, fallback_map
def normalize_video_for_web(input_path: str) -> str:
    """
    Re-encode uploaded video to a browser-friendly MP4 stream.
    Falls back to the original file if normalization fails.

    Returns the path of the re-encoded ``*.web.mp4`` on success, or
    *input_path* unchanged when both ffmpeg and the OpenCV fallback fail.
    """
    output_path = os.path.splitext(input_path)[0] + ".web.mp4"
    # First try ffmpeg -> H.264 + yuv420p for maximum browser compatibility.
    ffmpeg_cmd = [
        "ffmpeg",
        "-y",
        "-i",
        input_path,
        "-an",  # strip audio: playback clips are video-only
        "-c:v",
        "libx264",
        "-preset",
        "veryfast",
        "-pix_fmt",
        "yuv420p",
        "-movflags",
        "+faststart",  # moov atom up front for progressive streaming
        output_path,
    ]
    try:
        ffmpeg_proc = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
        if ffmpeg_proc.returncode == 0 and os.path.exists(output_path) and os.path.getsize(output_path) > 0:
            return output_path
    except Exception:
        # ffmpeg binary missing or not executable — fall through to OpenCV.
        pass
    # Fallback: OpenCV transcode when ffmpeg is unavailable.
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        return input_path
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or not np.isfinite(fps) or fps <= 0:
        fps = 30.0  # sane default when the container reports no/bogus fps
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    if width <= 0 or height <= 0:
        cap.release()
        return input_path
    writer = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        float(fps),
        (width, height),
    )
    if not writer.isOpened():
        cap.release()
        return input_path
    frame_count = 0
    success = True
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        writer.write(frame)
        frame_count += 1
    cap.release()
    writer.release()
    # Treat an empty or zero-byte output as a failed transcode.
    if frame_count <= 0:
        success = False
    elif not os.path.exists(output_path) or os.path.getsize(output_path) <= 0:
        success = False
    if not success:
        if os.path.exists(output_path):
            os.remove(output_path)
        return input_path
    return output_path
def build_camera_video_entries(request: Request, player_id: str, session_id: str, camera_map):
    """Turn {camera_id: video_path} into [{cameraId, url}] entries; skips missing files."""
    entries = []
    for camera_id, video_path in sorted(camera_map.items()):
        if not video_path or not os.path.exists(video_path):
            continue
        url = build_storage_url(
            request,
            safe_name(player_id),
            safe_name(session_id),
            "videos",
            os.path.basename(video_path),
        )
        entries.append({"cameraId": int(camera_id), "url": url})
    return entries
def build_action_clip_entries(request: Request, player_id: str, session_id: str, clip_map):
    """Turn {camera_id: clip_path} into [{cameraId, url}] entries; skips missing files."""
    entries = []
    for camera_id, clip_path in sorted(clip_map.items()):
        if not clip_path or not os.path.exists(clip_path):
            continue
        url = build_storage_url(
            request,
            safe_name(player_id),
            safe_name(session_id),
            "clips",
            os.path.basename(clip_path),
        )
        entries.append({"cameraId": int(camera_id), "url": url})
    return entries
def normalize_session_payload(session: dict, request: Request):
    """Enrich a stored session dict with per-action timing and video URLs.

    Returns the payload untouched when ids or the on-disk videos directory are
    missing. For each action, derives startSeconds/endSeconds from absolute
    source frame indices when available, otherwise from timecode strings plus
    totalFrames, and attaches camera clip URL lists (originals plus web
    fallback variants). The input dicts are shallow-copied, not mutated.
    """
    session_id = session.get("id")
    player_id = session.get("playerId")
    if not session_id or not player_id:
        return session
    session_dir = find_session_path(session_id)
    if not session_dir:
        return session
    videos_dir = os.path.join(session_dir, "videos")
    if not os.path.isdir(videos_dir):
        return session
    camera_map, fallback_camera_map = split_camera_video_variants(videos_dir)
    if not camera_map:
        return session
    normalized_actions = []
    for action in session.get("actions", []):
        normalized_action = dict(action)  # shallow copy of the stored action
        fps = float(normalized_action.get("fps") or 30.0)
        fps = max(1.0, fps)
        # Prefer absolute source-video frame indices when the pipeline recorded them.
        absolute_start_frame = normalized_action.get("sourceStartFrame")
        absolute_end_frame = normalized_action.get("sourceEndFrame")
        if absolute_start_frame is None or absolute_end_frame is None:
            absolute_start_frame = normalized_action.get("startFrame")
            absolute_end_frame = normalized_action.get("endFrame")
        try:
            absolute_start_frame = int(absolute_start_frame) if absolute_start_frame is not None else None
            absolute_end_frame = int(absolute_end_frame) if absolute_end_frame is not None else None
        except Exception:
            absolute_start_frame = None
            absolute_end_frame = None
        if absolute_start_frame is not None and absolute_end_frame is not None and absolute_end_frame >= absolute_start_frame:
            start_seconds = max(0.0, absolute_start_frame / fps)
            # +1 because the end frame is inclusive: play through its whole frame.
            end_seconds = max(start_seconds, (absolute_end_frame + 1) / fps)
            normalized_action["startFrame"] = absolute_start_frame
            normalized_action["endFrame"] = absolute_end_frame
        else:
            # Fall back to the annotated timecode plus a frame-count-derived span.
            total_frames = int(normalized_action.get("totalFrames") or 0)
            start_seconds = parse_video_timecode(normalized_action.get("start"), fps=fps)
            if total_frames > 0:
                end_seconds = start_seconds + (total_frames / fps)
            else:
                end_seconds = max(start_seconds, parse_video_timecode(normalized_action.get("end"), fps=fps))
        normalized_action["cameraClips"] = normalized_action.get("cameraClips") or build_camera_video_entries(
            request, player_id, session_id, camera_map
        )
        if fallback_camera_map:
            normalized_action["sourceCameraClips"] = normalized_action.get("sourceCameraClips") or build_camera_video_entries(
                request, player_id, session_id, fallback_camera_map
            )
        normalized_action["startSeconds"] = round(start_seconds, 6)
        normalized_action["endSeconds"] = round(end_seconds, 6)
        normalized_actions.append(normalized_action)
    normalized_session = dict(session)
    normalized_session["actions"] = normalized_actions
    return normalized_session
def json_default(value):
    """json.dump fallback: convert NumPy scalars and arrays to native Python types."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
def export_action_clips(camera_map, clips_dir, action_index, start_frame, end_frame, fps):
    """Cut the inclusive [start_frame, end_frame] window out of each camera video.

    Writes one mp4 per camera into *clips_dir* and returns {camera_id: clip_path}
    for clips that actually received frames. Cameras whose source cannot be
    opened, report no dimensions, or whose writer cannot be created are skipped.
    """
    os.makedirs(clips_dir, exist_ok=True)
    frame_count = max(0, end_frame - start_frame + 1)
    clip_paths = {}
    for camera_id, video_path in sorted(camera_map.items()):
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            continue
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
        if width <= 0 or height <= 0:
            cap.release()
            continue
        clip_name = f"action_{action_index:02d}_cam_{camera_id}.mp4"
        clip_path = os.path.join(clips_dir, clip_name)
        writer = cv2.VideoWriter(
            clip_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            max(1.0, float(fps)),
            (width, height),
        )
        # Fix: bail out when the codec/writer could not be initialized. The
        # original wrote into an unopened writer (silently dropping frames)
        # and could leave a stub file behind.
        if not writer.isOpened():
            writer.release()
            cap.release()
            if os.path.exists(clip_path):
                os.remove(clip_path)
            continue
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        written = 0
        while written < frame_count:
            ok, frame = cap.read()
            if not ok:
                break
            writer.write(frame)
            written += 1
        writer.release()
        cap.release()
        if written > 0 and os.path.exists(clip_path):
            clip_paths[camera_id] = clip_path
        elif os.path.exists(clip_path):
            # Remove the empty container produced for a zero-frame window.
            os.remove(clip_path)
    return clip_paths
def load_session_by_id(session_id: str):
    """Load the session.json whose directory name matches *session_id*; None if absent."""
    target_name = safe_name(session_id)
    for session_file in list_session_files():
        if os.path.basename(os.path.dirname(session_file)) != target_name:
            continue
        with open(session_file, "r", encoding="utf-8") as handle:
            return json.load(handle)
    return None
def find_session_path(session_id: str):
    """Return the on-disk directory of the session named *session_id*, or None."""
    target_name = safe_name(session_id)
    matches = (
        os.path.dirname(session_file)
        for session_file in list_session_files()
        if os.path.basename(os.path.dirname(session_file)) == target_name
    )
    return next(matches, None)
def player_storage_path(player_id: str):
    """Return the absolute storage directory for *player_id* under STORAGE_ROOT."""
    folder_name = safe_name(player_id)
    return os.path.join(STORAGE_ROOT, folder_name)
def get_cors_origins():
    """Parse CORS_ALLOW_ORIGINS into a list; "*" or empty means allow any origin."""
    configured = os.environ.get("CORS_ALLOW_ORIGINS", "*").strip()
    if configured in ("", "*"):
        return ["*"]
    origins = []
    for raw in configured.split(","):
        origin = raw.strip()
        if origin:
            origins.append(origin)
    return origins
def startup_event():
    """Warm the pipeline/weights and pull persisted storage at boot.

    NOTE(review): no route/event decorator is visible in this chunk — confirm
    this is registered via @app.on_event("startup") (decorators for the
    endpoint functions below appear stripped as well).
    """
    ensure_runtime_ready()
    sync_storage_from_hf(force=True)
def healthz():
    """Health probe: ensure the runtime is initialized and report key paths."""
    runtime = ensure_runtime_ready()
    payload = {"status": "ok", "storageRoot": STORAGE_ROOT}
    payload["pipelineRoot"] = runtime.get("pipeline_root")
    return payload
def cancel_processing(client_id: str):
    """Flag *client_id*'s in-flight pipeline run for cooperative cancellation."""
    cancel_store[client_id] = True
    return {"status": "cancelled"}
def get_progress(client_id: str):
    """Return the latest progress snapshot for *client_id* (default: initializing)."""
    default_state = {"progress": 0.0, "phase": "Initializing"}
    return progress_store.get(client_id, default_state)
def get_session(session_id: str, request: Request):
    """Fetch one stored session, normalized for the frontend; 404 when missing."""
    session = load_session_by_id(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return normalize_session_payload(session, request)
def get_archive(request: Request):
    """List every stored session (normalized); unreadable session files are skipped."""
    sessions = []
    for session_file in list_session_files():
        try:
            with open(session_file, "r", encoding="utf-8") as handle:
                payload = json.load(handle)
            sessions.append(normalize_session_payload(payload, request))
        except Exception:
            # Best-effort listing: a corrupt session.json must not break the archive.
            continue
    return {"sessions": sessions}
def delete_session(session_id: str):
    """Remove a session from disk and remote storage; prune the emptied player dir."""
    session_dir = find_session_path(session_id)
    if session_dir is None:
        raise HTTPException(status_code=404, detail="Session not found")
    player_dir, _ = os.path.split(session_dir)
    player_id = os.path.basename(player_dir)
    shutil.rmtree(session_dir, ignore_errors=True)
    delete_session_from_hf(player_id, session_id)
    if os.path.isdir(player_dir) and not os.listdir(player_dir):
        os.rmdir(player_dir)
    return {"status": "deleted", "sessionId": session_id}
def delete_player(player_id: str):
    """Delete all stored sessions for a player, locally and in the HF dataset."""
    player_dir = player_storage_path(player_id)
    if not os.path.isdir(player_dir):
        raise HTTPException(status_code=404, detail="Player storage not found")
    shutil.rmtree(player_dir, ignore_errors=True)
    delete_player_from_hf(player_id)
    return {"status": "deleted", "playerId": player_id}
# CORS: wildcard origins cannot be combined with credentials, so credentials
# are only enabled when an explicit origin list is configured.
cors_origins = get_cors_origins()
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=(cors_origins != ["*"]),
    allow_methods=["*"],
    allow_headers=["*"],
)
def format_metric_series(name, unit, values_list):
    """Package a per-frame metric series for the frontend payload."""
    values = [
        {"frame": frame_idx, "value": safe_float(raw)}
        for frame_idx, raw in enumerate(values_list)
    ]
    return {"name": name, "unit": unit, "values": values}
def safe_float(value):
    """Coerce *value* to a float suitable for JSON; None when non-finite or unparseable.

    Fix: NaN *and* +/-infinity now both map to None — the original let
    infinities through, and json.dump emits ``Infinity`` tokens, which are
    not valid JSON and break strict frontend parsers.
    """
    try:
        number = float(value)
    except Exception:
        return None
    return number if np.isfinite(number) else None
def metric_name(key: str) -> str:
    """Human-readable metric label: underscores to spaces, Title Case."""
    words = key.split("_")
    return " ".join(words).title()
# Metric keys reported over the full action interval as per-frame series.
FULL_INTERVAL_KEYS = [
    "left_knee_angles",
    "right_knee_angles",
    "torso_pitch_angles",
    "head_angles",
    "mid_foot_ball_distances",
    "left_right_foot_distances",
]
# Preferred metric ordering per action type. "pre"/"in"/"post" list the
# per-phase interval series, "frames" is Dribble's single flat series, and
# "top_level_scalars" are one-off values reported once per action.
ACTION_METRIC_LAYOUTS = {
    "Pass": {
        "pre": ["body_to_ball_angle"],
        "in": [
            "body_to_ball_angle",
            "l_r_foot_distance",
            "trunc_pitch_angle",
            "trunc_roll_angle",
            "left_foot_orientation_angle",
            "right_foot_orientation_angle",
            "difference_in_angles",
            "l_knee_angle",
            "r_knee_angle",
            "head_angle",
            "head_pitch_angle",
            "head_roll_angle",
            "stand_foot_angle",
            "active_foot_height_pct",
        ],
        "post": ["head_angle", "body_to_ball_angle"],
        "top_level_scalars": ["backward_weighted_angle", "forward_weighted_angle"],
    },
    "Shot": {
        "pre": ["body_to_ball_angle"],
        "in": [
            "body_to_ball_angle",
            "l_r_foot_distance",
            "trunc_pitch_angle",
            "trunc_roll_angle",
            "left_foot_orientation_angle",
            "right_foot_orientation_angle",
            "difference_in_angles",
            "l_knee_angle",
            "r_knee_angle",
            "head_angle",
            "head_pitch_angle",
            "head_roll_angle",
            "stand_foot_angle",
            "l_elbow_shoulder_hip_angle",
            "r_elbow_shoulder_hip_angle",
            "active_ankle_angle",
        ],
        "post": ["head_angle", "body_to_ball_angle"],
        "top_level_scalars": ["backward_weighted_angle", "forward_weighted_angle"],
    },
    "Receive": {
        "pre": ["body_orientation_vs_ball", "head_angle"],
        "in": [
            "head_angle",
            "l_knee_angle",
            "r_knee_angle",
            "trunc_pitch_angle",
            "trunc_roll_angle",
            "left_foot_orientation_angle",
            "right_foot_orientation_angle",
            "difference_in_angles",
            "l_r_foot_distance",
            "stand_foot_angle",
            "body_orientation_vs_ball",
            "active_foot_height_pct",
        ],
        "post": ["mid_feet_ball_dist", "ball_height_pct_body"],
        "top_level_scalars": [],
    },
    "Dribble": {
        "frames": [
            "ball_feet_distance",
            "trunk_pitch",
            "trunk_roll",
            "head_angle",
            "ball_possession_score",
        ],
        "top_level_scalars": [],
    },
}
def ordered_metric_keys(observed_keys, preferred_keys=None):
    """Order metric keys: preferred keys first (in given order), then the rest sorted."""
    preferred = []
    for key in preferred_keys or []:
        if key in observed_keys:
            preferred.append(key)
    remainder = sorted(key for key in observed_keys if key not in preferred)
    return preferred + remainder
def build_series_from_entries(entries, unit_for, skip_keys=None, preferred_keys=None):
    """Build one metric series per key observed across per-frame entry dicts.

    "frame" (plus any *skip_keys*) is excluded; entries missing a key
    contribute a None slot at that frame.
    """
    skip = {"frame"}
    skip.update(skip_keys or [])
    metric_keys = set(preferred_keys or [])
    for entry in entries:
        for key in entry:
            if key not in skip:
                metric_keys.add(key)
    return [
        format_metric_series(
            metric_name(key),
            unit_for(key),
            [entry.get(key) for entry in entries],
        )
        for key in ordered_metric_keys(metric_keys, preferred_keys)
    ]
def build_scalar_metrics(payload, unit_for, skip_keys=None, preferred_keys=None):
    """Build one-off scalar metric entries from *payload*, ordered and rounded to 3dp."""
    skip = set(skip_keys or [])
    candidate_keys = {key for key in payload if key not in skip}
    candidate_keys.update(key for key in (preferred_keys or []) if key not in skip)
    metrics = []
    for key in ordered_metric_keys(candidate_keys, preferred_keys):
        if key in skip:
            continue
        value = safe_float(payload.get(key))
        rounded = round(value, 3) if value is not None else None
        metrics.append({
            "name": metric_name(key),
            "value": rounded,
            "unit": unit_for(key),
        })
    return metrics
def build_top_level_interval_metrics(analytics, unit_for, skip_keys=None, preferred_keys=None):
    """Collect list-valued top-level analytics entries as metric series.

    Structural keys (phase sub-dicts, frame markers) are always skipped and
    non-list values are ignored. Output is ordered with *preferred_keys*
    first; the stable sort keeps non-preferred series in discovery order.
    """
    skip = {
        "action",
        "active_foot",
        "touch_frame",
        "pre_action",
        "action_frame",
        "post_action",
        "frames",
    }
    if skip_keys:
        skip.update(skip_keys)
    observed_keys = set(preferred_keys or [])
    observed_keys.update(analytics.keys())
    series = []
    for key in ordered_metric_keys(observed_keys, preferred_keys):
        if key in skip:
            continue
        values = analytics.get(key)
        if values is None:
            # Preferred keys absent from analytics still yield an empty series.
            values = []
        if not isinstance(values, list):
            continue
        series.append(format_metric_series(metric_name(key), unit_for(key), values))
    # Non-preferred series all share the same sort key (len(order_index)), so the
    # stable sort leaves them after the preferred ones in their original order.
    order_index = {metric_name(key): idx for idx, key in enumerate(preferred_keys or [])}
    return sorted(series, key=lambda item: order_index.get(item["name"], len(order_index)))
| async def analyze_endpoint( | |
| request: Request, | |
| playerId: str = Form(...), | |
| targetW: float = Form(...), | |
| targetH: float = Form(...), | |
| clientId: str = Form(...), | |
| videoOrders: List[int] = Form(...), | |
| actionsJson: UploadFile = File(...), | |
| calibration: UploadFile = File(...), | |
| videos: List[UploadFile] = File(...) | |
| ): | |
| temp_dir = None | |
    # Unique, time-sortable session id: unix seconds plus a short random suffix.
    session_id = f"session-{int(time.time())}-{uuid.uuid4().hex[:6]}"
    player_dir, session_dir, videos_dir = session_storage_paths(playerId, session_id)
    try:
        # Clear any previous cancellation flags
        cancel_store.pop(clientId, None)
        progress_store[clientId] = {"progress": 2.0, "step": 0, "total": 0, "phase": "Uploading & Validating Data"}
        os.makedirs(videos_dir, exist_ok=True)
        # NOTE(review): temp_dir is read by the except handler at the bottom of
        # this try; if an exception fires before this assignment (e.g. inside
        # os.makedirs above) the handler raises NameError and masks the original
        # error — consider initializing temp_dir = None before the try.
        temp_dir = session_dir
        # 1. Store incoming payloads
        actions_path = os.path.join(session_dir, "actions.json")
        with open(actions_path, "wb") as f:
            f.write(await actionsJson.read())
        calib_path = os.path.join(session_dir, "calibration.npz")
        with open(calib_path, "wb") as f:
            f.write(await calibration.read())
        # Each uploaded video must carry exactly one, unique camera-order value.
        if len(videoOrders) != len(videos):
            raise ValueError("Each uploaded video must include a matching camera order")
        if len(set(videoOrders)) != len(videoOrders):
            raise ValueError("Camera order values must be unique")
        camera_map = {}            # camera id -> stored (pipeline input) video path
        fallback_camera_map = {}   # camera id -> web-normalized copy, when one exists
        for idx, (camera_order, video) in enumerate(zip(videoOrders, videos)):
            original_name = video.filename or f"camera_{camera_order}.mp4"
            # Prefix with upload index + camera order so files sort deterministically.
            video_name = f"{idx:02d}_cam_{camera_order}_{safe_name(os.path.basename(original_name))}"
            vid_path = os.path.join(videos_dir, video_name)
            with open(vid_path, "wb") as f:
                f.write(await video.read())
            # normalize_video_for_web (defined elsewhere) presumably produces a
            # browser-friendly rendition; it may return the input path unchanged.
            normalized_path = normalize_video_for_web(vid_path)
            camera_id = int(camera_order)
            camera_map[camera_id] = vid_path
            if normalized_path != vid_path and os.path.exists(normalized_path):
                fallback_camera_map[camera_id] = normalized_path
        progress_store[clientId] = {"progress": 10.0, "step": 0, "total": 0, "phase": "Preparing AI Models"}
        # ensure_runtime_ready (defined elsewhere) supplies model weight paths and
        # the run_pipeline callable used below.
        runtime = ensure_runtime_ready()
        utils_paths = {
            "POSE_PATH": runtime["weights"]["POSE_PATH"],
            "YOLO_PATH": runtime["weights"]["YOLO_PATH"],
            "CALIBRATION_PATH": calib_path,
            "ACTIONS_PATH": actions_path
        }
        sizes = {
            "TARGET_SIZE": (int(targetW), int(targetH)),
            "YOLO_IMGSZ": 960
        }
        print("Starting physical pipeline execution...")
        progress_store[clientId] = {"progress": 20.0, "step": 0, "total": 0, "phase": "Extracting 3D Kinematics"}
| def progress_tracker(current_act, total_act, step, total_frames): | |
| if cancel_store.get(clientId): | |
| return False # Signal pipeline to abort | |
| base_p = current_act / max(1, total_act) | |
| segment_p = (step / max(1, total_frames)) * (1.0 / max(1, total_act)) | |
| # Rescale 20% to 90% for processing | |
| pct = 20.0 + round((base_p + segment_p) * 70.0, 1) | |
| progress_store[clientId] = { | |
| "progress": pct, | |
| "step": step, | |
| "total": total_frames, | |
| "phase": f"Processing Action {current_act + 1}/{total_act}" | |
| } | |
| return True | |
        # 2. Yield to worker thread to allow concurrent polling from front-end
        def execute_pipeline():
            # Thin closure so run_in_threadpool can call the pipeline with its args.
            return runtime["run_pipeline"](camera_map, utils_paths, sizes, progress_tracker)
        # Run the blocking pipeline off the event loop; progress_tracker keeps
        # progress_store fresh while the front-end polls for status.
        reports = await run_in_threadpool(execute_pipeline)
        progress_store[clientId] = {"progress": 100.0, "step": 0, "total": 0, "phase": "Completed"}
        # Persist the raw per-action reports alongside the session for debugging.
        raw_reports_path = os.path.join(session_dir, "raw_reports.json")
        with open(raw_reports_path, "w", encoding="utf-8") as f:
            json.dump(reports, f, indent=2, default=json_default)
        # 3. Format output dict perfectly mapping to the Frontend Types
        with open(actions_path, "r") as f:
            raw_actions = json.load(f).get("actions", [])
        formatted_actions = []
        failed_actions = []
        # Build clip URL entries for the stored videos and, separately, for the
        # web-normalized fallback copies collected during upload.
        camera_videos = build_camera_video_entries(request, playerId, session_id, camera_map)
        source_camera_videos = build_camera_video_entries(request, playerId, session_id, fallback_camera_map)
        for i, rep in enumerate(reports):
            # Reports are matched positionally to the entries in actions.json;
            # fall back to {} if the lists diverge in length.
            raw = raw_actions[i] if i < len(raw_actions) else {}
            if "error" in rep:
                # A failed action is reported to the client separately; it does
                # not abort formatting of the remaining actions.
                failed_actions.append({
                    "id": f"err-{uuid.uuid4().hex[:6]}",
                    "label": rep.get("action", raw.get("label", "Unknown")),
                    "start": raw.get("start", "00:00:00:00"),
                    "end": raw.get("end", "00:00:00:00"),
                    "error": rep["error"]
                })
                continue
            an = rep["analytics"]
            sf = rep["start_frame"]   # first frame of the action interval
            ef = rep["end_frame"]     # last frame of the interval (inclusive)
            fps = float(rep.get("fps", 30))
            is_dribble = (an.get("action") == "Dribble")
            if is_dribble:
                # Dribbles have no single touch moment; use the interval midpoint.
                tf = (sf + ef) // 2
            else:
                tf = an.get("touch_frame", (sf + ef) // 2)
            # --- Skeleton: support full COCO-25/WB joint range (0–32) ---
            skeleton_frames = []
            raw_ball_history = rep.get("ball_history", {})
            for f_idx in range(sf, ef + 1):
                raw_skel = rep["skel_history"].get(f_idx, {})
                raw_ball = raw_ball_history.get(f_idx)
                # Find the max joint index present so we don't truncate
                max_joint = max(raw_skel.keys()) if raw_skel else 32
                n_joints = max(33, max_joint + 1)
                joints = []
                for j in range(n_joints):
                    pt = raw_skel.get(j)
                    if pt is not None:
                        joints.append([float(pt[0]), float(pt[1]), float(pt[2])])
                    else:
                        # Missing joints are zero-filled so indices stay aligned.
                        joints.append([0.0, 0.0, 0.0])
                # Frame indices are rebased so the action starts at frame 0.
                frame_payload = {"frame": f_idx - sf, "joints": joints}
                if raw_ball is not None:
                    frame_payload["ball"] = [float(raw_ball[0]), float(raw_ball[1]), float(raw_ball[2])]
                skeleton_frames.append(frame_payload)
            # --- Unit dictionary for known metric names ---
            # NOTE(review): UNITS and unit_for are loop-invariant but are rebuilt
            # for every report in this loop; they could be hoisted above the loop
            # (or to module level) without changing behavior.
            UNITS = {
                "head_angle": "°", "l_knee_angle": "°", "r_knee_angle": "°",
                "trunc_pitch_angle": "°", "trunc_roll_angle": "°",
                "trunk_pitch": "°", "trunk_roll": "°",
                "head_pitch_angle": "°", "head_roll_angle": "°",
                "left_foot_orientation_angle": "°", "right_foot_orientation_angle": "°",
                "difference_in_angles": "°", "body_to_ball_angle": "°",
                "body_orientation_vs_ball": "°", "stand_foot_angle": "°",
                "active_ankle_angle": "°", "l_elbow_shoulder_hip_angle": "°",
                "r_elbow_shoulder_hip_angle": "°", "backward_weighted_angle": "°",
                "forward_weighted_angle": "°", "leg_separation_angle": "°",
                "l_r_foot_distance": "cm", "l_foot_ball_distance": "cm",
                "r_foot_ball_distance": "cm", "mid_feet_ball_dist": "cm",
                "active_foot_height_pct": "%", "ball_height_pct_body": "%",
                "ball_possession_score": "%", "ball_feet_distance": "cm",
            }
            def unit_for(key):
                # Metrics without a known unit get an empty string.
                return UNITS.get(key, "")
            # Preferred ordering/selection of metric keys per action type.
            action_layout = ACTION_METRIC_LAYOUTS.get(rep["action"], {})
            if is_dribble:
                # Dribble analytics arrive as per-frame entries: expose them as a
                # series ("pre" slot) plus a per-metric average for the in-action
                # view. Dribbles have no post-action metrics.
                dribble_frames = an.get("frames", [])
                pre_metrics = build_series_from_entries(
                    dribble_frames,
                    unit_for,
                    preferred_keys=action_layout.get("frames"),
                )
                in_action_metrics = []
                # Union of metric keys seen across frames, ordered by the layout.
                frame_metric_keys = ordered_metric_keys(
                    {k for frame in dribble_frames for k in frame.keys() if k != "frame"},
                    action_layout.get("frames"),
                )
                for key in frame_metric_keys:
                    numeric_values = [safe_float(frame.get(key)) for frame in dribble_frames]
                    # Drop frames where the metric was missing / non-numeric.
                    numeric_values = [value for value in numeric_values if value is not None]
                    in_action_metrics.append({
                        "name": f"Avg {metric_name(key)}",
                        "value": round(float(np.mean(numeric_values)), 3) if numeric_values else None,
                        "unit": unit_for(key),
                    })
                post_metrics = []
            else:
                # Discrete actions: separate pre/in/post phases around the touch frame.
                pre_entries = an.get("pre_action", [])
                post_entries = an.get("post_action", [])
                action_frame_data = an.get("action_frame", {})
                pre_metrics = build_series_from_entries(
                    pre_entries,
                    unit_for,
                    preferred_keys=action_layout.get("pre"),
                )
                in_action_metrics = build_scalar_metrics(
                    action_frame_data,
                    unit_for,
                    skip_keys={"active_foot"},
                    preferred_keys=action_layout.get("in"),
                )
                # Also surface top-level scalar analytics, skipping structural
                # keys and the series rendered in other sections.
                in_action_metrics.extend(
                    build_scalar_metrics(
                        an,
                        unit_for,
                        skip_keys={
                            "action",
                            "active_foot",
                            "touch_frame",
                            "pre_action",
                            "action_frame",
                            "post_action",
                            "frames",
                            "left_knee_angles",
                            "right_knee_angles",
                            "torso_pitch_angles",
                            "head_angles",
                            "mid_foot_ball_distances",
                            "left_right_foot_distances",
                        },
                        preferred_keys=action_layout.get("top_level_scalars"),
                    )
                )
                post_metrics = build_series_from_entries(
                    post_entries,
                    unit_for,
                    preferred_keys=action_layout.get("post"),
                )
            # Series spanning the whole interval (keys chosen by FULL_INTERVAL_KEYS).
            full_interval_metrics = build_top_level_interval_metrics(
                an,
                unit_for,
                preferred_keys=FULL_INTERVAL_KEYS,
            )
            # Shape mirrors the front-end Action type.
            formatted_actions.append({
                "id": f"{rep['action'].lower()}-{uuid.uuid4().hex[:6]}",
                "label": rep["action"],
                "start": raw.get("start", "00:00:00:00"),
                "end": raw.get("end", "00:00:00:00"),
                "fps": fps,
                "startFrame": sf,
                "endFrame": ef,
                # Seconds derived from frames; max() guards non-positive fps/frames.
                "startSeconds": max(0.0, sf / max(1.0, fps)),
                "endSeconds": max(0.0, (ef + 1) / max(1.0, fps)),
                "totalFrames": ef - sf + 1,
                # Phase boundaries rebased to the action start (tf = touch frame).
                "preFrames": tf - sf,
                "inFrame": tf - sf,
                "postFrames": ef - tf,
                "cameraClips": camera_videos,
                "sourceCameraClips": source_camera_videos,
                "preMetrics": pre_metrics,
                "inActionMetrics": in_action_metrics,
                "postMetrics": post_metrics,
                "fullIntervalMetrics": full_interval_metrics,
                "skeleton": skeleton_frames,
                "rawAnalytics": an,
            })
        print("Pipeline successful. Yielding payload payload.")
        response_payload = {
            "id": session_id,
            "playerId": playerId,
            "createdAt": int(time.time() * 1000),  # milliseconds, JS-style epoch
            "targetSize": [int(targetW), int(targetH)],
            "cameraCount": len(camera_map),
            "actions": formatted_actions,
            "failedActions": failed_actions
        }
        # Persist the formatted session next to the raw reports, then push the
        # session directory upstream (push_session_to_hf is defined elsewhere —
        # presumably mirrors to the Hugging Face dataset repo configured above).
        session_json_path = os.path.join(session_dir, "session.json")
        with open(session_json_path, "w", encoding="utf-8") as f:
            json.dump(response_payload, f, indent=2, default=json_default)
        push_session_to_hf(playerId, session_id, session_dir)
        return response_payload
    except Exception as e:
        # Best-effort cleanup of the partial session, then surface the failure.
        # NOTE(review): temp_dir is unbound here if the failure happened before
        # it was assigned inside the try. Also, progress_store[clientId] is left
        # at its last value — a polling client never sees a terminal error phase.
        print("--- PIPELINE ERROR ---")
        traceback.print_exc()
        if temp_dir and os.path.isdir(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Local/dev entry point: serve the FastAPI app directly so the React/Vite
    # front-end can talk to it. PORT env var overrides the default 8000.
    import uvicorn

    port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)