| import logging |
| import os |
| import uuid |
| from contextlib import asynccontextmanager |
| from tempfile import NamedTemporaryFile |
|
|
| import boto3 |
| import torchaudio |
| from fastapi import BackgroundTasks, Depends, FastAPI, Header, HTTPException |
| from fastapi.security import APIKeyHeader |
| from pydantic import BaseModel |
|
|
| from inference import load_models, process_voice_conversion |
|
|
# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Holds the loaded Seed-VC models; stays None until the lifespan handler
# assigns the result of load_models().
models = None

# Expected API key, read once at import time (None when the variable is unset).
API_KEY = os.environ.get("API_KEY")

# Security scheme bound to the "Authorization" header, with auto_error
# turned off.
api_key_header = APIKeyHeader(auto_error=False, name="Authorization")
|
|
|
|
async def verify_api_key(authorization: str = Header(None)):
    """Validate the API key carried in the ``Authorization`` header.

    The header may hold either the bare key or ``Bearer <key>``.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The presented token, once it has been matched against ``API_KEY``.

    Raises:
        HTTPException: 401 when the header is missing or empty, when the
            token does not match, or when no ``API_KEY`` is configured on
            the server (fail closed).
    """
    # Function-scope import: keeps this block self-contained without touching
    # the file's import section.
    import hmac

    if not authorization:
        logger.warning("No API key provided")
        raise HTTPException(status_code=401, detail="API key is missing")

    # Strip only the leading scheme. str.replace would also delete any later
    # "Bearer " occurrence inside the token itself.
    bearer_prefix = "Bearer "
    if authorization.startswith(bearer_prefix):
        token = authorization[len(bearer_prefix):]
    else:
        token = authorization

    # Without a configured key nothing may authenticate. This also closes the
    # hole where API_KEY == "" matched the empty token left by "Bearer ".
    if not API_KEY:
        logger.error("API_KEY is not configured; rejecting request")
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Constant-time comparison avoids leaking how much of the key matched.
    # Comparing bytes keeps compare_digest from raising on non-ASCII input.
    if not hmac.compare_digest(token.encode("utf-8"), API_KEY.encode("utf-8")):
        logger.warning("Invalid API key provided")
        raise HTTPException(status_code=401, detail="Invalid API key")

    return token
|
|
|
|
def get_s3_client():
    """Create a boto3 S3 client configured from environment variables.

    The region comes from ``AWS_REGION`` (default ``us-east-1``). When both
    ``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` are set they are
    passed explicitly, together with ``AWS_SESSION_TOKEN`` if present.
    Otherwise no credentials are passed and boto3 resolves them itself.

    Returns:
        The client returned by ``boto3.client('s3', ...)``.
    """
    client_kwargs = {'region_name': os.getenv("AWS_REGION", "us-east-1")}

    access_key = os.getenv("AWS_ACCESS_KEY_ID")
    secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
    if access_key and secret_key:
        client_kwargs['aws_access_key_id'] = access_key
        client_kwargs['aws_secret_access_key'] = secret_key

        # Explicit keys take precedence over the environment, so temporary
        # (STS) credentials would lose their session token and be rejected
        # unless it is forwarded as well.
        session_token = os.getenv("AWS_SESSION_TOKEN")
        if session_token:
            client_kwargs['aws_session_token'] = session_token

    return boto3.client('s3', **client_kwargs)
|
|
|
|
# Bucket used for source downloads and converted-audio uploads, and the key
# prefix under which converted files are stored. Both can be overridden
# through the environment.
S3_BUCKET = os.getenv("S3_BUCKET", "elevenlabs-clone")
S3_PREFIX = os.getenv("S3_PREFIX", "seedvc-outputs")

# Single module-level S3 client, created at import time.
s3_client = get_s3_client()
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the models on startup, log on shutdown.

    The result of ``load_models()`` is stored in the module-level ``models``
    global. If loading fails the error is logged and re-raised.
    """
    global models

    logger.info("Loading Seed-VC model...")
    try:
        models = load_models()
    except Exception as exc:
        logger.error("Failed to load model: %s", exc)
        raise
    else:
        logger.info("Seed-VC model loaded successfully")

    # Control returns to the application here; code below runs at shutdown.
    yield

    logger.info("Shutting down Seed-VC API")
|
|
# Supported target voices: voice name -> path of its reference recording.
TARGET_VOICES = {
    "andreas": "examples/reference/andreas1.wav",
    "woman": "examples/reference/s1p1.wav",
    "trump": "examples/reference/trump_0.wav",
}

# The FastAPI application; `lifespan` handles model loading at startup.
app = FastAPI(lifespan=lifespan, title="Seed-VC API")
|
|
|
|
class VoiceConversionRequest(BaseModel):
    # Request body of POST /convert. Documented with comments rather than a
    # docstring so the generated schema stays exactly as it was.

    # S3 object key of the audio to convert (looked up in S3_BUCKET).
    source_audio_key: str
    # Name of the voice to convert to; must be a key of TARGET_VOICES.
    target_voice: str
|
|
|
|
def _remove_quietly(path):
    """Best-effort removal of a temporary file.

    Does nothing when ``path`` is falsy or the file is already gone. Any
    other ``OSError`` is logged rather than raised, so cleanup cannot mask
    the error or response that is already in flight.
    """
    if not path:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as exc:
        logger.warning("Could not remove temporary file %s: %s", path, exc)


@app.post("/convert", dependencies=[Depends(verify_api_key)])
async def generate_speech(request: VoiceConversionRequest, background_tasks: BackgroundTasks):
    """Convert a source recording from S3 to one of the target voices.

    Downloads ``request.source_audio_key`` from ``S3_BUCKET``, runs
    ``process_voice_conversion`` against the reference file of the requested
    voice, and saves the result as ``<uuid>.wav``. The file is then uploaded
    under ``S3_PREFIX`` and returned as a presigned URL valid for one hour.

    Args:
        request: Source audio key and target voice name.
        background_tasks: Kept for interface compatibility. Temporary files
            are now removed in ``finally`` so failures are cleaned up too.

    Returns:
        A dict with ``audio_url`` (presigned GET URL) and ``s3_key``.

    Raises:
        HTTPException: 500 when the models are not loaded or conversion,
            saving or uploading fails; 400 for an unknown target voice;
            404 when the source audio cannot be downloaded.
    """
    # NOTE(review): the S3 calls and the model inference below are blocking
    # calls inside an async handler — consider moving them to a worker thread
    # once thread-safety of the models is confirmed.
    if not models:
        raise HTTPException(status_code=500, detail="Model not loaded")

    if request.target_voice not in TARGET_VOICES:
        raise HTTPException(
            status_code=400, detail=f"Target voice not supported. Choose from: {', '.join(TARGET_VOICES.keys())}")

    target_audio_path = TARGET_VOICES[request.target_voice]
    logger.info(
        "Converting voice: %s to %s", request.source_audio_key, request.target_voice)

    output_filename = f"{uuid.uuid4()}.wav"
    local_path = f"/tmp/{output_filename}"
    source_path = None

    try:
        logger.info("Downloading source audio")
        # The context manager closes the handle on every path; delete=False
        # keeps the file on disk so the conversion can reopen it by name.
        with NamedTemporaryFile(delete=False, suffix=".wav") as source_temp:
            source_path = source_temp.name
            try:
                s3_client.download_fileobj(
                    S3_BUCKET, Key=request.source_audio_key, Fileobj=source_temp)
            except Exception as exc:
                logger.warning(
                    "Could not download %s: %s", request.source_audio_key, exc)
                raise HTTPException(
                    status_code=404, detail="Source audio not found") from exc

        vc_wave, sr = process_voice_conversion(
            models=models, source=source_path, target_name=target_audio_path, output=None)

        torchaudio.save(local_path, vc_wave, sr)

        s3_key = f"{S3_PREFIX}/{output_filename}"
        s3_client.upload_file(local_path, S3_BUCKET, s3_key)

        presigned_url = s3_client.generate_presigned_url(
            'get_object',
            Params={'Bucket': S3_BUCKET, 'Key': s3_key},
            ExpiresIn=3600,
        )

        return {
            "audio_url": presigned_url,
            "s3_key": s3_key,
        }
    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) through unchanged; the
        # generic handler below would otherwise rewrite them as a 500.
        raise
    except Exception as exc:
        logger.exception("Error in voice conversion: %s", exc)
        raise HTTPException(
            status_code=500, detail="Error in voice conversion") from exc
    finally:
        # Runs on success and on every failure path, so neither the downloaded
        # source nor the rendered output is left behind in the temp directory.
        _remove_quietly(source_path)
        _remove_quietly(local_path)
|
|
|
|
@app.get("/voices", dependencies=[Depends(verify_api_key)])
async def list_voices():
    """Return the names of all supported target voices."""
    voice_names = list(TARGET_VOICES)
    return {"voices": voice_names}
|
|
|
|
@app.get("/health", dependencies=[Depends(verify_api_key)])
async def health_check():
    """Report whether the Seed-VC models have been loaded."""
    if not models:
        return {"status": "unhealthy", "model": "not loaded"}
    return {"status": "healthy", "model": "loaded"}
|
|