Commit d876213 (parent: 9f8524f), committed by MacBook pro

feat(docker): switch to Docker Space GPU runtime; prod WebRTC (aiortc) flow; remove legacy WS; token auth; instrumentation p50/p95; requirements harden
Files changed:
- .gitignore +9 -2
- Dockerfile +43 -9
- README.md +4 -4
- avatar_pipeline.py +13 -4
- deploy.sh +130 -0
- fastapi_app.py +194 -351
- models/hubert/.gitkeep +0 -0
- models/rmvpe/.gitkeep +0 -0
- models/rvc/.gitkeep +0 -0
- original_fastapi_app.py +290 -0
- requirements.txt +9 -19
- requirements_old.txt +38 -0
- static/app.js +4 -1
- static/index.html +13 -48
- static/webrtc_client.js +4 -0
- static/webrtc_prod.js +117 -0
- test_system.py +380 -0
- webrtc_server.py +444 -0
.gitignore
CHANGED
@@ -19,7 +19,14 @@ Thumbs.db
 pip-wheel-metadata/
 .cache/
 coverage/
-
-
+# Models: keep directory structure but ignore large weight/binary artifacts
+# (Allow .gitkeep, README, *.md, *.txt for documentation.)
+models/**/*.{pt,pth,bin,onnx,safetensors}
+models/**/*.npz
+models/**/*.ckpt
+!models/**/*.gitkeep
+!models/**/README.md
+!models/**/README.txt
+!models/**/README
 checkpoints/
 !checkpoints/.gitkeep
Dockerfile
CHANGED
@@ -1,20 +1,54 @@
-
+## Docker runtime for Hugging Face GPU Space (A10G) in Docker mode
+## Single-stage image on Ubuntu 22.04 (Python 3.10) with CUDA 12.1 + cuDNN 8
+FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
 
-
-
-
+ENV DEBIAN_FRONTEND=noninteractive \
+    PYTHONUNBUFFERED=1 \
+    PYTHONDONTWRITEBYTECODE=1 \
+    CUDA_CACHE_PATH=/tmp/cuda_cache \
+    TORCH_CUDA_ARCH_LIST="8.6" \
+    CUDA_LAUNCH_BLOCKING=0 \
+    CUDA_VISIBLE_DEVICES=0
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3 \
+    python3-pip \
+    python3-dev \
+    build-essential \
+    git \
+    curl \
+    ffmpeg \
+    libsm6 \
+    libxext6 \
+    libgl1 \
+    libglib2.0-0 \
+    && rm -rf /var/lib/apt/lists/*
 
 WORKDIR /app
 
-# Install
+# Install PyTorch with CUDA 12.1 first to avoid resolver overriding
+RUN pip3 install --no-cache-dir --upgrade pip wheel setuptools \
+    && pip3 install --no-cache-dir \
+    torch==2.3.1+cu121 \
+    torchaudio==2.3.1+cu121 \
+    --index-url https://download.pytorch.org/whl/cu121
+
+# Copy requirements and install remaining Python dependencies
 COPY requirements.txt ./
-RUN
-    && pip cache purge || true
+RUN pip3 install --no-cache-dir -r requirements.txt
 
 # Copy application source
 COPY . /app
 
+# Create directories for models and checkpoints (if not already present)
+RUN mkdir -p /app/models/liveportrait /app/models/rvc /app/models/hubert /app/models/rmvpe /app/checkpoints /tmp/cuda_cache
+
+# Expose HTTP port
 EXPOSE 7860
 
-#
-
+# Health check
+HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
+    CMD curl -fsS http://localhost:7860/health || exit 1
+
+# Run FastAPI app with uvicorn (WebRTC endpoints + static UI)
+CMD ["uvicorn", "original_fastapi_app:app", "--host", "0.0.0.0", "--port", "7860"]
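
The HEALTHCHECK above curls /health inside the container; the same probe works from the host once the image is built and the port published. A minimal sketch in Python rather than curl; the image tag "mirage" and the port mapping are assumptions, not part of the commit:

# health_probe.py: mirrors the Dockerfile HEALTHCHECK from outside the container.
# Assumed prior steps (hypothetical, not in this commit):
#   docker build -t mirage . && docker run --gpus all -p 7860:7860 mirage
import sys
import requests

try:
    r = requests.get("http://localhost:7860/health", timeout=10)
    r.raise_for_status()
    # /health (see original_fastapi_app.py below) returns status, pipeline_loaded, gpu_available
    print(r.json())
except Exception as exc:
    print(f"unhealthy: {exc}", file=sys.stderr)
    sys.exit(1)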
README.md
CHANGED
@@ -3,7 +3,7 @@ title: Mirage Real-time AI Avatar
 emoji: 🎭
 colorFrom: blue
 colorTo: purple
-sdk:
+sdk: docker
 sdk_version: 4.44.0
 app_file: app.py
 pinned: false
@@ -49,7 +49,7 @@ Transform yourself into an AI avatar in real-time with sub-250ms latency! Perfec
 - **Face Animation**: LivePortrait (KwaiVGI)
 - **Voice Conversion**: RVC (Retrieval-based Voice Conversion)
 - **Face Detection**: SCRFD with optimized inference
-- **Backend**: FastAPI with
+- **Backend**: FastAPI with WebRTC (aiortc)
 - **Frontend**: WebRTC-enabled real-time client
 - **GPU**: NVIDIA A10G with CUDA optimization
@@ -130,10 +130,10 @@ The system automatically adapts quality based on performance:
 ## 🛠️ Development
 
 Built with modern technologies:
-- FastAPI for high-performance backend
+- FastAPI for high-performance backend (Docker entrypoint: uvicorn original_fastapi_app:app)
 - PyTorch with CUDA acceleration
 - OpenCV for image processing
--
+- WebRTC (aiortc) for real-time media transport
 - Docker for consistent deployment
 
 ## 📄 License
avatar_pipeline.py
CHANGED
@@ -434,12 +434,21 @@ class RealTimeAvatarPipeline:
         opt_stats = self.optimizer.get_comprehensive_stats()
 
         # Basic pipeline stats
+        def _percentile(arr, p):
+            if not arr:
+                return 0
+            return float(np.percentile(np.array(arr), p))
+
         pipeline_stats = {
             "video_fps": len(video_times) / max(sum(video_times) / 1000, 0.001) if video_times else 0,
-            "avg_video_latency_ms": np.mean(video_times) if video_times else 0,
-            "
-            "
-            "
+            "avg_video_latency_ms": float(np.mean(video_times)) if video_times else 0,
+            "p50_video_latency_ms": _percentile(video_times, 50),
+            "p95_video_latency_ms": _percentile(video_times, 95),
+            "avg_audio_latency_ms": float(np.mean(audio_times)) if audio_times else 0,
+            "p50_audio_latency_ms": _percentile(audio_times, 50),
+            "p95_audio_latency_ms": _percentile(audio_times, 95),
+            "max_video_latency_ms": float(np.max(video_times)) if video_times else 0,
+            "max_audio_latency_ms": float(np.max(audio_times)) if audio_times else 0,
             "models_loaded": self.loaded,
             "gpu_available": torch.cuda.is_available(),
             "gpu_memory_used": torch.cuda.memory_allocated() / 1024**3 if torch.cuda.is_available() else 0,
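
In isolation, the new p50/p95 helper behaves like this (a quick sketch with made-up latency samples; the sample list is hypothetical, not pipeline data):

import numpy as np

def _percentile(arr, p):
    # Same logic as the helper added above: empty buffer -> 0, else float percentile.
    if not arr:
        return 0
    return float(np.percentile(np.array(arr), p))

video_times = [42.0, 45.5, 48.9, 51.2, 120.3]  # hypothetical per-frame latencies (ms)
print(_percentile(video_times, 50))  # 48.9: the median is robust to the 120 ms outlier
print(_percentile(video_times, 95))  # ~106.5: tail latency is dominated by the outlier
print(_percentile([], 95))           # 0 before any frames have been timed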
deploy.sh
ADDED
@@ -0,0 +1,130 @@
+#!/bin/bash
+# Deployment script for Mirage Real-time AI Avatar System
+
+set -e
+
+echo "🎭 Mirage Real-time AI Avatar - Deployment Script"
+echo "=================================================="
+
+# Check if we're deploying to HuggingFace Spaces
+if [[ "${SPACE_ID}" ]]; then
+    echo "📡 Deploying to HuggingFace Spaces: ${SPACE_ID}"
+    DEPLOYMENT_TARGET="huggingface"
+else
+    echo "🐳 Local Docker deployment"
+    DEPLOYMENT_TARGET="local"
+fi
+
+# Set environment variables for optimal A10G performance
+export CUDA_VISIBLE_DEVICES=0
+export TORCH_CUDA_ARCH_LIST="8.6"  # A10G architecture
+export CUDA_LAUNCH_BLOCKING=0
+export MIRAGE_VOICE_ENABLE=1
+export MIRAGE_CHUNK_MS=160
+export MIRAGE_VIDEO_MAX_FPS=20
+
+echo "🔧 Environment configured for A10G GPU"
+
+# Download required models
+echo "📥 Downloading AI models..."
+
+# Create model directories
+mkdir -p models/{liveportrait,rvc,hubert,rmvpe}
+mkdir -p checkpoints
+
+# Function to download from HuggingFace with retry
+download_hf_model() {
+    local repo=$1
+    local filename=$2
+    local output_dir=$3
+    local max_retries=3
+    local retry_count=0
+
+    while [ $retry_count -lt $max_retries ]; do
+        if python3 -c "
+from huggingface_hub import hf_hub_download
+import os
+try:
+    hf_hub_download('$repo', '$filename', local_dir='$output_dir', local_dir_use_symlinks=False)
+    print('✅ Downloaded $filename')
+except Exception as e:
+    print(f'❌ Failed to download $filename: {e}')
+    exit(1)
+"; then
+            break
+        fi
+
+        retry_count=$((retry_count + 1))
+        echo "⏳ Retry $retry_count/$max_retries for $filename"
+        sleep 2
+    done
+
+    if [ $retry_count -eq $max_retries ]; then
+        echo "❌ Failed to download $filename after $max_retries retries"
+        return 1
+    fi
+}
+
+# Download LivePortrait models (if available)
+if python3 -c "from huggingface_hub import HfApi; api = HfApi(); print('✅ HuggingFace available')" 2>/dev/null; then
+    echo "🎨 Attempting to download LivePortrait models..."
+    # Note: These would be the actual model files when available
+    # download_hf_model "KwaiVGI/LivePortrait" "appearance_feature_extractor.pth" "models/liveportrait"
+    # download_hf_model "KwaiVGI/LivePortrait" "motion_extractor.pth" "models/liveportrait"
+    # download_hf_model "KwaiVGI/LivePortrait" "warping_module.pth" "models/liveportrait"
+    # download_hf_model "KwaiVGI/LivePortrait" "spade_generator.pth" "models/liveportrait"
+    echo "ℹ️ LivePortrait models will be downloaded on first use"
+else
+    echo "⚠️ HuggingFace Hub not available, models will be downloaded at runtime"
+fi
+
+# Verify GPU availability
+echo "🔍 Checking GPU configuration..."
+python3 -c "
+import torch
+print(f'PyTorch version: {torch.__version__}')
+print(f'CUDA available: {torch.cuda.is_available()}')
+if torch.cuda.is_available():
+    print(f'GPU: {torch.cuda.get_device_name(0)}')
+    print(f'CUDA version: {torch.version.cuda}')
+    print(f'GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB')
+else:
+    print('⚠️ GPU not available - running in CPU mode')
+"
+
+# Setup virtual camera (Linux only)
+if [[ "$OSTYPE" == "linux-gnu"* ]]; then
+    echo "📹 Setting up virtual camera (v4l2loopback)..."
+
+    # Check if v4l2loopback is available
+    if modprobe v4l2loopback devices=1 video_nr=10 card_label="Mirage Virtual Camera" 2>/dev/null; then
+        echo "✅ Virtual camera device created: /dev/video10"
+    else
+        echo "⚠️ Could not create virtual camera device (requires sudo)"
+        echo "💡 Run: sudo modprobe v4l2loopback devices=1 video_nr=10 card_label='Mirage Virtual Camera'"
+    fi
+fi
+
+# Start the application
+echo "🚀 Starting Mirage AI Avatar System..."
+
+if [[ "${DEPLOYMENT_TARGET}" == "huggingface" ]]; then
+    # HuggingFace Spaces deployment
+    echo "🤗 Running on HuggingFace Spaces with A10G GPU"
+    exec python3 -u app.py
+else
+    # Local deployment
+    echo "💻 Running locally"
+
+    # Check if port 7860 is available
+    if lsof -Pi :7860 -sTCP:LISTEN -t >/dev/null; then
+        echo "⚠️ Port 7860 is already in use"
+        PORT=7861
+    else
+        PORT=7860
+    fi
+
+    echo "🌐 Server will be available at: http://localhost:${PORT}"
+    export PORT=${PORT}
+    exec python3 -u app.py
+fi
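
The retry helper shells out to Python for the actual download; standalone, the call it makes is just the following (a sketch: the repo/filename pair comes from the commented-out LivePortrait examples above and may not exist yet):

from huggingface_hub import hf_hub_download

# Same call deploy.sh embeds in its python3 -c heredoc; the arguments mirror the
# commented-out download_hf_model examples and are illustrative only.
path = hf_hub_download(
    "KwaiVGI/LivePortrait",
    "appearance_feature_extractor.pth",
    local_dir="models/liveportrait",
    local_dir_use_symlinks=False,
)
print(path)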
fastapi_app.py
CHANGED
@@ -1,368 +1,211 @@

Removed: the previous 368-line FastAPI application. The legacy WebSocket streaming path is dropped entirely: a shared _process_websocket(websocket, kind) handler that converted incoming PCM bytes to int16 for pipeline.process_audio_chunk (with a voice_processor.process_pcm_int16 fallback at 16 kHz and a pass-through mode), JPEG-decoded video frames for pipeline.process_video_frame, re-encoded output at JPEG quality 65, and fed metrics.record_audio_chunk / metrics.record_video_frame, exposed via @app.websocket("/audio") and @app.websocket("/video"). The surviving routes (/, /health, /initialize, /set_reference, /metrics, /pipeline_status, /gpu, startup logging) carry over to original_fastapi_app.py, added below, which also gains an aiortc WebRTC router mount.

Added: fastapi_app.py is replaced by a 211-line Gradio wrapper:

+#!/usr/bin/env python3
+"""
+Gradio interface for Mirage AI Avatar System
+Wraps the existing FastAPI application for HuggingFace Spaces deployment
+"""
+import gradio as gr
+import asyncio
+import threading
+import uvicorn
+import time
+import requests
+import os
+import sys
+from pathlib import Path
+
+# Add current directory to path for imports
+sys.path.append(str(Path(__file__).parent))
+
+# Import our existing app
+from fastapi_app import app as fastapi_app
+
+class MirageInterface:
+    def __init__(self):
+        self.server_port = 7860  # Gradio default port
+        self.fastapi_port = 8000
+        self.server_thread = None
+        self.server_running = False
+
+    def start_fastapi_server(self):
+        """Start the FastAPI server in background"""
+        try:
+            uvicorn.run(
+                fastapi_app,
+                host="0.0.0.0",
+                port=self.fastapi_port,
+                log_level="info"
+            )
+        except Exception as e:
+            print(f"FastAPI server error: {e}")
+
+    def initialize_system(self):
+        """Initialize the AI pipeline"""
+        try:
+            response = requests.post(f"http://localhost:{self.fastapi_port}/initialize")
+            if response.status_code == 200:
+                return "✅ AI Pipeline initialized successfully!"
+            else:
+                return f"❌ Initialization failed: {response.text}"
+        except Exception as e:
+            return f"❌ Connection error: {str(e)}"
+
+    def upload_reference_image(self, image):
+        """Upload reference image for avatar"""
+        if image is None:
+            return "❌ Please upload an image first"
+
+        try:
+            # Save uploaded image temporarily
+            image_path = "/tmp/reference_image.jpg"
+            image.save(image_path)
+
+            with open(image_path, "rb") as f:
+                files = {"file": f}
+                response = requests.post(
+                    f"http://localhost:{self.fastapi_port}/set_reference",
+                    files=files
+                )
+
+            if response.status_code == 200:
+                return "✅ Reference image uploaded successfully!"
+            else:
+                return f"❌ Upload failed: {response.text}"
+        except Exception as e:
+            return f"❌ Upload error: {str(e)}"
+
+    def get_system_status(self):
+        """Get current system status"""
+        try:
+            response = requests.get(f"http://localhost:{self.fastapi_port}/health")
+            if response.status_code == 200:
+                data = response.json()
+                return f"🟢 System Status: {data.get('status', 'Unknown')}"
+            else:
+                return "🔴 System offline"
+        except:
+            return "🔴 Cannot connect to system"
+
+def create_interface():
+    """Create the Gradio interface"""
+    mirage = MirageInterface()
+
+    # Start FastAPI server in background thread
+    server_thread = threading.Thread(target=mirage.start_fastapi_server, daemon=True)
+    server_thread.start()
+
+    # Wait a moment for server to start
+    time.sleep(2)
+
+    with gr.Blocks(
+        title="Mirage AI Avatar System",
+        theme=gr.themes.Soft(),
+        css="""
+        .gradio-container {
+            font-family: 'Arial', sans-serif;
+        }
+        .main-header {
+            text-align: center;
+            background: linear-gradient(45deg, #667eea 0%, #764ba2 100%);
+            -webkit-background-clip: text;
+            -webkit-text-fill-color: transparent;
+            font-size: 2.5em;
+            font-weight: bold;
+            margin-bottom: 20px;
+        }
+        """
+    ) as interface:
+
+        gr.HTML('<h1 class="main-header">🎭 Mirage AI Avatar System</h1>')
+        gr.Markdown("""
+        **Real-time AI Avatar with Face Animation & Voice Conversion**
+
+        Transform your appearance and voice in real-time for video calls. Built with LivePortrait and RVC.
+        """)
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("## 📋 System Setup")
+
+                init_btn = gr.Button("🚀 Initialize AI Pipeline", variant="primary")
+                init_status = gr.Textbox(label="Initialization Status", interactive=False)
+
+                gr.Markdown("## 🖼️ Reference Image")
+                reference_image = gr.Image(
+                    label="Upload your reference photo",
+                    type="pil",
+                    height=300
+                )
+                upload_btn = gr.Button("📤 Set Reference Image", variant="secondary")
+                upload_status = gr.Textbox(label="Upload Status", interactive=False)
+
+            with gr.Column(scale=2):
+                gr.Markdown("## 🎥 Live Avatar Interface")
+
+                gr.HTML(f"""
+                <iframe
+                    src="http://localhost:{mirage.fastapi_port}/"
+                    width="100%"
+                    height="600px"
+                    frameborder="0"
+                    style="border-radius: 10px; border: 2px solid #ddd;">
+                </iframe>
+                """)
+
+                status_btn = gr.Button("🔍 Check System Status")
+                system_status = gr.Textbox(label="System Status", interactive=False)
+
+        gr.Markdown("""
+        ## 🎯 How to Use
+
+        1. **Initialize**: Click "Initialize AI Pipeline" and wait for confirmation
+        2. **Reference**: Upload a clear photo of the person you want to become
+        3. **Setup**: Click "Set Reference Image" to configure your avatar
+        4. **Go Live**: Use the interface above to start your camera and see your AI avatar!
+
+        ## 🚀 Features
+
+        - **Real-time Processing**: <250ms latency for smooth interaction
+        - **Face Animation**: Powered by LivePortrait technology
+        - **Voice Conversion**: RVC-based voice transformation
+        - **GPU Accelerated**: Optimized for NVIDIA A10G hardware
+        - **Virtual Camera**: Ready for Zoom, Teams, Discord integration
+
+        ## ⚙️ Technical Details
+
+        - **Backend**: FastAPI with WebSocket streaming
+        - **Models**: InsightFace + LivePortrait + RVC
+        - **Hardware**: NVIDIA A10G GPU with CUDA 12.1
+        - **Performance**: 20 FPS video, 160ms audio chunks
+        """)
+
+        # Event handlers
+        init_btn.click(
+            fn=mirage.initialize_system,
+            outputs=init_status
+        )
+
+        upload_btn.click(
+            fn=mirage.upload_reference_image,
+            inputs=reference_image,
+            outputs=upload_status
+        )
+
+        status_btn.click(
+            fn=mirage.get_system_status,
+            outputs=system_status
+        )
+
+    return interface
+
+if __name__ == "__main__":
+    # Create and launch the interface
+    interface = create_interface()
+
+    # Launch with public sharing enabled for HuggingFace Spaces
+    interface.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,  # HF Spaces handles sharing
+        show_error=True,
+        quiet=False
+    )
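
The wrapper gives the background uvicorn thread a fixed time.sleep(2) head start before Gradio begins issuing requests against it. A readiness poll against the /health route is the usual alternative (a minimal sketch, not part of the commit):

import time
import requests

def wait_until_ready(port: int = 8000, timeout_s: float = 30.0) -> bool:
    # Poll /health until the background uvicorn thread accepts connections.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            if requests.get(f"http://localhost:{port}/health", timeout=1).ok:
                return True
        except requests.RequestException:
            pass
        time.sleep(0.25)
    return False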
models/hubert/.gitkeep
ADDED (empty placeholder)
models/rmvpe/.gitkeep
ADDED (empty placeholder)
models/rvc/.gitkeep
ADDED (empty placeholder)
original_fastapi_app.py
ADDED
@@ -0,0 +1,290 @@
+from fastapi import FastAPI, HTTPException, File, UploadFile
+from fastapi.responses import HTMLResponse, JSONResponse
+from fastapi.staticfiles import StaticFiles
+from pathlib import Path
+import traceback
+import time
+import subprocess
+import json
+import os
+import asyncio
+import numpy as np
+import cv2
+from typing import Any, Dict, List
+from metrics import metrics as _metrics_singleton, Metrics
+from config import config
+from voice_processor import voice_processor
+from avatar_pipeline import get_pipeline
+
+app = FastAPI(title="Mirage Real-time AI Avatar System")
+
+# Initialize AI pipeline
+pipeline = get_pipeline()
+pipeline_initialized = False
+
+# Potentially reconfigure metrics based on config
+if config.metrics_fps_window != 30:  # default in metrics module
+    metrics = Metrics(fps_window=config.metrics_fps_window)
+else:
+    metrics = _metrics_singleton
+
+# Mount the static directory
+static_dir = Path(__file__).parent / "static"
+app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
+
+# Mount WebRTC router (aiortc based)
+try:
+    from webrtc_server import router as webrtc_router  # type: ignore
+    app.include_router(webrtc_router)
+except Exception as e:  # pragma: no cover
+    print(f"[WARN] WebRTC router not loaded: {e}")
+
+@app.get("/", response_class=HTMLResponse)
+async def root():
+    """Serve the static/index.html file contents as HTML."""
+    index_path = static_dir / "index.html"
+    try:
+        content = index_path.read_text(encoding="utf-8")
+    except FileNotFoundError:
+        # Minimal fallback to satisfy route even if file not yet present.
+        content = "<html><body><h1>Mirage AI Avatar System</h1><p>Real-time AI avatar with face animation and voice conversion.</p></body></html>"
+    return HTMLResponse(content)
+
+
+@app.get("/health")
+async def health():
+    return {
+        "status": "ok",
+        "system": "real-time-ai-avatar",
+        "pipeline_loaded": pipeline_initialized,
+        "gpu_available": pipeline.config.device == "cuda"
+    }
+
+
+@app.post("/initialize")
+async def initialize_pipeline():
+    """Initialize the AI pipeline"""
+    global pipeline_initialized
+
+    if pipeline_initialized:
+        return {"status": "already_initialized", "message": "Pipeline already loaded"}
+
+    try:
+        success = await pipeline.initialize()
+        if success:
+            pipeline_initialized = True
+            return {"status": "success", "message": "Pipeline initialized successfully"}
+        else:
+            return {"status": "error", "message": "Failed to initialize pipeline"}
+    except Exception as e:
+        return {"status": "error", "message": f"Initialization error: {str(e)}"}
+
+
+@app.post("/set_reference")
+async def set_reference_image(file: UploadFile = File(...)):
+    """Set reference image for avatar"""
+    global pipeline_initialized
+
+    if not pipeline_initialized:
+        raise HTTPException(status_code=400, detail="Pipeline not initialized")
+
+    try:
+        # Read uploaded image
+        contents = await file.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if frame is None:
+            raise HTTPException(status_code=400, detail="Invalid image format")
+
+        # Set as reference frame
+        success = pipeline.set_reference_frame(frame)
+
+        if success:
+            return {"status": "success", "message": "Reference image set successfully"}
+        else:
+            return {"status": "error", "message": "No suitable face found in image"}
+
+    except Exception as e:
+        return {"status": "error", "message": f"Error setting reference: {str(e)}"}
+
+
+# Note: Legacy WebSocket streaming endpoints removed in production.
+
+
+@app.get("/metrics")
+async def get_metrics():
+    base_metrics = metrics.snapshot()
+
+    # Add AI pipeline metrics if available
+    if pipeline_initialized:
+        pipeline_stats = pipeline.get_performance_stats()
+        base_metrics.update({
+            "ai_pipeline": pipeline_stats
+        })
+
+    return base_metrics
+
+
+@app.get("/pipeline_status")
+async def get_pipeline_status():
+    """Get detailed pipeline status"""
+    if not pipeline_initialized:
+        return {
+            "initialized": False,
+            "message": "Pipeline not initialized"
+        }
+
+    try:
+        stats = pipeline.get_performance_stats()
+        return {
+            "initialized": True,
+            "stats": stats,
+            "reference_set": pipeline.reference_frame is not None
+        }
+    except Exception as e:
+        return {
+            "initialized": False,
+            "error": str(e)
+        }
+
+
+@app.get("/gpu")
+async def gpu_info():
+    """Return basic GPU availability and memory statistics.
+
+    Priority order:
+    1. torch (if installed and CUDA available) for detailed stats per device.
+    2. nvidia-smi (if executable present) for name/total/used.
+    3. Fallback: available false.
+    """
+    # Response scaffold
+    resp: Dict[str, Any] = {
+        "available": False,
+        "provider": None,
+        "device_count": 0,
+        "devices": [],  # type: ignore[list-item]
+    }
+
+    # Try torch first (lazy import)
+    try:
+        import torch  # type: ignore
+
+        if torch.cuda.is_available():
+            resp["available"] = True
+            resp["provider"] = "torch"
+            count = torch.cuda.device_count()
+            resp["device_count"] = count
+            devices: List[Dict[str, Any]] = []
+            for idx in range(count):
+                name = torch.cuda.get_device_name(idx)
+                try:
+                    free_bytes, total_bytes = torch.cuda.mem_get_info(idx)  # type: ignore[arg-type]
+                except TypeError:
+                    # Older PyTorch versions take no index
+                    free_bytes, total_bytes = torch.cuda.mem_get_info()
+                allocated = torch.cuda.memory_allocated(idx)
+                reserved = torch.cuda.memory_reserved(idx)
+                # Estimate free including unallocated reserved as reclaimable
+                est_free = free_bytes + max(reserved - allocated, 0)
+                to_mb = lambda b: round(b / (1024 * 1024), 2)
+                devices.append({
+                    "index": idx,
+                    "name": name,
+                    "total_mb": to_mb(total_bytes),
+                    "allocated_mb": to_mb(allocated),
+                    "reserved_mb": to_mb(reserved),
+                    "free_mem_get_info_mb": to_mb(free_bytes),
+                    "free_estimate_mb": to_mb(est_free),
+                })
+            resp["devices"] = devices
+            return resp
+    except Exception:  # noqa: BLE001
+        # Torch not installed or failed; fall through to nvidia-smi
+        pass
+
+    # Try nvidia-smi fallback
+    try:
+        cmd = [
+            "nvidia-smi",
+            "--query-gpu=name,memory.total,memory.used",
+            "--format=csv,noheader,nounits",
+        ]
+        out = subprocess.check_output(cmd, stderr=subprocess.STDOUT, timeout=2).decode("utf-8").strip()
+        lines = [l for l in out.splitlines() if l.strip()]
+        if lines:
+            resp["available"] = True
+            resp["provider"] = "nvidia-smi"
+            resp["device_count"] = len(lines)
+            devices: List[Dict[str, Any]] = []
+            for idx, line in enumerate(lines):
+                # Expect: name, total, used
+                parts = [p.strip() for p in line.split(',')]
+                if len(parts) >= 3:
+                    name, total_str, used_str = parts[:3]
+                    try:
+                        total = float(total_str)
+                        used = float(used_str)
+                        free = max(total - used, 0)
+                    except ValueError:
+                        total = used = free = 0.0
+                    devices.append({
+                        "index": idx,
+                        "name": name,
+                        "total_mb": total,
+                        "allocated_mb": used,  # approximate
+                        "reserved_mb": None,
+                        "free_estimate_mb": free,
+                    })
+            resp["devices"] = devices
+            return resp
+    except Exception:  # noqa: BLE001
+        pass
+
+    return resp
+
+
+@app.on_event("startup")
+async def log_config():
+    # Enhanced startup logging: core config + GPU availability summary.
+    cfg = config.as_dict()
+    # GPU probe (reuse gpu_info logic minimally without full device list to keep log concise)
+    gpu_available = False
+    gpu_name = None
+    try:
+        import torch  # type: ignore
+        if torch.cuda.is_available():
+            gpu_available = True
+            gpu_name = torch.cuda.get_device_name(0)
+        else:
+            # Fallback quick nvidia-smi single line
+            try:
+                out = subprocess.check_output([
+                    "nvidia-smi", "--query-gpu=name", "--format=csv,noheader,nounits"
+                ], stderr=subprocess.STDOUT, timeout=1).decode("utf-8").strip().splitlines()
+                if out:
+                    gpu_available = True
+                    gpu_name = out[0].strip()
+            except Exception:  # noqa: BLE001
+                pass
+    except Exception:  # noqa: BLE001
+        pass
+    # Honor dynamic PORT if provided (HF Spaces usually fixed at 7860 for docker, but logging helps debugging)
+    listen_port = int(os.getenv("PORT", "7860"))
+    startup_line = {
+        "chunk_ms": cfg.get("chunk_ms"),
+        "voice_enabled": cfg.get("voice_enable"),
+        "metrics_fps_window": cfg.get("metrics_fps_window"),
+        "video_fps_limit": cfg.get("video_max_fps"),
+        "port": listen_port,
+        "gpu_available": gpu_available,
+        "gpu_name": gpu_name,
+    }
+    print("[startup]", startup_line)
+
+
+# Note: The Dockerfile / README launch with: uvicorn app:app --port 7860
+if __name__ == "__main__":  # Optional direct run helper
+    import uvicorn  # type: ignore
+
+    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
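
A quick smoke test of the routes above (a sketch; assumes the server is already listening on localhost:7860 and that reference.jpg is any local face photo, both assumptions rather than part of the commit):

import requests

base = "http://localhost:7860"
print(requests.get(f"{base}/health").json())       # status, pipeline_loaded, gpu_available
print(requests.get(f"{base}/gpu").json())          # provider "torch" or "nvidia-smi", per-device MB figures
print(requests.post(f"{base}/initialize").json())  # loads the pipeline; idempotent on repeat calls
with open("reference.jpg", "rb") as f:             # hypothetical local file
    print(requests.post(f"{base}/set_reference", files={"file": f}).json())
print(requests.get(f"{base}/metrics").json())      # includes ai_pipeline stats once initialized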
requirements.txt
CHANGED
@@ -1,25 +1,15 @@
-# Core Dependencies
-gradio==4.44.0
-torch==2.3.1
-opencv-python-headless==4.9.0.80
-pillow==10.3.0
-
-# Optional - loaded on demand
 fastapi==0.111.0
 uvicorn[standard]==0.30.1
-
+numpy==1.26.4
+opencv-python-headless==4.9.0.80
+pillow==10.3.0
+psutil==5.9.8
 insightface==0.7.3
+transformers==4.44.2
 librosa==0.10.2
-
-# ONNX & GPU Acceleration
 onnx==1.16.1
 onnxruntime-gpu==1.18.1
-
-
-
-
-psutil==5.9.8
-
-# Optional GPU Optimization (may not be available on HF Spaces)
-# tensorrt==10.3.0
-# pycuda==2024.1.2
+huggingface-hub==0.24.5
+python-multipart==0.0.9
+aiortc==1.7.0
+av==11.0.0
requirements_old.txt
ADDED
@@ -0,0 +1,38 @@
+fastapi==0.111.0
+uvicorn[standard]==0.30.1
+websockets==12.0
+jinja2==3.1.4
+numpy==1.26.4
+psutil==5.9.8
+pillow==10.3.0
+torch==2.3.1
+torchaudio==2.3.1
+opencv-python==4.9.0.80
+insightface==0.7.3
+onnx==1.16.1
+onnxruntime-gpu==1.18.1
+huggingface-hub==0.24.5
+transformers==4.44.2
+accelerate==0.33.0
+diffusers==0.30.0
+python-multipart==0.0.9
+librosa==0.10.2
+scipy==1.13.1
+scikit-image==0.24.0
+opencv-python==4.8.1.78
+transformers==4.42.4
+diffusers==0.29.2
+accelerate==0.33.0
+xformers==0.0.27.post2
+librosa==0.10.2
+scipy==1.11.4
+scikit-image==0.22.0
+omegaconf==2.3.0
+insightface==0.7.3
+onnxruntime-gpu==1.18.1
+huggingface-hub==0.24.5
+safetensors==0.4.4
+einops==0.8.0
+av==12.3.0
+tensorrt==10.3.0
+pycuda==2024.1.2
static/app.js
CHANGED
@@ -1,4 +1,7 @@
-/*
+/* DEPRECATED (dev WebSocket client). Removed for production. Use webrtc_prod.js */
+// This file intentionally contains no executable code in production deployments.
+// It remains only to avoid broken references from older pages; index.html does not load it.
+export {};
 
 // Globals
 let audioWs = null;
static/index.html
CHANGED
@@ -107,66 +107,31 @@
 <body>
     <div class="container">
         <div class="header">
-            <h1
-            <p>
+            <h1>Mirage Realtime Avatar</h1>
+            <p class="subtitle">Production Preview</p>
         </div>
 
-        <div class="controls">
-            <
-            <button id="
-            <button id="
-            <
-            <button id="virtualCamBtn" disabled>Enable Virtual Camera</button>
-        </div>
-
-        <div id="statusDiv"></div>
-
-        <div class="metrics" id="metrics">
-            <div class="metric-card">
-                <div class="metric-value" id="fpsValue">0</div>
-                <div class="metric-label">Video FPS</div>
-            </div>
-            <div class="metric-card">
-                <div class="metric-value" id="latencyValue">0ms</div>
-                <div class="metric-label">Avg Latency</div>
-            </div>
-            <div class="metric-card">
-                <div class="metric-value" id="gpuValue">N/A</div>
-                <div class="metric-label">GPU Memory</div>
-            </div>
-            <div class="metric-card">
-                <div class="metric-value" id="statusValue">Idle</div>
-                <div class="metric-label">Pipeline Status</div>
-            </div>
+        <div class="controls" id="controls">
+            <input type="file" id="referenceInput" accept="image/*" title="Reference Image" />
+            <button id="connectBtn">Connect</button>
+            <button id="disconnectBtn" disabled>Disconnect</button>
+            <span id="statusText" style="margin-left:auto;font-size:12px;color:#888;">Idle</span>
         </div>
 
         <div class="video-container">
             <div class="video-box">
-                <h3
-                <video id="
+                <h3>Local</h3>
+                <video id="localVideo" autoplay muted playsinline></video>
             </div>
             <div class="video-box">
-                <h3
-                <
-                <canvas id="virtualCanvas" style="display: none;"></canvas>
+                <h3>Avatar</h3>
+                <video id="remoteVideo" autoplay playsinline></video>
             </div>
         </div>
 
-        <div
-            <h3>📺 Virtual Camera Integration</h3>
-            <p>The AI avatar output can be used as a virtual camera in:</p>
-            <ul>
-                <li>🎥 Zoom, Google Meet, Microsoft Teams</li>
-                <li>💬 Discord, Slack, WhatsApp Desktop</li>
-                <li>📱 OBS Studio, Streamlabs</li>
-            </ul>
-            <p><strong>Setup:</strong> Enable virtual camera, then select "Mirage Virtual Camera" in your video app settings.</p>
-        </div>
-
-        <audio id="remoteAudio" autoplay></audio>
-        <div id="log"></div>
+        <div id="perfBar" style="font-size:12px;color:#bbb;margin-top:10px;">Latency: -- ms · FPS: -- · GPU: --</div>
 
-        <script src="/static/
+        <script src="/static/webrtc_prod.js"></script>
     </div>
 </body>
 </html>
static/webrtc_client.js
ADDED
@@ -0,0 +1,4 @@
+/* Legacy dev WebRTC bootstrap (no-op in production). */
+(function(){
+  // intentionally empty
+})();
static/webrtc_prod.js
ADDED
@@ -0,0 +1,117 @@
/* Production-focused WebRTC client (replaces dev UI). */
(function(){
  const state = {
    pc: null,
    control: null,
    localStream: null,
    metricsTimer: null,
    referenceImage: null,
    connected: false
  };
  const els = {
    ref: document.getElementById('referenceInput'),
    connect: document.getElementById('connectBtn'),
    disconnect: document.getElementById('disconnectBtn'),
    localVideo: document.getElementById('localVideo'),
    remoteVideo: document.getElementById('remoteVideo'),
    status: document.getElementById('statusText'),
    perf: document.getElementById('perfBar')
  };
  function setStatus(txt){ els.status.textContent = txt; }
  function log(...a){ console.log('[PROD]', ...a); }

  async function handleReference(e){
    const file = e.target.files && e.target.files[0];
    if(!file) return;
    const bytes = new Uint8Array(await file.arrayBuffer());
    // Base64-encode in chunks: spreading a large Uint8Array into
    // String.fromCharCode(...) overflows the argument limit for big images.
    let bin = '';
    for(let i = 0; i < bytes.length; i += 0x8000){
      bin += String.fromCharCode.apply(null, bytes.subarray(i, i + 0x8000));
    }
    state.referenceImage = btoa(bin); // send after control channel open
  }

  async function connect(){
    if(state.connected) return;
    try {
      setStatus('Requesting media');
      els.connect.disabled = true;
      // Fetch a short-lived auth token (if the server requires one)
      let authToken = null;
      try {
        const t = await fetch('/webrtc/token');
        if (t.ok) {
          const j = await t.json();
          authToken = j.token;
        }
      } catch(_){}
      state.localStream = await navigator.mediaDevices.getUserMedia({video:true, audio:true});
      els.localVideo.srcObject = state.localStream;
      setStatus('Creating peer');
      state.pc = new RTCPeerConnection({iceServers:[{urls:['stun:stun.l.google.com:19302']}]});
      state.pc.onconnectionstatechange = ()=>{ log('pc state', state.pc.connectionState); if(['failed','disconnected','closed'].includes(state.pc.connectionState)){ disconnect(); } };
      state.pc.ontrack = ev => {
        if(ev.streams && ev.streams[0]){
          els.remoteVideo.srcObject = ev.streams[0];
        }
      };
      state.control = state.pc.createDataChannel('control');
      state.control.onopen = ()=>{
        setStatus('Connected');
        state.connected = true;
        els.disconnect.disabled = false;
        if(state.referenceImage){
          try { state.control.send(JSON.stringify({type:'set_reference', image_jpeg_base64: state.referenceImage})); } catch(e) {}
        }
        // Metrics polling
        state.metricsTimer = setInterval(()=>{
          try { state.control.send(JSON.stringify({type:'metrics_request'})); } catch(_){}
        }, 4000);
      };
      state.control.onmessage = (e)=>{
        try { const data = JSON.parse(e.data); if(data.type==='metrics' && data.payload){ updatePerf(data.payload); } } catch(_){}
      };
      state.localStream.getTracks().forEach(t=> state.pc.addTrack(t, state.localStream));
      const offer = await state.pc.createOffer();
      await state.pc.setLocalDescription(offer);
      setStatus('Negotiating');
      const headers = {'Content-Type':'application/json'};
      if (authToken) headers['X-Auth-Token'] = authToken;
      const r = await fetch('/webrtc/offer',{method:'POST', headers, body: JSON.stringify({sdp:offer.sdp, type:offer.type})});
      if(!r.ok){
        if(r.status===401 || r.status===403){
          setStatus('Unauthorized (check API key/token)');
        } else {
          setStatus('Offer failed '+r.status);
        }
        els.connect.disabled=false; return;
      }
      const answer = await r.json();
      await state.pc.setRemoteDescription(answer);
      setStatus('Finalizing');
    } catch(e){
      log('connect error', e);
      setStatus('Error');
      els.connect.disabled = false;
    }
  }

  function updatePerf(p){
    try {
      const fps = (p.video_fps || 0).toFixed(1);
      const lat = Math.round(p.avg_video_latency_ms || 0);
      const gpu = (p.gpu_memory_used !== undefined) ? (p.gpu_memory_used.toFixed(2)+'GB') : '--';
      els.perf.textContent = `Latency: ${lat} ms · FPS: ${fps} · GPU: ${gpu}`;
    } catch(_){}
  }

  async function disconnect(){
    if(state.metricsTimer){ clearInterval(state.metricsTimer); state.metricsTimer=null; }
    if(state.control){ try { state.control.close(); } catch(_){} }
    if(state.pc){ try { state.pc.close(); } catch(_){} }
    if(state.localStream){ state.localStream.getTracks().forEach(t=>t.stop()); }
    state.pc=null; state.control=null; state.localStream=null; state.connected=false;
    els.connect.disabled=false; els.disconnect.disabled=true; setStatus('Idle');
  }

  els.ref.addEventListener('change', handleReference);
  els.connect.addEventListener('click', connect);
  els.disconnect.addEventListener('click', disconnect);
})();
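The connect flow above is token -> getUserMedia -> offer -> answer. For protocol checks without a browser, the same handshake can be driven headlessly; a minimal sketch, assuming aiortc and aiohttp are installed client-side and the app is serving on localhost:7860 (both assumptions, not part of the commit):

import asyncio
import aiohttp
from aiortc import RTCPeerConnection, RTCSessionDescription

async def negotiate(base: str = "http://localhost:7860"):
    pc = RTCPeerConnection()
    channel = pc.createDataChannel("control")  # same label the server handles

    @channel.on("open")
    def on_open():
        channel.send('{"type": "metrics_request"}')

    @channel.on("message")
    def on_message(message):
        print("control:", message)

    offer = await pc.createOffer()
    await pc.setLocalDescription(offer)
    async with aiohttp.ClientSession() as s:
        headers = {}
        async with s.get(f"{base}/webrtc/token") as r:
            if r.status == 200:  # token is only minted when MIRAGE_API_KEY is set
                headers["X-Auth-Token"] = (await r.json())["token"]
        body = {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
        async with s.post(f"{base}/webrtc/offer", json=body, headers=headers) as r:
            answer = await r.json()
    await pc.setRemoteDescription(RTCSessionDescription(**answer))
    await asyncio.sleep(5)  # give the data channel time to open and reply
    await pc.close()

asyncio.run(negotiate())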
test_system.py
ADDED
@@ -0,0 +1,380 @@
"""
Testing and Validation Suite for Mirage AI Avatar System
Tests end-to-end functionality, latency, and performance
"""
import asyncio
import time
import aiohttp
import json
import numpy as np
import cv2
import logging
from pathlib import Path
import subprocess
import psutil
from typing import Dict, Any, List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MirageSystemTester:
    """Comprehensive testing suite for the AI avatar system"""

    def __init__(self, base_url: str = "http://localhost:7860"):
        self.base_url = base_url
        self.session = None
        self.test_results = {}

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    async def test_health_endpoint(self) -> bool:
        """Test basic health endpoint"""
        try:
            async with self.session.get(f"{self.base_url}/health") as response:
                data = await response.json()

                success = (
                    response.status == 200 and
                    data.get("status") == "ok" and
                    data.get("system") == "real-time-ai-avatar"
                )

                self.test_results["health"] = {
                    "success": success,
                    "status": response.status,
                    "data": data
                }

                logger.info(f"Health check: {'✅ PASS' if success else '❌ FAIL'}")
                return success

        except Exception as e:
            logger.error(f"Health check failed: {e}")
            self.test_results["health"] = {"success": False, "error": str(e)}
            return False

    async def test_pipeline_initialization(self) -> bool:
        """Test AI pipeline initialization"""
        try:
            start_time = time.time()
            async with self.session.post(f"{self.base_url}/initialize") as response:
                data = await response.json()
                init_time = time.time() - start_time

                success = (
                    response.status == 200 and
                    data.get("status") in ["success", "already_initialized"]
                )

                self.test_results["initialization"] = {
                    "success": success,
                    "status": response.status,
                    "data": data,
                    "init_time_seconds": init_time
                }

                logger.info(f"Pipeline init: {'✅ PASS' if success else '❌ FAIL'} ({init_time:.1f}s)")
                return success

        except Exception as e:
            logger.error(f"Pipeline initialization failed: {e}")
            self.test_results["initialization"] = {"success": False, "error": str(e)}
            return False

    async def test_reference_image_upload(self) -> bool:
        """Test reference image upload functionality"""
        try:
            # Create a test image
            test_image = np.zeros((512, 512, 3), dtype=np.uint8)
            cv2.circle(test_image, (256, 200), 50, (255, 255, 255), -1)  # Face-like circle
            cv2.circle(test_image, (230, 180), 10, (0, 0, 0), -1)  # Eye
            cv2.circle(test_image, (280, 180), 10, (0, 0, 0), -1)  # Eye
            cv2.ellipse(test_image, (256, 220), (20, 10), 0, 0, 180, (0, 0, 0), 2)  # Mouth

            # Encode as JPEG
            _, encoded = cv2.imencode('.jpg', test_image)
            image_data = encoded.tobytes()

            # Upload test image
            form_data = aiohttp.FormData()
            form_data.add_field('file', image_data, filename='test_face.jpg', content_type='image/jpeg')

            async with self.session.post(f"{self.base_url}/set_reference", data=form_data) as response:
                data = await response.json()

                success = (
                    response.status == 200 and
                    data.get("status") == "success"
                )

                self.test_results["reference_upload"] = {
                    "success": success,
                    "status": response.status,
                    "data": data
                }

                logger.info(f"Reference upload: {'✅ PASS' if success else '❌ FAIL'}")
                return success

        except Exception as e:
            logger.error(f"Reference image upload failed: {e}")
            self.test_results["reference_upload"] = {"success": False, "error": str(e)}
            return False

    async def test_websocket_connections(self) -> bool:
        """Test WebSocket connections for audio and video"""
        try:
            import websockets  # noqa: F401 (availability check)

            # Test audio WebSocket
            audio_success = await self._test_websocket_endpoint("/audio")

            # Test video WebSocket
            video_success = await self._test_websocket_endpoint("/video")

            success = audio_success and video_success

            self.test_results["websockets"] = {
                "success": success,
                "audio_success": audio_success,
                "video_success": video_success
            }

            logger.info(f"WebSocket connections: {'✅ PASS' if success else '❌ FAIL'}")
            return success

        except Exception as e:
            logger.error(f"WebSocket test failed: {e}")
            self.test_results["websockets"] = {"success": False, "error": str(e)}
            return False

    async def _test_websocket_endpoint(self, endpoint: str) -> bool:
        """Test a specific WebSocket endpoint"""
        try:
            import websockets

            ws_url = self.base_url.replace("http://", "ws://") + endpoint

            async with websockets.connect(ws_url) as websocket:
                # Send test data
                if endpoint == "/audio":
                    # Send 160 ms of silence (16 kHz, 16-bit)
                    test_audio = np.zeros(int(16000 * 0.160), dtype=np.int16)
                    await websocket.send(test_audio.tobytes())
                else:  # video
                    # Send a small test JPEG
                    test_frame = np.zeros((256, 256, 3), dtype=np.uint8)
                    _, encoded = cv2.imencode('.jpg', test_frame, [cv2.IMWRITE_JPEG_QUALITY, 50])
                    await websocket.send(encoded.tobytes())

                # Wait for a response
                response = await asyncio.wait_for(websocket.recv(), timeout=5.0)
                return len(response) > 0

        except Exception as e:
            logger.error(f"WebSocket {endpoint} test failed: {e}")
            return False

    async def test_performance_metrics(self) -> bool:
        """Test performance metrics endpoint"""
        try:
            async with self.session.get(f"{self.base_url}/pipeline_status") as response:
                data = await response.json()

                success = response.status == 200 and data.get("initialized", False)

                self.test_results["performance_metrics"] = {
                    "success": success,
                    "status": response.status,
                    "data": data
                }

                if success:
                    stats = data.get("stats", {})
                    logger.info("Performance metrics: ✅ PASS")
                    logger.info(f"  GPU Memory: {stats.get('gpu_memory_used', 0):.1f} GB")
                    logger.info(f"  Video FPS: {stats.get('video_fps', 0):.1f}")
                    logger.info(f"  Avg Latency: {stats.get('avg_video_latency_ms', 0):.1f} ms")
                else:
                    logger.info("Performance metrics: ❌ FAIL")

                return success

        except Exception as e:
            logger.error(f"Performance metrics test failed: {e}")
            self.test_results["performance_metrics"] = {"success": False, "error": str(e)}
            return False

    async def test_latency_benchmark(self) -> Dict[str, float]:
        """Benchmark system latency"""
        try:
            # Warm up (results discarded)
            for _ in range(5):
                async with self.session.get(f"{self.base_url}/health") as response:
                    await response.json()

            # Actual benchmark
            latencies = []
            for _ in range(20):
                start_time = time.time()
                async with self.session.get(f"{self.base_url}/pipeline_status") as response:
                    await response.json()
                latencies.append((time.time() - start_time) * 1000)

            results = {
                "avg_latency_ms": np.mean(latencies),
                "min_latency_ms": np.min(latencies),
                "max_latency_ms": np.max(latencies),
                "p95_latency_ms": np.percentile(latencies, 95),
                "p99_latency_ms": np.percentile(latencies, 99)
            }

            self.test_results["latency_benchmark"] = results

            logger.info("Latency benchmark results:")
            logger.info(f"  Average: {results['avg_latency_ms']:.1f} ms")
            logger.info(f"  P95: {results['p95_latency_ms']:.1f} ms")
            logger.info(f"  P99: {results['p99_latency_ms']:.1f} ms")

            return results

        except Exception as e:
            logger.error(f"Latency benchmark failed: {e}")
            return {}

    def test_system_requirements(self) -> Dict[str, Any]:
        """Test system requirements and capabilities"""
        results = {}

        try:
            # Check GPU availability
            try:
                import torch
                results["gpu_available"] = torch.cuda.is_available()
                if torch.cuda.is_available():
                    results["gpu_name"] = torch.cuda.get_device_name(0)
                    results["gpu_memory_gb"] = torch.cuda.get_device_properties(0).total_memory / 1024**3
                    results["cuda_version"] = torch.version.cuda
            except ImportError:
                results["gpu_available"] = False

            # Check system resources
            memory = psutil.virtual_memory()
            results["system_memory_gb"] = memory.total / 1024**3
            results["cpu_count"] = psutil.cpu_count()

            # Check disk space
            disk = psutil.disk_usage('/')
            results["disk_free_gb"] = disk.free / 1024**3

            # Check required packages (PyPI name -> import name; they differ for OpenCV)
            required_packages = {
                "torch": "torch",
                "torchvision": "torchvision",
                "torchaudio": "torchaudio",
                "opencv-python": "cv2",
                "numpy": "numpy",
                "fastapi": "fastapi",
                "websockets": "websockets",
            }

            missing_packages = []
            for package, import_name in required_packages.items():
                try:
                    __import__(import_name)
                except ImportError:
                    missing_packages.append(package)

            results["missing_packages"] = missing_packages
            results["requirements_met"] = len(missing_packages) == 0

            self.test_results["system_requirements"] = results

            logger.info("System requirements:")
            logger.info(f"  GPU: {'✅' if results['gpu_available'] else '❌'}")
            logger.info(f"  Memory: {results['system_memory_gb']:.1f} GB")
            logger.info(f"  CPU: {results['cpu_count']} cores")
            logger.info(f"  Packages: {'✅' if results['requirements_met'] else '❌'}")

            return results

        except Exception as e:
            logger.error(f"System requirements check failed: {e}")
            return {"error": str(e)}

    async def run_comprehensive_test(self) -> Dict[str, Any]:
        """Run all tests and return comprehensive results"""
        logger.info("🧪 Starting comprehensive system test...")

        # System requirements (runs first, no server needed)
        self.test_system_requirements()

        # Server-dependent tests
        tests = [
            ("Health Check", self.test_health_endpoint()),
            ("Pipeline Initialization", self.test_pipeline_initialization()),
            ("Reference Image Upload", self.test_reference_image_upload()),
            ("WebSocket Connections", self.test_websocket_connections()),
            ("Performance Metrics", self.test_performance_metrics()),
        ]

        # Run tests sequentially
        for test_name, test_coro in tests:
            logger.info(f"Running: {test_name}...")
            try:
                result = await test_coro
                if not result:
                    logger.warning(f"{test_name} failed - may affect subsequent tests")
            except Exception as e:
                logger.error(f"{test_name} threw exception: {e}")

        # Latency benchmark (runs last)
        logger.info("Running latency benchmark...")
        await self.test_latency_benchmark()

        # Calculate overall success rate
        successful_tests = sum(1 for result in self.test_results.values()
                               if isinstance(result, dict) and result.get("success", False))
        total_tests = len([r for r in self.test_results.values() if isinstance(r, dict) and "success" in r])

        overall_success = successful_tests / max(total_tests, 1) >= 0.8  # 80% success rate

        summary = {
            "overall_success": overall_success,
            "successful_tests": successful_tests,
            "total_tests": total_tests,
            "success_rate": successful_tests / max(total_tests, 1),
            "detailed_results": self.test_results
        }

        logger.info(f"🏁 Test completed: {successful_tests}/{total_tests} tests passed")
        logger.info(f"Overall result: {'✅ PASS' if overall_success else '❌ FAIL'}")

        return summary

async def main():
    """Main test runner"""
    import sys

    base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:7860"

    async with MirageSystemTester(base_url) as tester:
        results = await tester.run_comprehensive_test()

        # Save results to file
        results_file = Path("test_results.json")
        with open(results_file, "w") as f:
            json.dump(results, f, indent=2, default=str)

        logger.info(f"📊 Detailed results saved to: {results_file}")

        # Exit with appropriate code
        sys.exit(0 if results["overall_success"] else 1)

if __name__ == "__main__":
    asyncio.run(main())
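From the CLI, `python test_system.py http://localhost:7860` runs the full suite, writes test_results.json, and exits non-zero on failure. The tester can also be driven programmatically; a minimal sketch, assuming the app is serving on localhost:7860:

import asyncio
from test_system import MirageSystemTester

async def smoke():
    # Single probe instead of the full suite; see run_comprehensive_test for everything.
    async with MirageSystemTester("http://localhost:7860") as tester:
        ok = await tester.test_health_endpoint()
        print("health:", "PASS" if ok else "FAIL")

asyncio.run(smoke())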
webrtc_server.py
ADDED
@@ -0,0 +1,444 @@
"""WebRTC integration using aiortc for low-latency bi-directional media.

This module exposes:
- POST /webrtc/offer   : Accepts an SDP offer from the browser, returns an SDP answer.
- GET  /webrtc/token   : Mints a short-lived signed token usable as X-Auth-Token.
- POST /webrtc/cleanup : Closes the active peer connection.
ICE candidates travel inside the full offer/answer exchange; no trickle endpoint is used.

Media Flow (Phase 1):
Browser camera/mic -> WebRTC -> aiortc PeerConnection ->
    Video track -> frame hook -> pipeline.process_video_frame -> video track returned to client
    Audio track -> chunk hook -> pipeline.process_audio_chunk -> audio track returned to client

Control/Data channel: "control" carries lightweight JSON messages:
    {"type":"metrics_request"} -> server replies {"type":"metrics","payload":...}
    {"type":"set_reference","image_jpeg_base64":...}

Fallback: if aiortc is not supported in the environment or its import fails, the endpoints return 503.

Security: (basic) optional shared secret via X-API-Key header (env MIRAGE_API_KEY),
or a short-lived HMAC token via X-Auth-Token.

NOTE: This is a minimal, production-oriented skeleton focused on structure, error handling,
resource cleanup and integration points. Actual model inference remains in avatar_pipeline.
"""
from __future__ import annotations

import asyncio
import base64
import json
import logging
import os
import time
from dataclasses import dataclass
import hashlib
import hmac
import secrets as pysecrets
import base64 as pybase64
from typing import Optional, Dict, Any

from fastapi import APIRouter, HTTPException, Header

try:
    from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack, RTCConfiguration, RTCIceServer
    from aiortc.contrib.media import MediaBlackhole
    import av  # noqa: F401 (required by aiortc for codecs)
    AIORTC_AVAILABLE = True
except Exception as e:  # pragma: no cover
    AIORTC_IMPORT_ERROR = str(e)
    AIORTC_AVAILABLE = False

import numpy as np
import cv2

from avatar_pipeline import get_pipeline

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/webrtc", tags=["webrtc"])

API_KEY = os.getenv("MIRAGE_API_KEY")
REQUIRE_API_KEY = os.getenv("MIRAGE_REQUIRE_API_KEY", "0").strip().lower() in {"1", "true", "yes", "on"}
TOKEN_TTL_SECONDS = int(os.getenv("MIRAGE_TOKEN_TTL", "300"))  # 5 minutes default
STUN_URLS = os.getenv("MIRAGE_STUN_URLS", "stun:stun.l.google.com:19302")
TURN_URL = os.getenv("MIRAGE_TURN_URL")
TURN_USER = os.getenv("MIRAGE_TURN_USER")
TURN_PASS = os.getenv("MIRAGE_TURN_PASS")


def _b64u(data: bytes) -> str:
    return pybase64.urlsafe_b64encode(data).decode('ascii').rstrip('=')


def _b64u_decode(data: str) -> bytes:
    pad = '=' * (-len(data) % 4)
    return pybase64.urlsafe_b64decode(data + pad)


def _mint_token() -> str:
    """Stateless signed token: base64url(ts:nonce) + '.' + base64url(HMAC-SHA256)."""
    ts = str(int(time.time()))
    nonce = _b64u(pysecrets.token_bytes(12))
    msg = f"{ts}:{nonce}".encode('utf-8')
    mac = hmac.new(API_KEY.encode('utf-8'), msg, hashlib.sha256).digest()
    return _b64u(msg) + '.' + _b64u(mac)


def _verify_token(token: str) -> bool:
    try:
        parts = token.split('.')
        if len(parts) != 2:
            return False
        msg_b64, mac_b64 = parts
        msg = _b64u_decode(msg_b64)
        mac = _b64u_decode(mac_b64)
        ts_str, nonce = msg.decode('utf-8').split(':', 1)
        ts = int(ts_str)
        if time.time() - ts > TOKEN_TTL_SECONDS:
            return False
        expected = hmac.new(API_KEY.encode('utf-8'), msg, hashlib.sha256).digest()
        return hmac.compare_digest(expected, mac)
    except Exception:
        return False


def _check_api_key(header_val: Optional[str], token_val: Optional[str] = None):
    # If no API key configured, allow
    if not API_KEY:
        return
    # If enforcement disabled, allow
    if not REQUIRE_API_KEY:
        return
    # Accept the raw key or a signed token
    if header_val and header_val == API_KEY:
        return
    if token_val and _verify_token(token_val):
        return
    raise HTTPException(status_code=401, detail="Unauthorized")


def _ice_configuration() -> RTCConfiguration:
    servers = []
    # STUN servers (comma-separated)
    for url in [u.strip() for u in STUN_URLS.split(',') if u.strip()]:
        servers.append(RTCIceServer(urls=[url]))
    # Optional TURN
    if TURN_URL and TURN_USER and TURN_PASS:
        servers.append(RTCIceServer(urls=[TURN_URL], username=TURN_USER, credential=TURN_PASS))
    return RTCConfiguration(iceServers=servers)


def _prefer_codec(sdp: str, kind: str, codec: str) -> str:
    """Move payload types for the given codec to the front of the m-line.
    Minimal SDP munging for preferring codecs (e.g., H264 or VP8).
    """
    try:
        lines = sdp.splitlines()
        # Map payload type -> codec name
        pt_to_codec = {}
        for ln in lines:
            if ln.startswith('a=rtpmap:'):
                try:
                    rest = ln[len('a=rtpmap:'):]
                    pt, enc = rest.split(' ', 1)
                    codec_name = enc.split('/')[0].upper()
                    pt_to_codec[pt] = codec_name
                except Exception:
                    pass
        # Find the m-line for this kind
        for i, ln in enumerate(lines):
            if ln.startswith('m=') and kind in ln:
                parts = ln.split(' ')
                header = parts[:3]
                pts = parts[3:]
                preferred = [pt for pt in pts if pt_to_codec.get(pt, '') == codec.upper()]
                others = [pt for pt in pts if pt not in preferred]
                lines[i] = ' '.join(header + preferred + others)
                break
        return '\r\n'.join(lines) + '\r\n'
    except Exception:
        return sdp


async def _ensure_pipeline_initialized():
    """Initialize the pipeline if not already loaded."""
    pipeline = get_pipeline()
    try:
        if not getattr(pipeline, "loaded", False):
            init = getattr(pipeline, "initialize", None)
            if callable(init):
                result = init()
                if asyncio.iscoroutine(result):
                    await result
    except Exception as e:
        logger.error(f"Pipeline init failed: {e}")


@dataclass
class PeerState:
    pc: RTCPeerConnection
    created: float
    control_channel_ready: bool = False


# In-memory single peer (extend to a dict for multi-user)
_peer_state: Optional[PeerState] = None
_peer_lock = asyncio.Lock()


class IncomingVideoTrack(MediaStreamTrack):
    kind = "video"

    def __init__(self, track: MediaStreamTrack):
        super().__init__()  # base init
        self.track = track
        self.pipeline = get_pipeline()
        self.frame_id = 0
        self._last_processed: Optional[np.ndarray] = None
        self._processing_task: Optional[asyncio.Task] = None
        self._lock = asyncio.Lock()

    async def recv(self):  # type: ignore[override]
        frame = await self.track.recv()
        self.frame_id += 1
        # Convert to numpy BGR for the pipeline
        img = frame.to_ndarray(format="bgr24")
        h, w, _ = img.shape
        proc_input = img
        # Optionally downscale for processing to cap latency
        try:
            if max(h, w) > 512:
                scale_w = 512
                scale_h = int(h * (512 / w)) if w >= h else 512
                if w < h:
                    scale_w = int(w * (512 / h))
                proc_input = cv2.resize(img, (max(1, scale_w), max(1, scale_h)))
        except Exception as e:
            logger.debug(f"Video downscale skip: {e}")

        # Schedule background processing to avoid blocking recv()
        async def _process_async(inp: np.ndarray, expected_size: tuple[int, int], fid: int):
            try:
                out_small = self.pipeline.process_video_frame(inp, fid)
                if (out_small.shape[1], out_small.shape[0]) != expected_size:
                    out = cv2.resize(out_small, expected_size)
                else:
                    out = out_small
                async with self._lock:
                    self._last_processed = out
            except Exception as ex:
                logger.error(f"Video processing error(bg): {ex}")
            finally:
                self._processing_task = None

        expected = (w, h)
        if self._processing_task is None:
            # Only run one processing task at a time; older frames are dropped
            self._processing_task = asyncio.create_task(_process_async(proc_input, expected, self.frame_id))

        # Use the last processed frame if available, else pass through
        async with self._lock:
            processed = self._last_processed if self._last_processed is not None else img
        # Convert back to a VideoFrame
        new_frame = frame.from_ndarray(processed, format="bgr24")
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
        return new_frame


class IncomingAudioTrack(MediaStreamTrack):
    kind = "audio"

    def __init__(self, track: MediaStreamTrack):
        super().__init__()
        self.track = track
        self.pipeline = get_pipeline()
        self._resample_to_16k = None
        self._resample_from_16k = None

    async def recv(self):  # type: ignore[override]
        frame = await self.track.recv()
        # frame is an AudioFrame (PCM)
        try:
            import av
            from av.audio.resampler import AudioResampler
            # Initialize resamplers once using input characteristics
            if self._resample_to_16k is None:
                self._resample_to_16k = AudioResampler(format='s16', layout='mono', rate=16000)
            if self._resample_from_16k is None:
                # Back to the original sample rate and layout; keep s16 for low overhead
                target_layout = frame.layout.name if frame.layout else 'mono'
                target_rate = frame.sample_rate or 48000
                self._resample_from_16k = AudioResampler(format='s16', layout=target_layout, rate=target_rate)

            # 1) To mono s16 @16k for the pipeline
            f_16k_list = self._resample_to_16k.resample(frame)
            if isinstance(f_16k_list, list):
                f_16k = f_16k_list[0]
            else:
                f_16k = f_16k_list
            pcm16k = f_16k.to_ndarray()  # (channels, samples), dtype=int16
            if pcm16k.ndim == 2:
                # Convert to mono if needed
                if pcm16k.shape[0] > 1:
                    pcm16k = np.mean(pcm16k, axis=0, keepdims=True).astype(np.int16)
                # Drop the channel dim -> (samples,)
                pcm16k = pcm16k.reshape(-1)

            # 2) Pipeline processing (mono 16k int16 ndarray)
            processed_arr = self.pipeline.process_audio_chunk(pcm16k)
            if isinstance(processed_arr, bytes):
                processed_bytes = processed_arr
            else:
                processed_bytes = np.asarray(processed_arr, dtype=np.int16).tobytes()

            # 3) Wrap the processed audio back into an av frame @16k mono s16
            samples = len(processed_bytes) // 2
            f_proc_16k = av.AudioFrame(format='s16', layout='mono', samples=samples)
            f_proc_16k.sample_rate = 16000
            f_proc_16k.planes[0].update(processed_bytes)

            # 4) Resample back to the original sample rate/layout
            f_out_list = self._resample_from_16k.resample(f_proc_16k)
            if isinstance(f_out_list, list) and len(f_out_list) > 0:
                f_out = f_out_list[0]
            else:
                f_out = f_proc_16k  # fallback

            # Preserve timing as best effort
            f_out.pts = frame.pts
            f_out.time_base = frame.time_base
            return f_out
        except Exception as e:
            logger.error(f"Audio processing error: {e}")
            return frame


@router.post("/offer")
async def webrtc_offer(offer: Dict[str, Any], x_api_key: Optional[str] = Header(default=None), x_auth_token: Optional[str] = Header(default=None)):
    """Accept an SDP offer and return an SDP answer."""
    # If enforcement is enabled, require a valid signed token; otherwise allow
    if REQUIRE_API_KEY:
        if not (x_auth_token and _verify_token(x_auth_token)):
            raise HTTPException(status_code=401, detail="Unauthorized")
    if not AIORTC_AVAILABLE:
        raise HTTPException(status_code=503, detail=f"aiortc not available: {AIORTC_IMPORT_ERROR}")

    async with _peer_lock:
        global _peer_state
        # Ensure the pipeline is ready before wiring tracks
        await _ensure_pipeline_initialized()
        # Clean up an existing peer if present
        if _peer_state is not None:
            try:
                await _peer_state.pc.close()
            except Exception:
                pass
            _peer_state = None

        pc = RTCPeerConnection(configuration=_ice_configuration())
        blackhole = MediaBlackhole()  # optional sink (not wired up)

        @pc.on("datachannel")
        def on_datachannel(channel):
            logger.info("Data channel received: %s", channel.label)
            if channel.label == "control":
                def send_metrics():
                    pipeline = get_pipeline()
                    stats = pipeline.get_performance_stats() if pipeline.loaded else {}
                    payload = json.dumps({"type": "metrics", "payload": stats})
                    try:
                        channel.send(payload)
                    except Exception:
                        logger.debug("Failed sending metrics")

                @channel.on("message")
                def on_message(message):
                    try:
                        if isinstance(message, bytes):
                            return
                        data = json.loads(message)
                        mtype = data.get("type")
                        if mtype == "ping":
                            channel.send(json.dumps({"type": "pong", "t": time.time()}))
                        elif mtype == "metrics_request":
                            send_metrics()
                        elif mtype == "set_reference":
                            b64 = data.get("image_jpeg_base64")
                            if b64:
                                try:
                                    # Guard size (~2 MB decoded; base64 inflates by 4/3)
                                    if len(b64) > 2_800_000:
                                        channel.send(json.dumps({"type": "error", "message": "reference too large"}))
                                        return
                                    raw = base64.b64decode(b64)
                                    arr = np.frombuffer(raw, np.uint8)
                                    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
                                    if img is not None:
                                        pipeline = get_pipeline()
                                        pipeline.set_reference_frame(img)
                                        channel.send(json.dumps({"type": "reference_ack"}))
                                except Exception as e:
                                    channel.send(json.dumps({"type": "error", "message": str(e)}))
                    except Exception as e:
                        logger.error(f"Data channel message error: {e}")

        @pc.on("connectionstatechange")
        async def on_state_change():
            logger.info("Peer connection state: %s", pc.connectionState)
            if pc.connectionState in ("failed", "closed", "disconnected"):
                try:
                    await pc.close()
                except Exception:
                    pass

        # Attach incoming tracks and re-add outbound processed tracks.
        # Registered before setRemoteDescription: aiortc emits "track" events
        # while the remote description is being applied.
        @pc.on("track")
        def on_track(track):
            logger.info("Track received: %s", track.kind)
            if track.kind == "video":
                local = IncomingVideoTrack(track)
                pc.addTrack(local)
            elif track.kind == "audio":
                local_a = IncomingAudioTrack(track)
                pc.addTrack(local_a)

        # Set the remote description
        try:
            desc = RTCSessionDescription(sdp=offer["sdp"], type=offer["type"])
            await pc.setRemoteDescription(desc)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid SDP offer: {e}")

        # Create the answer
        answer = await pc.createAnswer()
        # Prefer H264 for broader compatibility (fall back to as-is if munging fails)
        patched_sdp = _prefer_codec(answer.sdp, 'video', os.getenv('MIRAGE_PREFERRED_VIDEO_CODEC', 'H264'))
        answer = RTCSessionDescription(sdp=patched_sdp, type=answer.type)
        await pc.setLocalDescription(answer)

        _peer_state = PeerState(pc=pc, created=time.time())

        logger.info("WebRTC answer created")
        return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}


@router.get("/token")
async def mint_token():
    """Return a short-lived signed token that can be used as X-Auth-Token.
    Public endpoint; the signature uses the server-held API key, if configured.
    """
    if not API_KEY:
        raise HTTPException(status_code=400, detail="API key not configured")
    return {"token": _mint_token(), "ttl": TOKEN_TTL_SECONDS}


@router.post("/cleanup")
async def cleanup_peer(x_api_key: Optional[str] = Header(default=None)):
    _check_api_key(x_api_key)
    async with _peer_lock:
        global _peer_state
        if _peer_state is None:
            return {"status": "no_peer"}
        try:
            await _peer_state.pc.close()
        except Exception:
            pass
        _peer_state = None
        return {"status": "closed"}
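The router above is self-contained; mounting it is a one-liner in the ASGI app. A minimal sketch (fastapi_app.py in this commit presumably does the equivalent; nothing beyond include_router is implied by the diff):

from fastapi import FastAPI
from webrtc_server import router as webrtc_router

app = FastAPI()
app.include_router(webrtc_router)  # exposes /webrtc/offer, /webrtc/token, /webrtc/cleanup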