Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| import io | |
| import numpy as np | |
| import torch | |
| import soundfile as sf | |
| import uuid | |
| import time | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| from omnivoice import OmniVoice, OmniVoiceGenerationConfig | |
| from text_preprocessor import chunk_text | |
# Root logging configuration: only WARNING and above, with timestamp, logger
# name and level in every record.
_LOG_FORMAT = "%(asctime)s %(name)s %(levelname)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.WARNING)

# Module-level logger used throughout this server.
logger = logging.getLogger(__name__)
# FastAPI application instance.
app = FastAPI(title="Arabic TTS Server (OmniVoice)")

# Browser origins allowed to call this API cross-origin.
_ALLOWED_ORIGINS = [
    # Firebase Hosting domains of the deployed frontend
    "https://arabic-tts-frontend.web.app",
    "https://arabic-tts-frontend.firebaseapp.com",
    # Local development servers
    "http://localhost:3000",
    "http://localhost:8000",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Model checkpoint (Hugging Face id or local path), overridable via env var.
CHECKPOINT = os.getenv("OMNIVOICE_MODEL", "k2-fsa/OmniVoice")

# Runtime model state. startup_event() assigns the loaded model and replaces
# this placeholder sampling rate with the model's own value.
model = None
sampling_rate = 24000

# In-memory task registry: task_id -> state dict (status, progress, chunk
# counters, ...). Nothing is persisted, so it is empty after every restart.
tasks_db: dict = {}
# Registered as a startup hook: without this registration nothing ever calls
# the function, `model` stays None and every synthesis request fails.
@app.on_event("startup")
async def startup_event():
    """Load the OmniVoice checkpoint into the module-level ``model``.

    Also replaces the placeholder module-level ``sampling_rate`` with the
    rate reported by the loaded model.

    The device is read from the ``OMNIVOICE_DEVICE`` environment variable and
    defaults to ``"cpu"``. For example, set ``OMNIVOICE_DEVICE=cuda`` on a GPU
    host.
    """
    global model, sampling_rate
    device = os.environ.get("OMNIVOICE_DEVICE", "cpu")
    print(f"Loading OmniVoice model from {CHECKPOINT} ...")
    model = OmniVoice.from_pretrained(
        CHECKPOINT,
        # NOTE(review): presumably lets create_voice_clone_prompt transcribe
        # reference audio when no transcript (ref_text) is given — confirm.
        load_asr=True,
        device_map=device,
    )
    sampling_rate = model.sampling_rate
    print("Model loaded successfully!")
class SynthesizeRequest(BaseModel):
    """JSON request body accepted by ``synthesize``."""

    # Text to speak; synthesize() strips it and splits it into chunks.
    text: str
    # Built-in voice name, looked up as voices/<voice>.wav (optional transcript
    # in voices/<voice>.txt), or "Auto" to let the model choose a voice that
    # is then cloned from the first chunk and reused for the rest.
    voice: Optional[str] = "Auto"
    # Speed value forwarded to model.generate(); when it is 1.0 or None the
    # argument is not passed at all.
    speed: Optional[float] = 1.0
| import os | |
| import shutil | |
def _load_builtin_voice_prompt(voice_id):
    """Return a voice-clone prompt for built-in voice *voice_id*, or None.

    Returns None in these cases:

    - *voice_id* is empty or "Auto".
    - *voice_id* is not a bare file name.
    - voices/<voice_id>.wav does not exist.
    - Building the prompt fails; the failure is logged.

    voices/<voice_id>.txt, when present, supplies the reference transcript.
    """
    if not voice_id or voice_id == "Auto":
        return None
    # voice_id comes straight from the request body: refuse anything with a
    # directory component so it cannot point outside voices/.
    if os.path.basename(voice_id) != voice_id:
        logger.warning(f"Rejected unsafe voice id: {voice_id!r}")
        return None
    voice_path = os.path.join("voices", f"{voice_id}.wav")
    if not os.path.exists(voice_path):
        return None
    text_path = os.path.join("voices", f"{voice_id}.txt")
    ref_text = None
    if os.path.exists(text_path):
        with open(text_path, "r", encoding="utf-8") as f:
            ref_text = f.read().strip()
    try:
        return model.create_voice_clone_prompt(ref_audio=voice_path, ref_text=ref_text)
    except Exception as e:
        logger.warning(f"Voice clone setup failed: {e}")
        return None


def _stitch_chunks(chunk_dir, total_chunks, final_path):
    """Concatenate chunk_0.wav .. chunk_<n-1>.wav from *chunk_dir* into *final_path*.

    The output is a 16-bit PCM WAV. Chunks are streamed into the output one
    at a time, so peak memory is a single chunk rather than the whole
    recording.

    Audio goes to ``<final_path>.part`` first and is renamed into place only
    on success. A failure therefore never leaves a truncated file at the
    public path.

    Raises:
        ValueError: if *total_chunks* is 0.
    """
    chunk_paths = [os.path.join(chunk_dir, f"chunk_{i}.wav") for i in range(total_chunks)]
    if not chunk_paths:
        raise ValueError("No audio chunks to stitch.")
    # All chunks were written with the same rate/layout; take it from the first.
    info = sf.info(chunk_paths[0])
    part_path = final_path + ".part"
    try:
        with sf.SoundFile(
            part_path,
            mode="w",
            samplerate=info.samplerate,
            channels=info.channels,
            format="WAV",
            subtype="PCM_16",
        ) as out:
            for path in chunk_paths:
                data, _ = sf.read(path)
                out.write(data)
        os.replace(part_path, final_path)
    except Exception:
        if os.path.exists(part_path):
            os.remove(part_path)
        raise


def process_audio_task(task_id: str, chunks: list[str], speed: float, voice_id: str):
    """Background worker: generate audio chunk by chunk, stitch it, clean up.

    Progress is reported through ``tasks_db[task_id]``. Status moves from
    "processing" to "stitching" to "completed" (with ``download_url``). On
    any error it becomes "failed", with the message in ``error``.

    Each chunk is saved to audio_chunks/<task_id>/chunk_<i>.wav before
    stitching. Chunks already on disk are reused instead of regenerated.
    On failure the chunk directory is left in place.

    The same voice is used for every chunk. It is either the requested
    built-in voice from voices/, or, in "Auto" mode, a clone of the voice the
    model produced for the first chunk.
    """
    try:
        task = tasks_db[task_id]
        total_chunks = len(chunks)
        task["status"] = "processing"
        task["total_chunks"] = total_chunks
        chunk_dir = os.path.join("audio_chunks", task_id)
        os.makedirs(chunk_dir, exist_ok=True)
        gen_config = OmniVoiceGenerationConfig(
            num_step=32,
            guidance_scale=2.0,
            denoise=True,
            preprocess_prompt=False,
            postprocess_output=False,
        )
        master_voice_prompt = _load_builtin_voice_prompt(voice_id)

        for i, chunk in enumerate(chunks):
            task["current_chunk"] = i + 1
            task["progress"] = int(i / total_chunks * 100)
            chunk_path = os.path.join(chunk_dir, f"chunk_{i}.wav")
            # Skip chunks that were already generated (resume ability).
            if not os.path.exists(chunk_path):
                kw = dict(text=chunk, language="Auto", generation_config=gen_config)
                if speed is not None and speed != 1.0:
                    kw["speed"] = speed
                # Apply consistent voice cloning (prevents mid-book gender switching).
                if master_voice_prompt is not None:
                    kw["voice_clone_prompt"] = master_voice_prompt
                audio = model.generate(**kw)
                # detach()/cpu() keep this working when the model runs on a
                # GPU or returns tensors that track gradients.
                waveform = audio[0].squeeze(0).detach().cpu().numpy()
                sf.write(chunk_path, waveform, sampling_rate, format='wav', subtype='PCM_16')
            # Auto mode: clone the voice of chunk 0 and use it for every later
            # chunk, locking the randomly chosen voice. This runs outside the
            # generation branch so a reused chunk 0 also locks the voice.
            if master_voice_prompt is None and i == 0:
                try:
                    master_voice_prompt = model.create_voice_clone_prompt(ref_audio=chunk_path, ref_text=chunk)
                except Exception as e:
                    logger.warning(f"Could not extract voice clone from chunk 0: {e}")

        task["status"] = "stitching"
        final_dir = os.path.join("static", "audio")
        os.makedirs(final_dir, exist_ok=True)
        final_path = os.path.join(final_dir, f"{task_id}.wav")
        _stitch_chunks(chunk_dir, total_chunks, final_path)

        # Remove the temporary per-chunk files only after a successful stitch.
        shutil.rmtree(chunk_dir)
        task["status"] = "completed"
        task["progress"] = 100
        task["download_url"] = f"audio/{task_id}.wav"
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception(f"Background task {task_id} failed: {e}")
        tasks_db[task_id]["status"] = "failed"
        tasks_db[task_id]["error"] = str(e)
async def synthesize(req: SynthesizeRequest, background_tasks: BackgroundTasks):
    """Chunk ``req.text`` and queue ``process_audio_task`` in the background.

    Returns a JSON body containing the new ``task_id``. Progress is tracked
    in ``tasks_db`` under that id.

    Raises:
        HTTPException(500): the model is not loaded, or queuing failed.
        HTTPException(400): the text is empty or produced no chunks.
    """
    # NOTE(review): no route decorator for this handler is visible in this
    # file, so it may not be registered on `app` — confirm and add one.
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded yet.")
    try:
        # Step 1 Integration: chunk the text
        chunks = chunk_text(req.text.strip())
        if not chunks:
            raise HTTPException(status_code=400, detail="Text is empty or invalid.")
        task_id = str(uuid.uuid4())
        # Step 4 preview: Initialize state tracking
        tasks_db[task_id] = {
            "task_id": task_id,
            "status": "pending",
            "progress": 0,
            "current_chunk": 0,
            "total_chunks": len(chunks),
        }
        # Step 2 Integration: Start the background process instead of blocking
        background_tasks.add_task(process_audio_task, task_id, chunks, req.speed, req.voice)
        return JSONResponse(content={"task_id": task_id, "message": "Audio generation started in the background."})
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 400 above) unchanged
        # instead of re-wrapping them as a 500 below.
        raise
    except Exception as e:
        logger.exception(f"Synthesis failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def get_status(task_id: str):
    """Return the tracked state dict of *task_id* as JSON.

    Raises:
        HTTPException(404): *task_id* is not present in ``tasks_db``.
    """
    # NOTE(review): no route decorator for this handler is visible in this
    # file, so it may not be registered on `app` — confirm and add one.
    task_state = tasks_db.get(task_id)
    if task_state is None:
        raise HTTPException(status_code=404, detail="Task not found.")
    return JSONResponse(content=task_state)
# StaticFiles raises at construction time when its directory is missing, so
# create every served directory before mounting. static/audio is where
# finished recordings are written.
os.makedirs("voices", exist_ok=True)
os.makedirs(os.path.join("static", "audio"), exist_ok=True)
app.mount("/voices", StaticFiles(directory="voices"), name="voices")
# Keep this mount last: a mount on "/" matches every path, so anything
# registered after it would be shadowed.
app.mount("/", StaticFiles(directory="static", html=True), name="static")