Backup-bdg commited on
Commit
102eca9
·
verified ·
1 Parent(s): e2adb57

Update model weights after training (epoch 7, loss 4.6721)

Browse files
cross_attention.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:15e348f1b98e8cc48f633f80a818a98a727f8a95e3794d3d7496c7c67d319c21
3
  size 174191400
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8cde5e1fb540a32b44b78415f2dcf2f037489d8b119a0f33b68a49811e6b8b50
3
  size 174191400
llm.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3888d6f2029add98a6540daf90a2fffaf8b2c0420fca1b401042a37ae56f957f
3
- size 1506832040
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b5ffb3061f8b427f852d024bd147adf53062939c8d92babd677d4ceacf953a8
3
+ size 1506836434
model.safetensors.index.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "metadata": {
3
- "total_size": 7309254542,
4
  "format": "components"
5
  },
6
  "weight_map": {
@@ -696,6 +696,9 @@
696
  "llm.model.layers.11.mlp.shared_expert.down_proj.lora_B": "llm.safetensors",
697
  "llm.model.layers.11.mlp.shared_expert.down_proj.linear.weight": "llm.safetensors",
698
  "llm.model.norm.weight": "llm.safetensors",
 
 
 
699
  "llm.lm_head.weight": "llm.safetensors",
700
  "vision_encoder.vision_model.vision_model.embeddings.patch_embedding.weight": "vision_encoder.safetensors",
701
  "vision_encoder.vision_model.vision_model.embeddings.patch_embedding.bias": "vision_encoder.safetensors",
 
1
  {
2
  "metadata": {
3
+ "total_size": 7309258640,
4
  "format": "components"
5
  },
6
  "weight_map": {
 
696
  "llm.model.layers.11.mlp.shared_expert.down_proj.lora_B": "llm.safetensors",
697
  "llm.model.layers.11.mlp.shared_expert.down_proj.linear.weight": "llm.safetensors",
698
  "llm.model.norm.weight": "llm.safetensors",
699
+ "llm.model.thought_gate.weight": "llm.safetensors",
700
+ "llm.model.thought_gate.bias": "llm.safetensors",
701
+ "llm.model.thought_layernorm.weight": "llm.safetensors",
702
  "llm.lm_head.weight": "llm.safetensors",
703
  "vision_encoder.vision_model.vision_model.embeddings.patch_embedding.weight": "vision_encoder.safetensors",
704
  "vision_encoder.vision_model.vision_model.embeddings.patch_embedding.bias": "vision_encoder.safetensors",
modeling_xoron.py CHANGED
@@ -9122,6 +9122,13 @@ class MoELlamaModel (nn .Module ):
9122
 
9123
  self .num_moe_layers =sum (1 for layer in self .layers if layer .is_moe_layer )
9124
 
 
 
 
 
 
 
 
9125
  self ._init_weights ()
9126
 
9127
  def _init_weights (self ):
@@ -9147,6 +9154,7 @@ class MoELlamaModel (nn .Module ):
9147
  output_hidden_states :bool =False ,
9148
  return_dict :bool =True ,
9149
  cache_position :Optional [torch .Tensor ]=None ,
 
9150
  )->Union [Tuple ,MoELlamaModelOutput ]:
9151
 
9152
  if inputs_embeds is None :
@@ -9207,6 +9215,39 @@ class MoELlamaModel (nn .Module ):
9207
  if output_attentions and attn_weights is not None :
9208
  all_attentions =all_attentions +(attn_weights ,)
9209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9210
  hidden_states =self .norm (hidden_states )
9211
 
9212
  if output_hidden_states :
@@ -9315,6 +9356,7 @@ class MoELlamaForCausalLM (nn .Module ):
9315
  output_hidden_states :bool =False ,
9316
  return_dict :bool =True ,
9317
  cache_position :Optional [torch .Tensor ]=None ,
 
9318
  **kwargs ,
9319
  )->Union [Tuple ,CausalLMOutput ]:
9320
 
@@ -9329,6 +9371,7 @@ class MoELlamaForCausalLM (nn .Module ):
9329
  output_hidden_states =output_hidden_states ,
9330
  return_dict =True ,
9331
  cache_position =cache_position ,
 
9332
  )
9333
 
9334
  hidden_states =outputs .last_hidden_state
@@ -9379,12 +9422,14 @@ class MoELlamaForCausalLM (nn .Module ):
9379
  pad_token_id :Optional [int ]=None ,
9380
  eos_token_id :Optional [int ]=None ,
9381
  attention_mask :Optional [torch .Tensor ]=None ,
 
9382
  **kwargs ,
9383
  )->torch .Tensor :
9384
  batch_size =input_ids .shape [0 ]
9385
  device =input_ids .device
9386
 
9387
  past_key_values =None
 
9388
 
9389
  if attention_mask is None :
9390
  attention_mask =torch .ones_like (input_ids )
@@ -9396,7 +9441,10 @@ class MoELlamaForCausalLM (nn .Module ):
9396
  attention_mask =attention_mask ,
9397
  )
9398
 
9399
- outputs =self .forward (**model_inputs ,use_cache =True ,return_dict =True )
 
 
 
9400
 
9401
  next_token_logits =outputs .logits [:,-1 ,:]
9402
 
@@ -9660,67 +9708,85 @@ class XoronMultimodalModel (nn .Module ):
9660
 
9661
 
9662
  def apply_model_parallel (self ,device_map :Dict [str ,str ]):
9663
- """Apply Model Parallelism by placing components on different devices."""
9664
- self .device_map =device_map
9665
-
9666
-
9667
-
9668
- device_values =[v for v in device_map .values ()if isinstance (v ,str )]
9669
- self ._model_parallel =len (set (device_values ))>1
 
 
9670
 
9671
- if not self ._model_parallel :
9672
  logger .info (" ℹ️ Single device - no model parallelism needed")
9673
- return self
9674
-
9675
- logger .info ("Applying Model Parallelism...")
9676
-
9677
- self .vision_encoder =self .vision_encoder .to (device_map ['vision_encoder'])
9678
- logger .info (f" ✅ Vision encoder -> {device_map ['vision_encoder']}")
9679
-
9680
- self .video_encoder =self .video_encoder .to (device_map ['video_encoder'])
9681
- logger .info (f" ✅ Video encoder -> {device_map ['video_encoder']}")
9682
-
9683
- self .audio_encoder =self .audio_encoder .to (device_map ['audio_encoder'])
9684
- logger .info (f" ✅ Audio encoder -> {device_map ['audio_encoder']}")
9685
-
9686
- self .audio_decoder =self .audio_decoder .to (device_map ['audio_decoder'])
9687
- logger .info (f" ✅ Audio decoder -> {device_map ['audio_decoder']}")
9688
-
9689
-
9690
- if hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None :
9691
- waveform_device =device_map .get ('waveform_decoder',device_map ['audio_decoder'])
9692
- self .waveform_decoder =self .waveform_decoder .to (waveform_device )
9693
- logger .info (f" ✅ Waveform decoder -> {waveform_device }")
9694
-
9695
- self .projector =self .projector .to (device_map ['projector'])
9696
- logger .info (f" ✅ Projector -> {device_map ['projector']}")
9697
-
9698
- self .audio_projector =self .audio_projector .to (device_map ['audio_projector'])
9699
- logger .info (f" ✅ Audio projector -> {device_map ['audio_projector']}")
9700
-
9701
- self .llm =self .llm .to (device_map ['llm'])
9702
- logger .info (f" ✅ LLM -> {device_map ['llm']}")
9703
-
9704
- if self .cross_attention_layers is not None :
9705
- self .cross_attention_layers =self .cross_attention_layers .to (device_map ['cross_attention'])
9706
- logger .info (f" ✅ Cross-attention -> {device_map ['cross_attention']}")
9707
-
9708
- if self .generator is not None :
9709
- self .generator =self .generator .to (device_map ['generator'])
9710
- logger .info (f" ✅ Image generator -> {device_map ['generator']}")
 
 
 
 
 
 
 
9711
 
9712
- if self .video_generator is not None :
9713
- self .video_generator =self .video_generator .to (device_map ['video_generator'])
9714
- logger .info (f" ✅ Video generator -> {device_map ['video_generator']}")
9715
-
9716
- marker_device =device_map ['modality_markers']
9717
- self .image_start =nn .Parameter (self .image_start .data .to (marker_device ))
9718
- self .image_end =nn .Parameter (self .image_end .data .to (marker_device ))
9719
- self .video_start =nn .Parameter (self .video_start .data .to (marker_device ))
9720
- self .video_end =nn .Parameter (self .video_end .data .to (marker_device ))
9721
- self .audio_start =nn .Parameter (self .audio_start .data .to (marker_device ))
9722
- self .audio_end =nn .Parameter (self .audio_end .data .to (marker_device ))
9723
- logger .info (f" ✅ Modality markers -> {marker_device }")
 
 
 
 
 
 
 
 
 
9724
 
9725
  logger .info ("Model Parallelism applied successfully!")
9726
  return self
@@ -10113,63 +10179,297 @@ class XoronMultimodalModel (nn .Module ):
10113
  def listen_and_respond (
10114
  self ,
10115
  audio_waveform :torch .Tensor ,
10116
- max_new_tokens :int =256 ,
 
10117
  speaker_embedding :torch .Tensor =None ,
10118
- )->torch .Tensor :
10119
- """
10120
- Full Speech-to-Speech: Listen to audio, generate text response, speak it back.
10121
-
10122
- This is the main conversational method - you speak to it, it responds with voice.
10123
-
10124
- Args:
10125
- audio_waveform: [B, T_audio] input audio (what you said)
10126
- max_new_tokens: Maximum tokens to generate for response
10127
- speaker_embedding: Optional speaker embedding for response voice
10128
-
10129
- Returns:
10130
- response_audio: [B, T_response] audio waveform of the model's response
10131
  """
10132
- device =audio_waveform .device
10133
-
10134
-
10135
- audio_embeds =self .listen (audio_waveform )
10136
-
10137
-
10138
- batch_size =audio_waveform .shape [0 ]
10139
-
10140
 
 
 
 
10141
 
10142
- dummy_input =torch .zeros (batch_size ,1 ,dtype =torch .long ,device =device )
10143
-
10144
-
10145
-
10146
- outputs =self .forward (
10147
- input_ids =dummy_input ,
10148
- audio_features =audio_waveform ,
10149
- )
10150
 
 
 
 
 
 
 
 
 
 
 
 
 
10151
 
10152
- response_embeds =outputs .get ('hidden_states',outputs .get ('last_hidden_state'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10154
 
10155
- if response_embeds is not None :
 
10156
 
10157
- mel ,durations ,_ ,_ =self .audio_decoder (
10158
  response_embeds ,
10159
  speaker_embedding =speaker_embedding ,
10160
- )
10161
-
10162
 
10163
- mel_features =mel .transpose (1 ,2 )
10164
- if not hasattr (self ,'_mel_to_hidden'):
10165
- self ._mel_to_hidden =nn .Linear (80 ,self .config .hidden_size ).to (device )
10166
- audio_features =self ._mel_to_hidden (mel_features )
10167
- response_audio =self .waveform_decoder (audio_features )
10168
 
10169
- return response_audio
10170
 
 
 
 
 
 
 
 
 
10171
 
10172
- return torch .zeros (batch_size ,16000 ,device =device )
10173
 
10174
  def merge_lora_weights (self ):
10175
  """Merge LoRA weights into main weights for inference."""
@@ -11120,6 +11420,10 @@ XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM")
11120
  model =cls (config ,device_map =device_map )
11121
 
11122
 
 
 
 
 
11123
  components_json =os .path .join (path ,"components.json")
11124
  model_path =os .path .join (path ,"model.safetensors")
11125
 
@@ -11127,7 +11431,7 @@ XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM")
11127
 
11128
  logger .info ("Loading from component-based format...")
11129
  model ._load_components (path ,strict =strict )
11130
- model .lora_applied =lora_was_applied
11131
 
11132
  elif os .path .exists (model_path ):
11133
  logger .info ("Loading weights from safetensors...")
@@ -11143,7 +11447,7 @@ XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM")
11143
  model .load_state_dict (checkpoint_state_dict ,strict =False )
11144
  logger .info ("Loaded weights from checkpoint")
11145
 
11146
- model .lora_applied =lora_was_applied
11147
  else :
11148
 
11149
  pytorch_path =os .path .join (path ,"pytorch_model.bin")
@@ -11154,7 +11458,7 @@ XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM")
11154
  model .load_state_dict (checkpoint_state_dict ,strict =False )
11155
  logger .info ("Loaded weights from checkpoint")
11156
 
11157
- model .lora_applied =lora_was_applied
11158
  else :
11159
  raise FileNotFoundError (f"No model weights found at {path }")
11160
 
 
9122
 
9123
  self .num_moe_layers =sum (1 for layer in self .layers if layer .is_moe_layer )
9124
 
9125
+ # ── Coconut: Continuous Thought components ──
9126
+ # Learned gate controls how much recurrent thought vs original input
9127
+ # to retain at each thinking step. Sigmoid output in [0,1].
9128
+ self .thought_gate = nn .Linear (config .hidden_size , 1 , bias =True )
9129
+ nn .init .constant_ (self .thought_gate .bias , -2.0 ) # Initialize gate biased toward original (sigmoid(-2)≈0.12)
9130
+ self .thought_layernorm = LlamaRMSNorm (config .hidden_size , eps =config .rms_norm_eps )
9131
+
9132
  self ._init_weights ()
9133
 
9134
  def _init_weights (self ):
 
9154
  output_hidden_states :bool =False ,
9155
  return_dict :bool =True ,
9156
  cache_position :Optional [torch .Tensor ]=None ,
9157
+ thinking_depth :int =0 ,
9158
  )->Union [Tuple ,MoELlamaModelOutput ]:
9159
 
9160
  if inputs_embeds is None :
 
9215
  if output_attentions and attn_weights is not None :
9216
  all_attentions =all_attentions +(attn_weights ,)
9217
 
9218
+ # ── Coconut: Continuous Thought Loop ──
9219
+ # After the normal pass, loop hidden states back through the
9220
+ # transformer layers for extra computation in latent space.
9221
+ # No tokens are decoded — pure continuous reasoning.
9222
+ if thinking_depth > 0 :
9223
+ original_hidden = hidden_states .clone ()
9224
+ thought_position_ids = torch .arange (
9225
+ seq_len , device =hidden_states .device
9226
+ ).unsqueeze (0 ).expand (batch_size , -1 )
9227
+
9228
+ for thought_step in range (thinking_depth ):
9229
+ # Normalize before re-entering the layers
9230
+ hidden_states = self .thought_layernorm (hidden_states )
9231
+
9232
+ # Run through all layers again (no cache — full re-computation)
9233
+ for layer in self .layers :
9234
+ hidden_states , _ , _ , step_aux = layer (
9235
+ hidden_states =hidden_states ,
9236
+ attention_mask =None , # Self-attend freely in thought space
9237
+ position_ids =thought_position_ids ,
9238
+ past_key_value =None ,
9239
+ output_attentions =False ,
9240
+ use_cache =False ,
9241
+ )
9242
+ if step_aux is not None :
9243
+ total_aux_loss = total_aux_loss + step_aux
9244
+
9245
+ # Gated residual: blend thought with original
9246
+ # gate ∈ [0,1], initialized small so early training
9247
+ # stays close to original behavior
9248
+ gate = torch .sigmoid (self .thought_gate (hidden_states ))
9249
+ hidden_states = gate * hidden_states + (1.0 - gate ) * original_hidden
9250
+
9251
  hidden_states =self .norm (hidden_states )
9252
 
9253
  if output_hidden_states :
 
9356
  output_hidden_states :bool =False ,
9357
  return_dict :bool =True ,
9358
  cache_position :Optional [torch .Tensor ]=None ,
9359
+ thinking_depth :int =0 ,
9360
  **kwargs ,
9361
  )->Union [Tuple ,CausalLMOutput ]:
9362
 
 
9371
  output_hidden_states =output_hidden_states ,
9372
  return_dict =True ,
9373
  cache_position =cache_position ,
9374
+ thinking_depth =thinking_depth ,
9375
  )
9376
 
9377
  hidden_states =outputs .last_hidden_state
 
9422
  pad_token_id :Optional [int ]=None ,
9423
  eos_token_id :Optional [int ]=None ,
9424
  attention_mask :Optional [torch .Tensor ]=None ,
9425
+ thinking_depth :int =0 ,
9426
  **kwargs ,
9427
  )->torch .Tensor :
9428
  batch_size =input_ids .shape [0 ]
9429
  device =input_ids .device
9430
 
9431
  past_key_values =None
9432
+ is_prefill =True # Deep thinking only on first pass (full context)
9433
 
9434
  if attention_mask is None :
9435
  attention_mask =torch .ones_like (input_ids )
 
9441
  attention_mask =attention_mask ,
9442
  )
9443
 
9444
+ # Apply thinking depth only on prefill, not per-token steps
9445
+ current_depth = thinking_depth if is_prefill else 0
9446
+ outputs =self .forward (**model_inputs ,use_cache =True ,return_dict =True ,thinking_depth =current_depth )
9447
+ is_prefill =False
9448
 
9449
  next_token_logits =outputs .logits [:,-1 ,:]
9450
 
 
9708
 
9709
 
9710
  def apply_model_parallel (self ,device_map :Dict [str ,str ]):
9711
+ """Apply Model Parallelism by sharding components across devices.
9712
+
9713
+ Trained components get their layers split across all training GPUs.
9714
+ Frozen components go to CPU. Small components (projectors, markers)
9715
+ go to the primary GPU.
9716
+ """
9717
+ self .device_map =device_map
9718
+ training_gpus = device_map .get ('training_gpus', ['cuda:0'])
9719
+ primary = device_map .get ('primary', 'cuda:0')
9720
 
9721
+ if len (training_gpus ) <= 1 and not any (v == 'cpu' for v in device_map .values () if isinstance (v, str)):
9722
  logger .info (" ℹ️ Single device - no model parallelism needed")
9723
+ return self
9724
+
9725
+ self ._model_parallel = True
9726
+ logger .info ("Applying Model Parallelism (layer sharding)...")
9727
+
9728
+ def _shard_module (module, name, gpus):
9729
+ """Shard a module's sub-layers across GPUs."""
9730
+ # Find shardable sub-layers (nn.ModuleList children)
9731
+ layer_lists = []
9732
+ for attr_name in dir (module):
9733
+ attr = getattr (module, attr_name, None)
9734
+ if isinstance (attr, nn .ModuleList) and len (attr) > 0:
9735
+ layer_lists .append ((attr_name, attr))
9736
+
9737
+ if layer_lists:
9738
+ # Shard the largest ModuleList across GPUs
9739
+ layer_lists .sort (key=lambda x: len (x[1]), reverse=True)
9740
+ list_name, layers = layer_lists [0]
9741
+ for i, layer in enumerate (layers):
9742
+ target_gpu = gpus [i % len (gpus)]
9743
+ layer .to (target_gpu)
9744
+ # Put remaining params on primary GPU
9745
+ for param_name, param in module .named_parameters ():
9746
+ if not any (f'{list_name}.' in param_name for _ in [1]):
9747
+ param .data = param .data .to (gpus [0])
9748
+ logger .info (f" ✅ {name}: {len(layers)} layers sharded across {gpus}")
9749
+ else:
9750
+ # No layers to shard — put whole module on first GPU
9751
+ module .to (gpus [0])
9752
+ logger .info (f" ✅ {name} -> {gpus[0]}")
9753
+
9754
+ # Map component names to actual attributes
9755
+ component_attrs = {
9756
+ 'vision_encoder': 'vision_encoder',
9757
+ 'video_encoder': 'video_encoder',
9758
+ 'audio_encoder': 'audio_encoder',
9759
+ 'audio_decoder': 'audio_decoder',
9760
+ 'waveform_decoder': 'waveform_decoder',
9761
+ 'projector': 'projector',
9762
+ 'audio_projector': 'audio_projector',
9763
+ 'llm': 'llm',
9764
+ 'cross_attention': 'cross_attention_layers',
9765
+ 'generator': 'generator',
9766
+ 'video_generator': 'video_generator',
9767
+ }
9768
 
9769
+ for comp_name, attr_name in component_attrs .items ():
9770
+ comp = getattr (self, attr_name, None)
9771
+ if comp is None:
9772
+ continue
9773
+ target = device_map .get (comp_name, 'cpu')
9774
+ if target == 'cpu':
9775
+ comp .to ('cpu')
9776
+ logger .info (f" ❄️ {comp_name} -> cpu (frozen)")
9777
+ else:
9778
+ # Shard across all training GPUs
9779
+ _shard_module (comp, comp_name, training_gpus)
9780
+
9781
+ # Modality markers → primary GPU
9782
+ marker_device = device_map .get ('modality_markers', primary)
9783
+ if marker_device != 'cpu':
9784
+ marker_device = primary
9785
+ for marker_name in ['image_start', 'image_end', 'video_start', 'video_end', 'audio_start', 'audio_end']:
9786
+ marker = getattr (self, marker_name, None)
9787
+ if marker is not None:
9788
+ setattr (self, marker_name, nn .Parameter (marker .data .to (marker_device)))
9789
+ logger .info (f" ✅ Modality markers -> {marker_device}")
9790
 
9791
  logger .info ("Model Parallelism applied successfully!")
9792
  return self
 
10179
  def listen_and_respond (
10180
  self ,
10181
  audio_waveform :torch .Tensor ,
10182
+ tokenizer =None ,
10183
+ max_new_tokens :int =512 ,
10184
  speaker_embedding :torch .Tensor =None ,
10185
+ temperature :float =0.7 ,
10186
+ top_p :float =0.9 ,
10187
+ tool_executor =None ,
10188
+ available_tools :list =None ,
10189
+ system_prompt :str =None ,
10190
+ max_tool_calls :int =5 ,
10191
+ ) -> Dict [str ,Any ]:
 
 
 
 
 
 
10192
  """
10193
+ Agentic Speech-to-Speech: Listen, think, use tools, speak back.
 
 
 
 
 
 
 
10194
 
10195
+ This is the full agentic pipeline for live voice conversations.
10196
+ The model can detect when the user is asking for actions (e.g.
10197
+ "write me a Python script") and execute tools mid-generation.
10198
 
10199
+ Pipeline:
10200
+ 1. Encode input audio → audio embeddings (ASR)
10201
+ 2. Build context (system prompt with tools + audio embeddings)
10202
+ 3. Generate tokens, watching for <|tool_call|> sequences
10203
+ 4. When tool call detected: parse, execute, inject result, resume
10204
+ 5. Synthesize final spoken response from non-tool text
 
 
10205
 
10206
+ Args:
10207
+ audio_waveform: [B, T_audio] input audio waveform
10208
+ tokenizer: Tokenizer for decoding tokens to text (required for tools)
10209
+ max_new_tokens: Maximum total tokens to generate
10210
+ speaker_embedding: [B, D] optional speaker embedding for voice cloning
10211
+ temperature: Sampling temperature
10212
+ top_p: Nucleus sampling probability
10213
+ tool_executor: Callable(tool_name, args_dict) -> str result.
10214
+ If None, tool calls are detected but not executed.
10215
+ available_tools: List of tool definition dicts for system prompt.
10216
+ system_prompt: Optional system prompt override.
10217
+ max_tool_calls: Maximum number of tool calls per response (safety limit).
10218
 
10219
+ Returns:
10220
+ Dict with:
10221
+ 'waveform': [B, T_response] audio waveform tensor (in-memory, no file I/O)
10222
+ 'text': str full response text (excluding tool call markup)
10223
+ 'token_ids': [B, T_tokens] all generated token IDs
10224
+ 'mel': [B, 80, T_mel] intermediate mel spectrogram
10225
+ 'tool_calls': List[Dict] executed tool calls and their results
10226
+ 'speaking_text': str clean text that was spoken (no tool markup)
10227
+ """
10228
+ import re
10229
+ import json as _json
10230
+
10231
+ device = audio_waveform .device
10232
+ batch_size = audio_waveform .shape [0 ]
10233
+ llm_device = self .get_llm_device ()
10234
+
10235
+ # ── 1. Listen: encode input audio ──
10236
+ audio_embeds = self .encode_audio (audio_waveform )
10237
+
10238
+
10239
+ # Wrap with start/end markers
10240
+ audio_start = self .audio_start .expand (batch_size , -1 , -1 ).to (llm_device )
10241
+ audio_end = self .audio_end .expand (batch_size , -1 , -1 ).to (llm_device )
10242
+ audio_embeds = audio_embeds .to (llm_device )
10243
+
10244
+ # ── 2. Build context with system prompt + tools ──
10245
+ context_parts = []
10246
+
10247
+ if tokenizer is not None and (system_prompt or tool_executor):
10248
+ sys_text = system_prompt or "You are Xoron, an intelligent voice assistant. You can use tools to help the user."
10249
+ if tool_executor and hasattr (tool_executor , 'get_tool_prompt' ):
10250
+ sys_text = sys_text + "\n\n" + tool_executor .get_tool_prompt ()
10251
+ elif available_tools :
10252
+ from utils .tool_executor import format_tools_for_prompt
10253
+ sys_text = sys_text + "\n\n" + format_tools_for_prompt (available_tools )
10254
+
10255
+ # Encode system prompt and prepend
10256
+ sys_str = "<|system|>" + sys_text + "<|/system|>"
10257
+ sys_token_ids = tokenizer .encode (sys_str , return_tensors ="pt" ).to (llm_device )
10258
+ sys_embeds = self .llm .model .embed_tokens (sys_token_ids )
10259
+ context_parts .append (sys_embeds .squeeze (0 ) if sys_embeds .dim () == 3 else sys_embeds )
10260
+
10261
+ # Audio context
10262
+ context_parts .extend ([audio_start , audio_embeds , audio_end ])
10263
+
10264
+ # Assistant generation prompt
10265
+ if tokenizer is not None :
10266
+ asst_str = "<|assistant|>"
10267
+ asst_ids = tokenizer .encode (asst_str , return_tensors ="pt" ).to (llm_device )
10268
+ asst_embeds = self .llm .model .embed_tokens (asst_ids )
10269
+ context_parts .append (asst_embeds .squeeze (0 ) if asst_embeds .dim () == 3 else asst_embeds )
10270
+
10271
+ input_embeds = torch .cat (context_parts , dim =1 )
10272
+
10273
+ # ── 3. Agentic generation loop with tool call detection ──
10274
+ tool_call_start_token = "<|tool_call|>"
10275
+ tool_call_end_token = "<|/tool_call|>"
10276
+ fn_name_start = "<|function_name|>"
10277
+ fn_name_end = "<|/function_name|>"
10278
+ fn_args_start = "<|function_args|>"
10279
+ fn_args_end = "<|/function_args|>"
10280
+ tool_result_start = "<|tool_result|>"
10281
+ tool_result_end = "<|/tool_result|>"
10282
+ eos_token = "<|eos|>"
10283
+
10284
+ all_generated_ids = []
10285
+ tool_calls_made = []
10286
+ num_tool_calls = 0
10287
+ generated_text = ""
10288
+ total_tokens = 0
10289
+
10290
+ # Use standard generation if no tool executor
10291
+ if tool_executor is None or tokenizer is None :
10292
+ gen_kwargs = {
10293
+ 'inputs_embeds': input_embeds ,
10294
+ 'max_new_tokens': max_new_tokens ,
10295
+ 'do_sample': True ,
10296
+ 'temperature': temperature ,
10297
+ 'top_p': top_p ,
10298
+ 'use_cache': True ,
10299
+ }
10300
+ generated_ids = self .llm .generate (**gen_kwargs )
10301
+ all_generated_ids = [generated_ids ]
10302
 
10303
+ if tokenizer is not None :
10304
+ generated_text = tokenizer .batch_decode (generated_ids , skip_special_tokens =True )[0 ]
10305
+ else :
10306
+ # Token-by-token generation with tool call detection
10307
+ current_embeds = input_embeds
10308
+ past_key_values = None
10309
+ in_tool_call = False
10310
+ tool_call_buffer = ""
10311
+
10312
+ while total_tokens < max_new_tokens :
10313
+ outputs = self .llm (
10314
+ inputs_embeds =current_embeds ,
10315
+ past_key_values =past_key_values ,
10316
+ use_cache =True ,
10317
+ )
10318
+ past_key_values = outputs .past_key_values
10319
+ logits = outputs .logits [:, -1 :, :]
10320
+
10321
+ # Sample next token
10322
+ if temperature > 0 :
10323
+ logits = logits / temperature
10324
+ if top_p < 1.0 :
10325
+ sorted_logits , sorted_indices = torch .sort (logits , descending =True , dim =-1 )
10326
+ cumulative_probs = torch .cumsum (F .softmax (sorted_logits , dim =-1 ), dim =-1 )
10327
+ sorted_mask = cumulative_probs - F .softmax (sorted_logits , dim =-1 ) >= top_p
10328
+ sorted_logits [sorted_mask ] = float ('-inf' )
10329
+ logits .scatter_ (-1 , sorted_indices , sorted_logits )
10330
+ probs = F .softmax (logits , dim =-1 )
10331
+ next_token = torch .multinomial (probs .squeeze (1 ), num_samples =1 )
10332
+ else :
10333
+ next_token = logits .argmax (dim =-1 )
10334
+
10335
+ total_tokens += 1
10336
+ all_generated_ids .append (next_token )
10337
+
10338
+ # Decode the token
10339
+ token_text = tokenizer .decode (next_token [0 ], skip_special_tokens =False )
10340
+ generated_text = generated_text + token_text
10341
+
10342
+ # Check for EOS
10343
+ if eos_token in token_text or next_token .item () == tokenizer .eos_token_id :
10344
+ break
10345
+
10346
+ # ── Tool call detection ──
10347
+ if tool_call_start_token in generated_text and not in_tool_call :
10348
+ in_tool_call = True
10349
+ # Extract everything after the tool_call_start
10350
+ tc_start_idx = generated_text .rfind (tool_call_start_token )
10351
+ tool_call_buffer = generated_text [tc_start_idx :]
10352
+
10353
+ if in_tool_call :
10354
+ tool_call_buffer = tool_call_buffer + token_text if tool_call_buffer else generated_text
10355
+
10356
+ # Check if we have a complete tool call
10357
+ if tool_call_end_token in tool_call_buffer :
10358
+ in_tool_call = False
10359
+ num_tool_calls += 1
10360
+
10361
+ # Parse the tool call
10362
+ tool_name = ""
10363
+ tool_args = {}
10364
+ try :
10365
+ # Extract function name
10366
+ name_start = tool_call_buffer .find (fn_name_start ) + len (fn_name_start )
10367
+ name_end = tool_call_buffer .find (fn_name_end )
10368
+ if name_start > 0 and name_end > 0 :
10369
+ tool_name = tool_call_buffer [name_start :name_end ].strip ()
10370
+
10371
+ # Extract arguments
10372
+ args_start = tool_call_buffer .find (fn_args_start ) + len (fn_args_start )
10373
+ args_end = tool_call_buffer .find (fn_args_end )
10374
+ if args_start > 0 and args_end > 0 :
10375
+ args_str = tool_call_buffer [args_start :args_end ].strip ()
10376
+ try :
10377
+ import json as _json
10378
+ tool_args = _json .loads (args_str )
10379
+ except Exception :
10380
+ tool_args = {"raw": args_str }
10381
+ except Exception :
10382
+ pass
10383
+
10384
+ # Execute the tool
10385
+ tool_result = "[error]: Failed to parse tool call"
10386
+ if tool_name :
10387
+ tool_result = tool_executor (tool_name , tool_args )
10388
+
10389
+ tool_calls_made .append ({
10390
+ "name": tool_name ,
10391
+ "arguments": tool_args ,
10392
+ "result": tool_result ,
10393
+ })
10394
+
10395
+ # Inject tool result back into generation context
10396
+ result_str = tool_result_start + tool_result + tool_result_end
10397
+ result_ids = tokenizer .encode (result_str , return_tensors ="pt" ).to (llm_device )
10398
+ result_embeds = self .llm .model .embed_tokens (result_ids )
10399
+ current_embeds = result_embeds
10400
+ past_key_values = None # Reset KV cache to include result
10401
+ all_generated_ids .append (result_ids .squeeze (0 ))
10402
+
10403
+ generated_text = generated_text + result_str
10404
+ tool_call_buffer = ""
10405
+
10406
+ if num_tool_calls >= max_tool_calls :
10407
+ break
10408
+
10409
+ continue
10410
+
10411
+ # Prepare next input
10412
+ next_embeds = self .llm .model .embed_tokens (next_token )
10413
+ current_embeds = next_embeds
10414
+
10415
+ # Combine all generated IDs
10416
+ if all_generated_ids :
10417
+ flat_ids = []
10418
+ for t in all_generated_ids :
10419
+ if t .dim () == 0 :
10420
+ flat_ids .append (t .unsqueeze (0 ))
10421
+ elif t .dim () == 1 :
10422
+ flat_ids .append (t )
10423
+ else :
10424
+ flat_ids .append (t .view (-1 ))
10425
+ generated_ids = torch .cat (flat_ids , dim =0 ).unsqueeze (0 )
10426
+ else :
10427
+ generated_ids = torch .tensor ([[]], dtype =torch .long , device =llm_device )
10428
+
10429
+ # ── 4. Extract speaking text (strip tool call/result markup) ──
10430
+ speaking_text = generated_text
10431
+ # Remove tool call blocks
10432
+ while tool_call_start_token in speaking_text :
10433
+ tc_s = speaking_text .find (tool_call_start_token )
10434
+ tc_e = speaking_text .find (tool_call_end_token )
10435
+ if tc_e > tc_s :
10436
+ speaking_text = speaking_text [:tc_s ] + speaking_text [tc_e + len (tool_call_end_token ):]
10437
+ else :
10438
+ break
10439
+ # Remove tool result blocks
10440
+ while tool_result_start in speaking_text :
10441
+ tr_s = speaking_text .find (tool_result_start )
10442
+ tr_e = speaking_text .find (tool_result_end )
10443
+ if tr_e > tr_s :
10444
+ speaking_text = speaking_text [:tr_s ] + speaking_text [tr_e + len (tool_result_end ):]
10445
+ else :
10446
+ break
10447
+ speaking_text = speaking_text .strip ()
10448
 
10449
+ # ── 5. Speak: encode → mel → stream_decode → waveform ──
10450
+ response_embeds = self .llm .model .embed_tokens (generated_ids .to (llm_device ))
10451
 
10452
+ mel , durations , _ , _ = self .audio_decoder (
10453
  response_embeds ,
10454
  speaker_embedding =speaker_embedding ,
10455
+ )
 
10456
 
10457
+ mel_features = mel .transpose (1 , 2 )
10458
+ if not hasattr (self , '_mel_to_hidden' ):
10459
+ self ._mel_to_hidden = nn .Linear (80 , self .config .hidden_size ).to (mel .device )
10460
+ audio_features = self ._mel_to_hidden (mel_features )
 
10461
 
10462
+ waveform = self .waveform_decoder .stream_decode (audio_features )
10463
 
10464
+ return {
10465
+ 'waveform': waveform ,
10466
+ 'text': generated_text ,
10467
+ 'speaking_text': speaking_text ,
10468
+ 'token_ids': generated_ids ,
10469
+ 'mel': mel ,
10470
+ 'tool_calls': tool_calls_made ,
10471
+ }
10472
 
 
10473
 
10474
  def merge_lora_weights (self ):
10475
  """Merge LoRA weights into main weights for inference."""
 
11420
  model =cls (config ,device_map =device_map )
11421
 
11422
 
11423
+ if lora_was_applied:
11424
+ logger .info ("Checkpoint has LoRA weights. Applying LoRA structure before loading...")
11425
+ model .apply_lora ()
11426
+
11427
  components_json =os .path .join (path ,"components.json")
11428
  model_path =os .path .join (path ,"model.safetensors")
11429
 
 
11431
 
11432
  logger .info ("Loading from component-based format...")
11433
  model ._load_components (path ,strict =strict )
11434
+ model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights)
11435
 
11436
  elif os .path .exists (model_path ):
11437
  logger .info ("Loading weights from safetensors...")
 
11447
  model .load_state_dict (checkpoint_state_dict ,strict =False )
11448
  logger .info ("Loaded weights from checkpoint")
11449
 
11450
+ model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights)
11451
  else :
11452
 
11453
  pytorch_path =os .path .join (path ,"pytorch_model.bin")
 
11458
  model .load_state_dict (checkpoint_state_dict ,strict =False )
11459
  logger .info ("Loaded weights from checkpoint")
11460
 
11461
+ model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights)
11462
  else :
11463
  raise FileNotFoundError (f"No model weights found at {path }")
11464
 
streaming_state.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
- "epoch": 135,
3
- "unique_samples": 350,
4
- "total_yields": 700,
5
  "dataset_positions": {
6
  "WebSight": 386,
7
  "ScienceQA": 364,
@@ -30,7 +30,7 @@
30
  "NoRobots": 800,
31
  "Synth-LanguageSetup": 200,
32
  "Function-Calling-ChatML": 200,
33
- "Synth-CoT": 550,
34
  "Python-Code-18k": 200,
35
  "Code-Feedback": 200,
36
  "HumanEval-CPP": 164,
@@ -97,7 +97,9 @@
97
  "Cosmopedia-OpenStax": 600,
98
  "MedMCQA": 650,
99
  "Medical-Reasoning-SFT-Mega": 650,
100
- "Medical-O1-Reasoning-EN": 650
 
 
101
  },
102
  "modality_positions": {
103
  "text": {
@@ -154,7 +156,10 @@
154
  "Synth-FactCheck": 550,
155
  "Synth-ConfidenceLevel": 550,
156
  "Synth-Citation": 550,
157
- "Synth-Uncertainty": 550
 
 
 
158
  },
159
  "image": {
160
  "WebSight": 386,
@@ -179,10 +184,11 @@
179
  "audio": {}
180
  },
181
  "modality_counts": {
182
- "text": 350,
183
  "image": 0,
184
  "video": 0,
185
- "audio": 0
 
186
  },
187
  "last_modality": null
188
  }
 
1
  {
2
+ "epoch": 148,
3
+ "unique_samples": 150,
4
+ "total_yields": 300,
5
  "dataset_positions": {
6
  "WebSight": 386,
7
  "ScienceQA": 364,
 
30
  "NoRobots": 800,
31
  "Synth-LanguageSetup": 200,
32
  "Function-Calling-ChatML": 200,
33
+ "Synth-CoT": 900,
34
  "Python-Code-18k": 200,
35
  "Code-Feedback": 200,
36
  "HumanEval-CPP": 164,
 
97
  "Cosmopedia-OpenStax": 600,
98
  "MedMCQA": 650,
99
  "Medical-Reasoning-SFT-Mega": 650,
100
+ "Medical-O1-Reasoning-EN": 650,
101
+ "OpenThoughts-114k": 350,
102
+ "Bespoke-Stratos-17k": 350
103
  },
104
  "modality_positions": {
105
  "text": {
 
156
  "Synth-FactCheck": 550,
157
  "Synth-ConfidenceLevel": 550,
158
  "Synth-Citation": 550,
159
+ "Synth-Uncertainty": 550,
160
+ "OpenThoughts-114k": 350,
161
+ "Bespoke-Stratos-17k": 350,
162
+ "Synth-CoT": 900
163
  },
164
  "image": {
165
  "WebSight": 386,
 
184
  "audio": {}
185
  },
186
  "modality_counts": {
187
+ "text": 0,
188
  "image": 0,
189
  "video": 0,
190
+ "audio": 0,
191
+ "reasoning": 150
192
  },
193
  "last_modality": null
194
  }
trainer_state.json CHANGED
@@ -1,14 +1,14 @@
1
  {
2
  "best_model_checkpoint": "/kaggle/working/xoron-final",
3
- "best_metric": 6.622317645549774,
4
  "epoch": 7,
5
  "epochs_completed": 7,
6
- "global_step": 301,
7
  "is_local_process_zero": true,
8
  "is_world_process_zero": true,
9
  "log_history": [],
10
  "logging_steps": 50,
11
- "max_steps": 301,
12
  "num_train_epochs": 7,
13
  "total_flos": 0,
14
  "train_batch_size": 1,
 
1
  {
2
  "best_model_checkpoint": "/kaggle/working/xoron-final",
3
+ "best_metric": 4.672067043383916,
4
  "epoch": 7,
5
  "epochs_completed": 7,
6
+ "global_step": 126,
7
  "is_local_process_zero": true,
8
  "is_world_process_zero": true,
9
  "log_history": [],
10
  "logging_steps": 50,
11
+ "max_steps": 126,
12
  "num_train_epochs": 7,
13
  "total_flos": 0,
14
  "train_batch_size": 1,
training_state.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a89e9a0652c7c060ae5d2f1211f9a8ce9e301009c1282faa827cfb44a01e4db3
3
- size 1514912171
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:681e7dfcad848ba9070f7c69e68e4517dd623592df17ce9d5e050436701a3611
3
+ size 1514916733