# OpenVoice V2 text-to-speech API service (FastAPI, deployed as a Hugging Face Space)
| import base64 | |
| import os | |
| import tempfile | |
| import uuid | |
| from pathlib import Path | |
| from threading import Lock | |
| from typing import Dict, Optional | |
| import requests | |
| import torch | |
| from fastapi import BackgroundTasks, Body, FastAPI, Header, HTTPException | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from pydantic import BaseModel, Field, HttpUrl | |
# Environment configuration
SPACE_API_KEY = os.getenv("SPACE_API_KEY")

# Accept any of the common Hugging Face token variable names; the first
# non-empty value wins (an empty string is treated as unset).
HF_TOKEN = next(
    (
        token
        for token in (
            os.getenv("HUGGING_FACE_HUB_TOKEN"),
            os.getenv("HUGGINGFACEHUB_API_TOKEN"),
            os.getenv("HF_TOKEN"),
        )
        if token
    ),
    None,
)

# Model configuration
MODEL_DIR = os.getenv("MODEL_DIR", "./checkpoints")
MAX_TEXT_LENGTH = 1000
DEFAULT_LANGUAGE = "en"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# In-memory job registry shared across request threads; guarded by JOB_LOCK.
JOBS: Dict[str, Dict[str, str]] = {}
JOB_LOCK = Lock()

# Export the token under both canonical names before any Hugging Face
# library import reads the environment.
if HF_TOKEN:
    os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
    os.environ["HF_TOKEN"] = HF_TOKEN
# Download and initialize the OpenVoice model at import time. On failure the
# service degrades to base TTS only (tone_color_converter is left as None and
# voice cloning is disabled); only a total failure aborts startup.
os.makedirs(MODEL_DIR, exist_ok=True)
print(f"Initializing OpenVoice on {DEVICE}...")
try:
    # Download checkpoints if needed
    if not Path(MODEL_DIR, "checkpoints_v2").exists():
        print("Downloading OpenVoice V2 checkpoints...")
        from huggingface_hub import snapshot_download
        # NOTE(review): assumes the myshell-ai/OpenVoice repo snapshot contains
        # a checkpoints_v2/ directory matching the paths used below — confirm.
        snapshot_download(
            repo_id="myshell-ai/OpenVoice",
            local_dir=MODEL_DIR,
            token=HF_TOKEN,
        )
        print("Model download complete.")
    # Import OpenVoice modules (deferred so the fallback path below can still
    # run when the openvoice package itself fails to import).
    from melo.api import TTS
    from openvoice import se_extractor
    from openvoice.api import ToneColorConverter
    # Initialize base TTS (MeloTTS)
    ckpt_converter = f'{MODEL_DIR}/checkpoints_v2/converter'
    # Initialize tone color converter
    tone_color_converter = ToneColorConverter(
        f'{ckpt_converter}/config.json',
        device=DEVICE
    )
    tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
    # Initialize base TTS for English
    base_speaker_tts = TTS(language='EN', device=DEVICE)
    base_speaker = base_speaker_tts.hps.data.spk2id['EN-US']
    print("OpenVoice V2 loaded successfully!")
except Exception as exc:
    print(f"Error loading OpenVoice: {exc}")
    print("Trying alternative initialization...")
    try:
        # Fallback: base MeloTTS only; tone_color_converter stays None so
        # _run_generate_job skips the cloning step entirely (se_extractor is
        # intentionally not imported here — it is only used when the
        # converter is available).
        from melo.api import TTS
        base_speaker_tts = TTS(language='EN', device=DEVICE)
        base_speaker = base_speaker_tts.hps.data.spk2id['EN-US']
        # Mock converter for basic functionality
        tone_color_converter = None
        print("Loaded base TTS only (voice cloning disabled)")
    except Exception as exc2:
        # Neither initialization path worked — the service cannot start.
        raise RuntimeError(f"Failed to load OpenVoice: {exc2}") from exc2
# FastAPI application instance; route handlers attach to this object.
app = FastAPI(title="openvoice-api", version="2.0.0")
class GenerateRequest(BaseModel):
    """Request body schema for speech generation.

    Validation is enforced by pydantic: text length is capped at
    MAX_TEXT_LENGTH and speed is clamped to the 0.5-2.0 range.
    """

    # Text to synthesize.
    text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)
    # Reference voice: either an http(s) URL or base64-encoded audio bytes
    # (dispatched in _temp_speaker_file by the "http" prefix).
    speaker_wav: str = Field(..., description="HTTPS URL or base64-encoded audio")
    # NOTE(review): language is accepted but the synthesis path uses the
    # English base speaker unconditionally — confirm intended behavior.
    language: Optional[str] = Field(DEFAULT_LANGUAGE, description="ISO code: en, es, fr, zh, ja, ko")
    speed: Optional[float] = Field(1.0, ge=0.5, le=2.0, description="Speech speed (0.5-2.0)")
def _require_api_key(x_api_key: Optional[str]) -> None:
    """Reject the request unless the caller presented the configured key.

    When SPACE_API_KEY is unset (or empty), authentication is disabled and
    every request is allowed through.

    Raises:
        HTTPException: 401 when a key is configured and does not match.
    """
    if SPACE_API_KEY and x_api_key != SPACE_API_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")
def _write_temp_audio_from_url(url: HttpUrl) -> str:
    """Download reference audio from a URL into a temporary file.

    Returns the path of the temp file; the caller is responsible for
    removing it (see _cleanup_files).

    Raises:
        HTTPException: 400 when the remote server answers with an error
            status (>= 400).
    """
    # requests expects a plain string URL; pydantic v2's HttpUrl is an
    # object, so convert explicitly instead of relying on coercion.
    response = requests.get(str(url), stream=True, timeout=30)
    if response.status_code >= 400:
        raise HTTPException(
            status_code=400,
            detail=f"Could not fetch speaker audio: {response.status_code}"
        )
    # url.path can be None (e.g. "https://host" with no path) on pydantic
    # v2 — Path(None) would raise TypeError, so guard with "".
    suffix = Path(url.path or "").suffix or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:  # skip keep-alive chunks
                tmp.write(chunk)
    return tmp.name
| def _write_temp_audio_from_base64(payload: str) -> str: | |
| """Decode base64 audio to temporary file.""" | |
| try: | |
| raw = base64.b64decode(payload) | |
| except Exception as exc: | |
| raise HTTPException( | |
| status_code=400, | |
| detail="Invalid base64 speaker_wav" | |
| ) from exc | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: | |
| tmp.write(raw) | |
| return tmp.name | |
def _temp_speaker_file(speaker_wav: str) -> str:
    """Materialize the reference speaker audio as a temporary file.

    Dispatches on the input form: an http(s) URL is downloaded, anything
    else is treated as base64-encoded audio bytes.
    """
    if speaker_wav.startswith(("http://", "https://")):
        return _write_temp_audio_from_url(HttpUrl(speaker_wav))
    return _write_temp_audio_from_base64(speaker_wav)
def _set_job(job_id: str, **kwargs) -> None:
    """Merge the given fields into the job record under the registry lock."""
    with JOB_LOCK:
        record = JOBS.get(job_id, {})
        record.update(kwargs)
        JOBS[job_id] = record
def _get_job(job_id: str) -> Optional[Dict[str, str]]:
    """Return a snapshot copy of the job record, or None if unknown.

    The copy is taken under the lock so callers never see a record that a
    concurrent _set_job is mutating.
    """
    with JOB_LOCK:
        record = JOBS.get(job_id)
        if record is None:
            return None
        return dict(record)
def _pop_job(job_id: str) -> Optional[Dict[str, str]]:
    """Atomically remove and return the job record (None if unknown)."""
    with JOB_LOCK:
        record = JOBS.pop(job_id, None)
    return record
| def _cleanup_files(*files: str): | |
| """Background task to clean up temporary files after response is sent.""" | |
| for file_path in files: | |
| if file_path and Path(file_path).exists(): | |
| try: | |
| Path(file_path).unlink(missing_ok=True) | |
| except Exception: | |
| pass # Ignore cleanup errors | |
def _run_generate_job(job_id: str, payload: Dict[str, str]):
    """
    Background job for TTS generation using OpenVoice V2.

    Two-step process:
    1. Generate base speech with MeloTTS (base_speaker_tts)
    2. Apply target voice characteristics with ToneColorConverter

    Progress is published through the shared JOBS registry: status moves
    from "processing" to "completed" (with output_file) or to "error"
    (with the exception text). Temporary files are removed on every
    failure path; the final output file is cleaned up later by the
    /result handler.
    """
    speaker_file = None
    temp_audio = None
    output_file = None
    _set_job(job_id, status="processing")
    try:
        # Step 1: Generate base speech
        temp_audio = os.path.join(
            tempfile.gettempdir(),
            f"openvoice-temp-{uuid.uuid4()}.wav"
        )
        # payload comes from GenerateRequest.dict(); pydantic already
        # constrained speed to the 0.5-2.0 range.
        speed = float(payload.get("speed", 1.0))
        base_speaker_tts.tts_to_file(
            payload["text"],
            base_speaker,
            temp_audio,
            speed=speed
        )
        # Step 2: Apply voice cloning if converter is available
        if tone_color_converter is not None:
            try:
                # Prepare reference audio (URL download or base64 decode)
                speaker_file = _temp_speaker_file(payload["speaker_wav"])
                # Extract target speaker embedding from the reference clip
                target_se, _ = se_extractor.get_se(
                    speaker_file,
                    tone_color_converter,
                    vad=True
                )
                # Get source speaker embedding for the EN-US base speaker
                source_se = torch.load(
                    f'{MODEL_DIR}/checkpoints_v2/base_speakers/ses/en-us.pth',
                    map_location=DEVICE
                )
                # Apply voice conversion
                output_file = os.path.join(
                    tempfile.gettempdir(),
                    f"openvoice-{uuid.uuid4()}.wav"
                )
                tone_color_converter.convert(
                    audio_src_path=temp_audio,
                    src_se=source_se,
                    tgt_se=target_se,
                    output_path=output_file,
                    message="@MyShell"
                )
                # Cleanup intermediate files; the converted output remains
                _cleanup_files(speaker_file, temp_audio)
            except Exception as convert_error:
                print(f"Voice conversion failed: {convert_error}")
                # Fall back to base audio without voice cloning. temp_audio
                # is reset to None so the outer except cannot delete the
                # file that now serves as the result.
                # NOTE(review): a partially written output_file from a
                # failed convert() is not removed here — verify convert()
                # leaves no partial files behind.
                output_file = temp_audio
                temp_audio = None
                _cleanup_files(speaker_file)
        else:
            # No converter available, use base audio as the final result
            output_file = temp_audio
            temp_audio = None
        # Verify output exists before declaring success
        if not Path(output_file).exists():
            raise RuntimeError(
                f"TTS generation failed: output file was not created"
            )
        _set_job(job_id, status="completed", output_file=output_file)
    except Exception as exc:
        # Any failure: remove whatever temp files were created and record
        # the error so /status can report it.
        _cleanup_files(speaker_file, temp_audio, output_file)
        _set_job(job_id, status="error", error=str(exc))
@app.get("/health")
def health(x_api_key: Optional[str] = Header(default=None)):
    """Health check endpoint.

    Reports the inference device and whether the tone-color converter
    (voice cloning) loaded successfully. Requires the API key when one is
    configured.

    Fix: the handler was never registered with the FastAPI app (missing
    route decorator), so the endpoint was unreachable; the path matches
    the list advertised by root().
    """
    _require_api_key(x_api_key)
    return {
        "status": "ok",
        "model": "openvoice-v2",
        "device": DEVICE,
        "voice_cloning": tone_color_converter is not None,
        "supported_languages": ["en", "es", "fr", "zh", "ja", "ko"]
    }
@app.post("/generate")
def generate(
    payload: GenerateRequest = Body(...),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    x_api_key: Optional[str] = Header(default=None),
):
    """
    Generate speech from text using voice cloning with OpenVoice.

    Queues the synthesis as a background task and immediately returns 202
    with the job id plus the status/result URLs to poll.

    Fix: the handler was never registered with the FastAPI app (missing
    route decorator), so the endpoint was unreachable.
    """
    _require_api_key(x_api_key)
    job_id = str(uuid.uuid4())
    _set_job(job_id, status="queued")
    # Offload the synthesis to a background task; the request returns
    # before any model work starts.
    background_tasks.add_task(_run_generate_job, job_id, payload.dict())
    return JSONResponse(
        status_code=202,
        content={
            "job_id": job_id,
            "status": "queued",
            "status_url": f"/status/{job_id}",
            "result_url": f"/result/{job_id}",
        },
    )
@app.get("/status/{job_id}")
def job_status(job_id: str, x_api_key: Optional[str] = Header(default=None)):
    """Check the status of a generation job.

    Returns the job id and its current status ("queued", "processing",
    "completed", or "error"; the error text is included when present).
    404 when the job id is unknown.

    Fix: the handler was never registered with the FastAPI app (missing
    route decorator); the path matches the status_url returned by
    generate().
    """
    _require_api_key(x_api_key)
    job = _get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    payload: Dict[str, str] = {
        "job_id": job_id,
        "status": job.get("status", "unknown")
    }
    if "error" in job:
        payload["error"] = job["error"]
    return payload
@app.get("/result/{job_id}")
def job_result(
    job_id: str,
    background_tasks: BackgroundTasks = BackgroundTasks(),
    x_api_key: Optional[str] = Header(default=None),
):
    """Retrieve the result of a completed generation job.

    Streams the WAV file and forgets the job; the output file is deleted
    by a background task after the response is sent, so each result can
    be fetched only once. 404 for unknown jobs, 409 while not completed,
    410 when the output file has gone missing.

    Fix: the handler was never registered with the FastAPI app (missing
    route decorator); the path matches the result_url returned by
    generate().
    """
    _require_api_key(x_api_key)
    job = _get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    status = job.get("status")
    if status != "completed":
        raise HTTPException(
            status_code=409,
            detail=f"Job not ready (status={status})"
        )
    output_file = job.get("output_file")
    if not output_file or not Path(output_file).exists():
        _pop_job(job_id)
        raise HTTPException(status_code=410, detail="Result expired or missing")
    # Remove job from memory and clean up the output after sending
    _pop_job(job_id)
    background_tasks.add_task(_cleanup_files, output_file)
    return FileResponse(
        output_file,
        media_type="audio/wav",
        filename="output.wav"
    )
@app.get("/")
def root():
    """API root: service metadata and the list of available endpoints.

    Fix: the handler was never registered with the FastAPI app (missing
    route decorator), so the root path returned 404.
    """
    return {
        "name": "openvoice-api",
        "version": "2.0.0",
        "model": "OpenVoice V2",
        "voice_cloning": tone_color_converter is not None,
        "endpoints": [
            "/health",
            "/generate",
            "/status/{job_id}",
            "/result/{job_id}"
        ],
        "features": [
            "Voice cloning with 3-10s reference audio" if tone_color_converter else "Base TTS only",
            "Multi-language support (EN, ES, FR, ZH, JA, KO)",
            "Adjustable speech speed (0.5-2.0x)",
            "Fast CPU performance"
        ]
    }