# StrokeMitra-API/src/explainability/segment_localiser.py
# Hugging Face page header: DhruvB1906, "Upload folder using huggingface_hub", commit 4e9a3bc (verified)
"""Segment localization from explainability outputs."""
import logging
import numpy as np
from scipy.signal import find_peaks
from src.reporting.schemas import SegmentAnnotation
logger = logging.getLogger(__name__)
class SegmentLocaliser:
    """Localize time segments from Grad-CAM and attention outputs.

    A Grad-CAM heatmap (averaged over frequency) and an attention importance
    vector are blended into one per-frame score. The strongest peaks of that
    score are expanded into fixed-width windows, converted to milliseconds,
    labelled, and finally merged when they lie close together.
    """

    # Relative contribution of each explainability signal to the combined score.
    _HEATMAP_WEIGHT = 0.6
    _ATTENTION_WEIGHT = 0.4
    # A peak must reach this percentile of the combined score to be considered.
    _PEAK_PERCENTILE = 60
    # Number of frames a segment extends on each side of its peak.
    _CONTEXT_FRAMES = 5
    # Labels handed out in rank order by ``_assign_label``.
    _LABELS = (
        "imprecise_consonants",
        "irregular_rate",
        "monopitch",
        "hypernasality",
        "reduced_breath_support",
    )

    def __init__(
        self,
        min_segment_duration_ms: int = 200,
        merge_threshold: float = 0.3,
    ):
        """
        Initialize segment localiser.

        Args:
            min_segment_duration_ms: Minimum segment duration; used as the
                minimum spacing between detected peaks.
            merge_threshold: Merge segments if gap < threshold (in seconds)
        """
        self.min_segment_duration_ms = min_segment_duration_ms
        self.merge_threshold = merge_threshold

    def localise(
        self,
        heatmap: np.ndarray,
        attention_importance: np.ndarray,
        hop_length: int = 512,
        sr: int = 16000,
        top_k: int = 5,
    ) -> list[SegmentAnnotation]:
        """
        Localise important segments.

        Args:
            heatmap: Grad-CAM heatmap (freq, time)
            attention_importance: Attention rollout importance (seq_len,)
            hop_length: Hop length in samples
            sr: Sample rate
            top_k: Number of segments to return (before merging)

        Returns:
            List of SegmentAnnotation objects sorted by start time. The list
            is empty when ``top_k`` is not positive or when the two inputs
            share no frames.
        """
        # Nothing requested: return early. In the no-peak fallback below,
        # ``[-0:]`` would otherwise select every frame.
        if top_k <= 0:
            return []

        # Average heatmap over frequency dimension -> (time,)
        heatmap_time = heatmap.mean(axis=0)

        # Truncate both signals to their common length.
        n_frames = min(len(heatmap_time), len(attention_importance))
        if n_frames == 0:
            # ``max`` / ``percentile`` are undefined on empty arrays.
            return []
        heatmap_time = heatmap_time[:n_frames]
        attention_importance = attention_importance[:n_frames]

        # Weighted blend, scaled so the largest value is ~1.
        combined_importance = (
            self._HEATMAP_WEIGHT * heatmap_time
            + self._ATTENTION_WEIGHT * attention_importance
        )
        combined_importance = combined_importance / (combined_importance.max() + 1e-8)

        # find_peaks requires distance >= 1; short durations or coarse hops
        # would otherwise truncate to 0 and raise.
        min_distance = max(
            1, int(self.min_segment_duration_ms / 1000 * sr / hop_length)
        )
        peaks, properties = find_peaks(
            combined_importance,
            height=np.percentile(combined_importance, self._PEAK_PERCENTILE),
            distance=min_distance,
        )

        if len(peaks) > 0:
            # Keep the top_k tallest peaks, tallest first.
            peak_heights = properties["peak_heights"]
            order = np.argsort(peak_heights)[::-1][:top_k]
            peaks = peaks[order]
            peak_heights = peak_heights[order]
        else:
            # If no peaks, use top-k frames (highest score first).
            peaks = np.argsort(combined_importance)[-top_k:][::-1]
            peak_heights = combined_importance[peaks]

        # Convert to time segments
        last_frame = n_frames - 1
        segments = []
        for rank, (peak_idx, height) in enumerate(zip(peaks, peak_heights)):
            # Expand a fixed window around the peak, clipped to the signal.
            start_idx = max(0, int(peak_idx) - self._CONTEXT_FRAMES)
            end_idx = min(last_frame, int(peak_idx) + self._CONTEXT_FRAMES)
            segments.append(
                SegmentAnnotation(
                    start_ms=self._frame_to_ms(start_idx, hop_length, sr),
                    end_ms=self._frame_to_ms(end_idx, hop_length, sr),
                    label=self._assign_label(rank, peak_idx, n_frames),
                    weight=float(height),
                )
            )

        # Merge nearby segments if needed
        segments = self._merge_segments(segments)
        logger.info("Localised %d segments", len(segments))
        return segments

    @staticmethod
    def _frame_to_ms(frame_idx: int, hop_length: int, sr: int) -> int:
        """Convert a frame index to whole milliseconds (truncated)."""
        # Multiply before dividing so only one float rounding happens;
        # dividing first can leave the result a hair below an exact
        # millisecond and lose 1 ms on truncation.
        return int(int(frame_idx) * hop_length * 1000 / sr)

    def _assign_label(self, segment_idx: int, peak_idx: int, total_frames: int) -> str:
        """Assign label to segment based on characteristics.

        Currently a placeholder heuristic: labels are cycled by the segment's
        rank; ``peak_idx`` and ``total_frames`` are accepted but not used.
        """
        return self._LABELS[segment_idx % len(self._LABELS)]

    def _merge_segments(
        self,
        segments: list[SegmentAnnotation],
    ) -> list[SegmentAnnotation]:
        """Merge segments that are close together.

        Segments are sorted by start time; a segment whose gap to the
        previous one is below ``merge_threshold`` seconds (overlaps count as
        negative gaps) is folded into it. The surviving segment is updated in
        place: it keeps its own label and start, takes the furthest end, and
        the larger of the two weights.
        """
        if len(segments) <= 1:
            return segments

        # Sort by start time
        ordered = sorted(segments, key=lambda s: s.start_ms)
        merged = [ordered[0]]
        for current in ordered[1:]:
            last = merged[-1]
            gap_sec = (current.start_ms - last.end_ms) / 1000.0
            if gap_sec < self.merge_threshold:
                # Never shrink: ``current`` may end before ``last`` does.
                last.end_ms = max(last.end_ms, current.end_ms)
                last.weight = max(last.weight, current.weight)
            else:
                merged.append(current)
        return merged