Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, Form | |
| from fastapi.responses import FileResponse | |
| import torch | |
| import torchaudio | |
| import os | |
| from pydantic import BaseModel | |
| from typing import List, Optional | |
| from pathlib import Path | |
# Folder that receives every generated audio file; created at import time.
OUTPUT_DIR = "outputs"
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
| from huggingface_hub import hf_hub_download | |
# ------------------------
# Download model files from Hugging Face if not present
# ------------------------
MODEL_DIR = "my_model"
_MODEL_REPO_ID = "MariaKaiser/egtts_finetuned_with_vocab"


def _fetch_model_file(name: str) -> str:
    """Return the local path of ``my_model/<name>`` fetched from the model repo."""
    return hf_hub_download(
        repo_id=_MODEL_REPO_ID,
        filename=f"my_model/{name}",
        cache_dir=MODEL_DIR,
    )


# Fetched in the same order as before: config, vocabulary, then weights.
config_path = _fetch_model_file("config.json")
vocab_path = _fetch_model_file("vocab.json")
model_path = _fetch_model_file("model.pth")
| from TTS.tts.models.xtts import Xtts | |
| from TTS.tts.configs.xtts_config import XttsConfig | |
# Build the XTTS model from its JSON config, load the checkpoint that sits
# next to model.pth, then move the model to the selected device.
config = XttsConfig()
config.load_json(config_path)

model = Xtts.init_from_config(config)
_checkpoint_dir = os.path.dirname(model_path)
model.load_checkpoint(
    config,
    checkpoint_dir=_checkpoint_dir,
    use_deepspeed=False,
    vocab_path=vocab_path,
)
model.to(device)
| # --------- Define your models ---------- | |
class BGMusicDto(BaseModel):
    """Background-music settings attached to a scene."""

    # Location of the music file; it is fetched by URL when the story is mixed.
    musicPath: str
    emotion: str
    # Mixing level: 1.0 leaves the track as is, each 0.1 below that removes 3 dB.
    volume: float
class SentenceDto(BaseModel):
    """A single line of text to synthesize, plus how it should be spoken."""

    speaker: str
    # Used as the file name (``<sentenceId>.wav``) of the generated audio.
    sentenceId: str
    # The text that is sent to the TTS model.
    sentence: str
    # URL of the reference audio the voice/prosody is conditioned on.
    prosodyReference: str
    # Looked up in ``emotion_map`` / ``intensity_map`` to build the text tags.
    emotion: str
    intensity: str
class LocationDto(BaseModel):
    """Where a scene takes place and the sound effect that goes with it."""

    locationName: str
    # URL of the location sound effect; an empty value means "no SFX".
    path: str
class SceneDto(BaseModel):
    """One scene of a chapter: its location, spoken sentences and optional music."""

    # Used as the folder name for the scene's generated sentence audio.
    sceneId: str
    location: LocationDto
    sentences: List[SentenceDto]
    # Optional: every consumer in this file already guards with
    # ``if scene.bgMusic and ...``, so a scene without music must be accepted
    # by validation instead of being rejected before it reaches that code.
    bgMusic: Optional[BGMusicDto] = None
class ChapterDto(BaseModel):
    """A chapter: a spoken title followed by an ordered list of scenes."""

    # Used as the folder name for the chapter's generated audio.
    chapterId: str
    # The title is synthesized like any other sentence and saved as title.wav.
    title: SentenceDto
    scenes: List[SceneDto]
class CastDto(BaseModel):
    """A character of the story as sent by the client."""

    # NOTE(review): none of these fields are read by the active pipeline code
    # visible in this file; they are only validated as part of the request.
    name: str
    gender: str
    isAdult: bool
    voiceReference: str
class StoryCreationDTO(BaseModel):
    """Top-level request body describing a whole story to synthesize."""

    # Names the story's output folder and the storage prefix of the final file.
    storyId: str
    chapters: List[ChapterDto]
    cast: List[CastDto]
| #----------------------------------------------------------- | |
| #__________ func to get file from supabase__________________ | |
| import httpx | |
| import tempfile | |
| import asyncio | |
| # async def download_file_from_url(url: str, retries: int = 3, delay: float = 2.0) -> str | None: | |
| # """ | |
| # Downloads a file from a URL and returns the path to a temporary file. | |
| # Retries on failure up to `retries` times, waiting `delay` seconds between attempts. | |
| # Returns None if all attempts fail. | |
| # """ | |
| # for attempt in range(1, retries + 1): | |
| # try: | |
| # async with httpx.AsyncClient(timeout=60.0) as client: # increased timeout | |
| # response = await client.get(url) | |
| # response.raise_for_status() # raises for non-200 status codes | |
| # # Save to a temporary file | |
| # temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") | |
| # temp_file.write(response.content) | |
| # temp_file.close() | |
| # print(f"Downloaded {url} successfully on attempt {attempt}") | |
| # return temp_file.name | |
| # except Exception as e: | |
| # print(f"Attempt {attempt} failed for {url}: {e}") | |
| # if attempt < retries: | |
| # await asyncio.sleep(delay) # wait before retrying | |
| # print(f"All {retries} attempts failed for {url}") | |
| # return None | |
# Maps a source URL to the temp-file path it was downloaded to, so that a URL
# requested more than once is only fetched the first time.
download_cache: dict[str, str] = {}
async def download_scene_files(scene: SceneDto):
    """
    Pre-fetch every remote audio file a scene needs, concurrently.

    The files requested are, in order: the prosody reference of each sentence,
    the location sound effect (if a path is set) and the background music
    (if set). Successful downloads end up in ``download_cache``.

    Returns a list with one entry per requested file, in that order. Each entry
    is whatever ``download_file_from_url`` returned for it: a local path, or
    None when the download failed.
    """
    urls = [sentence.prosodyReference for sentence in scene.sentences]
    # Location SFX
    if scene.location.path:
        urls.append(scene.location.path)
    # Background music
    if scene.bgMusic and scene.bgMusic.musicPath:
        urls.append(scene.bgMusic.musicPath)

    # Sentences frequently share one reference clip. Starting a download per
    # occurrence would fetch the same URL several times in parallel and leave
    # all but one of the temp files untracked, so fetch each URL only once.
    unique_urls = list(dict.fromkeys(urls))
    fetched = await asyncio.gather(
        *(download_file_from_url(url) for url in unique_urls)
    )
    path_by_url = dict(zip(unique_urls, fetched))

    # Expand back to one result per requested file, as callers may expect.
    return [path_by_url[url] for url in urls]
| async def download_file_from_url(url: str, retries: int = 3, delay: float = 2.0) -> str | None: | |
| """ | |
| Downloads a file from a URL and returns the path to a temporary file. | |
| If download fails after `retries` attempts, returns None instead of raising an error. | |
| Caches successful downloads to avoid repeated requests. | |
| """ | |
| if url in download_cache: | |
| #print(f"{url} is got from cache") | |
| return download_cache[url] | |
| for attempt in range(1, retries + 1): | |
| try: | |
| async with httpx.AsyncClient(timeout=60.0) as client: | |
| response = await client.get(url) | |
| response.raise_for_status() | |
| temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") | |
| temp_file.write(response.content) | |
| temp_file.close() | |
| #print(f"{url} is downloaded and saved in cache") | |
| download_cache[url] = temp_file.name | |
| return temp_file.name | |
| except Exception as e: | |
| #print(f"Attempt {attempt} failed for {url}: {e}") | |
| if attempt < retries: | |
| await asyncio.sleep(delay) | |
| #print(f"All {retries} attempts failed for {url}, skipping...") | |
| return None | |
| #----------------------------------------------------------- | |
def inference_by_model(text: str, audio_file: str, save_path: str) -> str:
    """
    Synthesize ``text`` in the voice/prosody of ``audio_file`` and save it.

    Args:
        text: The (already tagged) text to speak.
        audio_file: Path to the reference audio used for conditioning.
        save_path: Full path of the output file, including the file name
            (not just a folder). Missing parent folders are created.

    Returns:
        ``save_path``, unchanged.

    Raises:
        ValueError: If no reference audio was supplied, e.g. because its
            download failed and the caller passed None.
    """
    if not audio_file:
        raise ValueError("inference_by_model needs a reference audio file but got none")

    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[audio_file])
    out = model.inference(
        text=text,
        language="ar",
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=0.65,
        top_k=model.config.top_k,
        length_penalty=model.config.length_penalty,
        repetition_penalty=model.config.repetition_penalty,
        top_p=model.config.top_p,
    )

    # dirname is "" for a bare file name, and os.makedirs("") raises.
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # unsqueeze(0) adds a leading axis to the generated samples before saving.
    # NOTE(review): 24000 is the sample rate written to the file; it is assumed
    # to match the model's output rate - confirm against the model config.
    waveform = torch.tensor(out["wav"]).unsqueeze(0)
    torchaudio.save(os.fspath(save_path), waveform, 24000)
    return save_path
| #_______________generate audios and folder structure_______________________ | |
async def _synthesize_sentence(sentence: SentenceDto, save_path: Path) -> None:
    """Fetch a sentence's prosody reference and write its spoken audio to save_path."""
    # Served from download_cache when the scene pre-fetch already got it.
    prosody_file = await download_file_from_url(sentence.prosodyReference)
    if prosody_file is None:
        # Fail with a message that names the sentence instead of letting a
        # bare cache KeyError or a None path surface from deeper down.
        raise RuntimeError(
            f"Could not download prosody reference for sentence "
            f"{sentence.sentenceId}: {sentence.prosodyReference}"
        )
    tagged_text = generate_tagged_text(
        sentence.sentence,
        sentence.emotion,
        sentence.intensity,
    )
    inference_by_model(
        text=tagged_text,
        audio_file=prosody_file,
        save_path=save_path,
    )


async def generate_story_audios(story: StoryCreationDTO, base_output: str):
    """
    Generate the audio files and folders for the entire story.

    Layout written under ``base_output``::

        <storyId>/<chapterId>/title.wav
        <storyId>/<chapterId>/<sceneId>/<sentenceId>.wav

    Raises:
        RuntimeError: If the prosody reference of a title or sentence cannot
            be downloaded.

    NOTE(review): inference_by_model is a synchronous call made from this
    coroutine, so the event loop is busy while each clip is generated.
    """
    story_dir = Path(base_output) / story.storyId
    story_dir.mkdir(parents=True, exist_ok=True)

    for chapter in story.chapters:
        chapter_dir = story_dir / chapter.chapterId
        chapter_dir.mkdir(exist_ok=True)

        # --- Chapter title audio ---
        await _synthesize_sentence(chapter.title, chapter_dir / "title.wav")

        for scene in chapter.scenes:
            # Fetch all of the scene's remote files up front, concurrently.
            await download_scene_files(scene)
            scene_dir = chapter_dir / scene.sceneId
            scene_dir.mkdir(exist_ok=True)

            # --- Sentences audio ---
            for sentence in scene.sentences:
                await _synthesize_sentence(
                    sentence, scene_dir / f"{sentence.sentenceId}.wav"
                )
| #_______________ Concatenating the generated audios to make the final story (post-processing)_______________________ | |
| from pydub import AudioSegment | |
| import os | |
| import subprocess | |
def ensure_wav(file_path: str) -> str:
    """
    Make sure an audio file is available as WAV, converting it with ffmpeg if needed.

    The decision is based on the file extension only (case-insensitive). A file
    that already ends in ``.wav`` is returned as is; anything else is converted
    into a sibling file with the same name and a ``.wav`` extension, replacing
    an existing file of that name.

    Returns the path of the WAV file. Raises ``subprocess.CalledProcessError``
    when ffmpeg exits with an error.
    """
    stem, extension = os.path.splitext(file_path)
    if extension.lower() != ".wav":
        # Same folder, same name, .wav extension
        wav_path = stem + ".wav"
        command = ["ffmpeg", "-y", "-i", file_path, wav_path]
        subprocess.run(command, check=True)
        print(f"Converted: {file_path} → {wav_path}")
        return wav_path
    # Already WAV
    return file_path
| from pydub import AudioSegment | |
| import asyncio | |
def _overlay_background(scene_audio: AudioSegment, bg_file: str, volume: float) -> AudioSegment:
    """Return scene_audio with the music in bg_file mixed underneath it."""
    bg_audio = AudioSegment.from_file(ensure_wav(bg_file))
    # An empty track cannot be looped (division by zero below) and an empty
    # scene has nothing to put music under.
    if len(bg_audio) == 0 or len(scene_audio) == 0:
        return scene_audio

    # Approximate volume control: 1.0 keeps the level, each 0.1 less is -3 dB.
    bg_audio = bg_audio - (1 - volume) * 30

    # Loop the track if it is shorter than the scene, then trim to the scene.
    if len(bg_audio) < len(scene_audio):
        loops = (len(scene_audio) // len(bg_audio)) + 1
        bg_audio = bg_audio * loops
    bg_audio = bg_audio[:len(scene_audio)]
    return scene_audio.overlay(bg_audio)


async def _mix_scene_audio(scene: SceneDto, scene_dir: Path) -> AudioSegment:
    """Join a scene's sentence clips and lay location SFX and music over them."""
    scene_audio = AudioSegment.silent(duration=0)

    # --- Concatenate sentence audios ---
    for sentence in scene.sentences:
        scene_audio += AudioSegment.from_wav(scene_dir / f"{sentence.sentenceId}.wav")

    # --- Add SFX for location if available ---
    if scene.location.path:
        sfx_file = await download_file_from_url(scene.location.path)
        if sfx_file:
            # from_file (no forced format) lets the decoder detect the real
            # format, matching how the background music is loaded.
            sfx_audio = AudioSegment.from_file(ensure_wav(sfx_file))
            scene_audio = scene_audio.overlay(sfx_audio)

    # --- Add background music if available ---
    if scene.bgMusic and scene.bgMusic.musicPath:
        bg_file = await download_file_from_url(scene.bgMusic.musicPath)
        # A failed download returns None; skip the music like a failed SFX.
        if bg_file:
            scene_audio = _overlay_background(scene_audio, bg_file, scene.bgMusic.volume)

    return scene_audio


async def concat_story_audio(story: StoryCreationDTO, base_output: str, final_path: str = None):  # full path including filename
    """
    Assemble the previously generated clips into one WAV file for the story.

    For every chapter the title clip comes first, followed by its scenes. Each
    scene is its sentence clips back to back, overlaid with the location SFX
    and background music when those are set and could be downloaded. Two
    seconds of silence follow every scene and three seconds every chapter.

    Args:
        story: The story whose clips were written by generate_story_audios.
        base_output: Root folder that contains the story's folder.
        final_path: Output file (full path including file name). Defaults to
            ``<base_output>/<storyId>/<storyId>_full.wav``.

    Returns:
        The output file as a ``Path``.
    """
    story_dir = Path(base_output) / story.storyId
    story_dir.mkdir(parents=True, exist_ok=True)

    if final_path is None:
        final_path = story_dir / f"{story.storyId}_full.wav"
    else:
        final_path = Path(final_path)
    final_path.parent.mkdir(parents=True, exist_ok=True)  # ensure folder exists

    chapters_audio = AudioSegment.silent(duration=0)  # start empty
    for chapter in story.chapters:
        chapter_dir = story_dir / chapter.chapterId

        # --- Chapter title ---
        chapter_audio = AudioSegment.from_wav(chapter_dir / "title.wav")

        for scene in chapter.scenes:
            scene_audio = await _mix_scene_audio(scene, chapter_dir / scene.sceneId)
            # Add 2 seconds of silence between scenes
            scene_audio += AudioSegment.silent(duration=2000)
            chapter_audio += scene_audio

        # Add 3 seconds of silence between chapters
        chapter_audio += AudioSegment.silent(duration=3000)
        chapters_audio += chapter_audio

    # Export final story
    chapters_audio.export(final_path, format="wav")
    return final_path
| #------------------------------------------------------------- | |
# The web application object that uvicorn serves.
app = FastAPI(title="EGTTS Arabic TTS API")

# In-memory job registry keyed by task id. Each value is a dict with a
# "status" plus either a "result" or an "error" entry.
tasks: dict = {}
| #___________________Test end point to test supabase fetch | |
| from fastapi import Query | |
| from fastapi.responses import Response | |
async def test_download(url: str = Query(...)):
    """
    Test helper: download ``url`` and send the file back as the response body.

    Returns the file's bytes as ``audio/wav``, or ``{"error": ...}`` when the
    download failed or anything else went wrong.

    NOTE(review): no route decorator is visible on this function in this view
    of the file - confirm it is registered on ``app``.
    """
    try:
        # download_file_from_url returns the PATH of a temp file (or None),
        # not the content, so the file has to be read before responding.
        file_path = await download_file_from_url(url)
        if file_path is None:
            return {"error": f"Could not download {url}"}
        return Response(
            content=Path(file_path).read_bytes(),
            media_type="audio/wav"  # change if needed
        )
    except Exception as e:
        return {"error": str(e)}
| #_________________________________________ | |
def root():
    """Return the welcome payload that points visitors at the Swagger UI."""
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm it is registered on ``app``.
    welcome = "Welcome! Visit /docs for Swagger UI."
    return dict(message=welcome)
| #----------------------------------------------------------- | |
class TTSResponse(BaseModel):
    """Shape of a finished TTS result: file name, length and where to get it."""

    # NOTE(review): not referenced by the active code in this view; the field
    # names match the keys of the result stored by run_tts_pipeline.
    fileName: str
    # Length of the audio in seconds.
    duration: float
    audioPath: str
| #---------------------------concatenate text with tags --------------------------- | |
# Map intensity enum names to the fragment used inside the <int_...> tag
intensity_map = {
    "LOW": "low",
    "MEDIUM": "mid",
    "HIGH": "high",
}

# Map emotion enum names to the fragment used inside the <emo_...> tag
emotion_map = {
    "HAPPINESS": "happiness",
    "SADNESS": "sadness",
    "FEAR": "fear",
    "ANGER": "anger",
    "SURPRISE": "surprise",
    "WHISPER": "whisper",
    "NARRATION": "narration",
}


def _lookup_tag(table: dict, enum_name: str, kind: str) -> str:
    """Return the tag fragment for enum_name, ignoring case and outer spaces."""
    key = str(enum_name).strip().upper()
    try:
        return table[key]
    except KeyError:
        # Still a KeyError for existing callers, but one that says what was
        # wrong and what would have been accepted.
        raise KeyError(
            f"Unknown {kind} {enum_name!r}; expected one of {', '.join(table)}"
        ) from None


def generate_tagged_text(text: str, emotion_enum: str, intensity_enum: str) -> str:
    """
    Prefix ``text`` with the emotion and intensity tags: ``<emo_x> <int_y> text``.

    Args:
        text: The sentence to speak.
        emotion_enum: A key of ``emotion_map`` (case-insensitive).
        intensity_enum: A key of ``intensity_map`` (case-insensitive).

    Raises:
        KeyError: If the emotion or the intensity is not a known name.
    """
    emo_tag = f"<emo_{_lookup_tag(emotion_map, emotion_enum, 'emotion')}>"
    int_tag = f"<int_{_lookup_tag(intensity_map, intensity_enum, 'intensity')}>"
    return f"{emo_tag} {int_tag} {text}"
| #----------------------------------------------------------- | |
| #-----------------Post End Point_____________________________ | |
| # @app.post("/tts/") | |
| # async def process_story(story: StoryCreationDTO): | |
| # # Optional: print info for debugging | |
| # print(story.storyId) | |
| # for cast in story.cast: | |
| # print(cast.name, cast.voiceReference) | |
| # for chapter in story.chapters: | |
| # for scene in chapter.scenes: | |
| # for sentence in scene.sentences: | |
| # print(sentence.speaker, sentence.sentence) | |
| # # 1️⃣ Generate all sentence audios and folder structure | |
| # await generate_story_audios(story, base_output=OUTPUT_DIR) | |
| # # 2️⃣ Concatenate all into final story audio | |
| # final_story_path = os.path.join(OUTPUT_DIR, story.storyId, f"{story.storyId}_full.wav") | |
| # final_generated_story_path = await concat_story_audio(story, base_output=OUTPUT_DIR, final_path=final_story_path) | |
| # # Convert to base64 and get duration | |
| # audio_b64, duration = audio_to_base64(final_generated_story_path) | |
| # response = TTSResponse( | |
| # file_name= os.path.basename(final_generated_story_path), | |
| # duration=duration, | |
| # audio_base64=audio_b64 | |
| # ) | |
| # return response | |
| # async def run_tts_pipeline(task_id: str, story: StoryCreationDTO): | |
| # try: | |
| # await generate_story_audios(story, base_output=OUTPUT_DIR) | |
| # final_story_path = os.path.join( | |
| # OUTPUT_DIR, | |
| # story.storyId, | |
| # f"{story.storyId}_full.wav" | |
| # ) | |
| # final_generated_story_path = await concat_story_audio( | |
| # story, | |
| # base_output=OUTPUT_DIR, | |
| # final_path=final_story_path | |
| # ) | |
| # audio_b64, duration = audio_to_base64(final_generated_story_path) | |
| # tasks[task_id] = { | |
| # "status": "completed", | |
| # "result": { | |
| # "fileName": os.path.basename(final_generated_story_path), | |
| # "duration": duration, | |
| # "audioPath": audio_b64 | |
| # } | |
| # } | |
| # except Exception as e: | |
| # print(f"Exception caught at run tts pipeline {str(e)} and status is now failed") | |
| # tasks[task_id] = { | |
| # "status": "failed", | |
| # "error": str(e) | |
| # } | |
| import os | |
| import uuid | |
| from supabase import create_client, Client | |
| from pydub import AudioSegment # For duration in seconds | |
# Initialize Supabase client
# SECURITY: the credentials are read from the environment (for example the
# hosting platform's secret settings) instead of being written in the source.
# The key that used to be hard-coded here was published with the file and
# should be rotated.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "https://kvlxvhdgacktsgykyckm.supabase.co/")
SUPABASE_KEY = os.environ.get("SUPABASE_KEY")
if not SUPABASE_KEY:
    # Fail at start-up with a clear message rather than on the first upload.
    raise RuntimeError(
        "SUPABASE_KEY is not set; provide the Supabase key through the environment"
    )
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
| import time | |
async def run_tts_pipeline(task_id: str, story: StoryCreationDTO):
    """
    Run the whole job for one story and record the outcome in ``tasks``.

    Steps: generate every clip, concatenate them into one WAV, convert that to
    MP3, upload the MP3 to the "story-audio-files" bucket and store its public
    URL, file name and duration under ``tasks[task_id]``.

    Never raises: any failure is recorded as
    ``{"status": "failed", "error": <message>}``.
    """
    import traceback

    start_time = time.time()  # start timer
    try:
        # 1) Generate story audios
        await generate_story_audios(story, base_output=OUTPUT_DIR)

        # 2) Concatenate final story audio
        final_story_path = os.path.join(
            OUTPUT_DIR,
            story.storyId,
            f"{story.storyId}_full.wav"
        )
        final_generated_story_path = await concat_story_audio(
            story,
            base_output=OUTPUT_DIR,
            final_path=final_story_path
        )
        print(f" final_generated_story_path: {final_generated_story_path}")

        # 3) Convert to MP3. Path(...) keeps .with_suffix working whether the
        #    concatenation step hands back a Path or a plain string.
        wav = AudioSegment.from_wav(final_generated_story_path)
        mp3_path = Path(final_generated_story_path).with_suffix(".mp3")
        wav.export(mp3_path, format="mp3", bitrate="192k")
        print(f" final_generated_story_path after conversion to mp3: {mp3_path}")

        # 4) Calculate duration
        audio_segment = AudioSegment.from_file(mp3_path)
        duration_seconds = len(audio_segment) / 1000  # pydub gives length in milliseconds

        # 5) Upload. The bytes are read here so no file handle is left open,
        #    and the content type is set so the object is stored as MP3 audio.
        file_name = f"{uuid.uuid4()}_{os.path.basename(mp3_path)}"
        storage_path = f"{story.storyId}/final/{file_name}"
        bucket = supabase.storage.from_("story-audio-files")
        bucket.upload(
            storage_path,
            mp3_path.read_bytes(),
            file_options={"content-type": "audio/mpeg"},
        )

        # 6) Get public URL
        audio_url = bucket.get_public_url(storage_path)

        # 7) Update task status with audio URL and duration
        tasks[task_id] = {
            "status": "completed",
            "result": {
                "fileName": os.path.basename(mp3_path),
                "duration": duration_seconds,
                "audioPath": audio_url
            }
        }

        # --- Print processing time ---
        elapsed = time.time() - start_time
        print(f"Story {story.storyId} processed in {elapsed:.2f} seconds")
    except Exception as e:
        print(f"exception caught at run tts pipeline {str(e)}")
        # The message alone rarely shows where it failed; keep the traceback.
        traceback.print_exc()
        tasks[task_id] = {
            "status": "failed",
            "error": str(e)
        }
| from fastapi import BackgroundTasks | |
| import uuid | |
async def process_story(story: StoryCreationDTO, background_tasks: BackgroundTasks):
    """
    Queue a story for synthesis and immediately return the id to poll with.

    The task is registered as "processing" and the pipeline is handed to the
    background tasks of the request.

    NOTE(review): no route decorator is visible on this function in this view
    of the file - confirm it is registered on ``app``.
    """
    new_task_id = f"{uuid.uuid4()}"
    tasks[new_task_id] = dict(status="processing", result=None)
    background_tasks.add_task(run_tts_pipeline, new_task_id, story)
    return {"task_id": new_task_id}
| #-----------------------Results Get End Point ______________________________________ | |
| # @app.get("/tts/results/{task_id}") | |
| # async def get_results(task_id: str): | |
| # if task_id not in tasks: | |
| # return {"status": "not_found"} | |
| # task = tasks[task_id] | |
| # if task["status"] == "processing": | |
| # return {"status": "processing"} | |
| # if task["status"] == "failed": | |
| # return { | |
| # "status": "failed", | |
| # "error": task["error"] | |
| # } | |
| # return task["result"] | |
def _purge_download_cache() -> None:
    """Delete every cached download from disk and forget it."""
    for file_path in download_cache.values():
        try:
            os.remove(file_path)
        except OSError:
            # Already gone (or not removable): nothing more to do for it.
            pass
    download_cache.clear()


async def get_results(task_id: str):
    """
    Report the state of a task created by process_story.

    Returns one of:
        ``{"status": "not_found"}`` for an unknown id,
        ``{"status": "processing"}`` while running (or if the stored result
        is incomplete),
        ``{"status": "failed", "error": ...}``,
        ``{"status": "completed", "fileName": ..., "duration": ..., "audioPath": ...}``.

    NOTE(review): no route decorator is visible on this function in this view
    of the file - confirm it is registered on ``app``.
    """
    task = tasks.get(task_id)
    if task is None:
        return {"status": "not_found"}

    status = task["status"]
    if status == "processing":
        return {"status": "processing"}
    if status == "failed":
        return {
            "status": "failed",
            "error": task.get("error", "Unknown error")
        }

    # Ensure result exists and has all required fields
    result = task.get("result")
    if not result or any(k not in result for k in ("fileName", "duration", "audioPath")):
        print(f"missing field {result}")
        # If result is missing fields, mark as still processing
        return {"status": "processing"}

    print(f"all fields are available {result}")
    # The download cache is shared by all stories, so only clear it when no
    # other story is still being processed and may need its files.
    if not any(other.get("status") == "processing" for other in tasks.values()):
        _purge_download_cache()
    return {"status": "completed", **result}
| #----------------------------Test End Point to test tts inference------------------------------------ | |
async def tts_endpoint(
    text: str = Form(...),
    audio_file: UploadFile = File(...),
    emotionName: str = Form(...),
    intensity: str = Form(...),
):
    """
    Test helper: synthesize one sentence from an uploaded reference clip.

    Args:
        text: The sentence to speak.
        audio_file: Reference audio for the voice/prosody.
        emotionName: A key of ``emotion_map``, e.g. "HAPPINESS".
        intensity: A key of ``intensity_map``, e.g. "LOW". This is a string:
            the lookup table has string keys, so an integer could never match.

    Returns the generated audio as a WAV file download.

    NOTE(review): no route decorator is visible on this function in this view
    of the file - confirm it is registered on ``app``.
    """
    # The file name comes from the client. Keep only its last path component
    # so that it cannot point outside OUTPUT_DIR.
    uploaded_name = (audio_file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(uploaded_name) or "reference.wav"
    file_path = os.path.join(OUTPUT_DIR, safe_name)
    with open(file_path, "wb") as f:
        f.write(await audio_file.read())

    tagged_text = generate_tagged_text(text, emotionName, intensity)
    output_path = os.path.join(OUTPUT_DIR, "out_test.wav")
    output_wav = inference_by_model(tagged_text, file_path, output_path)
    return FileResponse(output_wav, media_type="audio/wav", filename="output.wav")
# Start the server only when this file is run as a script. Without the guard
# the server would also start whenever the module is merely imported.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)