don0726 committed on
Commit
7d819f5
·
verified ·
1 Parent(s): ba8cf88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -45
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import uuid
3
  import torch
@@ -15,24 +16,16 @@ from collections import deque
15
  os.environ["COQUI_TOS_AGREED"] = "1"
16
 
17
  # =========================
18
- # 🔥 DEVICE
19
  # =========================
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
- # =========================
23
- # 🔥 MULTI-CHANNEL MODEL POOL
24
- # =========================
25
- NUM_CHANNELS = 5
26
- print(f"πŸš€ Loading {NUM_CHANNELS} XTTS channels...")
27
-
28
- channels = []
29
- for i in range(NUM_CHANNELS):
30
- tts_model = TTS(
31
- model_name="tts_models/multilingual/multi-dataset/xtts_v2",
32
- progress_bar=False
33
- ).to(device)
34
- channels.append({"tts": tts_model, "busy": False})
35
- print(f"βœ… Channel {i+1} loaded")
36
 
37
  # =========================
38
  # 📁 OUTPUT DIR
@@ -41,41 +34,48 @@ OUTPUT_DIR = "outputs"
41
  os.makedirs(OUTPUT_DIR, exist_ok=True)
42
 
43
  # =========================
44
- # 🔹 REQUEST QUEUE
45
  # =========================
 
 
 
46
  request_queue = deque()
47
 
48
  # =========================
49
- # 🔥 MULTI-CHANNEL WORKER
50
  # =========================
51
- async def channel_worker(channel):
52
- """Worker that processes requests assigned to this channel."""
 
53
  while True:
54
  if len(request_queue) == 0:
55
  await asyncio.sleep(0.1)
56
  continue
57
 
58
- if channel["busy"]:
59
- await asyncio.sleep(0.05)
60
- continue
 
 
 
 
 
61
 
62
- # Pick next request
63
- text, lang, audio_path, output_path, future = request_queue.popleft()
64
- channel["busy"] = True
65
 
66
- try:
67
- channel["tts"].tts_to_file(
68
- text=text,
69
- speaker_wav=audio_path,
70
- language=lang,
71
- file_path=output_path,
72
- split_sentences=True
73
- )
74
- future.set_result(output_path)
75
- except Exception as e:
76
- future.set_result(str(e))
77
- finally:
78
- channel["busy"] = False
79
 
80
 
81
  # =========================
@@ -85,10 +85,7 @@ api = FastAPI()
85
 
86
  @api.on_event("startup")
87
  async def startup_event():
88
- # Start a worker for each channel
89
- for ch in channels:
90
- asyncio.create_task(channel_worker(ch))
91
- print(f"πŸ”₯ {NUM_CHANNELS} channel workers started!")
92
 
93
 
94
  @api.post("/clone-voice/")
@@ -107,7 +104,6 @@ async def clone_voice_api(
107
  loop = asyncio.get_event_loop()
108
  future = loop.create_future()
109
 
110
- # Add to request queue
111
  request_queue.append((text, language, input_path, output_path, future))
112
 
113
  result = await future
@@ -137,6 +133,7 @@ async def clone_voice_ui(audio_path, text, language):
137
  future = loop.create_future()
138
 
139
  request_queue.append((text, language, audio_path, output_path, future))
 
140
  result = await future
141
 
142
  if isinstance(result, str) and result.endswith(".wav"):
@@ -145,8 +142,8 @@ async def clone_voice_ui(audio_path, text, language):
145
  return f"❌ {result}", None
146
 
147
 
148
- with gr.Blocks(title="XTTS Voice Cloning (Multi-Channel)") as demo:
149
- gr.Markdown("# 🎀 XTTS Voice Cloning (Multi-Channel)")
150
 
151
  audio_input = gr.Audio(type="filepath", label="Speaker Audio")
152
  text_input = gr.Textbox(label="Text")
@@ -163,6 +160,7 @@ with gr.Blocks(title="XTTS Voice Cloning (Multi-Channel)") as demo:
163
  outputs=[status, output_audio]
164
  )
165
 
 
166
  demo.queue(max_size=20)
167
 
168
  # =========================
 
1
+
2
  import os
3
  import uuid
4
  import torch
 
16
  os.environ["COQUI_TOS_AGREED"] = "1"
17
 
18
  # =========================
19
+ # 🔥 LOAD MODEL ONCE
20
  # =========================
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
+ print("πŸš€ Loading XTTS model...")
24
+ tts = TTS(
25
+ model_name="tts_models/multilingual/multi-dataset/xtts_v2",
26
+ progress_bar=False
27
+ ).to(device)
28
+ print("βœ… Model loaded!")
 
 
 
 
 
 
 
 
29
 
30
  # =========================
31
  # 📁 OUTPUT DIR
 
34
  os.makedirs(OUTPUT_DIR, exist_ok=True)
35
 
36
  # =========================
37
+ # ⚡ BATCH CONFIG
38
  # =========================
39
+ BATCH_SIZE = 3
40
+ BATCH_WAIT_TIME = 1 # seconds
41
+
42
  request_queue = deque()
43
 
44
  # =========================
45
+ # 🔥 BATCH WORKER
46
  # =========================
47
+ async def batch_worker():
48
+ print("πŸ”₯ Batch worker started...")
49
+
50
  while True:
51
  if len(request_queue) == 0:
52
  await asyncio.sleep(0.1)
53
  continue
54
 
55
+ # Wait to collect batch
56
+ await asyncio.sleep(BATCH_WAIT_TIME)
57
+
58
+ batch = []
59
+ while len(request_queue) > 0 and len(batch) < BATCH_SIZE:
60
+ batch.append(request_queue.popleft())
61
+
62
+ print(f"⚑ Processing batch of {len(batch)}")
63
 
64
+ for item in batch:
65
+ text, lang, audio_path, output_path, future = item
 
66
 
67
+ try:
68
+ tts.tts_to_file(
69
+ text=text,
70
+ speaker_wav=audio_path,
71
+ language=lang,
72
+ file_path=output_path,
73
+ split_sentences=True
74
+ )
75
+ future.set_result(output_path)
76
+
77
+ except Exception as e:
78
+ future.set_result(str(e))
 
79
 
80
 
81
  # =========================
 
85
 
86
  @api.on_event("startup")
87
  async def startup_event():
88
+ asyncio.create_task(batch_worker())
 
 
 
89
 
90
 
91
  @api.post("/clone-voice/")
 
104
  loop = asyncio.get_event_loop()
105
  future = loop.create_future()
106
 
 
107
  request_queue.append((text, language, input_path, output_path, future))
108
 
109
  result = await future
 
133
  future = loop.create_future()
134
 
135
  request_queue.append((text, language, audio_path, output_path, future))
136
+
137
  result = await future
138
 
139
  if isinstance(result, str) and result.endswith(".wav"):
 
142
  return f"❌ {result}", None
143
 
144
 
145
+ with gr.Blocks(title="XTTS Voice Cloning (Batching)") as demo:
146
+ gr.Markdown("# 🎀 XTTS Voice Cloning (Batch Mode)")
147
 
148
  audio_input = gr.Audio(type="filepath", label="Speaker Audio")
149
  text_input = gr.Textbox(label="Text")
 
160
  outputs=[status, output_audio]
161
  )
162
 
163
+ # ✅ FIXED QUEUE (no concurrency_count)
164
  demo.queue(max_size=20)
165
 
166
  # =========================