mazesmazes committed on
Commit
812d81c
·
verified ·
1 Parent(s): 3542a34

Training in progress - step 1000

Browse files
alignment.py CHANGED
@@ -3,6 +3,11 @@
3
  import numpy as np
4
  import torch
5
 
 
 
 
 
 
6
 
7
  def _get_device() -> str:
8
  """Get best available device for non-transformers models."""
@@ -65,6 +70,11 @@ class ForcedAligner:
65
  trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
66
  trellis[0, 0] = 0
67
 
 
 
 
 
 
68
  for t in range(num_frames):
69
  for j in range(num_tokens + 1):
70
  # Stay: emit blank and stay at j tokens
@@ -80,7 +90,7 @@ class ForcedAligner:
80
  @staticmethod
81
  def _backtrack(
82
  trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
83
- ) -> list[tuple[int, float, float]]:
84
  """Backtrack through trellis to find optimal forced monotonic alignment.
85
 
86
  Guarantees:
@@ -88,7 +98,8 @@ class ForcedAligner:
88
  - Strictly monotonic: each token's frames come after previous token's
89
  - No frame skipping or token teleporting
90
 
91
- Returns list of (token_id, start_frame, end_frame) for each token.
 
92
  """
93
  num_frames = emission.size(0)
94
  num_tokens = len(tokens)
@@ -102,13 +113,18 @@ class ForcedAligner:
102
  # Alignment failed - fall back to uniform distribution
103
  frames_per_token = num_frames / num_tokens
104
  return [
105
- (tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
 
 
 
 
 
106
  for i in range(num_tokens)
107
  ]
108
 
109
  # Backtrack: find where each token transition occurred
110
- # path[i] = frame where token i was first emitted
111
- token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
112
 
113
  t = num_frames
114
  j = num_tokens
@@ -120,38 +136,40 @@ class ForcedAligner:
120
 
121
  if move_score >= stay_score:
122
  # Token j-1 was emitted at frame t-1
123
- token_frames[j - 1].insert(0, t - 1)
 
 
124
  j -= 1
125
  # Always decrement time (monotonic)
126
  t -= 1
127
 
128
  # Handle any remaining tokens at the start (edge case)
129
  while j > 0:
130
- token_frames[j - 1].insert(0, 0)
131
  j -= 1
132
 
133
- # Convert to spans
134
- token_spans: list[tuple[int, float, float]] = []
135
- for token_idx, frames in enumerate(token_frames):
136
- if not frames:
137
  # Token never emitted - assign minimal span after previous
138
  if token_spans:
139
  prev_end = token_spans[-1][2]
140
- frames = [int(prev_end)]
141
  else:
142
- frames = [0]
143
 
144
  token_id = tokens[token_idx]
 
145
  start_frame = float(min(frames))
146
  end_frame = float(max(frames)) + 1.0
147
- token_spans.append((token_id, start_frame, end_frame))
148
 
149
- return token_spans
 
150
 
151
- # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
152
- # Calibrated on librispeech-alignments dataset
153
- START_OFFSET = 0.06 # Subtract from start times (shift earlier)
154
- END_OFFSET = -0.03 # Add to end times (shift later)
155
 
156
  @classmethod
157
  def align(
@@ -229,26 +247,28 @@ class ForcedAligner:
229
  frame_duration = 320 / cls._bundle.sample_rate
230
 
231
  # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
232
- start_offset = cls.START_OFFSET
233
- end_offset = cls.END_OFFSET
234
 
235
  # Group aligned tokens into words based on pipe separator
 
236
  words = text.split()
237
  word_timestamps = []
238
- current_word_start = None
239
- current_word_end = None
240
  word_idx = 0
241
  separator_id = dictionary.get("|", dictionary.get(" ", 0))
242
 
243
- for token_id, start_frame, end_frame in alignment_path:
244
  if token_id == separator_id: # Word separator
245
  if (
246
- current_word_start is not None
247
- and current_word_end is not None
248
  and word_idx < len(words)
249
  ):
250
- start_time = max(0.0, current_word_start * frame_duration - start_offset)
251
- end_time = max(0.0, current_word_end * frame_duration - end_offset)
 
252
  word_timestamps.append(
253
  {
254
  "word": words[word_idx],
@@ -257,21 +277,17 @@ class ForcedAligner:
257
  }
258
  )
259
  word_idx += 1
260
- current_word_start = None
261
- current_word_end = None
262
  else:
263
- if current_word_start is None:
264
- current_word_start = start_frame
265
- current_word_end = end_frame
266
 
267
  # Don't forget the last word
268
- if (
269
- current_word_start is not None
270
- and current_word_end is not None
271
- and word_idx < len(words)
272
- ):
273
- start_time = max(0.0, current_word_start * frame_duration - start_offset)
274
- end_time = max(0.0, current_word_end * frame_duration - end_offset)
275
  word_timestamps.append(
276
  {
277
  "word": words[word_idx],
 
3
  import numpy as np
4
  import torch
5
 
6
+ # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
7
+ # Calibrated on librispeech-alignments dataset (n=25, MAE=48ms)
8
+ START_OFFSET = 0.04 # Subtract from start times (shift earlier)
9
+ END_OFFSET = -0.04 # Subtract from end times (shift later)
10
+
11
 
12
  def _get_device() -> str:
13
  """Get best available device for non-transformers models."""
 
70
  trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
71
  trellis[0, 0] = 0
72
 
73
+ # Force alignment to use all tokens by preventing staying in blank
74
+ # at the end when there are still tokens to emit
75
+ if num_tokens > 1:
76
+ trellis[-num_tokens + 1 :, 0] = float("inf")
77
+
78
  for t in range(num_frames):
79
  for j in range(num_tokens + 1):
80
  # Stay: emit blank and stay at j tokens
 
90
  @staticmethod
91
  def _backtrack(
92
  trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
93
+ ) -> list[tuple[int, float, float, float]]:
94
  """Backtrack through trellis to find optimal forced monotonic alignment.
95
 
96
  Guarantees:
 
98
  - Strictly monotonic: each token's frames come after previous token's
99
  - No frame skipping or token teleporting
100
 
101
+ Returns list of (token_id, start_frame, end_frame, peak_frame) for each token.
102
+ The peak_frame is the frame with highest emission probability for that token.
103
  """
104
  num_frames = emission.size(0)
105
  num_tokens = len(tokens)
 
113
  # Alignment failed - fall back to uniform distribution
114
  frames_per_token = num_frames / num_tokens
115
  return [
116
+ (
117
+ tokens[i],
118
+ i * frames_per_token,
119
+ (i + 1) * frames_per_token,
120
+ (i + 0.5) * frames_per_token,
121
+ )
122
  for i in range(num_tokens)
123
  ]
124
 
125
  # Backtrack: find where each token transition occurred
126
+ # Store (frame, emission_score) for each token
127
+ token_frames: list[list[tuple[int, float]]] = [[] for _ in range(num_tokens)]
128
 
129
  t = num_frames
130
  j = num_tokens
 
136
 
137
  if move_score >= stay_score:
138
  # Token j-1 was emitted at frame t-1
139
+ # Store frame and its emission probability
140
+ emit_prob = emission[t - 1, tokens[j - 1]].exp().item()
141
+ token_frames[j - 1].insert(0, (t - 1, emit_prob))
142
  j -= 1
143
  # Always decrement time (monotonic)
144
  t -= 1
145
 
146
  # Handle any remaining tokens at the start (edge case)
147
  while j > 0:
148
+ token_frames[j - 1].insert(0, (0, 0.0))
149
  j -= 1
150
 
151
+ # Convert to spans with peak frame
152
+ token_spans: list[tuple[int, float, float, float]] = []
153
+ for token_idx, frames_with_scores in enumerate(token_frames):
154
+ if not frames_with_scores:
155
  # Token never emitted - assign minimal span after previous
156
  if token_spans:
157
  prev_end = token_spans[-1][2]
158
+ frames_with_scores = [(int(prev_end), 0.0)]
159
  else:
160
+ frames_with_scores = [(0, 0.0)]
161
 
162
  token_id = tokens[token_idx]
163
+ frames = [f for f, _ in frames_with_scores]
164
  start_frame = float(min(frames))
165
  end_frame = float(max(frames)) + 1.0
 
166
 
167
+ # Find peak frame (highest emission probability)
168
+ peak_frame, _ = max(frames_with_scores, key=lambda x: x[1])
169
 
170
+ token_spans.append((token_id, start_frame, end_frame, float(peak_frame)))
171
+
172
+ return token_spans
 
173
 
174
  @classmethod
175
  def align(
 
247
  frame_duration = 320 / cls._bundle.sample_rate
248
 
249
  # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
250
+ start_offset = START_OFFSET
251
+ end_offset = END_OFFSET
252
 
253
  # Group aligned tokens into words based on pipe separator
254
+ # Use peak emission frame for more accurate word boundaries
255
  words = text.split()
256
  word_timestamps = []
257
+ first_char_peak = None
258
+ last_char_peak = None
259
  word_idx = 0
260
  separator_id = dictionary.get("|", dictionary.get(" ", 0))
261
 
262
+ for token_id, _start_frame, _end_frame, peak_frame in alignment_path:
263
  if token_id == separator_id: # Word separator
264
  if (
265
+ first_char_peak is not None
266
+ and last_char_peak is not None
267
  and word_idx < len(words)
268
  ):
269
+ # Use peak frames for word boundaries
270
+ start_time = max(0.0, first_char_peak * frame_duration - start_offset)
271
+ end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
272
  word_timestamps.append(
273
  {
274
  "word": words[word_idx],
 
277
  }
278
  )
279
  word_idx += 1
280
+ first_char_peak = None
281
+ last_char_peak = None
282
  else:
283
+ if first_char_peak is None:
284
+ first_char_peak = peak_frame
285
+ last_char_peak = peak_frame
286
 
287
  # Don't forget the last word
288
+ if first_char_peak is not None and last_char_peak is not None and word_idx < len(words):
289
+ start_time = max(0.0, first_char_peak * frame_duration - start_offset)
290
+ end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
 
 
 
 
291
  word_timestamps.append(
292
  {
293
  "word": words[word_idx],
asr_config.py CHANGED
@@ -64,6 +64,7 @@ class ASRConfig(transformers.PretrainedConfig):
64
  lora_target_modules: Optional[list] = None, # Default: all linear layers
65
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
66
  do_sample: bool = False,
 
67
  temperature: Optional[float] = None,
68
  top_p: Optional[float] = None,
69
  top_k: Optional[int] = None,
@@ -174,6 +175,7 @@ class ASRConfig(transformers.PretrainedConfig):
174
  )
175
  self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
176
  self.do_sample = do_sample
 
177
  self.temperature = temperature
178
  self.top_p = top_p
179
  self.top_k = top_k
 
64
  lora_target_modules: Optional[list] = None, # Default: all linear layers
65
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
66
  do_sample: bool = False,
67
+ enable_thinking: bool = False, # Enable Qwen3 thinking mode for omni models
68
  temperature: Optional[float] = None,
69
  top_p: Optional[float] = None,
70
  top_k: Optional[int] = None,
 
175
  )
176
  self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
177
  self.do_sample = do_sample
178
+ self.enable_thinking = enable_thinking
179
  self.temperature = temperature
180
  self.top_p = top_p
181
  self.top_k = top_k
asr_modeling.py CHANGED
@@ -582,7 +582,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
582
  tokenize=True,
583
  add_generation_prompt=True,
584
  return_tensors="pt",
585
- enable_thinking=False, # Disable Qwen3 thinking mode for ASR
586
  )
587
  input_ids = chat_result.input_ids.to(device)
588
 
@@ -665,7 +665,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
665
  tokenize=True,
666
  add_generation_prompt=True,
667
  return_tensors="pt",
668
- enable_thinking=False, # Disable Qwen3 thinking mode for ASR
669
  )
670
  input_ids = chat_result.input_ids.to(device)
671
 
@@ -764,7 +764,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
764
  tokenize=True,
765
  add_generation_prompt=True,
766
  return_tensors="pt",
767
- enable_thinking=False, # Disable Qwen3 thinking mode for ASR
768
  ).to(device)
769
 
770
  if input_ids.dim() == 1:
 
582
  tokenize=True,
583
  add_generation_prompt=True,
584
  return_tensors="pt",
585
+ enable_thinking=getattr(self.config, "enable_thinking", False),
586
  )
587
  input_ids = chat_result.input_ids.to(device)
588
 
 
665
  tokenize=True,
666
  add_generation_prompt=True,
667
  return_tensors="pt",
668
+ enable_thinking=getattr(self.config, "enable_thinking", False),
669
  )
670
  input_ids = chat_result.input_ids.to(device)
671
 
 
764
  tokenize=True,
765
  add_generation_prompt=True,
766
  return_tensors="pt",
767
+ enable_thinking=getattr(self.config, "enable_thinking", False),
768
  ).to(device)
769
 
770
  if input_ids.dim() == 1:
chat_template.jinja CHANGED
@@ -1,89 +1,94 @@
1
- {%- if tools %}
2
- {{- '<|im_start|>system\n' }}
3
- {%- if messages[0].role == 'system' %}
4
- {{- messages[0].content + '\n\n' }}
5
- {%- endif %}
6
- {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
- {%- for tool in tools %}
8
- {{- "\n" }}
9
- {{- tool | tojson }}
10
- {%- endfor %}
11
- {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
- {%- else %}
13
- {%- if messages[0].role == 'system' %}
14
- {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
- {%- endif %}
16
- {%- endif %}
17
- {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
- {%- for message in messages[::-1] %}
19
- {%- set index = (messages|length - 1) - loop.index0 %}
20
- {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
- {%- set ns.multi_step_tool = false %}
22
- {%- set ns.last_query_index = index %}
23
- {%- endif %}
24
- {%- endfor %}
25
- {%- for message in messages %}
26
- {%- if message.content is string %}
27
- {%- set content = message.content %}
28
- {%- else %}
29
- {%- set content = '' %}
30
- {%- endif %}
31
- {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
32
- {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
33
- {%- elif message.role == "assistant" %}
34
- {%- set reasoning_content = '' %}
35
- {%- if message.reasoning_content is string %}
36
- {%- set reasoning_content = message.reasoning_content %}
37
- {%- else %}
38
- {%- if '</think>' in content %}
39
- {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
40
- {%- set content = content.split('</think>')[-1].lstrip('\n') %}
41
- {%- endif %}
42
- {%- endif %}
43
- {%- if loop.index0 > ns.last_query_index %}
44
- {%- if loop.last or (not loop.last and reasoning_content) %}
45
- {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
46
- {%- else %}
47
- {{- '<|im_start|>' + message.role + '\n' + content }}
48
- {%- endif %}
49
- {%- else %}
50
- {{- '<|im_start|>' + message.role + '\n' + content }}
51
- {%- endif %}
52
- {%- if message.tool_calls %}
53
- {%- for tool_call in message.tool_calls %}
54
- {%- if (loop.first and content) or (not loop.first) %}
55
- {{- '\n' }}
56
- {%- endif %}
57
- {%- if tool_call.function %}
58
- {%- set tool_call = tool_call.function %}
59
- {%- endif %}
60
- {{- '<tool_call>\n{"name": "' }}
61
- {{- tool_call.name }}
62
- {{- '", "arguments": ' }}
63
- {%- if tool_call.arguments is string %}
64
- {{- tool_call.arguments }}
65
- {%- else %}
66
- {{- tool_call.arguments | tojson }}
67
- {%- endif %}
68
- {{- '}\n</tool_call>' }}
69
- {%- endfor %}
70
- {%- endif %}
71
- {{- '<|im_end|>\n' }}
72
- {%- elif message.role == "tool" %}
73
- {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
74
- {{- '<|im_start|>user' }}
75
- {%- endif %}
76
- {{- '\n<tool_response>\n' }}
77
- {{- content }}
78
- {{- '\n</tool_response>' }}
79
- {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
80
- {{- '<|im_end|>\n' }}
81
- {%- endif %}
82
- {%- endif %}
83
- {%- endfor %}
84
- {%- if add_generation_prompt %}
85
- {{- '<|im_start|>assistant\n' }}
86
- {%- if true %}
87
- {{- '<think>\n\n</think>\n\n' }}
88
- {%- endif %}
89
- {%- endif %}
 
 
 
 
 
 
1
+ {# ───── defaults ───── #}
2
+ {%- if enable_thinking is not defined -%}
3
+ {%- set enable_thinking = true -%}
4
+ {%- endif -%}
5
+
6
+ {# ───── reasoning mode ───── #}
7
+ {%- if enable_thinking -%}
8
+ {%- set reasoning_mode = "/think" -%}
9
+ {%- else -%}
10
+ {%- set reasoning_mode = "/no_think" -%}
11
+ {%- endif -%}
12
+
13
+ {# ───── header (system message) ───── #}
14
+ {{- "<|im_start|>system\n" -}}
15
+
16
+ {%- if messages[0].role == "system" -%}
17
+ {%- set system_message = messages[0].content -%}
18
+ {%- if "/no_think" in system_message -%}
19
+ {%- set reasoning_mode = "/no_think" -%}
20
+ {%- elif "/think" in system_message -%}
21
+ {%- set reasoning_mode = "/think" -%}
22
+ {%- endif -%}
23
+ {%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
24
+ {%- endif -%}
25
+
26
+ {%- if "/system_override" in system_message -%}
27
+ {{- custom_instructions.replace("/system_override", "").rstrip() -}}
28
+ {{- "<|im_end|>\n" -}}
29
+ {%- else -%}
30
+ {{- "## Metadata\n\n" -}}
31
+ {{- "Knowledge Cutoff Date: June 2025\n" -}}
32
+ {%- set today = strftime_now("%d %B %Y") -%}
33
+ {{- "Today Date: " ~ today ~ "\n" -}}
34
+ {{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}
35
+
36
+ {{- "## Custom Instructions\n\n" -}}
37
+ {%- if custom_instructions -%}
38
+ {{- custom_instructions + "\n\n" -}}
39
+ {%- elif reasoning_mode == "/think" -%}
40
+ {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
41
+ {%- else -%}
42
+ {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
43
+ {%- endif -%}
44
+
45
+ {%- if xml_tools or python_tools or tools -%}
46
+ {{- "### Tools\n\n" -}}
47
+ {%- if xml_tools or tools -%}
48
+ {%- if tools -%}
49
+ {%- set xml_tools = tools -%}
50
+ {%- endif -%}
51
+ {%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
52
+ {%- for tool in xml_tools[:] -%} {# The slicing makes sure that xml_tools is a list #}
53
+ {%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | string) ~ "\n" -%}
54
+ {%- endfor -%}
55
+ {%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
56
+ {{- xml_tool_string -}}
57
+ {%- endif -%}
58
+ {%- if python_tools -%}
59
+ {%- set ns = namespace(python_tool_string="When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continued reasoning in an agentic loop.\n\nYou can use the following tools in your python code like regular functions:\n<tools>\n") -%}
60
+ {%- for tool in python_tools[:] -%} {# The slicing makes sure that python_tools is a list #}
61
+ {%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
62
+ {%- endfor -%}
63
+ {%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions: so variables that you define in one step are still available thereafter." -%}
64
+ {{- python_tool_string -}}
65
+ {%- endif -%}
66
+ {{- "\n\n" -}}
67
+ {{- "<|im_end|>\n" -}}
68
+ {%- endif -%}
69
+ {%- endif -%}
70
+ {# ───── main loop ───── #}
71
+ {%- for message in messages -%}
72
+ {%- set content = message.content if message.content is string else "" -%}
73
+ {%- if message.role == "user" -%}
74
+ {{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
75
+ {%- elif message.role == "assistant" -%}
76
+ {% generation %}
77
+ {%- if reasoning_mode == "/think" -%}
78
+ {{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
79
+ {%- else -%}
80
+ {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
81
+ {%- endif -%}
82
+ {% endgeneration %}
83
+ {%- elif message.role == "tool" -%}
84
+ {{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
85
+ {%- endif -%}
86
+ {%- endfor -%}
87
+ {# ───── generation prompt ───── #}
88
+ {%- if add_generation_prompt -%}
89
+ {%- if reasoning_mode == "/think" -%}
90
+ {{ "<|im_start|>assistant\n" }}
91
+ {%- else -%}
92
+ {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
93
+ {%- endif -%}
94
+ {%- endif -%}
diarization.py CHANGED
@@ -91,20 +91,47 @@ class SpectralCluster:
91
  def get_spec_embs(
92
  self, laplacian: np.ndarray, k_oracle: int | None = None
93
  ) -> tuple[np.ndarray, int]:
94
- """Extract spectral embeddings from Laplacian."""
 
 
 
 
 
 
95
  lambdas, eig_vecs = scipy.linalg.eigh(laplacian)
96
 
97
- if k_oracle is not None:
98
- num_of_spk = k_oracle
99
- else:
100
- lambda_gap_list = self.get_eigen_gaps(
101
- lambdas[self.min_num_spks - 1 : self.max_num_spks + 1]
102
- )
103
- num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
104
 
105
  emb = eig_vecs[:, :num_of_spk]
106
  return emb, num_of_spk
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
109
  """Cluster spectral embeddings using k-means."""
110
  _, labels, _ = k_means(emb, k, n_init=10)
 
91
  def get_spec_embs(
92
  self, laplacian: np.ndarray, k_oracle: int | None = None
93
  ) -> tuple[np.ndarray, int]:
94
+ """Extract spectral embeddings from Laplacian.
95
+
96
+ Uses the eigengap heuristic to estimate the number of clusters:
97
+ The number of clusters k is chosen where the gap between consecutive
98
+ eigenvalues is largest, indicating a transition from "cluster" eigenvalues
99
+ (near 0) to "noise" eigenvalues.
100
+ """
101
  lambdas, eig_vecs = scipy.linalg.eigh(laplacian)
102
 
103
+ num_of_spk = k_oracle if k_oracle is not None else self._estimate_num_speakers(lambdas)
 
 
 
 
 
 
104
 
105
  emb = eig_vecs[:, :num_of_spk]
106
  return emb, num_of_spk
107
 
108
+ def _estimate_num_speakers(self, lambdas: np.ndarray) -> int:
109
+ """Estimate number of speakers using refined eigengap heuristic.
110
+
111
+ For spectral clustering, we look for the largest gap in eigenvalues.
112
+ The eigenvalues corresponding to clusters are close to 0, and there
113
+ should be a significant jump to the remaining eigenvalues.
114
+ """
115
+ # Consider eigenvalues from index 1 to max_num_spks (skip first, it's always ~0)
116
+ # We need gaps between positions, so look at indices 1 to max_num_spks+1
117
+ max_idx = min(self.max_num_spks + 1, len(lambdas))
118
+ relevant_lambdas = lambdas[1:max_idx] # Skip first eigenvalue
119
+
120
+ if len(relevant_lambdas) < 2:
121
+ return self.min_num_spks
122
+
123
+ # Compute absolute gaps (not ratios - ratios are unstable near 0)
124
+ gaps = np.diff(relevant_lambdas)
125
+
126
+ # Find the largest gap - the index gives us (k-1) since we skipped first
127
+ # Add 1 to convert from gap index to number of speakers
128
+ # Add 1 again because we skipped the first eigenvalue
129
+ max_gap_idx = int(np.argmax(gaps))
130
+ num_of_spk = max_gap_idx + 2 # +1 for gap->count, +1 for skipped eigenvalue
131
+
132
+ # Clamp between min and max
133
+ return max(self.min_num_spks, min(num_of_spk, self.max_num_spks))
134
+
135
  def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
136
  """Cluster spectral embeddings using k-means."""
137
  _, labels, _ = k_means(emb, k, n_init=10)
tokenizer.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:33b674fb8444e2553eae8f1b261093371920a28ef75b5c18f4deb3f9217ed0ba
3
- size 11422834
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
3
+ size 17209003
tokenizer_config.json CHANGED
Binary files a/tokenizer_config.json and b/tokenizer_config.json differ