File size: 4,862 Bytes
4e9a3bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Segment localization from explainability outputs."""

import logging
import numpy as np
from scipy.signal import find_peaks

from src.reporting.schemas import SegmentAnnotation

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)


class SegmentLocaliser:
    """Localise time segments from Grad-CAM and attention outputs.

    A frequency-averaged Grad-CAM heatmap and an attention importance vector
    are blended into a single per-frame score. Peaks of that score are turned
    into fixed-width segments, and segments that lie close together in time
    are merged.
    """

    # Number of frames added on each side of a peak to form its segment.
    _CONTEXT_FRAMES = 5
    # Blend weights applied to the heatmap and attention signals.
    _HEATMAP_WEIGHT = 0.6
    _ATTENTION_WEIGHT = 0.4
    # A peak must reach this percentile of the combined score to be kept.
    _PEAK_PERCENTILE = 60
    # Labels handed out cyclically by segment rank (see ``_assign_label``).
    _LABELS = (
        "imprecise_consonants",
        "irregular_rate",
        "monopitch",
        "hypernasality",
        "reduced_breath_support",
    )

    def __init__(
        self,
        min_segment_duration_ms: int = 200,
        merge_threshold: float = 0.3,
    ):
        """
        Initialize segment localiser.

        Args:
            min_segment_duration_ms: Minimum spacing between detected peaks,
                in milliseconds (converted to frames and passed to
                ``find_peaks`` as ``distance``)
            merge_threshold: Merge segments if gap < threshold (in seconds)
        """
        self.min_segment_duration_ms = min_segment_duration_ms
        self.merge_threshold = merge_threshold

    def localise(
        self,
        heatmap: np.ndarray,
        attention_importance: np.ndarray,
        hop_length: int = 512,
        sr: int = 16000,
        top_k: int = 5,
    ) -> "list[SegmentAnnotation]":
        """
        Localise important segments.

        Args:
            heatmap: Grad-CAM heatmap (freq, time)
            attention_importance: Attention rollout importance (seq_len,)
            hop_length: Hop length in samples
            sr: Sample rate
            top_k: Maximum number of peaks to turn into segments

        Returns:
            List of SegmentAnnotation objects ordered by start time. The list
            may hold fewer than ``top_k`` entries because nearby segments are
            merged, and is empty when ``top_k`` is not positive or when
            either input has no frames.
        """
        if top_k <= 0:
            # Guard: a slice like ``[-0:]`` would select *every* frame.
            return []

        # Average heatmap over the frequency dimension -> (time,)
        heatmap_time = heatmap.mean(axis=0)

        # Truncate both signals to their common length.
        n_frames = min(len(heatmap_time), len(attention_importance))
        if n_frames == 0:
            # Nothing to score; max()/percentile would fail on empty arrays.
            return []
        heatmap_time = heatmap_time[:n_frames]
        attention_importance = attention_importance[:n_frames]

        # Combined importance score, scaled by its maximum (epsilon avoids
        # division by zero for an all-zero score).
        combined_importance = (
            self._HEATMAP_WEIGHT * heatmap_time
            + self._ATTENTION_WEIGHT * attention_importance
        )
        combined_importance = combined_importance / (combined_importance.max() + 1e-8)

        # find_peaks rejects distance < 1, which the truncation below can
        # produce for short durations or coarse hops, so clamp it.
        min_distance = max(
            1,
            int(self.min_segment_duration_ms / 1000 * sr / hop_length),
        )
        peaks, properties = find_peaks(
            combined_importance,
            height=np.percentile(combined_importance, self._PEAK_PERCENTILE),
            distance=min_distance,
        )

        if len(peaks) > 0:
            # Keep the top_k highest peaks, strongest first.
            peak_heights = properties["peak_heights"]
            strongest = np.argsort(peak_heights)[::-1][:top_k]
            peaks = peaks[strongest]
            peak_heights = peak_heights[strongest]
        else:
            # No peaks found: fall back to the top_k highest-scoring frames.
            peaks = np.argsort(combined_importance)[-top_k:][::-1]
            peak_heights = combined_importance[peaks]

        # Convert each peak into a fixed-width time segment.
        segments = []
        last_frame = n_frames - 1
        for rank, (peak_idx, height) in enumerate(zip(peaks, peak_heights)):
            start_idx = max(0, peak_idx - self._CONTEXT_FRAMES)
            end_idx = min(last_frame, peak_idx + self._CONTEXT_FRAMES)

            segments.append(
                SegmentAnnotation(
                    # frame index -> samples -> seconds -> milliseconds
                    start_ms=int((start_idx * hop_length / sr) * 1000),
                    end_ms=int((end_idx * hop_length / sr) * 1000),
                    label=self._assign_label(rank, peak_idx, n_frames),
                    weight=float(height),
                )
            )

        # Merge nearby segments if needed
        segments = self._merge_segments(segments)

        logger.info("Localised %d segments", len(segments))

        return segments

    def _assign_label(self, segment_idx: int, peak_idx: int, total_frames: int) -> str:
        """Assign a label to a segment.

        Simple heuristic: labels are cycled by ``segment_idx``. ``peak_idx``
        and ``total_frames`` are accepted but currently unused.
        """
        return self._LABELS[segment_idx % len(self._LABELS)]

    def _merge_segments(
        self,
        segments: "list[SegmentAnnotation]",
    ) -> "list[SegmentAnnotation]":
        """Merge segments that are close together.

        Segments are sorted by start time; a segment whose gap to the
        previous kept segment is below ``merge_threshold`` seconds
        (overlapping segments have a negative gap) is folded into it. The
        kept segment is updated in place: its end is extended if needed, its
        weight becomes the larger of the two, and its label is retained.

        Args:
            segments: Segments to merge, in any order

        Returns:
            Merged segments ordered by start time
        """
        if len(segments) <= 1:
            return segments

        ordered = sorted(segments, key=lambda s: s.start_ms)
        merged = [ordered[0]]

        for current in ordered[1:]:
            last = merged[-1]
            gap_sec = (current.start_ms - last.end_ms) / 1000.0

            if gap_sec < self.merge_threshold:
                # max() keeps the merged span from shrinking when the
                # current segment ends before the one it is merged into.
                last.end_ms = max(last.end_ms, current.end_ms)
                last.weight = max(last.weight, current.weight)
            else:
                merged.append(current)

        return merged