voice-tools / src /models /audio_segment.py
jcudit's picture
jcudit HF Staff
feat: implement cross-mode robustness fixes (phases 1-8)
95e1515
"""
Audio Segment Model
Represents a contiguous portion of audio with speaker and timing information.
"""
import math
from dataclasses import dataclass
from enum import Enum
from typing import Iterable, List, Optional
class SegmentType(Enum):
    """Classification of audio segment types.

    The ``value`` strings are what ``AudioSegment.__repr__`` prints as ``type=``.
    """

    SPEECH = "speech"
    NONVERBAL = "nonverbal"
    SILENCE = "silence"
    OVERLAP = "overlap"  # Multiple speakers talking simultaneously


@dataclass
class AudioSegment:
    """
    Audio segment with time range and speaker information.

    IMPORTANT: This model stores ONLY metadata (timestamps, speaker info,
    classification). Audio data is NEVER stored in AudioSegment instances.
    Audio is extracted on-demand from source files using the stored
    timestamps during concatenation or processing.

    This metadata-only design enables memory-efficient processing of large
    audio files (>1 hour) by avoiding storage of thousands of audio arrays
    in memory.

    Attributes:
        start_time: Beginning timestamp in seconds (finite, >= 0)
        end_time: Ending timestamp in seconds (finite, > start_time)
        speaker_id: Identifier of the speaker in this segment
        confidence: Certainty of speaker identification (0.0-1.0)
        segment_type: Classification of the segment
        audio_file: Path to the source audio file (optional, for reference only)

    Raises:
        ValueError: At construction, if either timestamp is NaN or infinite,
            start_time is negative, end_time is not strictly after
            start_time, confidence is outside [0.0, 1.0], or an ``audio``
            attribute is present on the instance.

    Usage Pattern:
        # Create segment with metadata only
        segment = AudioSegment(
            start_time=10.5,
            end_time=15.3,
            speaker_id="speaker_00",
            confidence=0.95,
            segment_type=SegmentType.SPEECH
        )

        # Extract audio on-demand when needed
        start_sample = int(segment.start_time * sample_rate)
        end_sample = int(segment.end_time * sample_rate)
        segment_audio = source_audio[start_sample:end_sample]
    """

    start_time: float
    end_time: float
    speaker_id: str
    confidence: float = 1.0
    segment_type: SegmentType = SegmentType.SPEECH
    audio_file: Optional[str] = None

    def __post_init__(self):
        """Validate audio segment data.

        Raises:
            ValueError: If any field violates the constraints listed in the
                class docstring.
        """
        # NaN compares False against everything, so it would slip past the
        # range/ordering checks below. An infinite timestamp would pass them
        # too, then overflow when converted to a sample index
        # (int(inf * sample_rate) raises OverflowError). Reject both up front.
        if not math.isfinite(self.start_time):
            raise ValueError(f"Start time must be a finite number: {self.start_time}")
        if not math.isfinite(self.end_time):
            raise ValueError(f"End time must be a finite number: {self.end_time}")
        if self.start_time < 0:
            raise ValueError(f"Start time cannot be negative: {self.start_time}")
        if self.end_time <= self.start_time:
            raise ValueError(
                f"End time ({self.end_time}) must be after start time ({self.start_time})"
            )
        if not 0.0 <= self.confidence <= 1.0:
            raise ValueError(f"Confidence must be between 0.0 and 1.0, got {self.confidence}")

        # Ensure no audio data is accidentally stored (metadata-only enforcement).
        # NOTE: this check runs once, at construction; it does not prevent a
        # later ``segment.audio = ...`` assignment on this (non-frozen) dataclass.
        if hasattr(self, "audio") or "audio" in self.__dict__:
            raise ValueError(
                "AudioSegment must not contain 'audio' attribute. "
                "Audio data should be extracted on-demand using timestamps."
            )

    @property
    def duration(self) -> float:
        """Duration of the segment in seconds (``end_time - start_time``)."""
        return self.end_time - self.start_time

    def overlaps_with(self, other: "AudioSegment") -> bool:
        """Check if this segment overlaps with another segment.

        Segments behave as half-open intervals here: two segments that only
        touch (one ends exactly where the other starts) do NOT overlap.
        """
        return not (self.end_time <= other.start_time or other.end_time <= self.start_time)

    def contains_time(self, time: float) -> bool:
        """Check if a timestamp falls within this segment.

        Both endpoints are inclusive, so a boundary time shared by two
        adjacent segments is contained in both (unlike ``overlaps_with``).
        """
        return self.start_time <= time <= self.end_time

    def __repr__(self) -> str:
        return (
            f"AudioSegment("
            f"speaker='{self.speaker_id}', "
            f"time={self.start_time:.2f}-{self.end_time:.2f}s, "
            f"duration={self.duration:.2f}s, "
            f"confidence={self.confidence:.2f}, "
            f"type={self.segment_type.value})"
        )
class SegmentCollection:
    """
    Collection of audio segments with utility methods.

    Provides methods for filtering, sorting, and analyzing groups of segments.
    The filter/sort methods return a new ``SegmentCollection`` and leave the
    collection they are called on unchanged.
    """

    def __init__(self, segments: Iterable["AudioSegment"]):
        """
        Initialize collection with segments.

        Args:
            segments: Segments to hold. They are copied into a new list, so
                later changes to the caller's container do not leak into
                this collection, and one-shot iterables such as generators
                are supported.
        """
        self.segments: List["AudioSegment"] = list(segments)

    def __len__(self) -> int:
        """Return number of segments."""
        return len(self.segments)

    def __iter__(self):
        """Iterate over segments in stored order."""
        return iter(self.segments)

    def __getitem__(self, index):
        """Get segment by index (a slice returns a plain list)."""
        return self.segments[index]

    @property
    def total_duration(self) -> float:
        """Sum of ``duration`` over all segments, in seconds (0 when empty)."""
        return sum(seg.duration for seg in self.segments)

    def filter_by_speaker(self, speaker_id: str) -> "SegmentCollection":
        """Return a new collection with only the segments of ``speaker_id``."""
        filtered = [seg for seg in self.segments if seg.speaker_id == speaker_id]
        return SegmentCollection(filtered)

    def filter_by_type(self, segment_type: "SegmentType") -> "SegmentCollection":
        """Return a new collection with only segments of ``segment_type``."""
        filtered = [seg for seg in self.segments if seg.segment_type == segment_type]
        return SegmentCollection(filtered)

    def sort_by_time(self) -> "SegmentCollection":
        """Return a new collection sorted by start time.

        ``sorted`` is stable, so segments with equal start times keep their
        original relative order.
        """
        sorted_segments = sorted(self.segments, key=lambda s: s.start_time)
        return SegmentCollection(sorted_segments)

    def get_speakers(self) -> List[str]:
        """Get the unique speaker IDs, in order of first appearance.

        ``dict.fromkeys`` de-duplicates while preserving insertion order; a
        ``set`` would yield an order that can differ between interpreter runs
        because string hashes are randomized per process.
        """
        return list(dict.fromkeys(seg.speaker_id for seg in self.segments))

    def average_confidence(self) -> float:
        """Mean ``confidence`` across all segments, or 0.0 for an empty collection."""
        if not self.segments:
            return 0.0
        return sum(seg.confidence for seg in self.segments) / len(self.segments)