hivecorp committed on
Commit
a63a008
·
verified ·
1 Parent(s): 5b5a1cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -46
app.py CHANGED
@@ -11,6 +11,10 @@ from concurrent.futures import ThreadPoolExecutor
11
  from typing import List, Tuple, Optional, Dict, Any
12
  import math
13
  from dataclasses import dataclass
 
 
 
 
14
 
15
  class TimingManager:
16
  def __init__(self):
@@ -189,8 +193,27 @@ class TTSError(Exception):
189
  """Custom exception for TTS processing errors"""
190
  pass
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  async def process_segment_with_timing(segment: Segment, voice: str, rate: str, pitch: str) -> Segment:
193
  """Process a complete segment as a single TTS unit with improved error handling"""
 
 
 
194
  audio_file = os.path.join(tempfile.gettempdir(), f"temp_segment_{segment.id}_{uuid.uuid4()}.wav")
195
  try:
196
  # Process the entire segment text as one unit, replacing newlines with spaces
@@ -207,7 +230,8 @@ async def process_segment_with_timing(segment: Segment, voice: str, rate: str, p
207
 
208
  try:
209
  segment.audio = AudioSegment.from_file(audio_file)
210
- # Reduced silence to 30ms for more natural flow
 
211
  silence = AudioSegment.silent(duration=30)
212
  segment.audio = silence + segment.audio + silence
213
  segment.duration = len(segment.audio)
@@ -215,16 +239,12 @@ async def process_segment_with_timing(segment: Segment, voice: str, rate: str, p
215
  raise TTSError(f"Failed to process audio file for segment {segment.id}: {str(e)}")
216
 
217
  return segment
218
- except Exception as e:
219
- if not isinstance(e, TTSError):
220
- raise TTSError(f"Unexpected error processing segment {segment.id}: {str(e)}")
221
- raise
222
  finally:
223
  if os.path.exists(audio_file):
224
  try:
225
  os.remove(audio_file)
226
  except Exception:
227
- pass # Ignore deletion errors
228
 
229
  # IMPROVEMENT 2: Better File Management with cleanup
230
  class FileManager:
@@ -294,56 +314,45 @@ async def generate_accurate_srt(
294
  lines_per_segment: int,
295
  progress_callback=None,
296
  parallel: bool = True,
297
- max_workers: int = 4
298
  ) -> Tuple[str, str]:
299
- """Generate accurate SRT with parallel processing option"""
300
  processor = TextProcessor(words_per_line, lines_per_segment)
301
  segments = processor.split_into_segments(text)
302
-
303
  total_segments = len(segments)
304
- processed_segments = []
305
 
306
- # Update progress to show segmentation is complete
307
- if progress_callback:
308
- progress_callback(0.1, "Text segmentation complete")
309
 
310
  if parallel and total_segments > 1:
311
- # Process segments in parallel
312
- processed_count = 0
313
- segment_tasks = []
314
-
315
- # Create a semaphore to limit concurrent tasks
316
  semaphore = asyncio.Semaphore(max_workers)
 
 
317
 
318
- async def process_with_semaphore(segment):
319
- async with semaphore:
320
- nonlocal processed_count
321
- try:
322
- result = await process_segment_with_timing(segment, voice, rate, pitch)
323
- processed_count += 1
324
- if progress_callback:
325
- progress = 0.1 + (0.8 * processed_count / total_segments)
326
- progress_callback(progress, f"Processed {processed_count}/{total_segments} segments")
327
- return result
328
- except Exception as e:
329
- # Handle errors in individual segments
330
- processed_count += 1
331
- if progress_callback:
332
- progress = 0.1 + (0.8 * processed_count / total_segments)
333
- progress_callback(progress, f"Error in segment {segment.id}: {str(e)}")
334
- raise
335
-
336
- # Create tasks for all segments
337
- for segment in segments:
338
- segment_tasks.append(process_with_semaphore(segment))
339
-
340
- # Run all tasks and collect results
341
- try:
342
- processed_segments = await asyncio.gather(*segment_tasks)
343
- except Exception as e:
344
  if progress_callback:
345
- progress_callback(0.9, f"Error during parallel processing: {str(e)}")
346
- raise TTSError(f"Failed during parallel processing: {str(e)}")
 
347
  else:
348
  # Process segments sequentially (original method)
349
  for i, segment in enumerate(segments):
@@ -417,6 +426,10 @@ async def generate_accurate_srt(
417
 
418
  return srt_path, audio_path
419
 
 
 
 
 
420
  # IMPROVEMENT 4: Progress Reporting with proper error handling for older Gradio versions
421
  async def process_text_with_progress(
422
  text,
@@ -601,4 +614,10 @@ with gr.Blocks(title="Advanced TTS with Configurable SRT Generation") as app:
601
  )
602
 
603
  if __name__ == "__main__":
 
 
 
 
 
 
604
  app.launch()
 
11
  from typing import List, Tuple, Optional, Dict, Any
12
  import math
13
  from dataclasses import dataclass
14
+ import multiprocessing
15
+ import psutil
16
+ import concurrent.futures
17
+ import gc
18
 
19
  class TimingManager:
20
  def __init__(self):
 
193
  """Custom exception for TTS processing errors"""
194
  pass
195
 
196
+ class ResourceOptimizer:
197
+ @staticmethod
198
+ def get_optimal_workers():
199
+ cpu_count = multiprocessing.cpu_count()
200
+ return max(cpu_count - 1, 1) # Leave one core for system
201
+
202
+ @staticmethod
203
+ def get_memory_limit():
204
+ # Use up to 70% of available RAM
205
+ return int(psutil.virtual_memory().available * 0.7)
206
+
207
+ @staticmethod
208
+ def get_batch_size(total_segments):
209
+ # Calculate optimal batch size based on CPU cores
210
+ return min(total_segments, ResourceOptimizer.get_optimal_workers() * 2)
211
+
212
  async def process_segment_with_timing(segment: Segment, voice: str, rate: str, pitch: str) -> Segment:
213
  """Process a complete segment as a single TTS unit with improved error handling"""
214
+ # Run a garbage-collection pass before this segment's audio is loaded (nothing is pre-allocated here)
215
+ gc.collect() # Force garbage collection before processing
216
+
217
  audio_file = os.path.join(tempfile.gettempdir(), f"temp_segment_{segment.id}_{uuid.uuid4()}.wav")
218
  try:
219
  # Process the entire segment text as one unit, replacing newlines with spaces
 
230
 
231
  try:
232
  segment.audio = AudioSegment.from_file(audio_file)
233
+ # Optimize memory usage for audio processing
234
+ segment.audio = segment.audio.set_channels(1) # Convert to mono for memory efficiency
235
  silence = AudioSegment.silent(duration=30)
236
  segment.audio = silence + segment.audio + silence
237
  segment.duration = len(segment.audio)
 
239
  raise TTSError(f"Failed to process audio file for segment {segment.id}: {str(e)}")
240
 
241
  return segment
 
 
 
 
242
  finally:
243
  if os.path.exists(audio_file):
244
  try:
245
  os.remove(audio_file)
246
  except Exception:
247
+ pass
248
 
249
  # IMPROVEMENT 2: Better File Management with cleanup
250
  class FileManager:
 
314
  lines_per_segment: int,
315
  progress_callback=None,
316
  parallel: bool = True,
317
+ max_workers: Optional[int] = None
318
  ) -> Tuple[str, str]:
319
+ """Generate accurate SRT with optimized resource utilization"""
320
  processor = TextProcessor(words_per_line, lines_per_segment)
321
  segments = processor.split_into_segments(text)
 
322
  total_segments = len(segments)
 
323
 
324
+ # Optimize worker count based on system resources
325
+ if max_workers is None:
326
+ max_workers = ResourceOptimizer.get_optimal_workers()
327
 
328
  if parallel and total_segments > 1:
329
+ # Enhanced parallel processing with resource optimization
330
+ batch_size = ResourceOptimizer.get_batch_size(total_segments)
 
 
 
331
  semaphore = asyncio.Semaphore(max_workers)
332
+ processed_segments = []
333
+ processed_count = 0
334
 
335
+ # Process in batches for better resource utilization
336
+ for i in range(0, total_segments, batch_size):
337
+ batch = segments[i:i + batch_size]
338
+ batch_tasks = []
339
+
340
+ for segment in batch:
341
+ batch_tasks.append(
342
+ process_with_semaphore(segment, voice, rate, pitch, semaphore)
343
+ )
344
+
345
+ # Run the batch concurrently; the semaphore caps in-flight segments at max_workers
346
+ batch_results = await asyncio.gather(*batch_tasks)
347
+ processed_segments.extend(batch_results)
348
+
349
+ # Force garbage collection between batches
350
+ gc.collect()
351
+
 
 
 
 
 
 
 
 
 
352
  if progress_callback:
353
+ processed_count += len(batch)
354
+ progress = 0.1 + (0.8 * processed_count / total_segments)
355
+ progress_callback(progress, f"Processed {processed_count}/{total_segments} segments")
356
  else:
357
  # Process segments sequentially (original method)
358
  for i, segment in enumerate(segments):
 
426
 
427
  return srt_path, audio_path
428
 
429
+ async def process_with_semaphore(segment, voice, rate, pitch, semaphore):
430
+ async with semaphore:
431
+ return await process_segment_with_timing(segment, voice, rate, pitch)
432
+
433
  # IMPROVEMENT 4: Progress Reporting with proper error handling for older Gradio versions
434
  async def process_text_with_progress(
435
  text,
 
614
  )
615
 
616
  if __name__ == "__main__":
617
+ # Lower the process priority (below-normal class on Windows, nice value 10 elsewhere)
618
+ p = psutil.Process()
619
+ try:
620
+ p.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS if os.name == 'nt' else 10)
621
+ except Exception:
622
+ pass
623
  app.launch()