MariaKaiser commited on
Commit
37516b2
·
verified ·
1 Parent(s): e5595a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +199 -469
app.py CHANGED
@@ -1,22 +1,114 @@
1
- from fastapi import FastAPI, UploadFile, File, Form
2
  from fastapi.responses import FileResponse
3
- import torch
4
- import torchaudio
5
- import os
6
  from pydantic import BaseModel
7
- from typing import List, Optional
8
  from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
 
 
 
10
  OUTPUT_DIR = "outputs"
11
  os.makedirs(OUTPUT_DIR, exist_ok=True)
12
 
 
 
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
- from huggingface_hub import hf_hub_download
 
 
 
 
 
16
 
17
- # ------------------------
18
- # Download model files from Hugging Face if not present
19
- # ------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  MODEL_DIR = "my_model"
21
 
22
  config_path = hf_hub_download(
@@ -37,30 +129,45 @@ model_path = hf_hub_download(
37
  cache_dir=MODEL_DIR
38
  )
39
 
40
- from TTS.tts.models.xtts import Xtts
41
- from TTS.tts.configs.xtts_config import XttsConfig
42
-
43
- # Load model
44
  config = XttsConfig()
45
  config.load_json(config_path)
46
 
47
  model = Xtts.init_from_config(config)
48
  model.load_checkpoint(
49
  config,
50
- checkpoint_dir= os.path.dirname(model_path),
51
  use_deepspeed=False,
52
- vocab_path= vocab_path
53
  )
54
  model.to(device)
55
 
56
- # --------- Define your models ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
 
 
 
58
  class BGMusicDto(BaseModel):
59
  musicPath: str
60
  emotion: str
61
  volume: float
62
 
63
-
64
  class SentenceDto(BaseModel):
65
  speaker: str
66
  sentenceId: str
@@ -72,7 +179,7 @@ class SentenceDto(BaseModel):
72
  class LocationDto(BaseModel):
73
  locationName: str
74
  path: str
75
-
76
  class SceneDto(BaseModel):
77
  sceneId: str
78
  location: LocationDto
@@ -84,114 +191,26 @@ class ChapterDto(BaseModel):
84
  title: SentenceDto
85
  scenes: List[SceneDto]
86
 
87
-
88
  class CastDto(BaseModel):
89
  name: str
90
  gender: str
91
  isAdult: bool
92
  voiceReference: str
93
 
94
-
95
  class StoryCreationDTO(BaseModel):
96
  storyId: str
97
  chapters: List[ChapterDto]
98
  cast: List[CastDto]
99
 
100
- #-----------------------------------------------------------
101
-
102
- #__________ func to get file from supabase__________________
103
-
104
- import httpx
105
- import tempfile
106
- import asyncio
107
-
108
- # async def download_file_from_url(url: str, retries: int = 3, delay: float = 2.0) -> str | None:
109
- # """
110
- # Downloads a file from a URL and returns the path to a temporary file.
111
- # Retries on failure up to `retries` times, waiting `delay` seconds between attempts.
112
- # Returns None if all attempts fail.
113
- # """
114
- # for attempt in range(1, retries + 1):
115
- # try:
116
- # async with httpx.AsyncClient(timeout=60.0) as client: # increased timeout
117
- # response = await client.get(url)
118
- # response.raise_for_status() # raises for non-200 status codes
119
-
120
- # # Save to a temporary file
121
- # temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
122
- # temp_file.write(response.content)
123
- # temp_file.close()
124
- # print(f"Downloaded {url} successfully on attempt {attempt}")
125
- # return temp_file.name
126
-
127
- # except Exception as e:
128
- # print(f"Attempt {attempt} failed for {url}: {e}")
129
- # if attempt < retries:
130
- # await asyncio.sleep(delay) # wait before retrying
131
-
132
- # print(f"All {retries} attempts failed for {url}")
133
- # return None
134
-
135
-
136
- download_cache = {}
137
-
138
- async def download_scene_files(scene: SceneDto):
139
- tasks = []
140
-
141
- # Sentence prosody references
142
- for sentence in scene.sentences:
143
- tasks.append(download_file_from_url(sentence.prosodyReference))
144
-
145
- # Location SFX
146
- if scene.location.path:
147
- tasks.append(download_file_from_url(scene.location.path))
148
-
149
- # Background music
150
- if scene.bgMusic and scene.bgMusic.musicPath:
151
- tasks.append(download_file_from_url(scene.bgMusic.musicPath))
152
-
153
- # Run all downloads concurrently
154
- downloaded_files = await asyncio.gather(*tasks)
155
- return downloaded_files
156
 
157
- async def download_file_from_url(url: str, retries: int = 3, delay: float = 2.0) -> str | None:
158
- """
159
- Downloads a file from a URL and returns the path to a temporary file.
160
- If download fails after `retries` attempts, returns None instead of raising an error.
161
- Caches successful downloads to avoid repeated requests.
162
- """
163
- if url in download_cache:
164
- #print(f"{url} is got from cache")
165
- return download_cache[url]
166
-
167
- for attempt in range(1, retries + 1):
168
- try:
169
- async with httpx.AsyncClient(timeout=60.0) as client:
170
- response = await client.get(url)
171
- response.raise_for_status()
172
-
173
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
174
- temp_file.write(response.content)
175
- temp_file.close()
176
-
177
- #print(f"{url} is downloaded and saved in cache")
178
-
179
- download_cache[url] = temp_file.name
180
- return temp_file.name
181
-
182
- except Exception as e:
183
- #print(f"Attempt {attempt} failed for {url}: {e}")
184
- if attempt < retries:
185
- await asyncio.sleep(delay)
186
-
187
- #print(f"All {retries} attempts failed for {url}, skipping...")
188
- return None
189
-
190
- #-----------------------------------------------------------
191
-
192
- #takes the text to be said and path to the prosody audio and path to save the generated audio and returns path to the generated audio
193
- # (save_path -> full path including the filename, not just a folder.)
194
- def inference_by_model(text: str, audio_file: str, save_path: str) -> str:
195
  gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[audio_file])
196
  out = model.inference(
197
  text=text,
@@ -204,17 +223,14 @@ def inference_by_model(text: str, audio_file: str, save_path: str) -> str:
204
  repetition_penalty=model.config.repetition_penalty,
205
  top_p=model.config.top_p,
206
  )
207
-
208
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
209
  torchaudio.save(save_path, torch.tensor(out["wav"]).unsqueeze(0), 24000)
210
  return save_path
211
 
212
- #_______________generate audios and folder structure_______________________
213
-
 
214
  async def generate_story_audios(story: StoryCreationDTO, base_output: str):
215
- """
216
- Generates audio files and folders for the entire story
217
- """
218
  story_dir = Path(base_output) / story.storyId
219
  story_dir.mkdir(parents=True, exist_ok=True)
220
 
@@ -222,337 +238,115 @@ async def generate_story_audios(story: StoryCreationDTO, base_output: str):
222
  chapter_dir = story_dir / chapter.chapterId
223
  chapter_dir.mkdir(exist_ok=True)
224
 
225
- # --- Chapter title audio ---
226
- prosody_file_title = await download_file_from_url(chapter.title.prosodyReference)
227
  title_save_path = chapter_dir / "title.wav"
228
-
229
  tagged_text_title = generate_tagged_text(
230
- chapter.title.sentence,
231
- chapter.title.emotion,
232
- chapter.title.intensity
233
- )
234
-
235
- title_generated_audio_path = inference_by_model(
236
- text=tagged_text_title,
237
- audio_file=prosody_file_title,
238
- save_path=title_save_path
239
  )
240
- # os.remove(prosody_file_title)
241
 
242
  for scene in chapter.scenes:
243
- await download_scene_files(scene)
244
  scene_dir = chapter_dir / scene.sceneId
245
  scene_dir.mkdir(exist_ok=True)
246
 
247
- # --- Sentences audio ---
248
  for sentence in scene.sentences:
249
- # Download the prosody reference audio from Supabase
250
- prosody_file = download_cache[sentence.prosodyReference]
251
  sentence_save_path = scene_dir / f"{sentence.sentenceId}.wav"
252
  tagged_text = generate_tagged_text(
253
- sentence.sentence,
254
- sentence.emotion,
255
- sentence.intensity
256
- )
257
- sentence_generated_audio_path = inference_by_model(
258
- text=tagged_text,
259
- audio_file=prosody_file,
260
- save_path=sentence_save_path
261
  )
262
- # os.remove(prosody_file)
263
-
264
- #_______________ Concatenating the generated audios to make the final story (post-processing)_______________________
265
-
266
- from pydub import AudioSegment
267
- import os
268
- import subprocess
269
 
 
 
 
270
  def ensure_wav(file_path: str) -> str:
271
- """
272
- Convert a single audio file to WAV using ffmpeg.
273
- Returns the path to the WAV file.
274
- If the file is already WAV, returns the original path.
275
- """
276
  ext = os.path.splitext(file_path)[1].lower()
277
-
278
  if ext == ".wav":
279
- return file_path # Already WAV
280
-
281
- # Output path: same folder, same name, .wav extension
282
  wav_path = os.path.splitext(file_path)[0] + ".wav"
283
-
284
- # Run ffmpeg conversion
285
  subprocess.run(["ffmpeg", "-y", "-i", file_path, wav_path], check=True)
286
-
287
- print(f"Converted: {file_path} → {wav_path}")
288
  return wav_path
289
 
290
- from pydub import AudioSegment
291
- import asyncio
292
-
293
- async def concat_story_audio(story: StoryCreationDTO, base_output: str, final_path: str = None): # full path including filename
294
  story_dir = Path(base_output) / story.storyId
295
  story_dir.mkdir(parents=True, exist_ok=True)
296
-
297
  if final_path is None:
298
  final_path = story_dir / f"{story.storyId}_full.wav"
299
  else:
300
  final_path = Path(final_path)
301
- final_path.parent.mkdir(parents=True, exist_ok=True) # ensure folder exists
302
 
303
- chapters_audio = AudioSegment.silent(duration=0) # start empty
304
 
305
  for chapter in story.chapters:
306
  chapter_dir = story_dir / chapter.chapterId
307
-
308
- # --- Chapter title ---
309
- title_path = chapter_dir / "title.wav"
310
- chapter_audio = AudioSegment.from_wav(title_path)
311
 
312
  for scene in chapter.scenes:
313
  scene_dir = chapter_dir / scene.sceneId
314
  scene_audio = AudioSegment.silent(duration=0)
315
 
316
- # --- Concatenate sentence audios ---
317
  for sentence in scene.sentences:
318
  sentence_path = scene_dir / f"{sentence.sentenceId}.wav"
319
- sentence_audio = AudioSegment.from_wav(sentence_path)
320
- scene_audio += sentence_audio
321
 
322
- # --- Add SFX for location if available ---
323
  if scene.location.path:
324
- sfx_file = await download_file_from_url(scene.location.path)
325
- if sfx_file:
326
- sfx_file_wav = ensure_wav(sfx_file)
327
- sfx_audio = AudioSegment.from_wav(sfx_file_wav)
328
- scene_audio = scene_audio.overlay(sfx_audio)
329
- # os.remove(sfx_file)
330
- #else:
331
- #print(f"SFX skipped for {scene.location.locationName}")
332
-
333
- # --- Add background music if available ---
334
  if scene.bgMusic and scene.bgMusic.musicPath:
335
- bg_url = scene.bgMusic.musicPath
336
- bg_file = await download_file_from_url(bg_url)
337
  bg_file_wav = ensure_wav(bg_file)
338
  bg_audio = AudioSegment.from_file(bg_file_wav)
339
-
340
- # Adjust volume
341
- bg_audio = bg_audio - (1 - scene.bgMusic.volume) * 30 # approximate
342
- # Loop if shorter than scene
343
  if len(bg_audio) < len(scene_audio):
344
- loops = (len(scene_audio) // len(bg_audio)) + 1
345
- bg_audio = bg_audio * loops
346
- bg_audio = bg_audio[:len(scene_audio)] # trim to match scene
347
  scene_audio = scene_audio.overlay(bg_audio)
348
- # os.remove(bg_file)
349
 
350
- # Add 2 seconds of silence between scenes
351
  scene_audio += AudioSegment.silent(duration=2000)
352
  chapter_audio += scene_audio
353
 
354
- # Add 3 seconds of silence between chapters
355
  chapter_audio += AudioSegment.silent(duration=3000)
356
  chapters_audio += chapter_audio
357
 
358
- # Export final story
359
  chapters_audio.export(final_path, format="wav")
360
  return final_path
361
 
362
- #-------------------------------------------------------------
363
-
 
364
  app = FastAPI(title="EGTTS Arabic TTS API")
365
-
366
  tasks = {}
367
 
368
- #___________________Test end point to test supabase fetch
369
-
370
- from fastapi import Query
371
- from fastapi.responses import Response
372
-
373
- @app.get("/test-download/")
374
- async def test_download(url: str = Query(...)):
375
- try:
376
- file_bytes = await download_file_from_url(url)
377
-
378
- return Response(
379
- content=file_bytes,
380
- media_type="audio/wav" # change if needed
381
- )
382
-
383
- except Exception as e:
384
- return {"error": str(e)}
385
- #_________________________________________
386
-
387
- @app.get("/")
388
- def root():
389
- return {"message": "Welcome! Visit /docs for Swagger UI."}
390
-
391
- #-----------------------------------------------------------
392
-
393
- class TTSResponse(BaseModel):
394
- fileName: str
395
- duration: float # seconds
396
- audioPath: str
397
-
398
- #---------------------------concatenate text with tags ---------------------------
399
-
400
- # Map Intensity numbers to tag strings
401
- intensity_map = {
402
- "LOW": "low",
403
- "MEDIUM": "mid",
404
- "HIGH": "high"
405
- }
406
-
407
- # Map Emotion enum names to lowercase tag strings
408
- emotion_map = {
409
- "HAPPINESS": "happiness",
410
- "SADNESS": "sadness",
411
- "FEAR": "fear",
412
- "ANGER": "anger",
413
- "SURPRISE": "surprise",
414
- "WHISPER": "whisper",
415
- "NARRATION": "narration"
416
- }
417
-
418
- def generate_tagged_text(text: str, emotion_enum: str, intensity_enum: str) -> str:
419
- """
420
- Convert enums to <emo_x> <int_y> format and concatenate with text
421
- """
422
- emo_tag = f"<emo_{emotion_map[emotion_enum]}>"
423
- int_tag = f"<int_{intensity_map[intensity_enum]}>"
424
- return f"{emo_tag} {int_tag} {text}"
425
-
426
- #-----------------------------------------------------------
427
-
428
- #-----------------Post End Point_____________________________
429
-
430
- # @app.post("/tts/")
431
- # async def process_story(story: StoryCreationDTO):
432
-
433
- # # Optional: print info for debugging
434
- # print(story.storyId)
435
- # for cast in story.cast:
436
- # print(cast.name, cast.voiceReference)
437
- # for chapter in story.chapters:
438
- # for scene in chapter.scenes:
439
- # for sentence in scene.sentences:
440
- # print(sentence.speaker, sentence.sentence)
441
-
442
- # # 1️⃣ Generate all sentence audios and folder structure
443
- # await generate_story_audios(story, base_output=OUTPUT_DIR)
444
-
445
- # # 2️⃣ Concatenate all into final story audio
446
- # final_story_path = os.path.join(OUTPUT_DIR, story.storyId, f"{story.storyId}_full.wav")
447
- # final_generated_story_path = await concat_story_audio(story, base_output=OUTPUT_DIR, final_path=final_story_path)
448
-
449
- # # Convert to base64 and get duration
450
- # audio_b64, duration = audio_to_base64(final_generated_story_path)
451
-
452
- # response = TTSResponse(
453
- # file_name= os.path.basename(final_generated_story_path),
454
- # duration=duration,
455
- # audio_base64=audio_b64
456
- # )
457
-
458
- # return response
459
-
460
-
461
- # async def run_tts_pipeline(task_id: str, story: StoryCreationDTO):
462
- # try:
463
- # await generate_story_audios(story, base_output=OUTPUT_DIR)
464
-
465
- # final_story_path = os.path.join(
466
- # OUTPUT_DIR,
467
- # story.storyId,
468
- # f"{story.storyId}_full.wav"
469
- # )
470
-
471
- # final_generated_story_path = await concat_story_audio(
472
- # story,
473
- # base_output=OUTPUT_DIR,
474
- # final_path=final_story_path
475
- # )
476
-
477
- # audio_b64, duration = audio_to_base64(final_generated_story_path)
478
-
479
- # tasks[task_id] = {
480
- # "status": "completed",
481
- # "result": {
482
- # "fileName": os.path.basename(final_generated_story_path),
483
- # "duration": duration,
484
- # "audioPath": audio_b64
485
- # }
486
- # }
487
-
488
- # except Exception as e:
489
- # print(f"Exception caught at run tts pipeline {str(e)} and status is now failed")
490
- # tasks[task_id] = {
491
- # "status": "failed",
492
- # "error": str(e)
493
- # }
494
-
495
- import os
496
- import uuid
497
- from supabase import create_client, Client
498
- from pydub import AudioSegment # For duration in seconds
499
-
500
- # Initialize Supabase client
501
- SUPABASE_URL = "https://kvlxvhdgacktsgykyckm.supabase.co/"
502
- SUPABASE_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Imt2bHh2aGRnYWNrdHNneWt5Y2ttIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImlhdCI6MTc3MTk2MTQ5MSwiZXhwIjoyMDg3NTM3NDkxfQ.tzfHcbzwzctHDDDp3vk4JGz30ajN2szncAV-1wK7_pM"
503
- supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
504
-
505
- import time
506
  async def run_tts_pipeline(task_id: str, story: StoryCreationDTO):
507
- start_time = time.time() # start timer
508
  try:
509
- # 1️⃣ Generate story audios
510
- await generate_story_audios(story, base_output=OUTPUT_DIR)
511
-
512
- # 2️⃣ Concatenate final story audio
513
- final_story_path = os.path.join(
514
- OUTPUT_DIR,
515
- story.storyId,
516
- f"{story.storyId}_full.wav"
517
- )
518
-
519
- final_generated_story_path = await concat_story_audio(
520
- story,
521
- base_output=OUTPUT_DIR,
522
- final_path=final_story_path
523
- )
524
-
525
- print(f" final_generated_story_path: {final_generated_story_path}")
526
 
 
527
  wav = AudioSegment.from_wav(final_generated_story_path)
528
  mp3_path = final_generated_story_path.with_suffix(".mp3")
529
  wav.export(mp3_path, format="mp3", bitrate="192k")
530
 
531
- print(f" final_generated_story_path after conversion to mp3: {mp3_path}")
532
-
533
-
534
- # 3️⃣ Calculate duration
535
  audio_segment = AudioSegment.from_file(mp3_path)
536
- duration_seconds = len(audio_segment) / 1000 # pydub gives length in milliseconds
537
-
538
- # 4️⃣ Prepare the file for upload
539
  file_name = f"{uuid.uuid4()}_{os.path.basename(mp3_path)}"
540
  storage_path = f"{story.storyId}/final/{file_name}"
541
-
542
- # with open(final_generated_story_path, "rb") as f:
543
- # file_bytes = f.read()
544
-
545
-
546
-
547
- supabase.storage.from_("story-audio-files").upload(
548
- storage_path,
549
- mp3_path
550
- )
551
-
552
- # 6️⃣ Get public URL
553
  audio_url = supabase.storage.from_("story-audio-files").get_public_url(storage_path)
554
 
555
- # 7️⃣ Update task status with audio URL and duration
556
  tasks[task_id] = {
557
  "status": "completed",
558
  "result": {
@@ -562,112 +356,48 @@ async def run_tts_pipeline(task_id: str, story: StoryCreationDTO):
562
  }
563
  }
564
 
565
- # --- Print processing time ---
566
- end_time = time.time()
567
- elapsed = end_time - start_time
568
  print(f"Story {story.storyId} processed in {elapsed:.2f} seconds")
569
 
570
  except Exception as e:
571
- print(f"exception caught at run tts pipeline {str(e)}")
572
- tasks[task_id] = {
573
- "status": "failed",
574
- "error": str(e)
575
- }
576
-
577
- from fastapi import BackgroundTasks
578
- import uuid
579
 
 
 
 
580
  @app.post("/tts/")
581
  async def process_story(story: StoryCreationDTO, background_tasks: BackgroundTasks):
582
-
583
  task_id = str(uuid.uuid4())
584
-
585
- tasks[task_id] = {
586
- "status": "processing",
587
- "result": None
588
- }
589
-
590
  background_tasks.add_task(run_tts_pipeline, task_id, story)
591
-
592
  return {"task_id": task_id}
593
 
594
- #-----------------------Results Get End Point ______________________________________
595
-
596
- # @app.get("/tts/results/{task_id}")
597
- # async def get_results(task_id: str):
598
-
599
- # if task_id not in tasks:
600
- # return {"status": "not_found"}
601
-
602
- # task = tasks[task_id]
603
-
604
- # if task["status"] == "processing":
605
- # return {"status": "processing"}
606
-
607
- # if task["status"] == "failed":
608
- # return {
609
- # "status": "failed",
610
- # "error": task["error"]
611
- # }
612
-
613
- # return task["result"]
614
-
615
  @app.get("/tts/results/{task_id}")
616
  async def get_results(task_id: str):
617
  if task_id not in tasks:
618
  return {"status": "not_found"}
619
-
620
  task = tasks[task_id]
621
-
622
  if task["status"] == "processing":
623
  return {"status": "processing"}
624
-
625
  if task["status"] == "failed":
626
- return {
627
- "status": "failed",
628
- "error": task.get("error", "Unknown error")
629
- }
630
-
631
- # Ensure result exists and has all required fields
632
- result = task.get("result")
633
- if result and all(k in result for k in ("fileName", "duration", "audioPath")):
634
- #clearing cache
635
- print(f"all fields are available {result}")
636
- for file_path in download_cache.values():
637
- if os.path.exists(file_path):
638
- os.remove(file_path)
639
- download_cache.clear()
640
-
641
- return {"status": "completed", **result}
642
- else:
643
- print(f"missing field {result}")
644
- # If result is missing fields, mark as still processing
645
- return {"status": "processing"}
646
-
647
- #----------------------------Test End Point to test tts inference------------------------------------
648
-
649
- @app.post("/tts_test/")
650
- async def tts_endpoint(
651
- text: str = Form(...),
652
- audio_file: UploadFile = File(...),
653
- emotionName: str = Form(...),
654
- intensity: int = Form(...)
655
- ):
656
-
657
- file_path = os.path.join(OUTPUT_DIR, audio_file.filename)
658
- with open(file_path, "wb") as f:
659
- f.write(await audio_file.read())
660
 
661
- tagged_text = generate_tagged_text(text, emotionName, intensity)
 
 
662
 
663
- output_path = os.path.join(OUTPUT_DIR, "out_test.wav")
664
- output_wav = inference_by_model(tagged_text, file_path,output_path)
665
- return FileResponse(output_wav, media_type="audio/wav", filename="output.wav")
 
 
 
666
 
 
 
 
667
  import uvicorn
668
- uvicorn.run(app, host="0.0.0.0", port=7860)
669
-
670
-
671
- # if __name__ == "__main__":
672
- # import uvicorn
673
- # uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ from fastapi import FastAPI, BackgroundTasks, UploadFile, File, Form
2
  from fastapi.responses import FileResponse
 
 
 
3
  from pydantic import BaseModel
4
+ from typing import List
5
  from pathlib import Path
6
+ import os
7
+ import uuid
8
+ import asyncio
9
+ import time
10
+ import httpx
11
+ from supabase import create_client, Client
12
+ import torchaudio
13
+ import torch
14
+ from TTS.tts.models.xtts import Xtts
15
+ from TTS.tts.configs.xtts_config import XttsConfig
16
+ from huggingface_hub import hf_hub_download
17
+ from pydub import AudioSegment
18
+ import subprocess
19
 
20
+ # -----------------------------
21
+ # Paths & Device
22
+ # -----------------------------
23
  OUTPUT_DIR = "outputs"
24
  os.makedirs(OUTPUT_DIR, exist_ok=True)
25
 
26
+ CACHE_DIR = "disk cache"
27
+ os.makedirs(CACHE_DIR, exist_ok=True)
28
+
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
 
31
+ # -----------------------------
32
+ # Supabase client
33
+ # -----------------------------
34
+ SUPABASE_URL = "https://kvlxvhdgacktsgykyckm.supabase.co/"
35
+ SUPABASE_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Imt2bHh2aGRnYWNrdHNneWt5Y2ttIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImlhdCI6MTc3MTk2MTQ5MSwiZXhwIjoyMDg3NTM3NDkxfQ.tzfHcbzwzctHDDDp3vk4JGz30ajN2szncAV-1wK7_pM"
36
+ supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
37
 
38
+ # -----------------------------
39
+ # Download cache (memory)
40
+ # -----------------------------
41
+ download_cache = {} # URL -> local path
42
+
43
+ # -----------------------------
44
+ # Helper to get cached file (downloads if missing)
45
+ # -----------------------------
46
+ async def get_cached_file(url: str, subfolder: str) -> str:
47
+ """
48
+ Returns local cached path for URL.
49
+ Downloads and stores in subfolder if missing.
50
+ """
51
+ if url in download_cache:
52
+ return download_cache[url]
53
+
54
+ folder_path = os.path.join(CACHE_DIR, subfolder)
55
+ os.makedirs(folder_path, exist_ok=True)
56
+
57
+ local_path = os.path.join(folder_path, os.path.basename(url))
58
+
59
+ if os.path.exists(local_path):
60
+ download_cache[url] = local_path
61
+ print(f"Found on disk, added to cache: {local_path}")
62
+ return local_path
63
+
64
+ # Download from URL/Supabase
65
+ async with httpx.AsyncClient(timeout=60) as client:
66
+ resp = await client.get(url)
67
+ resp.raise_for_status()
68
+ with open(local_path, "wb") as f:
69
+ f.write(resp.content)
70
+ download_cache[url] = local_path
71
+ print(f"Downloaded and cached: {url} → {local_path}")
72
+ return local_path
73
+
74
+ # -----------------------------
75
+ # Preload all assets from Supabase at startup
76
+ # -----------------------------
77
+ def list_all_files(bucket_name: str):
78
+ """
79
+ Returns list of (file_name, public_url) tuples in bucket
80
+ """
81
+ response = supabase.storage.from_(bucket_name).list()
82
+ files = []
83
+ for f in response:
84
+ url = supabase.storage.from_(bucket_name).get_public_url(f["name"])
85
+ files.append((f["name"], url))
86
+ return files
87
+
88
+ async def download_to_cache(url: str, subfolder: str):
89
+ await get_cached_file(url, subfolder)
90
+
91
+ async def preload_all_assets():
92
+ print("Starting Supabase asset preloading...")
93
+ tasks = []
94
+
95
+ buckets = {
96
+ "voice-actor-files": "prosody",
97
+ "bg-music": "bg_music",
98
+ "location-audio-files": "sfx"
99
+ }
100
+
101
+ for bucket_name, subfolder in buckets.items():
102
+ files = list_all_files(bucket_name)
103
+ for _, url in files:
104
+ tasks.append(download_to_cache(url, subfolder))
105
+
106
+ await asyncio.gather(*tasks)
107
+ print(f"Preloading completed. {len(download_cache)} files cached on disk.")
108
+
109
+ # -----------------------------
110
+ # TTS Model (XTTS)
111
+ # -----------------------------
112
  MODEL_DIR = "my_model"
113
 
114
  config_path = hf_hub_download(
 
129
  cache_dir=MODEL_DIR
130
  )
131
 
 
 
 
 
132
  config = XttsConfig()
133
  config.load_json(config_path)
134
 
135
  model = Xtts.init_from_config(config)
136
  model.load_checkpoint(
137
  config,
138
+ checkpoint_dir=os.path.dirname(model_path),
139
  use_deepspeed=False,
140
+ vocab_path=vocab_path
141
  )
142
  model.to(device)
143
 
144
+ # -----------------------------
145
+ # Enums mapping for TTS tags
146
+ # -----------------------------
147
+ intensity_map = {"LOW": "low", "MEDIUM": "mid", "HIGH": "high"}
148
+ emotion_map = {
149
+ "HAPPINESS": "happiness",
150
+ "SADNESS": "sadness",
151
+ "FEAR": "fear",
152
+ "ANGER": "anger",
153
+ "SURPRISE": "surprise",
154
+ "WHISPER": "whisper",
155
+ "NARRATION": "narration"
156
+ }
157
+
158
+ def generate_tagged_text(text: str, emotion_enum: str, intensity_enum: str) -> str:
159
+ emo_tag = f"<emo_{emotion_map[emotion_enum]}>"
160
+ int_tag = f"<int_{intensity_map[intensity_enum]}>"
161
+ return f"{emo_tag} {int_tag} {text}"
162
 
163
+ # -----------------------------
164
+ # DTO Models
165
+ # -----------------------------
166
  class BGMusicDto(BaseModel):
167
  musicPath: str
168
  emotion: str
169
  volume: float
170
 
 
171
  class SentenceDto(BaseModel):
172
  speaker: str
173
  sentenceId: str
 
179
  class LocationDto(BaseModel):
180
  locationName: str
181
  path: str
182
+
183
  class SceneDto(BaseModel):
184
  sceneId: str
185
  location: LocationDto
 
191
  title: SentenceDto
192
  scenes: List[SceneDto]
193
 
 
194
  class CastDto(BaseModel):
195
  name: str
196
  gender: str
197
  isAdult: bool
198
  voiceReference: str
199
 
 
200
  class StoryCreationDTO(BaseModel):
201
  storyId: str
202
  chapters: List[ChapterDto]
203
  cast: List[CastDto]
204
 
205
+ class TTSResponse(BaseModel):
206
+ fileName: str
207
+ duration: float
208
+ audioPath: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
+ # -----------------------------
211
+ # TTS Inference
212
+ # -----------------------------
213
+ def inference_by_model(text: str, audio_file: str, save_path: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[audio_file])
215
  out = model.inference(
216
  text=text,
 
223
  repetition_penalty=model.config.repetition_penalty,
224
  top_p=model.config.top_p,
225
  )
 
226
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
227
  torchaudio.save(save_path, torch.tensor(out["wav"]).unsqueeze(0), 24000)
228
  return save_path
229
 
230
+ # -----------------------------
231
+ # Generate story audios
232
+ # -----------------------------
233
  async def generate_story_audios(story: StoryCreationDTO, base_output: str):
 
 
 
234
  story_dir = Path(base_output) / story.storyId
235
  story_dir.mkdir(parents=True, exist_ok=True)
236
 
 
238
  chapter_dir = story_dir / chapter.chapterId
239
  chapter_dir.mkdir(exist_ok=True)
240
 
241
+ prosody_file_title = await get_cached_file(chapter.title.prosodyReference, "prosody")
 
242
  title_save_path = chapter_dir / "title.wav"
 
243
  tagged_text_title = generate_tagged_text(
244
+ chapter.title.sentence,
245
+ chapter.title.emotion,
246
+ chapter.title.intensity
 
 
 
 
 
 
247
  )
248
+ inference_by_model(tagged_text_title, prosody_file_title, str(title_save_path))
249
 
250
  for scene in chapter.scenes:
 
251
  scene_dir = chapter_dir / scene.sceneId
252
  scene_dir.mkdir(exist_ok=True)
253
 
 
254
  for sentence in scene.sentences:
255
+ prosody_file = await get_cached_file(sentence.prosodyReference, "prosody")
 
256
  sentence_save_path = scene_dir / f"{sentence.sentenceId}.wav"
257
  tagged_text = generate_tagged_text(
258
+ sentence.sentence,
259
+ sentence.emotion,
260
+ sentence.intensity
 
 
 
 
 
261
  )
262
+ inference_by_model(tagged_text, prosody_file, str(sentence_save_path))
 
 
 
 
 
 
263
 
264
+ # -----------------------------
265
+ # Concatenate audio
266
+ # -----------------------------
267
  def ensure_wav(file_path: str) -> str:
 
 
 
 
 
268
  ext = os.path.splitext(file_path)[1].lower()
 
269
  if ext == ".wav":
270
+ return file_path
 
 
271
  wav_path = os.path.splitext(file_path)[0] + ".wav"
 
 
272
  subprocess.run(["ffmpeg", "-y", "-i", file_path, wav_path], check=True)
 
 
273
  return wav_path
274
 
275
+ async def concat_story_audio(story: StoryCreationDTO, base_output: str, final_path: str = None):
 
 
 
276
  story_dir = Path(base_output) / story.storyId
277
  story_dir.mkdir(parents=True, exist_ok=True)
 
278
  if final_path is None:
279
  final_path = story_dir / f"{story.storyId}_full.wav"
280
  else:
281
  final_path = Path(final_path)
282
+ final_path.parent.mkdir(parents=True, exist_ok=True)
283
 
284
+ chapters_audio = AudioSegment.silent(duration=0)
285
 
286
  for chapter in story.chapters:
287
  chapter_dir = story_dir / chapter.chapterId
288
+ chapter_audio = AudioSegment.from_wav(chapter_dir / "title.wav")
 
 
 
289
 
290
  for scene in chapter.scenes:
291
  scene_dir = chapter_dir / scene.sceneId
292
  scene_audio = AudioSegment.silent(duration=0)
293
 
 
294
  for sentence in scene.sentences:
295
  sentence_path = scene_dir / f"{sentence.sentenceId}.wav"
296
+ scene_audio += AudioSegment.from_wav(sentence_path)
 
297
 
 
298
  if scene.location.path:
299
+ sfx_file = await get_cached_file(scene.location.path, "sfx")
300
+ sfx_file_wav = ensure_wav(sfx_file)
301
+ scene_audio = scene_audio.overlay(AudioSegment.from_wav(sfx_file_wav))
302
+
 
 
 
 
 
 
303
  if scene.bgMusic and scene.bgMusic.musicPath:
304
+ bg_file = await get_cached_file(scene.bgMusic.musicPath, "bg_music")
 
305
  bg_file_wav = ensure_wav(bg_file)
306
  bg_audio = AudioSegment.from_file(bg_file_wav)
307
+ bg_audio = bg_audio - (1 - scene.bgMusic.volume) * 30
 
 
 
308
  if len(bg_audio) < len(scene_audio):
309
+ bg_audio = bg_audio * ((len(scene_audio) // len(bg_audio)) + 1)
310
+ bg_audio = bg_audio[:len(scene_audio)]
 
311
  scene_audio = scene_audio.overlay(bg_audio)
 
312
 
 
313
  scene_audio += AudioSegment.silent(duration=2000)
314
  chapter_audio += scene_audio
315
 
 
316
  chapter_audio += AudioSegment.silent(duration=3000)
317
  chapters_audio += chapter_audio
318
 
 
319
  chapters_audio.export(final_path, format="wav")
320
  return final_path
321
 
322
+ # -----------------------------
323
+ # FastAPI app & tasks
324
+ # -----------------------------
325
  app = FastAPI(title="EGTTS Arabic TTS API")
 
326
  tasks = {}
327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  async def run_tts_pipeline(task_id: str, story: StoryCreationDTO):
329
+ start_time = time.time()
330
  try:
331
+ print(f"Starting story: {story.storyId}")
332
+ await generate_story_audios(story, OUTPUT_DIR)
333
+ final_wav_path = Path(OUTPUT_DIR) / story.storyId / f"{story.storyId}_full.wav"
334
+ final_generated_story_path = await concat_story_audio(story, OUTPUT_DIR, final_path=str(final_wav_path))
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
+ # Convert to mp3
337
  wav = AudioSegment.from_wav(final_generated_story_path)
338
  mp3_path = final_generated_story_path.with_suffix(".mp3")
339
  wav.export(mp3_path, format="mp3", bitrate="192k")
340
 
 
 
 
 
341
  audio_segment = AudioSegment.from_file(mp3_path)
342
+ duration_seconds = len(audio_segment) / 1000
343
+
344
+ # Upload final story
345
  file_name = f"{uuid.uuid4()}_{os.path.basename(mp3_path)}"
346
  storage_path = f"{story.storyId}/final/{file_name}"
347
+ supabase.storage.from_("story-audio-files").upload(storage_path, mp3_path)
 
 
 
 
 
 
 
 
 
 
 
348
  audio_url = supabase.storage.from_("story-audio-files").get_public_url(storage_path)
349
 
 
350
  tasks[task_id] = {
351
  "status": "completed",
352
  "result": {
 
356
  }
357
  }
358
 
359
+ elapsed = time.time() - start_time
 
 
360
  print(f"Story {story.storyId} processed in {elapsed:.2f} seconds")
361
 
362
  except Exception as e:
363
+ tasks[task_id] = {"status": "failed", "error": str(e)}
364
+ print(f"Exception for story {story.storyId}: {e}")
 
 
 
 
 
 
365
 
366
+ # -----------------------------
367
+ # FastAPI endpoints
368
+ # -----------------------------
369
  @app.post("/tts/")
370
  async def process_story(story: StoryCreationDTO, background_tasks: BackgroundTasks):
 
371
  task_id = str(uuid.uuid4())
372
+ tasks[task_id] = {"status": "processing", "result": None}
 
 
 
 
 
373
  background_tasks.add_task(run_tts_pipeline, task_id, story)
 
374
  return {"task_id": task_id}
375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  @app.get("/tts/results/{task_id}")
377
  async def get_results(task_id: str):
378
  if task_id not in tasks:
379
  return {"status": "not_found"}
 
380
  task = tasks[task_id]
 
381
  if task["status"] == "processing":
382
  return {"status": "processing"}
 
383
  if task["status"] == "failed":
384
+ return {"status": "failed", "error": task.get("error", "Unknown error")}
385
+ return {"status": "completed", **task["result"]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
+ @app.get("/")
388
+ def root():
389
+ return {"message": "Welcome! Visit /docs for Swagger UI."}
390
 
391
+ # -----------------------------
392
+ # Startup preload event
393
+ # -----------------------------
394
+ @app.on_event("startup")
395
+ async def startup_event():
396
+ await preload_all_assets()
397
 
398
+ # -----------------------------
399
+ # Run app
400
+ # -----------------------------
401
  import uvicorn
402
+ if __name__ == "__main__":
403
+ uvicorn.run(app, host="0.0.0.0", port=7860)