colab-user committed on
Commit
45711db
·
1 Parent(s): 22c6367

fix word alignment

Browse files
app/services/alignment.py CHANGED
@@ -8,7 +8,6 @@ from typing import List, Tuple, Optional
8
  from dataclasses import dataclass
9
 
10
  from app.core.config import get_settings
11
-
12
  from app.services.transcription import WordTimestamp
13
  from app.services.diarization import SpeakerSegment
14
  from app.schemas.models import TranscriptSegment
@@ -34,120 +33,133 @@ class AlignmentService:
34
  Uses word-center-based algorithm for accurate speaker-to-text mapping.
35
  """
36
 
37
- # Pause threshold for splitting segments (seconds)
38
  PAUSE_THRESHOLD = 1.0
 
 
 
39
 
40
  @staticmethod
41
  def get_word_center(word: WordTimestamp) -> float:
42
  """Calculate the center time of a word."""
43
  return (word.start + word.end) / 2
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  @classmethod
46
- def find_speaker_at_time(
47
  cls,
48
  time: float,
49
- speaker_segments: List[SpeakerSegment]
50
  ) -> Optional[str]:
51
- """
52
- Find which speaker is speaking at a given time.
53
-
54
- Args:
55
- time: Time point in seconds
56
- speaker_segments: List of speaker segments from diarization
57
-
58
- Returns:
59
- Speaker label or None if no speaker found
60
- """
61
  for seg in speaker_segments:
62
- if seg.start <= time <= seg.end:
63
  return seg.speaker
64
  return None
65
 
 
66
  @classmethod
67
- def find_closest_speaker(
68
  cls,
69
  time: float,
70
- speaker_segments: List[SpeakerSegment]
71
- ) -> str:
72
- """
73
- Find the closest speaker to a given time (for gaps/silence).
74
-
75
- Args:
76
- time: Time point in seconds
77
- speaker_segments: List of speaker segments
78
-
79
- Returns:
80
- Closest speaker label or "Unknown"
81
- """
82
  if not speaker_segments:
83
  return "Unknown"
84
-
85
- min_distance = float('inf')
86
- closest_speaker = "Unknown"
87
-
88
  for seg in speaker_segments:
89
- # Distance to segment start or end
90
- dist_to_start = abs(time - seg.start)
91
- dist_to_end = abs(time - seg.end)
92
- min_seg_dist = min(dist_to_start, dist_to_end)
93
-
94
- if min_seg_dist < min_distance:
95
- min_distance = min_seg_dist
96
- closest_speaker = seg.speaker
97
-
98
- return closest_speaker
99
 
100
  @classmethod
101
  def assign_speakers_to_words(
102
  cls,
103
  words: List[WordTimestamp],
104
- speaker_segments: List[SpeakerSegment]
105
  ) -> List[WordWithSpeaker]:
106
- """
107
- Step 3c: Assign speakers to each word based on word center time.
108
-
109
- Args:
110
- words: List of words with timestamps from transcription
111
- speaker_segments: List of speaker segments from diarization
112
-
113
- Returns:
114
- List of words with speaker assignments
115
- """
116
  if not speaker_segments:
117
- # No diarization available, assign all to "Speaker 1"
118
- logger.warning("No speaker segments available, using single speaker")
119
  return [
120
- WordWithSpeaker(
121
- word=w.word,
122
- start=w.start,
123
- end=w.end,
124
- speaker="Speaker 1"
125
- )
126
  for w in words
127
  ]
128
-
129
- words_with_speakers = []
130
-
 
 
131
  for word in words:
132
- # Calculate word center time
133
- center_time = cls.get_word_center(word)
134
-
135
- # Find speaker at this time
136
- speaker = cls.find_speaker_at_time(center_time, speaker_segments)
137
-
138
- # If no direct match, find closest speaker
139
  if speaker is None:
140
- speaker = cls.find_closest_speaker(center_time, speaker_segments)
141
-
142
- words_with_speakers.append(WordWithSpeaker(
143
- word=word.word,
144
- start=word.start,
145
- end=word.end,
146
- speaker=speaker
147
- ))
148
-
149
- logger.debug(f"Assigned speakers to {len(words_with_speakers)} words")
150
- return words_with_speakers
 
 
 
 
 
 
 
 
 
 
151
 
152
  @classmethod
153
  def reconstruct_segments(
 
8
  from dataclasses import dataclass
9
 
10
  from app.core.config import get_settings
 
11
  from app.services.transcription import WordTimestamp
12
  from app.services.diarization import SpeakerSegment
13
  from app.schemas.models import TranscriptSegment
 
33
  Uses word-center-based algorithm for accurate speaker-to-text mapping.
34
  """
35
 
 
36
  PAUSE_THRESHOLD = 1.0
37
+ CENTER_TOL = 0.12
38
+ OVERLAP_TH = 0.5
39
+ DIA_MERGE_GAP = 0.25
40
 
41
  @staticmethod
42
  def get_word_center(word: WordTimestamp) -> float:
43
  """Calculate the center time of a word."""
44
  return (word.start + word.end) / 2
45
 
46
+
47
+ @staticmethod
48
+ def overlap_ratio(w_start, w_end, s_start, s_end):
49
+ overlap = max(0.0, min(w_end, s_end) - max(w_start, s_start))
50
+ dur = max(1e-6, w_end - w_start)
51
+ return overlap / dur
52
+
53
+
54
+ # Diarization merge
55
+ @classmethod
56
+ def merge_dia_segments(cls, segments: List[SpeakerSegment]) -> List[SpeakerSegment]:
57
+ if not segments:
58
+ return []
59
+
60
+ segments = sorted(segments, key=lambda s: s.start)
61
+ merged = [segments[0]]
62
+
63
+ for s in segments[1:]:
64
+ p = merged[-1]
65
+ if s.speaker == p.speaker and (s.start - p.end) <= cls.DIA_MERGE_GAP:
66
+ p.end = s.end
67
+ else:
68
+ merged.append(s)
69
+
70
+ return merged
71
+
72
+
73
  @classmethod
74
+ def find_speaker_center(
75
  cls,
76
  time: float,
77
+ speaker_segments: List[SpeakerSegment],
78
  ) -> Optional[str]:
79
+
 
 
 
 
 
 
 
 
 
80
  for seg in speaker_segments:
81
+ if seg.start - cls.CENTER_TOL <= time <= seg.end + cls.CENTER_TOL:
82
  return seg.speaker
83
  return None
84
 
85
+
86
  @classmethod
87
+ def find_speaker_center(
88
  cls,
89
  time: float,
90
+ speaker_segments: List[SpeakerSegment],
91
+ ) -> Optional[str]:
92
+
93
+ for seg in speaker_segments:
94
+ if seg.start - cls.CENTER_TOL <= time <= seg.end + cls.CENTER_TOL:
95
+ return seg.speaker
96
+ return None
97
+
98
+ @staticmethod
99
+ def find_closest_speaker(time: float, speaker_segments: List[SpeakerSegment]) -> str:
 
 
100
  if not speaker_segments:
101
  return "Unknown"
102
+
103
+ min_dist = float("inf")
104
+ closest = "Unknown"
105
+
106
  for seg in speaker_segments:
107
+ d = min(abs(time - seg.start), abs(time - seg.end))
108
+ if d < min_dist:
109
+ min_dist = d
110
+ closest = seg.speaker
111
+
112
+ return closest
113
+
 
 
 
114
 
115
  @classmethod
116
  def assign_speakers_to_words(
117
  cls,
118
  words: List[WordTimestamp],
119
+ speaker_segments: List[SpeakerSegment],
120
  ) -> List[WordWithSpeaker]:
121
+
122
+ words = [w for w in words if w.word and w.word.strip()]
123
+
 
 
 
 
 
 
 
124
  if not speaker_segments:
125
+ logger.warning("No diarization, fallback single speaker")
 
126
  return [
127
+ WordWithSpeaker(w.word, w.start, w.end, "Speaker 1")
 
 
 
 
 
128
  for w in words
129
  ]
130
+
131
+ speaker_segments = cls.merge_dia_segments(speaker_segments)
132
+
133
+ results = []
134
+
135
  for word in words:
136
+ center = cls.get_word_center(word)
137
+
138
+ # 1. CENTER
139
+ speaker = cls.find_speaker_center(center, speaker_segments)
140
+
 
 
141
  if speaker is None:
142
+ # 2. OVERLAP
143
+ best_ratio = 0
144
+ best_spk = None
145
+
146
+ for seg in speaker_segments:
147
+ r = cls.overlap_ratio(word.start, word.end, seg.start, seg.end)
148
+ if r > best_ratio:
149
+ best_ratio = r
150
+ best_spk = seg.speaker
151
+
152
+ if best_ratio >= cls.OVERLAP_TH:
153
+ speaker = best_spk
154
+ else:
155
+ # 3. CLOSEST
156
+ speaker = cls.find_closest_speaker(center, speaker_segments)
157
+
158
+ results.append(
159
+ WordWithSpeaker(word.word, word.start, word.end, speaker)
160
+ )
161
+
162
+ return results
163
 
164
  @classmethod
165
  def reconstruct_segments(
app/services/processor.py CHANGED
@@ -13,6 +13,9 @@ import torch
13
 
14
  from app.core.config import get_settings
15
  from app.services.transcription import TranscriptionService
 
 
 
16
  from app.services.diarization import DiarizationService, SpeakerSegment, DiarizationResult
17
 
18
  logger = logging.getLogger(__name__)
@@ -42,92 +45,6 @@ class ProcessingResult:
42
  txt_content: str = ""
43
  csv_content: str = ""
44
 
45
- def assign_speaker_to_word(word_start, diarization_segments):
46
- if not diarization_segments:
47
- return "SPEAKER_0"
48
-
49
- for seg in diarization_segments:
50
- if seg.start - 0.05 <= word_start <= seg.end + 0.05:
51
- return seg.speaker
52
-
53
- return min(
54
- diarization_segments,
55
- key=lambda s: abs((s.start + s.end) / 2 - word_start)
56
- ).speaker
57
-
58
-
59
-
60
- def _safe_extract_word(w: dict) -> str:
61
- """
62
- Robust extractor for finetuned Whisper outputs.
63
- Handles:
64
- - str
65
- - dict {text, confidence, ...}
66
- - None
67
- """
68
- raw = w.get("word", "")
69
-
70
- if isinstance(raw, dict):
71
- return str(raw.get("text", "")).strip()
72
-
73
- if isinstance(raw, str):
74
- return raw.strip()
75
-
76
- return str(raw).strip()
77
-
78
- def group_words_into_segments(
79
- words: List[dict],
80
- diarization_segments: List[SpeakerSegment],
81
- speaker_map: Dict[str, str],
82
- roles: Dict[str, str],
83
- max_word_gap_s: float = 0.6,
84
- ) -> List[TranscriptSegment]:
85
-
86
- segments: List[TranscriptSegment] = []
87
- current: Optional[TranscriptSegment] = None
88
-
89
- for w in words or []:
90
- text = _safe_extract_word(w)
91
- if not text:
92
- continue
93
-
94
- w_start = float(w["start"])
95
- w_end = float(w["end"])
96
-
97
- speaker_raw = assign_speaker_to_word(w_start, diarization_segments)
98
- speaker = speaker_map.get(speaker_raw, speaker_raw)
99
- role = roles.get(speaker, "KH")
100
-
101
- if current is None:
102
- current = TranscriptSegment(
103
- start=w_start,
104
- end=w_end,
105
- speaker=speaker,
106
- role=role,
107
- text=text,
108
- )
109
- continue
110
-
111
- gap = w_start - current.end
112
-
113
- if speaker == current.speaker and gap <= max_word_gap_s:
114
- current.text += " " + text
115
- current.end = max(current.end, w_end)
116
- else:
117
- segments.append(current)
118
- current = TranscriptSegment(
119
- start=w_start,
120
- end=w_end,
121
- speaker=speaker,
122
- role=role,
123
- text=text,
124
- )
125
-
126
- if current:
127
- segments.append(current)
128
-
129
- return segments
130
-
131
 
132
  def normalize_asr_result(result):
133
  """
@@ -280,9 +197,8 @@ class Processor:
280
  temperature: float = 0.0,
281
  best_of: int = 5,
282
  log_prob_threshold = -1.2,
283
- compression_ratio_threshold = 2.4,
284
-
285
- initial_prompt: Optional[str] = None,
286
  ) -> ProcessingResult:
287
 
288
  import asyncio
@@ -374,8 +290,13 @@ class Processor:
374
 
375
  logger.info(f"roles(mapped) = {roles}")
376
 
377
-
378
-
 
 
 
 
 
379
  # 7: Transcribe
380
  vad_options = None
381
  if vad_filter:
@@ -439,22 +360,47 @@ class Processor:
439
 
440
  if not words and not text:
441
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
 
443
- # ===== MAP WORD → SPEAKER =====
444
- word_segments = group_words_into_segments(
445
- words=words,
446
- diarization_segments=refined_segments,
447
- speaker_map=speaker_map,
448
- roles=roles,
449
- max_word_gap_s=0.6
450
  )
451
 
452
- for seg in word_segments:
453
- processed_segments.append(seg)
 
 
 
 
 
 
 
 
 
 
 
 
454
 
455
  if text:
456
  prev_prompt += " " + text
457
-
458
 
459
  if not processed_segments:
460
  processed_segments = [
@@ -468,15 +414,9 @@ class Processor:
468
  ]
469
 
470
  processed_segments = cls._merge_adjacent_segments(
471
- processed_segments,
472
- max_gap_s=0.7
473
  )
474
 
475
- processed_segments = [
476
- s for s in processed_segments
477
- if cls._is_meaningful_segment(s)
478
- ]
479
-
480
  processed_segments = cls._filter_segments_with_context(processed_segments)
481
  processing_time = time.time() - t0
482
 
 
13
 
14
  from app.core.config import get_settings
15
  from app.services.transcription import TranscriptionService
16
+ from app.services.alignment import AlignmentService
17
+ from app.services.transcription import WordTimestamp
18
+
19
  from app.services.diarization import DiarizationService, SpeakerSegment, DiarizationResult
20
 
21
  logger = logging.getLogger(__name__)
 
45
  txt_content: str = ""
46
  csv_content: str = ""
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def normalize_asr_result(result):
50
  """
 
197
  temperature: float = 0.0,
198
  best_of: int = 5,
199
  log_prob_threshold = -1.2,
200
+ compression_ratio_threshold = 2.4
201
+
 
202
  ) -> ProcessingResult:
203
 
204
  import asyncio
 
290
 
291
  logger.info(f"roles(mapped) = {roles}")
292
 
293
+ # context window
294
+ windows = build_context_windows(
295
+ duration,
296
+ win_s=settings.CONTEXT_WINDOW_S,
297
+ overlap_s=settings.CONTEXT_OVERLAP_S,
298
+ )
299
+
300
  # 7: Transcribe
301
  vad_options = None
302
  if vad_filter:
 
360
 
361
  if not words and not text:
362
  continue
363
+
364
+ words = []
365
+ for w in words:
366
+ try:
367
+ words.append(
368
+ WordTimestamp(
369
+ word=str(w.get("word", "")).strip(),
370
+ start=float(w.get("start", 0)) + w_start,
371
+ end=float(w.get("end", 0)) + w_start,
372
+ )
373
+ )
374
+ except:
375
+ pass
376
+
377
+ if not words:
378
+ continue
379
 
380
+ # ===== ALIGNMENT =====
381
+ aligned_segments = AlignmentService.align_precision(
382
+ words,
383
+ refined_segments
 
 
 
384
  )
385
 
386
+ # ===== MAP WORD → ROLE =====
387
+ for seg in aligned_segments:
388
+ label = speaker_map.get(seg.speaker, seg.speaker)
389
+ role = roles.get(label, "KH")
390
+
391
+ processed_segments.append(
392
+ TranscriptSegment(
393
+ start=seg.start,
394
+ end=seg.end,
395
+ speaker=label,
396
+ role=role,
397
+ text=seg.text
398
+ )
399
+ )
400
 
401
  if text:
402
  prev_prompt += " " + text
403
+
404
 
405
  if not processed_segments:
406
  processed_segments = [
 
414
  ]
415
 
416
  processed_segments = cls._merge_adjacent_segments(
417
+ processed_segments
 
418
  )
419
 
 
 
 
 
 
420
  processed_segments = cls._filter_segments_with_context(processed_segments)
421
  processing_time = time.time() - t0
422
 
app/services/transcription.py CHANGED
@@ -83,9 +83,10 @@ class TranscriptionService:
83
 
84
  @classmethod
85
  def is_loaded(cls, model_name: str = None) -> bool:
 
86
  if model_name is None:
87
  model_name = settings.default_whisper_model
88
- """Check if a model is loaded."""
89
  cache_key = f"{model_name}_{settings.resolved_compute_type}"
90
  return cache_key in cls._models
91
 
 
83
 
84
  @classmethod
85
  def is_loaded(cls, model_name: str = None) -> bool:
86
+ """Check if a model is loaded."""
87
  if model_name is None:
88
  model_name = settings.default_whisper_model
89
+
90
  cache_key = f"{model_name}_{settings.resolved_compute_type}"
91
  return cache_key in cls._models
92