# voice-tools/src/lib/memory_optimizer.py
# Author: jcudit (HF Staff)
# Commit 3ff2f18: fix: also correct lib/ in gitignore to only exclude
# root-level, add src/lib package
"""
Memory optimization utilities.
Provides utilities for processing large audio files (>1 hour) efficiently
without running out of memory.
"""
import gc
import logging
from pathlib import Path
from typing import Iterator, List, Optional, Tuple
import numpy as np
from src.lib.audio_io import AudioIOError, read_audio
logger = logging.getLogger(__name__)
class AudioChunker:
    """
    Utility for processing large audio files in chunks.

    Allows processing audio files that are too large to fit in memory
    by streaming them in manageable chunks.
    """

    def __init__(self, chunk_duration: float = 60.0, overlap: float = 5.0):
        """
        Initialize audio chunker.

        Args:
            chunk_duration: Duration of each chunk in seconds (default: 60s)
            overlap: Overlap between chunks in seconds (default: 5s)

        Raises:
            ValueError: If chunk_duration is not positive, overlap is
                negative, or overlap >= chunk_duration. The latter would
                make the chunk stride non-positive, so iter_chunks would
                never advance (infinite loop).
        """
        # Validate up front: the stride is chunk_duration - overlap, which
        # must be strictly positive for iteration to terminate.
        if chunk_duration <= 0:
            raise ValueError(f"chunk_duration must be positive, got {chunk_duration}")
        if overlap < 0:
            raise ValueError(f"overlap must be non-negative, got {overlap}")
        if overlap >= chunk_duration:
            raise ValueError(
                f"overlap ({overlap}s) must be smaller than "
                f"chunk_duration ({chunk_duration}s)"
            )
        self.chunk_duration = chunk_duration
        self.overlap = overlap
        logger.debug(f"AudioChunker initialized (chunk: {chunk_duration}s, overlap: {overlap}s)")

    def iter_chunks(
        self, file_path: str, target_sr: int = 16000
    ) -> Iterator[Tuple[np.ndarray, int, float, float]]:
        """
        Iterate over audio file in chunks.

        Args:
            file_path: Path to audio file
            target_sr: Target sample rate

        Yields:
            Tuples of (audio_chunk, sample_rate, start_time, end_time)

        Raises:
            AudioIOError: If file cannot be read
        """
        try:
            # Read full audio (we'll optimize this for truly large files later)
            audio, sr = read_audio(file_path, target_sr=target_sr)
            total_duration = len(audio) / sr
            logger.info(
                f"Processing {Path(file_path).name} in chunks "
                f"(duration: {total_duration:.1f}s, chunk size: {self.chunk_duration}s)"
            )
            # Calculate chunk parameters
            chunk_samples = int(self.chunk_duration * sr)
            overlap_samples = int(self.overlap * sr)
            # Guard against a non-positive stride: even with valid durations,
            # int() truncation at tiny sample rates could make both counts
            # equal, which would loop forever. Advance by at least 1 sample.
            step_samples = max(chunk_samples - overlap_samples, 1)
            position = 0
            chunk_idx = 0
            while position < len(audio):
                # Extract chunk (the last one may be shorter than chunk_samples)
                chunk_start = position
                chunk_end = min(position + chunk_samples, len(audio))
                chunk = audio[chunk_start:chunk_end]
                # Calculate time boundaries
                start_time = chunk_start / sr
                end_time = chunk_end / sr
                logger.debug(
                    f"Chunk {chunk_idx}: {start_time:.1f}s - {end_time:.1f}s "
                    f"({len(chunk) / sr:.1f}s)"
                )
                yield chunk, sr, start_time, end_time
                # Move to next chunk
                position += step_samples
                chunk_idx += 1
                # Force garbage collection between chunks
                gc.collect()
            logger.info(f"Processed {chunk_idx} chunks")
        except AudioIOError:
            # Already the documented error type (e.g. from read_audio);
            # re-raise as-is instead of double-wrapping it.
            raise
        except Exception as e:
            logger.error(f"Failed to process chunks: {e}")
            # Chain the original exception so the root cause isn't lost.
            raise AudioIOError(f"Chunking failed: {e}") from e

    def process_file_in_chunks(
        self, file_path: str, processor_func, target_sr: int = 16000, **processor_kwargs
    ) -> List:
        """
        Process audio file in chunks with custom processor function.

        Args:
            file_path: Path to audio file
            processor_func: Function to process each chunk
                Should accept (audio, sr, start_time, end_time, **kwargs)
            target_sr: Target sample rate
            **processor_kwargs: Additional arguments for processor function

        Returns:
            List of processing results from each chunk. Chunks whose
            processing raised are logged and skipped (best-effort), so the
            list may be shorter than the number of chunks.

        Example:
            >>> def detect_segments(audio, sr, start_time, end_time):
            ...     # Process audio chunk
            ...     return segments
            >>>
            >>> chunker = AudioChunker(chunk_duration=60.0)
            >>> results = chunker.process_file_in_chunks(
            ...     "long_file.m4a",
            ...     detect_segments
            ... )
        """
        results = []
        for chunk, sr, start_time, end_time in self.iter_chunks(file_path, target_sr):
            try:
                result = processor_func(chunk, sr, start_time, end_time, **processor_kwargs)
                results.append(result)
            except Exception as e:
                logger.error(f"Chunk processing failed at {start_time:.1f}s: {e}")
                # Continue with next chunk
                continue
        return results
class MemoryMonitor:
    """
    Monitor and manage memory usage during processing.

    Uses psutil (when available) to read the process RSS; all methods
    degrade gracefully when psutil is not installed.
    """

    def __init__(self, max_memory_mb: Optional[float] = None):
        """
        Initialize memory monitor.

        Args:
            max_memory_mb: Maximum memory usage in MB (None = no limit)
        """
        self.max_memory_mb = max_memory_mb
        try:
            import os
            import psutil

            self.process = psutil.Process(os.getpid())
            self.psutil_available = True
        except ImportError:
            logger.warning("psutil not available, memory monitoring disabled")
            self.psutil_available = False

    def get_current_memory_mb(self) -> float:
        """
        Get current memory usage (resident set size) in MB.

        Returns:
            Memory usage in MB, or 0 if unavailable
        """
        if not self.psutil_available:
            return 0.0
        try:
            return self.process.memory_info().rss / 1024 / 1024
        except Exception:
            # Best-effort: treat any psutil failure as "unknown" (0 MB).
            return 0.0

    def check_memory_limit(self) -> bool:
        """
        Check if memory usage is below limit.

        Returns:
            True if within limit (or no limit set), False if exceeded
        """
        if self.max_memory_mb is None:
            return True
        current_mb = self.get_current_memory_mb()
        if current_mb > self.max_memory_mb:
            logger.warning(
                f"Memory limit exceeded: {current_mb:.1f}MB > {self.max_memory_mb:.1f}MB"
            )
            return False
        return True

    def force_cleanup(self):
        """Force garbage collection and release cached GPU memory if possible."""
        gc.collect()
        # Bug fix: CUDA cache clearing was previously gated on
        # psutil_available, which is unrelated to torch. Attempt it whenever
        # torch is importable, regardless of psutil.
        try:
            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared CUDA cache")
        except ImportError:
            pass
        logger.debug("Forced garbage collection")
def optimize_for_large_files(audio_duration: float) -> dict:
    """
    Get optimization recommendations for large files.

    Args:
        audio_duration: Duration of audio file in seconds

    Returns:
        Dictionary with optimization parameters
    """
    # Size tiers, in seconds.
    one_hour = 3600
    two_hours = 7200

    # Defaults for files up to one hour: no chunking required.
    params = {
        "use_chunking": False,
        "chunk_duration": 60.0,
        "chunk_overlap": 5.0,
        "force_gc_frequency": 10,  # Force GC every N chunks
        "recommended_batch_size": 32,
    }

    if audio_duration > two_hours:
        # Very large file (>2 hours): aggressive settings — smaller chunks,
        # more frequent GC, smaller batches.
        params["use_chunking"] = True
        params["chunk_duration"] = 30.0
        params["chunk_overlap"] = 3.0
        params["force_gc_frequency"] = 5
        params["recommended_batch_size"] = 16
        logger.info(
            f"Large file detected ({audio_duration / 3600:.1f}h), "
            "using aggressive memory optimization"
        )
    elif audio_duration > one_hour:
        # Large file (1-2 hours): enable chunking with default chunk sizing,
        # trimming only the batch size.
        params["use_chunking"] = True
        params["recommended_batch_size"] = 24
        logger.info(
            f"Large file detected ({audio_duration / 3600:.1f}h), using memory optimization"
        )
    return params
def estimate_memory_requirements(
    audio_duration: float, sample_rate: int = 16000, num_models: int = 3, safety_factor: float = 2.0
) -> float:
    """
    Estimate memory requirements for processing.

    Args:
        audio_duration: Duration in seconds
        sample_rate: Sample rate in Hz
        num_models: Number of ML models to load
        safety_factor: Safety multiplier (default: 2.0)

    Returns:
        Estimated memory requirement in MB
    """
    bytes_per_sample = 4  # float32 samples

    # Raw waveform footprint.
    audio_mb = (audio_duration * sample_rate * bytes_per_sample) / 1024 / 1024
    # Rough per-model overhead (~500MB each).
    model_mb = num_models * 500
    # Intermediate buffers, embeddings, etc. scale with the audio size.
    processing_mb = audio_mb * 2

    total_mb = (audio_mb + model_mb + processing_mb) * safety_factor
    logger.debug(
        f"Estimated memory: audio={audio_mb:.1f}MB, "
        f"models={model_mb:.1f}MB, processing={processing_mb:.1f}MB, "
        f"total={total_mb:.1f}MB (with {safety_factor}x safety factor)"
    )
    return total_mb