# Sporalize Labs 3D Analysis Engine backend (FastAPI).
# Last change: "Optimize storage sync performance on read paths."
# (commit 1e1f04d, author Shoraky — scrape residue converted to a comment header)
import os
import json
import time
import uuid
import shutil
import traceback
import re
import sys
import importlib.util
import threading
import subprocess
from urllib.parse import urlsplit, urlunsplit
import cv2
import numpy as np
from fastapi import FastAPI, File, UploadFile, Form, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from fastapi.staticfiles import StaticFiles
from typing import List
from huggingface_hub import hf_hub_download, snapshot_download, HfApi
# ASGI application instance and the absolute directory containing this file.
app = FastAPI(title="Sporalize Labs 3D Analysis Engine")
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
def default_runtime_root():
    """Pick a writable runtime directory: the persistent /data volume when
    present (e.g. on HF Spaces), otherwise a dot-directory in $HOME."""
    persistent_base = "/data"
    if not os.path.isdir(persistent_base):
        return os.path.join(os.path.expanduser("~"), ".sporalize_runtime")
    return os.path.join(persistent_base, "sporalize_runtime")
# Root directories for runtime artifacts; all overridable via environment.
RUNTIME_ROOT = os.environ.get("SPORALIZE_RUNTIME_DIR", default_runtime_root())
ASSETS_RUNTIME_ROOT = os.environ.get("SPORALIZE_ASSETS_DIR", os.path.join(RUNTIME_ROOT, "assets"))
WEIGHTS_RUNTIME_ROOT = os.environ.get("SPORALIZE_WEIGHTS_DIR", os.path.join(RUNTIME_ROOT, "weights"))
DEFAULT_LOCAL_STORAGE_ROOT = os.path.join(CURRENT_DIR, "Storage")
# If the source tree is read-only (e.g. a container image) and there is no
# persistent /data volume, fall back to the runtime root for storage.
if not os.path.isdir("/data") and not os.access(CURRENT_DIR, os.W_OK):
    DEFAULT_LOCAL_STORAGE_ROOT = os.path.join(RUNTIME_ROOT, "storage")
STORAGE_ROOT = os.environ.get(
    "SPORALIZE_STORAGE_DIR",
    os.path.join("/data", "sporalize_storage") if os.path.isdir("/data") else DEFAULT_LOCAL_STORAGE_ROOT,
)
# Hugging Face dataset repo used to mirror session storage across restarts.
STORAGE_DATASET_REPO_ID = os.environ.get("SPORALIZE_STORAGE_REPO_ID", "Shoraky/SporalizeLabs-runtime-private").strip()
STORAGE_DATASET_REPO_TYPE = os.environ.get("SPORALIZE_STORAGE_REPO_TYPE", "dataset").strip()
STORAGE_DATASET_PATH = os.environ.get("SPORALIZE_STORAGE_DATASET_PATH", "Storage").strip("/").strip()
# Minimum number of seconds between HF -> local storage syncs on read paths.
STORAGE_SYNC_INTERVAL_SECONDS = float(os.environ.get("SPORALIZE_STORAGE_SYNC_INTERVAL_SECONDS", "20"))
# Lock and timestamp guarding sync_storage_from_hf's rate limiting.
_storage_sync_lock = threading.Lock()
_storage_last_sync_ts = 0.0
# Model-weight descriptors: where each weight lives on the HF Hub, which env
# var can override its path, and a path bundled next to this file as fallback.
# Consumed by ensure_weight_file.
DEFAULT_WEIGHT_SPECS = {
    "POSE_PATH": {
        "filename": "vitpose-s-coco_25.onnx",
        "repo_id": os.environ.get("SPORALIZE_POSE_MODEL_REPO_ID", "JunkyByte/easy_ViTPose"),
        "repo_type": os.environ.get("SPORALIZE_POSE_MODEL_REPO_TYPE", "model"),
        "repo_file": os.environ.get("SPORALIZE_POSE_MODEL_FILE", "onnx/coco_25/vitpose-25-s.onnx"),
        "override_env": "SPORALIZE_POSE_MODEL_PATH",
        "local_fallback": os.path.join(CURRENT_DIR, "Weights", "vitpose-s-coco_25.onnx"),
    },
    "YOLO_PATH": {
        "filename": "yolov8m.pt",
        "repo_id": os.environ.get("SPORALIZE_YOLO_MODEL_REPO_ID", "Ultralytics/YOLOv8"),
        "repo_type": os.environ.get("SPORALIZE_YOLO_MODEL_REPO_TYPE", "model"),
        "repo_file": os.environ.get("SPORALIZE_YOLO_MODEL_FILE", "yolov8m.pt"),
        "override_env": "SPORALIZE_YOLO_MODEL_PATH",
        "local_fallback": os.path.join(CURRENT_DIR, "Weights", "yolov8m.pt"),
    },
}
# Lazily-populated runtime singletons; filled in by ensure_runtime_ready.
runtime_state = {
    "ready": False,          # True once pipeline + weights are resolved
    "pipeline_root": None,   # directory containing pipeline.py
    "run_pipeline": None,    # callable loaded from pipeline.py
    "weights": {},           # {spec name: resolved weight file path}
}
def get_hf_token():
    """Return the Hugging Face access token, preferring HF_TOKEN over
    HUGGING_FACE_HUB_TOKEN; None when neither is set."""
    primary = os.environ.get("HF_TOKEN")
    if primary:
        return primary
    return os.environ.get("HUGGING_FACE_HUB_TOKEN")
def hf_storage_enabled():
    """True when both the dataset repo id and the in-repo path are configured."""
    if not STORAGE_DATASET_REPO_ID:
        return False
    return bool(STORAGE_DATASET_PATH)
def hf_storage_path(*parts: str) -> str:
    """Join path components under STORAGE_DATASET_PATH with forward slashes.

    None parts and parts that are empty after stripping slashes are dropped;
    backslashes are normalized for repo paths.
    """
    segments = [STORAGE_DATASET_PATH]
    for part in parts:
        if part is None:
            continue
        if not str(part).strip("/"):
            continue
        segments.append(part.strip("/").replace("\\", "/"))
    return "/".join(piece for piece in segments if piece)
def sync_storage_from_hf(force: bool = False):
    """Mirror the HF dataset's storage subtree into the local STORAGE_ROOT.

    Rate-limited by STORAGE_SYNC_INTERVAL_SECONDS unless force=True. The
    interval is checked both before and after acquiring the lock
    (double-checked) so concurrent read paths do not trigger redundant
    snapshot downloads.
    """
    global _storage_last_sync_ts
    if not hf_storage_enabled():
        return
    now = time.time()
    if not force and (now - _storage_last_sync_ts) < STORAGE_SYNC_INTERVAL_SECONDS:
        return
    with _storage_sync_lock:
        # Re-check after acquiring the lock: another thread may have just synced.
        now = time.time()
        if not force and (now - _storage_last_sync_ts) < STORAGE_SYNC_INTERVAL_SECONDS:
            return
        sync_cache_root = os.path.join(RUNTIME_ROOT, "storage-sync-cache")
        os.makedirs(sync_cache_root, exist_ok=True)
        local_repo_dir = os.path.join(sync_cache_root, safe_name(STORAGE_DATASET_REPO_ID))
        snapshot_download(
            repo_id=STORAGE_DATASET_REPO_ID,
            repo_type=STORAGE_DATASET_REPO_TYPE,
            token=get_hf_token(),
            local_dir=local_repo_dir,
            allow_patterns=[f"{STORAGE_DATASET_PATH}/**"],
        )
        source_storage = os.path.join(local_repo_dir, STORAGE_DATASET_PATH)
        if os.path.isdir(source_storage):
            # NOTE(review): this wipes STORAGE_ROOT before copying, so local
            # sessions not yet pushed to the HF mirror are discarded on sync —
            # confirm that is intended.
            if os.path.isdir(STORAGE_ROOT):
                shutil.rmtree(STORAGE_ROOT, ignore_errors=True)
            shutil.copytree(source_storage, STORAGE_ROOT, dirs_exist_ok=True)
        _storage_last_sync_ts = time.time()
def push_session_to_hf(player_id: str, session_id: str, session_dir: str):
    """Upload one session directory to the configured HF dataset mirror."""
    if not hf_storage_enabled():
        return
    client = HfApi(token=get_hf_token())
    client.upload_folder(
        repo_id=STORAGE_DATASET_REPO_ID,
        repo_type=STORAGE_DATASET_REPO_TYPE,
        folder_path=session_dir,
        path_in_repo=hf_storage_path(safe_name(player_id), safe_name(session_id)),
        commit_message=f"Add session {safe_name(session_id)} for player {safe_name(player_id)}",
    )
def delete_session_from_hf(player_id: str, session_id: str):
    """Remove a single session folder from the HF dataset mirror."""
    if not hf_storage_enabled():
        return
    client = HfApi(token=get_hf_token())
    client.delete_folder(
        repo_id=STORAGE_DATASET_REPO_ID,
        repo_type=STORAGE_DATASET_REPO_TYPE,
        path_in_repo=hf_storage_path(safe_name(player_id), safe_name(session_id)),
        commit_message=f"Delete session {safe_name(session_id)} for player {safe_name(player_id)}",
    )
def delete_player_from_hf(player_id: str):
    """Remove a player's entire storage folder from the HF dataset mirror."""
    if not hf_storage_enabled():
        return
    client = HfApi(token=get_hf_token())
    client.delete_folder(
        repo_id=STORAGE_DATASET_REPO_ID,
        repo_type=STORAGE_DATASET_REPO_TYPE,
        path_in_repo=hf_storage_path(safe_name(player_id)),
        commit_message=f"Delete player {safe_name(player_id)} storage",
    )
def path_has_session_data(directory: str):
    """Return True when any directory under *directory* contains a session.json."""
    if not os.path.isdir(directory):
        return False
    return any(
        "session.json" in file_names
        for _dirpath, _subdirs, file_names in os.walk(directory)
    )
def seed_storage_if_needed(seed_dir: str, target_dir: str):
    """Copy bundled seed storage into *target_dir* unless it already holds sessions."""
    if not os.path.isdir(seed_dir):
        return
    os.makedirs(target_dir, exist_ok=True)
    if not path_has_session_data(target_dir):
        shutil.copytree(seed_dir, target_dir, dirs_exist_ok=True)
def resolve_pipeline_root():
    """Locate the directory containing pipeline.py and the ViTPose package.

    Prefers a checkout bundled next to this file; otherwise downloads the
    assets snapshot from the repo named by SPORALIZE_ASSETS_REPO_ID.

    Returns:
        Absolute path of the pipeline root directory.

    Raises:
        RuntimeError: when no local pipeline exists and no assets repo is set.
    """
    local_pipeline = os.path.join(CURRENT_DIR, "pipeline.py")
    local_vitpose = os.path.join(CURRENT_DIR, "ViTPose")
    if os.path.isfile(local_pipeline) and os.path.isdir(local_vitpose):
        return CURRENT_DIR
    repo_id = os.environ.get("SPORALIZE_ASSETS_REPO_ID")
    if not repo_id:
        raise RuntimeError(
            "SPORALIZE_ASSETS_REPO_ID is required when Backend/pipeline.py is not bundled locally."
        )
    assets_dir = os.path.join(ASSETS_RUNTIME_ROOT, safe_name(repo_id))
    snapshot_download(
        repo_id=repo_id,
        repo_type=os.environ.get("SPORALIZE_ASSETS_REPO_TYPE", "dataset"),
        revision=os.environ.get("SPORALIZE_ASSETS_REVISION"),
        token=get_hf_token(),
        local_dir=assets_dir,
        allow_patterns=["pipeline.py", "ViTPose/**", "Storage/**", "Weights/**"],
    )
    # Seed the storage root with any sessions bundled in the assets snapshot.
    seed_storage_if_needed(os.path.join(assets_dir, "Storage"), STORAGE_ROOT)
    return assets_dir
def load_pipeline_callable(pipeline_root: str):
    """Import pipeline.py from *pipeline_root* and return its run_pipeline.

    The module is (re)loaded under a fixed name so repeated calls pick up a
    refreshed pipeline file; *pipeline_root* is added to sys.path so the
    pipeline can import its sibling packages.

    Raises:
        RuntimeError: when pipeline.py is missing, unloadable, or lacks
            a run_pipeline definition.
    """
    pipeline_path = os.path.join(pipeline_root, "pipeline.py")
    if not os.path.isfile(pipeline_path):
        raise RuntimeError(f"pipeline.py was not found at {pipeline_path}")
    if pipeline_root not in sys.path:
        sys.path.insert(0, pipeline_root)
    module_name = "sporalize_runtime_pipeline"
    sys.modules.pop(module_name, None)
    spec = importlib.util.spec_from_file_location(module_name, pipeline_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Unable to create import spec for {pipeline_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    entry_point = getattr(module, "run_pipeline", None)
    if entry_point is None:
        raise RuntimeError("run_pipeline was not found in the resolved pipeline module")
    return entry_point
def ensure_weight_file(spec: dict, pipeline_root: str):
    """Resolve a model weight file for *spec* (see DEFAULT_WEIGHT_SPECS).

    Search order: override env var, pipeline tree Weights/, bundled fallback,
    local weights cache, and finally a download from the HF Hub.
    """
    local_candidates = [
        os.environ.get(spec["override_env"]),
        os.path.join(pipeline_root, "Weights", spec["filename"]),
        spec.get("local_fallback"),
    ]
    for candidate in local_candidates:
        if candidate and os.path.isfile(candidate):
            return candidate
    os.makedirs(WEIGHTS_RUNTIME_ROOT, exist_ok=True)
    cached_path = os.path.join(WEIGHTS_RUNTIME_ROOT, spec["filename"])
    if os.path.isfile(cached_path):
        return cached_path
    # Last resort: fetch from the Hub into the local weights cache.
    return hf_hub_download(
        repo_id=spec["repo_id"],
        repo_type=spec.get("repo_type", "model"),
        filename=spec["repo_file"],
        token=get_hf_token(),
        local_dir=WEIGHTS_RUNTIME_ROOT,
    )
def ensure_runtime_ready(force: bool = False):
    """Lazily initialize pipeline code, weights, and storage directories.

    Results are memoized in the module-level runtime_state dict; pass
    force=True to re-resolve everything.

    Returns:
        The populated runtime_state dict.
    """
    if runtime_state["ready"] and not force:
        return runtime_state
    os.makedirs(RUNTIME_ROOT, exist_ok=True)
    os.makedirs(STORAGE_ROOT, exist_ok=True)
    pipeline_root = resolve_pipeline_root()
    run_pipeline = load_pipeline_callable(pipeline_root)
    weight_paths = {name: ensure_weight_file(spec, pipeline_root) for name, spec in DEFAULT_WEIGHT_SPECS.items()}
    runtime_state.update({
        "ready": True,
        "pipeline_root": pipeline_root,
        "run_pipeline": run_pipeline,
        "weights": weight_paths,
    })
    return runtime_state
# Ensure the storage root exists before exposing it as static files.
os.makedirs(STORAGE_ROOT, exist_ok=True)
app.mount("/storage", StaticFiles(directory=STORAGE_ROOT), name="storage")
# Per-client processing progress snapshots and cancellation flags, keyed by
# the clientId supplied to /api/analyze.
progress_store = {}
cancel_store = {}
def safe_name(value: str) -> str:
    """Sanitize *value* into a filesystem-safe name.

    Characters other than alphanumerics, '-', '_' and '.' become '_';
    leading/trailing dots and underscores are stripped. Falls back to
    "item" when nothing survives.
    """
    sanitized = "".join(
        ch if ch.isalnum() or ch in "-_." else "_"
        for ch in str(value)
    )
    sanitized = sanitized.strip("._")
    return sanitized if sanitized else "item"
def session_storage_paths(player_id: str, session_id: str):
    """Return (player_dir, session_dir, videos_dir) under STORAGE_ROOT."""
    player_dir = os.path.join(STORAGE_ROOT, safe_name(player_id))
    session_dir = os.path.join(player_dir, safe_name(session_id))
    return player_dir, session_dir, os.path.join(session_dir, "videos")
def list_session_files():
    """Find every session.json under STORAGE_ROOT, most recently modified first."""
    found = [
        os.path.join(dirpath, "session.json")
        for dirpath, _subdirs, file_names in os.walk(STORAGE_ROOT)
        if "session.json" in file_names
    ]
    found.sort(key=os.path.getmtime, reverse=True)
    return found
def build_storage_url(request: Request, *parts: str) -> str:
    """Build an absolute URL under /storage for the given path components.

    All components except the last are sanitized with safe_name; the last is
    assumed to be an already-stored file name and only has backslashes
    normalized. Base URL resolution order: SPORALIZE_PUBLIC_BASE_URL, then
    X-Forwarded-Proto/X-Forwarded-Host headers, then the request base URL.
    """
    relative = "/".join(safe_name(part) if idx < len(parts) - 1 else part.replace("\\", "/") for idx, part in enumerate(parts))
    configured_public_base = os.environ.get("SPORALIZE_PUBLIC_BASE_URL", "").strip()
    if configured_public_base:
        return configured_public_base.rstrip("/") + "/storage/" + relative.lstrip("/")
    # Respect reverse-proxy headers so generated URLs match the public origin.
    forwarded_proto = request.headers.get("x-forwarded-proto", "").split(",")[0].strip().lower()
    forwarded_host = request.headers.get("x-forwarded-host", "").split(",")[0].strip()
    if forwarded_proto in ("http", "https") and forwarded_host:
        return f"{forwarded_proto}://{forwarded_host}".rstrip("/") + "/storage/" + relative.lstrip("/")
    base_url = str(request.base_url).rstrip("/")
    if forwarded_proto in ("http", "https"):
        # Only the scheme was forwarded; rewrite it on the request's base URL.
        parsed = urlsplit(base_url)
        base_url = urlunsplit((forwarded_proto, parsed.netloc, parsed.path, "", "")).rstrip("/")
    return base_url + "/storage/" + relative.lstrip("/")
def parse_video_timecode(value, fps=30.0):
    """Convert a timecode into seconds, clamped to >= 0.

    Accepts:
      * None -> 0.0
      * numeric values (including NumPy scalars) -> float seconds
      * "HH:MM:SS:FF" -> hours/minutes/seconds plus frames at *fps*
      * "HH:MM:SS" or "MM:SS" -> seconds (generalized; these previously
        fell through and returned 0.0)
      * any other float-parseable string -> seconds; unparseable -> 0.0

    Malformed colon-separated fields now return 0.0 instead of raising
    (the old code let ValueError escape for e.g. "a:b:c:d").
    """
    if value is None:
        return 0.0
    if isinstance(value, (int, float, np.integer, np.floating)):
        return max(0.0, float(value))
    parts = str(value).split(":")
    if 2 <= len(parts) <= 4:
        try:
            fields = [int(float(part or 0)) for part in parts]
        except Exception:
            return 0.0
        if len(fields) == 4:
            h, m, s, f = fields
            # Frame component is converted to seconds at the given frame rate.
            return max(0.0, (h * 3600) + (m * 60) + s + (f / max(1.0, float(fps))))
        if len(fields) == 3:
            h, m, s = fields
            return max(0.0, (h * 3600) + (m * 60) + s)
        m, s = fields
        return max(0.0, (m * 60) + s)
    try:
        return max(0.0, float(value))
    except Exception:
        return 0.0
def detect_camera_id(file_name: str):
    """Extract the integer camera id from a "_cam_<N>_" file-name marker, or None."""
    found = re.search(r"_cam_(\d+)_", file_name)
    return int(found.group(1)) if found else None
def split_camera_video_variants(videos_dir: str):
    """Group camera videos into (primary, fallback) path maps keyed by camera id.

    Files ending in ".web.mp4" are browser-friendly transcodes. When both a
    base file and a web transcode exist for a camera, the base is primary and
    the web copy the fallback; a camera with only a web copy uses it as primary.
    """
    grouped: dict[int, dict[str, str]] = {}
    for entry in sorted(os.listdir(videos_dir)):
        cam = detect_camera_id(entry)
        if cam is None:
            continue
        entry_path = os.path.join(videos_dir, entry)
        if not os.path.isfile(entry_path):
            continue
        variant = "web" if entry.lower().endswith(".web.mp4") else "base"
        grouped.setdefault(cam, {})[variant] = entry_path
    primary_map = {}
    fallback_map = {}
    for cam in sorted(grouped):
        base_path = grouped[cam].get("base")
        web_path = grouped[cam].get("web")
        if base_path:
            primary_map[cam] = base_path
            if web_path:
                fallback_map[cam] = web_path
        elif web_path:
            primary_map[cam] = web_path
    return primary_map, fallback_map
def normalize_video_for_web(input_path: str) -> str:
    """
    Re-encode uploaded video to a browser-friendly MP4 stream.
    Falls back to the original file if normalization fails.

    Returns the path of the transcoded ".web.mp4" on success, otherwise
    *input_path* unchanged.
    """
    output_path = os.path.splitext(input_path)[0] + ".web.mp4"
    # First try ffmpeg -> H.264 + yuv420p for maximum browser compatibility.
    ffmpeg_cmd = [
        "ffmpeg",
        "-y",
        "-i",
        input_path,
        "-an",  # drop audio; only the video track is needed for playback
        "-c:v",
        "libx264",
        "-preset",
        "veryfast",
        "-pix_fmt",
        "yuv420p",
        "-movflags",
        "+faststart",  # moov atom up front so playback can start while streaming
        output_path,
    ]
    try:
        ffmpeg_proc = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
        if ffmpeg_proc.returncode == 0 and os.path.exists(output_path) and os.path.getsize(output_path) > 0:
            return output_path
    except Exception:
        # ffmpeg binary missing or failed to launch; fall through to OpenCV.
        pass
    # Fallback: OpenCV transcode when ffmpeg is unavailable.
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        return input_path
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or not np.isfinite(fps) or fps <= 0:
        # Sane default when the container reports no/garbage frame rate.
        fps = 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
    if width <= 0 or height <= 0:
        cap.release()
        return input_path
    writer = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        float(fps),
        (width, height),
    )
    if not writer.isOpened():
        cap.release()
        return input_path
    frame_count = 0
    success = True
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        writer.write(frame)
        frame_count += 1
    cap.release()
    writer.release()
    # Treat an empty or missing output file as a failed transcode.
    if frame_count <= 0:
        success = False
    elif not os.path.exists(output_path) or os.path.getsize(output_path) <= 0:
        success = False
    if not success:
        if os.path.exists(output_path):
            os.remove(output_path)
        return input_path
    return output_path
def build_camera_video_entries(request: Request, player_id: str, session_id: str, camera_map):
    """Map {camera_id: video_path} to [{"cameraId", "url"}] entries for the UI,
    skipping paths that are missing on disk."""
    entries = []
    for camera_id, video_path in sorted(camera_map.items()):
        if not video_path or not os.path.exists(video_path):
            continue
        entries.append({
            "cameraId": int(camera_id),
            "url": build_storage_url(
                request,
                safe_name(player_id),
                safe_name(session_id),
                "videos",
                os.path.basename(video_path),
            ),
        })
    return entries
def build_action_clip_entries(request: Request, player_id: str, session_id: str, clip_map):
    """Map {camera_id: clip_path} to [{"cameraId", "url"}] entries under the
    session's clips directory, skipping paths that are missing on disk."""
    entries = []
    for camera_id, clip_path in sorted(clip_map.items()):
        if not clip_path or not os.path.exists(clip_path):
            continue
        entries.append({
            "cameraId": int(camera_id),
            "url": build_storage_url(
                request,
                safe_name(player_id),
                safe_name(session_id),
                "clips",
                os.path.basename(clip_path),
            ),
        })
    return entries
def normalize_session_payload(session: dict, request: Request):
    """Augment a stored session payload with camera clip URLs and second-based
    action boundaries so the frontend can render it directly.

    Returns the session unchanged when ids, the session directory, the videos
    directory, or the camera map are missing. Never mutates the input dict.
    """
    session_id = session.get("id")
    player_id = session.get("playerId")
    if not session_id or not player_id:
        return session
    session_dir = find_session_path(session_id)
    if not session_dir:
        return session
    videos_dir = os.path.join(session_dir, "videos")
    if not os.path.isdir(videos_dir):
        return session
    camera_map, fallback_camera_map = split_camera_video_variants(videos_dir)
    if not camera_map:
        return session
    normalized_actions = []
    for action in session.get("actions", []):
        normalized_action = dict(action)
        fps = float(normalized_action.get("fps") or 30.0)
        fps = max(1.0, fps)
        # Prefer absolute source-video frame bounds; fall back to the
        # clip-relative start/end frames when they are absent.
        absolute_start_frame = normalized_action.get("sourceStartFrame")
        absolute_end_frame = normalized_action.get("sourceEndFrame")
        if absolute_start_frame is None or absolute_end_frame is None:
            absolute_start_frame = normalized_action.get("startFrame")
            absolute_end_frame = normalized_action.get("endFrame")
        try:
            absolute_start_frame = int(absolute_start_frame) if absolute_start_frame is not None else None
            absolute_end_frame = int(absolute_end_frame) if absolute_end_frame is not None else None
        except Exception:
            absolute_start_frame = None
            absolute_end_frame = None
        if absolute_start_frame is not None and absolute_end_frame is not None and absolute_end_frame >= absolute_start_frame:
            start_seconds = max(0.0, absolute_start_frame / fps)
            # End frame is inclusive, so the interval covers end_frame + 1 frames.
            end_seconds = max(start_seconds, (absolute_end_frame + 1) / fps)
            normalized_action["startFrame"] = absolute_start_frame
            normalized_action["endFrame"] = absolute_end_frame
        else:
            # No usable frame bounds: derive seconds from timecodes/frame count.
            total_frames = int(normalized_action.get("totalFrames") or 0)
            start_seconds = parse_video_timecode(normalized_action.get("start"), fps=fps)
            if total_frames > 0:
                end_seconds = start_seconds + (total_frames / fps)
            else:
                end_seconds = max(start_seconds, parse_video_timecode(normalized_action.get("end"), fps=fps))
        # Only fill in clip URL lists when the stored payload lacks them.
        normalized_action["cameraClips"] = normalized_action.get("cameraClips") or build_camera_video_entries(
            request, player_id, session_id, camera_map
        )
        if fallback_camera_map:
            normalized_action["sourceCameraClips"] = normalized_action.get("sourceCameraClips") or build_camera_video_entries(
                request, player_id, session_id, fallback_camera_map
            )
        normalized_action["startSeconds"] = round(start_seconds, 6)
        normalized_action["endSeconds"] = round(end_seconds, 6)
        normalized_actions.append(normalized_action)
    normalized_session = dict(session)
    normalized_session["actions"] = normalized_actions
    return normalized_session
def json_default(value):
    """json.dump fallback converting NumPy scalars/arrays to native Python types."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
def export_action_clips(camera_map, clips_dir, action_index, start_frame, end_frame, fps):
    """Cut frames [start_frame, end_frame] (inclusive) out of every camera video.

    Writes one mp4v-encoded clip per camera into *clips_dir* and returns
    {camera_id: clip_path} for clips that actually received frames; empty
    outputs are deleted. Cameras whose source cannot be opened or reports
    invalid dimensions are silently skipped.
    """
    os.makedirs(clips_dir, exist_ok=True)
    frame_count = max(0, end_frame - start_frame + 1)
    clip_paths = {}
    for camera_id, video_path in sorted(camera_map.items()):
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            continue
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
        if width <= 0 or height <= 0:
            cap.release()
            continue
        clip_name = f"action_{action_index:02d}_cam_{camera_id}.mp4"
        clip_path = os.path.join(clips_dir, clip_name)
        writer = cv2.VideoWriter(
            clip_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            max(1.0, float(fps)),  # guard against zero/negative frame rates
            (width, height),
        )
        # Seek to the clip start, then copy frames until the count is reached
        # or the source runs out.
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        written = 0
        while written < frame_count:
            ok, frame = cap.read()
            if not ok:
                break
            writer.write(frame)
            written += 1
        writer.release()
        cap.release()
        if written > 0 and os.path.exists(clip_path):
            clip_paths[camera_id] = clip_path
        elif os.path.exists(clip_path):
            # Remove the empty container produced for a clip with no frames.
            os.remove(clip_path)
    return clip_paths
def load_session_by_id(session_id: str):
    """Read and return the session.json whose directory name matches
    *session_id* (sanitized), or None when no such session exists."""
    target_name = safe_name(session_id)
    for session_file in list_session_files():
        if os.path.basename(os.path.dirname(session_file)) != target_name:
            continue
        with open(session_file, "r", encoding="utf-8") as handle:
            return json.load(handle)
    return None
def find_session_path(session_id: str):
    """Return the directory holding *session_id*'s session.json, or None."""
    target_name = safe_name(session_id)
    matches = (
        os.path.dirname(session_file)
        for session_file in list_session_files()
        if os.path.basename(os.path.dirname(session_file)) == target_name
    )
    return next(matches, None)
def player_storage_path(player_id: str):
    """Absolute storage directory for a single player."""
    sanitized_player = safe_name(player_id)
    return os.path.join(STORAGE_ROOT, sanitized_player)
def get_cors_origins():
    """Parse CORS_ALLOW_ORIGINS into a list; '*' or unset allows everything."""
    configured = os.environ.get("CORS_ALLOW_ORIGINS", "*").strip()
    if configured in ("", "*"):
        return ["*"]
    origins = []
    for raw_origin in configured.split(","):
        origin = raw_origin.strip()
        if origin:
            origins.append(origin)
    return origins
@app.on_event("startup")
def startup_event():
    """Warm the pipeline/weights and pull the latest storage mirror on boot."""
    ensure_runtime_ready()
    sync_storage_from_hf(force=True)
@app.get("/healthz")
def healthz():
    """Health probe; also reports which pipeline root is active."""
    runtime = ensure_runtime_ready()
    return {
        "status": "ok",
        "storageRoot": STORAGE_ROOT,
        "pipelineRoot": runtime.get("pipeline_root"),
    }
@app.post("/api/cancel/{client_id}")
def cancel_processing(client_id: str):
    """Flag an in-flight analysis for cancellation (polled by progress_tracker)."""
    cancel_store.update({client_id: True})
    return {"status": "cancelled"}
@app.get("/api/progress/{client_id}")
def get_progress(client_id: str):
    """Report the latest progress snapshot for a client."""
    default_snapshot = {"progress": 0.0, "phase": "Initializing"}
    return progress_store.get(client_id, default_snapshot)
@app.get("/api/sessions/{session_id}")
def get_session(session_id: str, request: Request):
    """Fetch one stored session, normalized for the frontend."""
    session = load_session_by_id(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return normalize_session_payload(session, request)
@app.get("/api/archive")
def get_archive(request: Request):
    """List every stored session, newest first; unreadable entries are skipped."""
    sessions = []
    for session_file in list_session_files():
        # Best effort: a corrupt session.json must not break the archive view.
        try:
            with open(session_file, "r", encoding="utf-8") as handle:
                payload = json.load(handle)
            sessions.append(normalize_session_payload(payload, request))
        except Exception:
            continue
    return {"sessions": sessions}
@app.delete("/api/sessions/{session_id}")
def delete_session(session_id: str):
    """Delete a session locally and from the HF mirror, pruning the player's
    directory when it becomes empty."""
    session_dir = find_session_path(session_id)
    if session_dir is None:
        raise HTTPException(status_code=404, detail="Session not found")
    player_dir = os.path.dirname(session_dir)
    player_id = os.path.basename(player_dir)
    shutil.rmtree(session_dir, ignore_errors=True)
    delete_session_from_hf(player_id, session_id)
    remaining = os.listdir(player_dir) if os.path.isdir(player_dir) else None
    if remaining == []:
        os.rmdir(player_dir)
    return {"status": "deleted", "sessionId": session_id}
@app.delete("/api/players/{player_id}")
def delete_player(player_id: str):
    """Delete a player's whole storage tree locally and from the HF mirror."""
    player_dir = player_storage_path(player_id)
    if not os.path.isdir(player_dir):
        raise HTTPException(status_code=404, detail="Player storage not found")
    shutil.rmtree(player_dir, ignore_errors=True)
    delete_player_from_hf(player_id)
    return {"status": "deleted", "playerId": player_id}
# Configure CORS after route registration. Credentials are only enabled for an
# explicit origin list, since "*" combined with credentials is rejected by browsers.
cors_origins = get_cors_origins()
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=(cors_origins != ["*"]),
    allow_methods=["*"],
    allow_headers=["*"],
)
def format_metric_series(name, unit, values_list):
    """Package a per-frame value list as a metric-series payload for the UI."""
    series_values = []
    for frame_index, raw_value in enumerate(values_list):
        series_values.append({"frame": frame_index, "value": safe_float(raw_value)})
    return {"name": name, "unit": unit, "values": series_values}
def safe_float(value):
    """Coerce *value* to a finite float, or None.

    Returns None for values that cannot be converted and for NaN or +/-inf,
    so results are always JSON-safe (the old code only rejected NaN, letting
    infinities produce invalid `Infinity` tokens in session.json).
    """
    try:
        number = float(value)
    except Exception:
        return None
    # np.isfinite rejects NaN and both infinities in one check.
    return number if np.isfinite(number) else None
def metric_name(key: str) -> str:
    """Turn a snake_case metric key into a Title Case display name."""
    return " ".join(key.split("_")).title()
# Analytics keys whose values are per-frame lists spanning the whole action
# interval; rendered as "full interval" chart series (see
# build_top_level_interval_metrics).
FULL_INTERVAL_KEYS = [
    "left_knee_angles",
    "right_knee_angles",
    "torso_pitch_angles",
    "head_angles",
    "mid_foot_ball_distances",
    "left_right_foot_distances",
]
# Preferred metric ordering per action type. "pre"/"post" name series keys for
# the frames before/after the touch, "in" names scalar keys at the action
# frame, "frames" (Dribble only) names per-frame series keys, and
# "top_level_scalars" names scalar keys stored at the top level of the
# analytics payload. Keys not listed here still appear, alphabetically after
# the preferred ones (see ordered_metric_keys).
ACTION_METRIC_LAYOUTS = {
    "Pass": {
        "pre": ["body_to_ball_angle"],
        "in": [
            "body_to_ball_angle",
            "l_r_foot_distance",
            "trunc_pitch_angle",
            "trunc_roll_angle",
            "left_foot_orientation_angle",
            "right_foot_orientation_angle",
            "difference_in_angles",
            "l_knee_angle",
            "r_knee_angle",
            "head_angle",
            "head_pitch_angle",
            "head_roll_angle",
            "stand_foot_angle",
            "active_foot_height_pct",
        ],
        "post": ["head_angle", "body_to_ball_angle"],
        "top_level_scalars": ["backward_weighted_angle", "forward_weighted_angle"],
    },
    "Shot": {
        "pre": ["body_to_ball_angle"],
        "in": [
            "body_to_ball_angle",
            "l_r_foot_distance",
            "trunc_pitch_angle",
            "trunc_roll_angle",
            "left_foot_orientation_angle",
            "right_foot_orientation_angle",
            "difference_in_angles",
            "l_knee_angle",
            "r_knee_angle",
            "head_angle",
            "head_pitch_angle",
            "head_roll_angle",
            "stand_foot_angle",
            "l_elbow_shoulder_hip_angle",
            "r_elbow_shoulder_hip_angle",
            "active_ankle_angle",
        ],
        "post": ["head_angle", "body_to_ball_angle"],
        "top_level_scalars": ["backward_weighted_angle", "forward_weighted_angle"],
    },
    "Receive": {
        "pre": ["body_orientation_vs_ball", "head_angle"],
        "in": [
            "head_angle",
            "l_knee_angle",
            "r_knee_angle",
            "trunc_pitch_angle",
            "trunc_roll_angle",
            "left_foot_orientation_angle",
            "right_foot_orientation_angle",
            "difference_in_angles",
            "l_r_foot_distance",
            "stand_foot_angle",
            "body_orientation_vs_ball",
            "active_foot_height_pct",
        ],
        "post": ["mid_feet_ball_dist", "ball_height_pct_body"],
        "top_level_scalars": [],
    },
    "Dribble": {
        "frames": [
            "ball_feet_distance",
            "trunk_pitch",
            "trunk_roll",
            "head_angle",
            "ball_possession_score",
        ],
        "top_level_scalars": [],
    },
}
def ordered_metric_keys(observed_keys, preferred_keys=None):
    """Order keys with the preferred ones first (in their given order),
    followed by the remaining observed keys alphabetically."""
    preferred_present = [key for key in (preferred_keys or []) if key in observed_keys]
    remainder = sorted(key for key in observed_keys if key not in preferred_present)
    return preferred_present + remainder
def build_series_from_entries(entries, unit_for, skip_keys=None, preferred_keys=None):
    """Build metric series from a list of per-frame dicts.

    Collects every key seen across *entries* (plus *preferred_keys*), skipping
    "frame" and any *skip_keys*, and emits one series per key in
    preferred-first order. Missing values appear as None in the series.
    """
    skip = {"frame"}
    skip.update(skip_keys or ())
    observed = set(preferred_keys or [])
    for entry in entries:
        observed.update(key for key in entry if key not in skip)
    return [
        format_metric_series(
            metric_name(key),
            unit_for(key),
            [entry.get(key) for entry in entries],
        )
        for key in ordered_metric_keys(observed, preferred_keys)
    ]
def build_scalar_metrics(payload, unit_for, skip_keys=None, preferred_keys=None):
    """Build scalar metric entries from a flat payload dict.

    Considers every payload key plus *preferred_keys* (minus *skip_keys*);
    each value is coerced via safe_float and rounded to 3 decimals, with
    None for missing/non-numeric values.
    """
    skip = set(skip_keys or [])
    observed = {key for key in payload if key not in skip}
    observed.update(key for key in (preferred_keys or []) if key not in skip)
    metrics = []
    for key in ordered_metric_keys(observed, preferred_keys):
        if key in skip:
            continue
        value = safe_float(payload.get(key))
        metrics.append({
            "name": metric_name(key),
            "value": round(value, 3) if value is not None else None,
            "unit": unit_for(key),
        })
    return metrics
def build_top_level_interval_metrics(analytics, unit_for, skip_keys=None, preferred_keys=None):
    """Build series for list-valued keys stored at the top of the analytics dict.

    Structural keys (action metadata and nested sections) are skipped, and
    non-list values are ignored since scalars are handled elsewhere. Output is
    sorted so preferred keys come first, in their configured order.
    """
    # Keys that are structural rather than metric series.
    skip = {
        "action",
        "active_foot",
        "touch_frame",
        "pre_action",
        "action_frame",
        "post_action",
        "frames",
    }
    if skip_keys:
        skip.update(skip_keys)
    observed_keys = set(preferred_keys or [])
    observed_keys.update(analytics.keys())
    series = []
    for key in ordered_metric_keys(observed_keys, preferred_keys):
        if key in skip:
            continue
        values = analytics.get(key)
        if values is None:
            # Preferred-but-absent keys still get an (empty) series.
            values = []
        if not isinstance(values, list):
            continue
        series.append(format_metric_series(metric_name(key), unit_for(key), values))
    # Preferred display names keep their configured order; extras sort last
    # (sorted() is stable, so extras keep their relative order).
    order_index = {metric_name(key): idx for idx, key in enumerate(preferred_keys or [])}
    return sorted(series, key=lambda item: order_index.get(item["name"], len(order_index)))
@app.post("/api/analyze")
async def analyze_endpoint(
    request: Request,
    playerId: str = Form(...),
    targetW: float = Form(...),
    targetH: float = Form(...),
    clientId: str = Form(...),
    videoOrders: List[int] = Form(...),
    actionsJson: UploadFile = File(...),
    calibration: UploadFile = File(...),
    videos: List[UploadFile] = File(...)
):
    """Run the full 3D-analysis pipeline on an uploaded multi-camera capture.

    Stores the uploads under a new session directory, executes the pipeline in
    a worker thread (publishing progress through progress_store and honoring
    cancel_store), formats the per-action analytics for the frontend, persists
    session.json, and pushes the session to the HF mirror.

    Raises:
        HTTPException(500): on any failure; the partial session directory is
            removed first so it never appears in the archive.
    """
    temp_dir = None
    session_id = f"session-{int(time.time())}-{uuid.uuid4().hex[:6]}"
    player_dir, session_dir, videos_dir = session_storage_paths(playerId, session_id)
    try:
        # Clear any previous cancellation flags
        cancel_store.pop(clientId, None)
        progress_store[clientId] = {"progress": 2.0, "step": 0, "total": 0, "phase": "Uploading & Validating Data"}
        os.makedirs(videos_dir, exist_ok=True)
        # Remember the session dir so the except-branch can clean it up.
        temp_dir = session_dir
        # 1. Store incoming payloads
        actions_path = os.path.join(session_dir, "actions.json")
        with open(actions_path, "wb") as f:
            f.write(await actionsJson.read())
        calib_path = os.path.join(session_dir, "calibration.npz")
        with open(calib_path, "wb") as f:
            f.write(await calibration.read())
        if len(videoOrders) != len(videos):
            raise ValueError("Each uploaded video must include a matching camera order")
        if len(set(videoOrders)) != len(videoOrders):
            raise ValueError("Camera order values must be unique")
        camera_map = {}
        fallback_camera_map = {}
        # Persist each upload, then attempt a browser-friendly .web.mp4 copy.
        for idx, (camera_order, video) in enumerate(zip(videoOrders, videos)):
            original_name = video.filename or f"camera_{camera_order}.mp4"
            video_name = f"{idx:02d}_cam_{camera_order}_{safe_name(os.path.basename(original_name))}"
            vid_path = os.path.join(videos_dir, video_name)
            with open(vid_path, "wb") as f:
                f.write(await video.read())
            normalized_path = normalize_video_for_web(vid_path)
            camera_id = int(camera_order)
            camera_map[camera_id] = vid_path
            if normalized_path != vid_path and os.path.exists(normalized_path):
                fallback_camera_map[camera_id] = normalized_path
        progress_store[clientId] = {"progress": 10.0, "step": 0, "total": 0, "phase": "Preparing AI Models"}
        runtime = ensure_runtime_ready()
        utils_paths = {
            "POSE_PATH": runtime["weights"]["POSE_PATH"],
            "YOLO_PATH": runtime["weights"]["YOLO_PATH"],
            "CALIBRATION_PATH": calib_path,
            "ACTIONS_PATH": actions_path
        }
        sizes = {
            "TARGET_SIZE": (int(targetW), int(targetH)),
            "YOLO_IMGSZ": 960
        }
        print("Starting physical pipeline execution...")
        progress_store[clientId] = {"progress": 20.0, "step": 0, "total": 0, "phase": "Extracting 3D Kinematics"}
        def progress_tracker(current_act, total_act, step, total_frames):
            # Callback invoked by the pipeline; returning False aborts it.
            if cancel_store.get(clientId):
                return False  # Signal pipeline to abort
            base_p = current_act / max(1, total_act)
            segment_p = (step / max(1, total_frames)) * (1.0 / max(1, total_act))
            # Rescale 20% to 90% for processing
            pct = 20.0 + round((base_p + segment_p) * 70.0, 1)
            progress_store[clientId] = {
                "progress": pct,
                "step": step,
                "total": total_frames,
                "phase": f"Processing Action {current_act + 1}/{total_act}"
            }
            return True
        # 2. Yield to worker thread to allow concurrent polling from front-end
        def execute_pipeline():
            return runtime["run_pipeline"](camera_map, utils_paths, sizes, progress_tracker)
        reports = await run_in_threadpool(execute_pipeline)
        progress_store[clientId] = {"progress": 100.0, "step": 0, "total": 0, "phase": "Completed"}
        raw_reports_path = os.path.join(session_dir, "raw_reports.json")
        with open(raw_reports_path, "w", encoding="utf-8") as f:
            json.dump(reports, f, indent=2, default=json_default)
        # 3. Format output dict perfectly mapping to the Frontend Types
        with open(actions_path, "r") as f:
            raw_actions = json.load(f).get("actions", [])
        formatted_actions = []
        failed_actions = []
        camera_videos = build_camera_video_entries(request, playerId, session_id, camera_map)
        source_camera_videos = build_camera_video_entries(request, playerId, session_id, fallback_camera_map)
        for i, rep in enumerate(reports):
            raw = raw_actions[i] if i < len(raw_actions) else {}
            if "error" in rep:
                # Per-action failures are reported but do not fail the session.
                failed_actions.append({
                    "id": f"err-{uuid.uuid4().hex[:6]}",
                    "label": rep.get("action", raw.get("label", "Unknown")),
                    "start": raw.get("start", "00:00:00:00"),
                    "end": raw.get("end", "00:00:00:00"),
                    "error": rep["error"]
                })
                continue
            an = rep["analytics"]
            sf = rep["start_frame"]
            ef = rep["end_frame"]
            fps = float(rep.get("fps", 30))
            is_dribble = (an.get("action") == "Dribble")
            # Dribbles have no single "touch" moment; use the interval midpoint.
            if is_dribble:
                tf = (sf + ef) // 2
            else:
                tf = an.get("touch_frame", (sf + ef) // 2)
            # --- Skeleton: support full COCO-25/WB joint range (0–32) ---
            skeleton_frames = []
            raw_ball_history = rep.get("ball_history", {})
            for f_idx in range(sf, ef + 1):
                raw_skel = rep["skel_history"].get(f_idx, {})
                raw_ball = raw_ball_history.get(f_idx)
                # Find the max joint index present so we don't truncate
                max_joint = max(raw_skel.keys()) if raw_skel else 32
                n_joints = max(33, max_joint + 1)
                joints = []
                for j in range(n_joints):
                    pt = raw_skel.get(j)
                    if pt is not None:
                        joints.append([float(pt[0]), float(pt[1]), float(pt[2])])
                    else:
                        # Missing joints are zero-filled so the array stays dense.
                        joints.append([0.0, 0.0, 0.0])
                frame_payload = {"frame": f_idx - sf, "joints": joints}
                if raw_ball is not None:
                    frame_payload["ball"] = [float(raw_ball[0]), float(raw_ball[1]), float(raw_ball[2])]
                skeleton_frames.append(frame_payload)
            # --- Unit dictionary for known metric names ---
            UNITS = {
                "head_angle": "°", "l_knee_angle": "°", "r_knee_angle": "°",
                "trunc_pitch_angle": "°", "trunc_roll_angle": "°",
                "trunk_pitch": "°", "trunk_roll": "°",
                "head_pitch_angle": "°", "head_roll_angle": "°",
                "left_foot_orientation_angle": "°", "right_foot_orientation_angle": "°",
                "difference_in_angles": "°", "body_to_ball_angle": "°",
                "body_orientation_vs_ball": "°", "stand_foot_angle": "°",
                "active_ankle_angle": "°", "l_elbow_shoulder_hip_angle": "°",
                "r_elbow_shoulder_hip_angle": "°", "backward_weighted_angle": "°",
                "forward_weighted_angle": "°", "leg_separation_angle": "°",
                "l_r_foot_distance": "cm", "l_foot_ball_distance": "cm",
                "r_foot_ball_distance": "cm", "mid_feet_ball_dist": "cm",
                "active_foot_height_pct": "%", "ball_height_pct_body": "%",
                "ball_possession_score": "%", "ball_feet_distance": "cm",
            }
            def unit_for(key):
                # Unknown metrics get an empty unit rather than failing.
                return UNITS.get(key, "")
            action_layout = ACTION_METRIC_LAYOUTS.get(rep["action"], {})
            if is_dribble:
                # Dribble analytics are one flat per-frame list; per-key
                # averages become the "in action" scalar metrics.
                dribble_frames = an.get("frames", [])
                pre_metrics = build_series_from_entries(
                    dribble_frames,
                    unit_for,
                    preferred_keys=action_layout.get("frames"),
                )
                in_action_metrics = []
                frame_metric_keys = ordered_metric_keys(
                    {k for frame in dribble_frames for k in frame.keys() if k != "frame"},
                    action_layout.get("frames"),
                )
                for key in frame_metric_keys:
                    numeric_values = [safe_float(frame.get(key)) for frame in dribble_frames]
                    numeric_values = [value for value in numeric_values if value is not None]
                    in_action_metrics.append({
                        "name": f"Avg {metric_name(key)}",
                        "value": round(float(np.mean(numeric_values)), 3) if numeric_values else None,
                        "unit": unit_for(key),
                    })
                post_metrics = []
            else:
                # Touch-based actions carry pre/post frame series and a single
                # action-frame scalar snapshot.
                pre_entries = an.get("pre_action", [])
                post_entries = an.get("post_action", [])
                action_frame_data = an.get("action_frame", {})
                pre_metrics = build_series_from_entries(
                    pre_entries,
                    unit_for,
                    preferred_keys=action_layout.get("pre"),
                )
                in_action_metrics = build_scalar_metrics(
                    action_frame_data,
                    unit_for,
                    skip_keys={"active_foot"},
                    preferred_keys=action_layout.get("in"),
                )
                # Append top-level scalar analytics, skipping structural keys
                # and the full-interval series handled separately below.
                in_action_metrics.extend(
                    build_scalar_metrics(
                        an,
                        unit_for,
                        skip_keys={
                            "action",
                            "active_foot",
                            "touch_frame",
                            "pre_action",
                            "action_frame",
                            "post_action",
                            "frames",
                            "left_knee_angles",
                            "right_knee_angles",
                            "torso_pitch_angles",
                            "head_angles",
                            "mid_foot_ball_distances",
                            "left_right_foot_distances",
                        },
                        preferred_keys=action_layout.get("top_level_scalars"),
                    )
                )
                post_metrics = build_series_from_entries(
                    post_entries,
                    unit_for,
                    preferred_keys=action_layout.get("post"),
                )
            full_interval_metrics = build_top_level_interval_metrics(
                an,
                unit_for,
                preferred_keys=FULL_INTERVAL_KEYS,
            )
            formatted_actions.append({
                "id": f"{rep['action'].lower()}-{uuid.uuid4().hex[:6]}",
                "label": rep["action"],
                "start": raw.get("start", "00:00:00:00"),
                "end": raw.get("end", "00:00:00:00"),
                "fps": fps,
                "startFrame": sf,
                "endFrame": ef,
                "startSeconds": max(0.0, sf / max(1.0, fps)),
                "endSeconds": max(0.0, (ef + 1) / max(1.0, fps)),
                "totalFrames": ef - sf + 1,
                "preFrames": tf - sf,
                "inFrame": tf - sf,
                "postFrames": ef - tf,
                "cameraClips": camera_videos,
                "sourceCameraClips": source_camera_videos,
                "preMetrics": pre_metrics,
                "inActionMetrics": in_action_metrics,
                "postMetrics": post_metrics,
                "fullIntervalMetrics": full_interval_metrics,
                "skeleton": skeleton_frames,
                "rawAnalytics": an,
            })
        print("Pipeline successful. Yielding payload payload.")
        response_payload = {
            "id": session_id,
            "playerId": playerId,
            "createdAt": int(time.time() * 1000),
            "targetSize": [int(targetW), int(targetH)],
            "cameraCount": len(camera_map),
            "actions": formatted_actions,
            "failedActions": failed_actions
        }
        session_json_path = os.path.join(session_dir, "session.json")
        with open(session_json_path, "w", encoding="utf-8") as f:
            json.dump(response_payload, f, indent=2, default=json_default)
        # Mirror the finished session to the HF dataset before responding.
        push_session_to_hf(playerId, session_id, session_dir)
        return response_payload
    except Exception as e:
        print("--- PIPELINE ERROR ---")
        traceback.print_exc()
        # Remove the partially-written session so it never appears in the archive.
        if temp_dir and os.path.isdir(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn
    # Start ASGI interface natively mapping locally to the React vite environment
    # (port defaults to 8000, overridable via PORT).
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))