Spaces:
Sleeping
Sleeping
File size: 4,862 Bytes
4e9a3bc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | """Segment localization from explainability outputs."""
import logging
import numpy as np
from scipy.signal import find_peaks
from src.reporting.schemas import SegmentAnnotation
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class SegmentLocaliser:
    """Localize time segments from Grad-CAM and attention outputs.

    A per-frame importance curve is built as a weighted sum of the Grad-CAM
    heatmap (averaged over its first axis) and an attention-rollout vector,
    then normalised by its maximum. The strongest local maxima of that curve
    are expanded into fixed-width frame windows, converted to milliseconds,
    labelled, and finally merged when they lie close together.
    """

    # Candidate labels, handed out by rank (cycling) in ``_assign_label``.
    _LABELS = (
        "imprecise_consonants",
        "irregular_rate",
        "monopitch",
        "hypernasality",
        "reduced_breath_support",
    )

    def __init__(
        self,
        min_segment_duration_ms: int = 200,
        merge_threshold: float = 0.3,
        *,
        segment_half_width_frames: int = 5,
        heatmap_weight: float = 0.6,
    ):
        """
        Initialize segment localiser.

        Args:
            min_segment_duration_ms: Minimum spacing between selected peaks,
                in milliseconds. It is converted to frames and passed to
                ``scipy.signal.find_peaks`` as ``distance``; it does not
                enforce a minimum length on the returned segments.
            merge_threshold: Merge two segments when the gap between them is
                smaller than this many seconds (overlaps are negative gaps).
            segment_half_width_frames: Frames included on each side of a peak
                when expanding it into a segment (previously fixed at 5).
            heatmap_weight: Weight of the Grad-CAM curve in the combined
                score; the attention curve gets ``1 - heatmap_weight``
                (previously fixed at 0.6 / 0.4).
        """
        self.min_segment_duration_ms = min_segment_duration_ms
        self.merge_threshold = merge_threshold
        self.segment_half_width_frames = segment_half_width_frames
        self.heatmap_weight = heatmap_weight

    def localise(
        self,
        heatmap: np.ndarray,
        attention_importance: np.ndarray,
        hop_length: int = 512,
        sr: int = 16000,
        top_k: int = 5,
    ) -> list[SegmentAnnotation]:
        """
        Localise important segments.

        Args:
            heatmap: Grad-CAM heatmap (freq, time). A 1-D array is treated as
                an already time-resolved curve.
            attention_importance: Attention rollout importance (seq_len,).
                Both curves are truncated to their common length.
            hop_length: Hop length in samples
            sr: Sample rate
            top_k: Maximum number of peaks expanded into segments (merging
                may reduce the final count). Values <= 0 yield no segments.

        Returns:
            List of SegmentAnnotation objects ordered by start time; empty
            when the inputs share no frames or ``top_k <= 0``.
        """
        combined = self._combined_importance(heatmap, attention_importance)

        segments: list[SegmentAnnotation] = []
        if top_k > 0 and combined.size > 0:
            peaks, peak_heights = self._select_peaks(combined, hop_length, sr, top_k)
            last_frame = len(combined) - 1
            half_width = self.segment_half_width_frames
            for rank, (peak_idx, height) in enumerate(zip(peaks, peak_heights)):
                # Fixed-width window around the peak, clipped to the signal.
                start_idx = max(0, peak_idx - half_width)
                end_idx = min(last_frame, peak_idx + half_width)
                segments.append(
                    SegmentAnnotation(
                        start_ms=self._frames_to_ms(start_idx, hop_length, sr),
                        end_ms=self._frames_to_ms(end_idx, hop_length, sr),
                        label=self._assign_label(rank, peak_idx, len(combined)),
                        weight=float(height),
                    )
                )
            segments = self._merge_segments(segments)

        logger.info("Localised %d segments", len(segments))
        return segments

    def _combined_importance(
        self,
        heatmap: np.ndarray,
        attention_importance: np.ndarray,
    ) -> np.ndarray:
        """Return the max-normalised weighted sum of both per-frame curves."""
        heatmap = np.asarray(heatmap)
        # Average over the frequency axis unless the heatmap is already 1-D.
        heatmap_time = heatmap if heatmap.ndim == 1 else heatmap.mean(axis=0)
        attention = np.asarray(attention_importance)

        # Truncate both curves to their common length.
        n_frames = min(len(heatmap_time), len(attention))
        combined = (
            self.heatmap_weight * heatmap_time[:n_frames]
            + (1.0 - self.heatmap_weight) * attention[:n_frames]
        )
        if n_frames == 0:
            # Nothing to normalise; .max() would raise on an empty array.
            return combined
        # The epsilon avoids division by zero for an all-zero curve.
        return combined / (combined.max() + 1e-8)

    def _select_peaks(
        self,
        combined: np.ndarray,
        hop_length: int,
        sr: int,
        top_k: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return up to ``top_k`` peak frame indices and their heights, strongest first."""
        # scipy requires distance >= 1; a short minimum duration or a long hop
        # would otherwise truncate to 0 and make find_peaks raise ValueError.
        distance = max(1, int(self.min_segment_duration_ms / 1000 * sr / hop_length))
        peaks, properties = find_peaks(
            combined,
            height=np.percentile(combined, 60),  # Above 60th percentile
            distance=distance,
        )
        if len(peaks) > 0:
            heights = properties["peak_heights"]
            order = np.argsort(heights)[::-1][:top_k]
            return peaks[order], heights[order]

        # No qualifying local maxima (e.g. flat or monotonic curve): fall back
        # to the top_k highest-scoring frames. top_k > 0 is guaranteed by the
        # caller, so the [-top_k:] slice never degenerates to the full array.
        frames = np.argsort(combined)[-top_k:][::-1]
        return frames, combined[frames]

    @staticmethod
    def _frames_to_ms(frame: int, hop_length: int, sr: int) -> int:
        """Convert a frame index to a millisecond offset (truncated to int)."""
        return int((frame * hop_length / sr) * 1000)

    def _assign_label(self, segment_idx: int, peak_idx: int, total_frames: int) -> str:
        """Assign label to segment based on characteristics.

        Currently a placeholder heuristic: labels cycle through ``_LABELS`` by
        the segment's rank. ``peak_idx`` and ``total_frames`` are accepted but
        not yet used.
        """
        return self._LABELS[segment_idx % len(self._LABELS)]

    def _merge_segments(
        self,
        segments: list[SegmentAnnotation],
    ) -> list[SegmentAnnotation]:
        """Merge segments that are close together.

        Segments are sorted by start time. A segment whose gap to the current
        merged segment is below ``merge_threshold`` seconds is folded into it;
        the merged segment keeps the earlier segment's label and the larger
        weight. Merging updates the earlier SegmentAnnotation in place.
        """
        if len(segments) <= 1:
            return segments

        ordered = sorted(segments, key=lambda s: s.start_ms)
        merged = [ordered[0]]
        for current in ordered[1:]:
            last = merged[-1]
            gap_sec = (current.start_ms - last.end_ms) / 1000.0
            if gap_sec < self.merge_threshold:
                # Take the max: a segment can be fully contained in its
                # predecessor (e.g. both clipped at frame 0) and must not
                # shrink it.
                last.end_ms = max(last.end_ms, current.end_ms)
                last.weight = max(last.weight, current.weight)
            else:
                merged.append(current)
        return merged
|