Any-to-Any
Transformers
Safetensors
English
xoron
multimodal
Mixture of Experts
text-to-image
image editing
image to video
text-to-video
video editing
text-to-speech
speech-to-text
speech-to-speech
image-to-text
video-to-text
agentic
tool-use
flow-matching
3d-rope
titok
vidtok
dual-stream-attention
zero-shot-voice-cloning
bigvgan
snake-activation
multi-receptive-field-fusion
custom_code
Update model weights after training (epoch 7, loss 4.6721)
Browse files
- cross_attention.safetensors +1 -1
- llm.safetensors +2 -2
- model.safetensors.index.json +4 -1
- modeling_xoron.py +408 -104
- streaming_state.json +14 -8
- trainer_state.json +3 -3
- training_state.pt +2 -2
cross_attention.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 174191400
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8cde5e1fb540a32b44b78415f2dcf2f037489d8b119a0f33b68a49811e6b8b50
|
| 3 |
size 174191400
|
llm.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b5ffb3061f8b427f852d024bd147adf53062939c8d92babd677d4ceacf953a8
|
| 3 |
+
size 1506836434
|
model.safetensors.index.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
{
|
| 2 |
"metadata": {
|
| 3 |
-
"total_size":
|
| 4 |
"format": "components"
|
| 5 |
},
|
| 6 |
"weight_map": {
|
|
@@ -696,6 +696,9 @@
|
|
| 696 |
"llm.model.layers.11.mlp.shared_expert.down_proj.lora_B": "llm.safetensors",
|
| 697 |
"llm.model.layers.11.mlp.shared_expert.down_proj.linear.weight": "llm.safetensors",
|
| 698 |
"llm.model.norm.weight": "llm.safetensors",
|
|
|
|
|
|
|
|
|
|
| 699 |
"llm.lm_head.weight": "llm.safetensors",
|
| 700 |
"vision_encoder.vision_model.vision_model.embeddings.patch_embedding.weight": "vision_encoder.safetensors",
|
| 701 |
"vision_encoder.vision_model.vision_model.embeddings.patch_embedding.bias": "vision_encoder.safetensors",
|
|
|
|
| 1 |
{
|
| 2 |
"metadata": {
|
| 3 |
+
"total_size": 7309258640,
|
| 4 |
"format": "components"
|
| 5 |
},
|
| 6 |
"weight_map": {
|
|
|
|
| 696 |
"llm.model.layers.11.mlp.shared_expert.down_proj.lora_B": "llm.safetensors",
|
| 697 |
"llm.model.layers.11.mlp.shared_expert.down_proj.linear.weight": "llm.safetensors",
|
| 698 |
"llm.model.norm.weight": "llm.safetensors",
|
| 699 |
+
"llm.model.thought_gate.weight": "llm.safetensors",
|
| 700 |
+
"llm.model.thought_gate.bias": "llm.safetensors",
|
| 701 |
+
"llm.model.thought_layernorm.weight": "llm.safetensors",
|
| 702 |
"llm.lm_head.weight": "llm.safetensors",
|
| 703 |
"vision_encoder.vision_model.vision_model.embeddings.patch_embedding.weight": "vision_encoder.safetensors",
|
| 704 |
"vision_encoder.vision_model.vision_model.embeddings.patch_embedding.bias": "vision_encoder.safetensors",
|
modeling_xoron.py
CHANGED
|
@@ -9122,6 +9122,13 @@ class MoELlamaModel (nn .Module ):
|
|
| 9122 |
|
| 9123 |
self .num_moe_layers =sum (1 for layer in self .layers if layer .is_moe_layer )
|
| 9124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9125 |
self ._init_weights ()
|
| 9126 |
|
| 9127 |
def _init_weights (self ):
|
|
@@ -9147,6 +9154,7 @@ class MoELlamaModel (nn .Module ):
|
|
| 9147 |
output_hidden_states :bool =False ,
|
| 9148 |
return_dict :bool =True ,
|
| 9149 |
cache_position :Optional [torch .Tensor ]=None ,
|
|
|
|
| 9150 |
)->Union [Tuple ,MoELlamaModelOutput ]:
|
| 9151 |
|
| 9152 |
if inputs_embeds is None :
|
|
@@ -9207,6 +9215,39 @@ class MoELlamaModel (nn .Module ):
|
|
| 9207 |
if output_attentions and attn_weights is not None :
|
| 9208 |
all_attentions =all_attentions +(attn_weights ,)
|
| 9209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9210 |
hidden_states =self .norm (hidden_states )
|
| 9211 |
|
| 9212 |
if output_hidden_states :
|
|
@@ -9315,6 +9356,7 @@ class MoELlamaForCausalLM (nn .Module ):
|
|
| 9315 |
output_hidden_states :bool =False ,
|
| 9316 |
return_dict :bool =True ,
|
| 9317 |
cache_position :Optional [torch .Tensor ]=None ,
|
|
|
|
| 9318 |
**kwargs ,
|
| 9319 |
)->Union [Tuple ,CausalLMOutput ]:
|
| 9320 |
|
|
@@ -9329,6 +9371,7 @@ class MoELlamaForCausalLM (nn .Module ):
|
|
| 9329 |
output_hidden_states =output_hidden_states ,
|
| 9330 |
return_dict =True ,
|
| 9331 |
cache_position =cache_position ,
|
|
|
|
| 9332 |
)
|
| 9333 |
|
| 9334 |
hidden_states =outputs .last_hidden_state
|
|
@@ -9379,12 +9422,14 @@ class MoELlamaForCausalLM (nn .Module ):
|
|
| 9379 |
pad_token_id :Optional [int ]=None ,
|
| 9380 |
eos_token_id :Optional [int ]=None ,
|
| 9381 |
attention_mask :Optional [torch .Tensor ]=None ,
|
|
|
|
| 9382 |
**kwargs ,
|
| 9383 |
)->torch .Tensor :
|
| 9384 |
batch_size =input_ids .shape [0 ]
|
| 9385 |
device =input_ids .device
|
| 9386 |
|
| 9387 |
past_key_values =None
|
|
|
|
| 9388 |
|
| 9389 |
if attention_mask is None :
|
| 9390 |
attention_mask =torch .ones_like (input_ids )
|
|
@@ -9396,7 +9441,10 @@ class MoELlamaForCausalLM (nn .Module ):
|
|
| 9396 |
attention_mask =attention_mask ,
|
| 9397 |
)
|
| 9398 |
|
| 9399 |
-
|
|
|
|
|
|
|
|
|
|
| 9400 |
|
| 9401 |
next_token_logits =outputs .logits [:,-1 ,:]
|
| 9402 |
|
|
@@ -9660,67 +9708,85 @@ class XoronMultimodalModel (nn .Module ):
|
|
| 9660 |
|
| 9661 |
|
| 9662 |
def apply_model_parallel (self ,device_map :Dict [str ,str ]):
|
| 9663 |
-
"""Apply Model Parallelism by
|
| 9664 |
-
|
| 9665 |
-
|
| 9666 |
-
|
| 9667 |
-
|
| 9668 |
-
|
| 9669 |
-
self .
|
|
|
|
|
|
|
| 9670 |
|
| 9671 |
-
if not
|
| 9672 |
logger .info (" ℹ️ Single device - no model parallelism needed")
|
| 9673 |
-
return self
|
| 9674 |
-
|
| 9675 |
-
|
| 9676 |
-
|
| 9677 |
-
|
| 9678 |
-
|
| 9679 |
-
|
| 9680 |
-
|
| 9681 |
-
|
| 9682 |
-
|
| 9683 |
-
|
| 9684 |
-
|
| 9685 |
-
|
| 9686 |
-
|
| 9687 |
-
|
| 9688 |
-
|
| 9689 |
-
|
| 9690 |
-
|
| 9691 |
-
|
| 9692 |
-
|
| 9693 |
-
|
| 9694 |
-
|
| 9695 |
-
|
| 9696 |
-
|
| 9697 |
-
|
| 9698 |
-
|
| 9699 |
-
|
| 9700 |
-
|
| 9701 |
-
|
| 9702 |
-
|
| 9703 |
-
|
| 9704 |
-
|
| 9705 |
-
|
| 9706 |
-
|
| 9707 |
-
|
| 9708 |
-
|
| 9709 |
-
|
| 9710 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9711 |
|
| 9712 |
-
|
| 9713 |
-
|
| 9714 |
-
|
| 9715 |
-
|
| 9716 |
-
|
| 9717 |
-
|
| 9718 |
-
|
| 9719 |
-
|
| 9720 |
-
|
| 9721 |
-
|
| 9722 |
-
|
| 9723 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9724 |
|
| 9725 |
logger .info ("Model Parallelism applied successfully!")
|
| 9726 |
return self
|
|
@@ -10113,63 +10179,297 @@ class XoronMultimodalModel (nn .Module ):
|
|
| 10113 |
def listen_and_respond (
|
| 10114 |
self ,
|
| 10115 |
audio_waveform :torch .Tensor ,
|
| 10116 |
-
|
|
|
|
| 10117 |
speaker_embedding :torch .Tensor =None ,
|
| 10118 |
-
|
| 10119 |
-
|
| 10120 |
-
|
| 10121 |
-
|
| 10122 |
-
|
| 10123 |
-
|
| 10124 |
-
|
| 10125 |
-
audio_waveform: [B, T_audio] input audio (what you said)
|
| 10126 |
-
max_new_tokens: Maximum tokens to generate for response
|
| 10127 |
-
speaker_embedding: Optional speaker embedding for response voice
|
| 10128 |
-
|
| 10129 |
-
Returns:
|
| 10130 |
-
response_audio: [B, T_response] audio waveform of the model's response
|
| 10131 |
"""
|
| 10132 |
-
|
| 10133 |
-
|
| 10134 |
-
|
| 10135 |
-
audio_embeds =self .listen (audio_waveform )
|
| 10136 |
-
|
| 10137 |
-
|
| 10138 |
-
batch_size =audio_waveform .shape [0 ]
|
| 10139 |
-
|
| 10140 |
|
|
|
|
|
|
|
|
|
|
| 10141 |
|
| 10142 |
-
|
| 10143 |
-
|
| 10144 |
-
|
| 10145 |
-
|
| 10146 |
-
|
| 10147 |
-
|
| 10148 |
-
audio_features =audio_waveform ,
|
| 10149 |
-
)
|
| 10150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10151 |
|
| 10152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10154 |
|
| 10155 |
-
|
|
|
|
| 10156 |
|
| 10157 |
-
|
| 10158 |
response_embeds ,
|
| 10159 |
speaker_embedding =speaker_embedding ,
|
| 10160 |
-
|
| 10161 |
-
|
| 10162 |
|
| 10163 |
-
|
| 10164 |
-
|
| 10165 |
-
|
| 10166 |
-
|
| 10167 |
-
response_audio =self .waveform_decoder (audio_features )
|
| 10168 |
|
| 10169 |
-
|
| 10170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10171 |
|
| 10172 |
-
return torch .zeros (batch_size ,16000 ,device =device )
|
| 10173 |
|
| 10174 |
def merge_lora_weights (self ):
|
| 10175 |
"""Merge LoRA weights into main weights for inference."""
|
|
@@ -11120,6 +11420,10 @@ XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
|
| 11120 |
model =cls (config ,device_map =device_map )
|
| 11121 |
|
| 11122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11123 |
components_json =os .path .join (path ,"components.json")
|
| 11124 |
model_path =os .path .join (path ,"model.safetensors")
|
| 11125 |
|
|
@@ -11127,7 +11431,7 @@ XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
|
| 11127 |
|
| 11128 |
logger .info ("Loading from component-based format...")
|
| 11129 |
model ._load_components (path ,strict =strict )
|
| 11130 |
-
model .lora_applied =
|
| 11131 |
|
| 11132 |
elif os .path .exists (model_path ):
|
| 11133 |
logger .info ("Loading weights from safetensors...")
|
|
@@ -11143,7 +11447,7 @@ XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
|
| 11143 |
model .load_state_dict (checkpoint_state_dict ,strict =False )
|
| 11144 |
logger .info ("Loaded weights from checkpoint")
|
| 11145 |
|
| 11146 |
-
model .lora_applied =
|
| 11147 |
else :
|
| 11148 |
|
| 11149 |
pytorch_path =os .path .join (path ,"pytorch_model.bin")
|
|
@@ -11154,7 +11458,7 @@ XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
|
| 11154 |
model .load_state_dict (checkpoint_state_dict ,strict =False )
|
| 11155 |
logger .info ("Loaded weights from checkpoint")
|
| 11156 |
|
| 11157 |
-
model .lora_applied =
|
| 11158 |
else :
|
| 11159 |
raise FileNotFoundError (f"No model weights found at {path }")
|
| 11160 |
|
|
|
|
| 9122 |
|
| 9123 |
self .num_moe_layers =sum (1 for layer in self .layers if layer .is_moe_layer )
|
| 9124 |
|
| 9125 |
+
# ── Coconut: Continuous Thought components ──
|
| 9126 |
+
# Learned gate controls how much recurrent thought vs original input
|
| 9127 |
+
# to retain at each thinking step. Sigmoid output in [0,1].
|
| 9128 |
+
self .thought_gate = nn .Linear (config .hidden_size , 1 , bias =True )
|
| 9129 |
+
nn .init .constant_ (self .thought_gate .bias , -2.0 ) # Initialize gate biased toward original (sigmoid(-2)≈0.12)
|
| 9130 |
+
self .thought_layernorm = LlamaRMSNorm (config .hidden_size , eps =config .rms_norm_eps )
|
| 9131 |
+
|
| 9132 |
self ._init_weights ()
|
| 9133 |
|
| 9134 |
def _init_weights (self ):
|
|
|
|
| 9154 |
output_hidden_states :bool =False ,
|
| 9155 |
return_dict :bool =True ,
|
| 9156 |
cache_position :Optional [torch .Tensor ]=None ,
|
| 9157 |
+
thinking_depth :int =0 ,
|
| 9158 |
)->Union [Tuple ,MoELlamaModelOutput ]:
|
| 9159 |
|
| 9160 |
if inputs_embeds is None :
|
|
|
|
| 9215 |
if output_attentions and attn_weights is not None :
|
| 9216 |
all_attentions =all_attentions +(attn_weights ,)
|
| 9217 |
|
| 9218 |
+
# ── Coconut: Continuous Thought Loop ──
|
| 9219 |
+
# After the normal pass, loop hidden states back through the
|
| 9220 |
+
# transformer layers for extra computation in latent space.
|
| 9221 |
+
# No tokens are decoded — pure continuous reasoning.
|
| 9222 |
+
if thinking_depth > 0 :
|
| 9223 |
+
original_hidden = hidden_states .clone ()
|
| 9224 |
+
thought_position_ids = torch .arange (
|
| 9225 |
+
seq_len , device =hidden_states .device
|
| 9226 |
+
).unsqueeze (0 ).expand (batch_size , -1 )
|
| 9227 |
+
|
| 9228 |
+
for thought_step in range (thinking_depth ):
|
| 9229 |
+
# Normalize before re-entering the layers
|
| 9230 |
+
hidden_states = self .thought_layernorm (hidden_states )
|
| 9231 |
+
|
| 9232 |
+
# Run through all layers again (no cache — full re-computation)
|
| 9233 |
+
for layer in self .layers :
|
| 9234 |
+
hidden_states , _ , _ , step_aux = layer (
|
| 9235 |
+
hidden_states =hidden_states ,
|
| 9236 |
+
attention_mask =None , # Self-attend freely in thought space
|
| 9237 |
+
position_ids =thought_position_ids ,
|
| 9238 |
+
past_key_value =None ,
|
| 9239 |
+
output_attentions =False ,
|
| 9240 |
+
use_cache =False ,
|
| 9241 |
+
)
|
| 9242 |
+
if step_aux is not None :
|
| 9243 |
+
total_aux_loss = total_aux_loss + step_aux
|
| 9244 |
+
|
| 9245 |
+
# Gated residual: blend thought with original
|
| 9246 |
+
# gate ∈ [0,1], initialized small so early training
|
| 9247 |
+
# stays close to original behavior
|
| 9248 |
+
gate = torch .sigmoid (self .thought_gate (hidden_states ))
|
| 9249 |
+
hidden_states = gate * hidden_states + (1.0 - gate ) * original_hidden
|
| 9250 |
+
|
| 9251 |
hidden_states =self .norm (hidden_states )
|
| 9252 |
|
| 9253 |
if output_hidden_states :
|
|
|
|
| 9356 |
output_hidden_states :bool =False ,
|
| 9357 |
return_dict :bool =True ,
|
| 9358 |
cache_position :Optional [torch .Tensor ]=None ,
|
| 9359 |
+
thinking_depth :int =0 ,
|
| 9360 |
**kwargs ,
|
| 9361 |
)->Union [Tuple ,CausalLMOutput ]:
|
| 9362 |
|
|
|
|
| 9371 |
output_hidden_states =output_hidden_states ,
|
| 9372 |
return_dict =True ,
|
| 9373 |
cache_position =cache_position ,
|
| 9374 |
+
thinking_depth =thinking_depth ,
|
| 9375 |
)
|
| 9376 |
|
| 9377 |
hidden_states =outputs .last_hidden_state
|
|
|
|
| 9422 |
pad_token_id :Optional [int ]=None ,
|
| 9423 |
eos_token_id :Optional [int ]=None ,
|
| 9424 |
attention_mask :Optional [torch .Tensor ]=None ,
|
| 9425 |
+
thinking_depth :int =0 ,
|
| 9426 |
**kwargs ,
|
| 9427 |
)->torch .Tensor :
|
| 9428 |
batch_size =input_ids .shape [0 ]
|
| 9429 |
device =input_ids .device
|
| 9430 |
|
| 9431 |
past_key_values =None
|
| 9432 |
+
is_prefill =True # Deep thinking only on first pass (full context)
|
| 9433 |
|
| 9434 |
if attention_mask is None :
|
| 9435 |
attention_mask =torch .ones_like (input_ids )
|
|
|
|
| 9441 |
attention_mask =attention_mask ,
|
| 9442 |
)
|
| 9443 |
|
| 9444 |
+
# Apply thinking depth only on prefill, not per-token steps
|
| 9445 |
+
current_depth = thinking_depth if is_prefill else 0
|
| 9446 |
+
outputs =self .forward (**model_inputs ,use_cache =True ,return_dict =True ,thinking_depth =current_depth )
|
| 9447 |
+
is_prefill =False
|
| 9448 |
|
| 9449 |
next_token_logits =outputs .logits [:,-1 ,:]
|
| 9450 |
|
|
|
|
| 9708 |
|
| 9709 |
|
| 9710 |
def apply_model_parallel(self, device_map: Dict[str, str]):
    """Distribute the model's components across devices.

    For every known component attribute that exists on ``self``:

    * if ``device_map`` maps the component to ``'cpu'`` (or has no entry
      for it) the whole component is moved to CPU;
    * otherwise the component's largest direct-child ``nn.ModuleList`` is
      spread round-robin over ``device_map['training_gpus']`` and
      everything else in the component is placed on the first training
      GPU. A component without a direct-child ``nn.ModuleList`` is moved
      to the first training GPU as a whole.

    The modality marker tensors are moved to ``device_map['primary']``
    unless ``device_map['modality_markers']`` is ``'cpu'``.

    Args:
        device_map: Mapping of component name to device string, plus the
            special keys ``'training_gpus'`` (list of device strings,
            default ``['cuda:0']``), ``'primary'`` (default ``'cuda:0'``)
            and ``'modality_markers'``.

    Returns:
        ``self``, to allow chaining. When there is at most one training
        GPU and no string entry equals ``'cpu'`` nothing is moved.
    """
    self.device_map = device_map
    primary = device_map.get('primary', 'cuda:0')
    training_gpus = device_map.get('training_gpus', ['cuda:0'])
    if not training_gpus:
        # An empty list would make the round-robin index below divide by
        # zero; fall back to the primary device instead.
        training_gpus = [primary]

    has_cpu_entry = any(
        value == 'cpu' for value in device_map.values() if isinstance(value, str)
    )
    if len(training_gpus) <= 1 and not has_cpu_entry:
        logger.info(" ℹ️ Single device - no model parallelism needed")
        return self

    self._model_parallel = True
    logger.info("Applying Model Parallelism (layer sharding)...")

    def _shard_module(module, name, gpus):
        """Spread the largest direct-child ModuleList of ``module`` over ``gpus``."""
        # Only registered direct children are inspected; going through
        # named_children() avoids evaluating unrelated properties the way
        # dir() + getattr() does. Sorting by name keeps the tie-break
        # between equally long lists alphabetical.
        # NOTE(review): a ModuleList nested deeper (e.g. under a `.model`
        # attribute) is not found here, so such a component lands on
        # gpus[0] as a whole - confirm this is intended.
        candidates = sorted(
            (
                (child_name, child)
                for child_name, child in module.named_children()
                if isinstance(child, nn.ModuleList) and len(child) > 0
            ),
            key=lambda item: item[0],
        )

        if not candidates:
            # No layers to shard - put whole module on first GPU
            module.to(gpus[0])
            logger.info(f" ✅ {name} -> {gpus[0]}")
            return

        _, layers = max(candidates, key=lambda item: len(item[1]))
        for index, layer in enumerate(layers):
            layer.to(gpus[index % len(gpus)])

        # Everything that is not part of the sharded list lives on the
        # first training GPU. Children are matched by identity rather than
        # by substring of the parameter path, and buffers are moved along
        # with parameters so no tensor is left behind on the old device.
        home = gpus[0]
        for _, child in module.named_children():
            if child is not layers:
                child.to(home)
        for param in module.parameters(recurse=False):
            param.data = param.data.to(home)
        for buf in module.buffers(recurse=False):
            buf.data = buf.data.to(home)
        logger.info(f" ✅ {name}: {len(layers)} layers sharded across {gpus}")

    # Map component names (device_map keys) to actual attributes
    component_attrs = {
        'vision_encoder': 'vision_encoder',
        'video_encoder': 'video_encoder',
        'audio_encoder': 'audio_encoder',
        'audio_decoder': 'audio_decoder',
        'waveform_decoder': 'waveform_decoder',
        'projector': 'projector',
        'audio_projector': 'audio_projector',
        'llm': 'llm',
        'cross_attention': 'cross_attention_layers',
        'generator': 'generator',
        'video_generator': 'video_generator',
    }

    for comp_name, attr_name in component_attrs.items():
        comp = getattr(self, attr_name, None)
        if comp is None:
            continue
        if device_map.get(comp_name, 'cpu') == 'cpu':
            comp.to('cpu')
            logger.info(f" ❄️ {comp_name} -> cpu (frozen)")
        else:
            # Shard across all training GPUs
            _shard_module(comp, comp_name, training_gpus)

    # Modality markers -> primary GPU (or CPU when explicitly requested)
    marker_device = device_map.get('modality_markers', primary)
    if marker_device != 'cpu':
        marker_device = primary
    marker_names = (
        'image_start', 'image_end',
        'video_start', 'video_end',
        'audio_start', 'audio_end',
    )
    for marker_name in marker_names:
        marker = getattr(self, marker_name, None)
        if marker is None:
            continue
        if isinstance(marker, nn.Parameter):
            # Move the storage in place: the Parameter object (and with it
            # requires_grad and any optimizer reference) stays the same.
            marker.data = marker.data.to(marker_device)
        else:
            setattr(self, marker_name, nn.Parameter(marker.data.to(marker_device)))
    logger.info(f" ✅ Modality markers -> {marker_device}")

    logger.info("Model Parallelism applied successfully!")
    return self
|
|
|
|
| 10179 |
def listen_and_respond(
    self,
    audio_waveform: torch.Tensor,
    tokenizer=None,
    max_new_tokens: int = 512,
    speaker_embedding: torch.Tensor = None,
    temperature: float = 0.7,
    top_p: float = 0.9,
    tool_executor=None,
    available_tools: list = None,
    system_prompt: str = None,
    max_tool_calls: int = 5,
) -> Dict[str, Any]:
    """
    Agentic Speech-to-Speech: Listen, think, use tools, speak back.

    Pipeline:
        1. Encode the input audio into embeddings and wrap them with the
           audio start/end markers.
        2. Build the context: optional system prompt (with tool
           descriptions), the audio embeddings, and the assistant prompt.
        3. Generate. Without ``tool_executor`` or ``tokenizer`` this is a
           single ``self.llm.generate`` call; otherwise tokens are sampled
           one by one while watching for ``<|tool_call|>`` blocks.
        4. When a complete tool call is seen: parse it, execute it, append
           the result to the context and resume generation.
        5. Synthesize the spoken response (mel -> waveform).

    Args:
        audio_waveform: Input audio; the first dimension is the batch.
        tokenizer: Tokenizer used to encode prompts and decode tokens
            (required for tool use).
        max_new_tokens: Maximum number of sampled tokens.
        speaker_embedding: Optional speaker embedding passed to the audio
            decoder.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding
            in the tool-calling loop.
        top_p: Nucleus sampling probability.
        tool_executor: Callable ``(tool_name, args) -> result``. If it has
            a ``get_tool_prompt`` method, its output is appended to the
            system prompt. If None, no tool-call detection is performed.
        available_tools: Tool definitions rendered into the system prompt
            when ``tool_executor`` has no ``get_tool_prompt``.
        system_prompt: Optional system prompt override.
        max_tool_calls: Maximum number of tool calls per response.

    Returns:
        Dict with:
            'waveform': output of ``waveform_decoder.stream_decode``
            'text': full generated text (tool markup and results included)
            'speaking_text': generated text with tool call/result blocks removed
            'token_ids': generated token ids (tool results included)
            'mel': mel spectrogram returned by the audio decoder
            'tool_calls': list of ``{"name", "arguments", "result"}`` dicts

    Raises:
        ValueError: If tool calling is requested with a batch size other
            than 1 (the token-by-token loop decodes a single sequence).
    """
    import json as _json

    batch_size = audio_waveform.shape[0]
    llm_device = self.get_llm_device()

    tool_call_start_token = "<|tool_call|>"
    tool_call_end_token = "<|/tool_call|>"
    fn_name_start = "<|function_name|>"
    fn_name_end = "<|/function_name|>"
    fn_args_start = "<|function_args|>"
    fn_args_end = "<|/function_args|>"
    tool_result_start = "<|tool_result|>"
    tool_result_end = "<|/tool_result|>"
    eos_token = "<|eos|>"

    def _encode(text):
        """Tokenize ``text`` into a [1, T] id tensor on the LLM device."""
        ids = tokenizer.encode(text, return_tensors="pt").to(llm_device)
        return ids.unsqueeze(0) if ids.dim() == 1 else ids

    def _embed_for_batch(text):
        """Embed ``text`` and broadcast it to [batch_size, T, D]."""
        embeds = self.llm.model.embed_tokens(_encode(text))
        return embeds.expand(batch_size, -1, -1)

    def _between(text, open_tag, close_tag):
        """Return the stripped text between the tags, or None if either is missing."""
        start = text.find(open_tag)
        if start == -1:
            return None
        start += len(open_tag)
        end = text.find(close_tag, start)
        if end == -1:
            return None
        return text[start:end].strip()

    def _strip_blocks(text, open_tag, close_tag):
        """Remove every ``open_tag ... close_tag`` block from ``text``.

        An unterminated block (generation stopped mid-call) is dropped from
        its opening tag onwards so partial markup is never spoken.
        """
        while True:
            start = text.find(open_tag)
            if start == -1:
                return text
            end = text.find(close_tag, start + len(open_tag))
            if end == -1:
                return text[:start]
            text = text[:start] + text[end + len(close_tag):]

    # ── 1. Listen: encode input audio and wrap with start/end markers ──
    audio_embeds = self.encode_audio(audio_waveform).to(llm_device)
    audio_start = self.audio_start.expand(batch_size, -1, -1).to(llm_device)
    audio_end = self.audio_end.expand(batch_size, -1, -1).to(llm_device)

    # ── 2. Build context with system prompt + tools ──
    # Every part must be 3-D ([B, T, D]) so they can be concatenated along
    # the sequence dimension.
    context_parts = []

    if tokenizer is not None and (system_prompt or tool_executor):
        sys_text = system_prompt or "You are Xoron, an intelligent voice assistant. You can use tools to help the user."
        if tool_executor and hasattr(tool_executor, 'get_tool_prompt'):
            sys_text = sys_text + "\n\n" + tool_executor.get_tool_prompt()
        elif available_tools:
            from utils.tool_executor import format_tools_for_prompt
            sys_text = sys_text + "\n\n" + format_tools_for_prompt(available_tools)
        context_parts.append(_embed_for_batch("<|system|>" + sys_text + "<|/system|>"))

    context_parts.extend([audio_start, audio_embeds, audio_end])

    if tokenizer is not None:
        context_parts.append(_embed_for_batch("<|assistant|>"))

    input_embeds = torch.cat(context_parts, dim=1)

    # ── 3. Generation ──
    tool_calls_made = []
    generated_text = ""

    if tool_executor is None or tokenizer is None:
        # Plain generation, no tool-call detection.
        # NOTE(review): assumes self.llm.generate accepts inputs_embeds and
        # these sampling kwargs - confirm against its signature.
        generated_ids = self.llm.generate(
            inputs_embeds=input_embeds,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            use_cache=True,
        )
        if tokenizer is not None:
            generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        if batch_size != 1:
            raise ValueError(
                "listen_and_respond with a tool_executor supports batch size 1 only, "
                f"got {batch_size}"
            )

        id_chunks = []
        num_tool_calls = 0
        total_tokens = 0
        # Position in generated_text up to which tool calls have already
        # been handled; earlier markup must never be matched again or the
        # same call would be executed repeatedly.
        scan_from = 0
        # Embeddings of everything the model has been conditioned on, kept
        # so the full context can be replayed after a tool result.
        history = [input_embeds]
        current_embeds = input_embeds
        past_key_values = None

        while total_tokens < max_new_tokens:
            outputs = self.llm(
                inputs_embeds=current_embeds,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
            step_logits = outputs.logits[:, -1:, :]

            # Sample next token
            if temperature > 0:
                step_logits = step_logits / temperature
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(step_logits, descending=True, dim=-1)
                    sorted_probs = F.softmax(sorted_logits, dim=-1)
                    # Probability mass strictly before each token; the top
                    # token always has 0 and is therefore always kept.
                    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
                    sorted_logits = sorted_logits.masked_fill(mass_before >= top_p, float('-inf'))
                    step_logits = step_logits.scatter(-1, sorted_indices, sorted_logits)
                probs = F.softmax(step_logits, dim=-1)
                next_token = torch.multinomial(probs.squeeze(1), num_samples=1)
            else:
                next_token = step_logits.argmax(dim=-1)

            total_tokens += 1
            id_chunks.append(next_token.view(-1))

            token_text = tokenizer.decode(next_token[0], skip_special_tokens=False)
            generated_text = generated_text + token_text

            # Check for EOS
            if eos_token in token_text or next_token.item() == tokenizer.eos_token_id:
                break

            next_embeds = self.llm.model.embed_tokens(next_token)
            history.append(next_embeds)

            # ── Tool call detection (only in text not yet processed) ──
            call_start = generated_text.find(tool_call_start_token, scan_from)
            call_end = -1
            if call_start != -1:
                call_end = generated_text.find(
                    tool_call_end_token, call_start + len(tool_call_start_token)
                )

            if call_end == -1:
                # No complete call yet: feed the sampled token and go on.
                current_embeds = next_embeds
                continue

            num_tool_calls += 1
            call_text = generated_text[call_start:call_end + len(tool_call_end_token)]

            # Parse the tool call
            tool_name = _between(call_text, fn_name_start, fn_name_end) or ""
            tool_args = {}
            args_str = _between(call_text, fn_args_start, fn_args_end)
            if args_str:
                try:
                    tool_args = _json.loads(args_str)
                except (ValueError, RecursionError):
                    # Not valid JSON: hand the raw text to the executor.
                    tool_args = {"raw": args_str}

            # Execute the tool
            tool_result = "[error]: Failed to parse tool call"
            if tool_name:
                tool_result = tool_executor(tool_name, tool_args)

            tool_calls_made.append({
                "name": tool_name,
                "arguments": tool_args,
                "result": tool_result,
            })

            # Inject tool result back into generation context
            result_str = tool_result_start + str(tool_result) + tool_result_end
            result_ids = _encode(result_str)
            history.append(self.llm.model.embed_tokens(result_ids))
            id_chunks.append(result_ids.view(-1))
            generated_text = generated_text + result_str
            scan_from = len(generated_text)

            if num_tool_calls >= max_tool_calls:
                break

            # Drop the KV cache and replay the complete context (prompt,
            # audio, everything generated so far, tool result) so the
            # continuation is conditioned on all of it rather than on the
            # tool result alone.
            current_embeds = torch.cat(history, dim=1)
            history = [current_embeds]
            past_key_values = None

        # Combine all generated IDs into a single [1, T] sequence
        if id_chunks:
            generated_ids = torch.cat(id_chunks, dim=0).unsqueeze(0)
        else:
            generated_ids = torch.empty((1, 0), dtype=torch.long, device=llm_device)

    # ── 4. Extract speaking text (strip tool call/result markup) ──
    speaking_text = _strip_blocks(generated_text, tool_call_start_token, tool_call_end_token)
    speaking_text = _strip_blocks(speaking_text, tool_result_start, tool_result_end)
    speaking_text = speaking_text.strip()

    # ── 5. Speak: embed → mel → stream_decode → waveform ──
    # When tools were used, generated_ids also holds the call markup and
    # the injected results; speak the cleaned text instead.
    speech_ids = generated_ids
    if tool_calls_made and speaking_text:
        speech_ids = _encode(speaking_text)
    response_embeds = self.llm.model.embed_tokens(speech_ids.to(llm_device))

    mel, _durations, _, _ = self.audio_decoder(
        response_embeds,
        speaker_embedding=speaker_embedding,
    )

    mel_features = mel.transpose(1, 2)
    if not hasattr(self, '_mel_to_hidden'):
        # NOTE(review): this projection is created on first use with random
        # weights; nothing visible here loads or trains it - confirm the
        # waveform decoder really expects its output.
        self._mel_to_hidden = nn.Linear(mel_features.shape[-1], self.config.hidden_size)
    # Follow the mel tensor's device and dtype on every call (half-precision
    # mels would otherwise hit a float32 layer).
    self._mel_to_hidden.to(device=mel.device, dtype=mel.dtype)
    audio_features = self._mel_to_hidden(mel_features)

    waveform = self.waveform_decoder.stream_decode(audio_features)

    return {
        'waveform': waveform,
        'text': generated_text,
        'speaking_text': speaking_text,
        'token_ids': generated_ids,
        'mel': mel,
        'tool_calls': tool_calls_made,
    }
|
| 10472 |
|
|
|
|
| 10473 |
|
| 10474 |
def merge_lora_weights (self ):
|
| 10475 |
"""Merge LoRA weights into main weights for inference."""
|
|
|
|
| 11420 |
model =cls (config ,device_map =device_map )
|
| 11421 |
|
| 11422 |
|
| 11423 |
+
if lora_was_applied:
|
| 11424 |
+
logger .info ("Checkpoint has LoRA weights. Applying LoRA structure before loading...")
|
| 11425 |
+
model .apply_lora ()
|
| 11426 |
+
|
| 11427 |
components_json =os .path .join (path ,"components.json")
|
| 11428 |
model_path =os .path .join (path ,"model.safetensors")
|
| 11429 |
|
|
|
|
| 11431 |
|
| 11432 |
logger .info ("Loading from component-based format...")
|
| 11433 |
model ._load_components (path ,strict =strict )
|
| 11434 |
+
model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights)
|
| 11435 |
|
| 11436 |
elif os .path .exists (model_path ):
|
| 11437 |
logger .info ("Loading weights from safetensors...")
|
|
|
|
| 11447 |
model .load_state_dict (checkpoint_state_dict ,strict =False )
|
| 11448 |
logger .info ("Loaded weights from checkpoint")
|
| 11449 |
|
| 11450 |
+
model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights)
|
| 11451 |
else :
|
| 11452 |
|
| 11453 |
pytorch_path =os .path .join (path ,"pytorch_model.bin")
|
|
|
|
| 11458 |
model .load_state_dict (checkpoint_state_dict ,strict =False )
|
| 11459 |
logger .info ("Loaded weights from checkpoint")
|
| 11460 |
|
| 11461 |
+
model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights)
|
| 11462 |
else :
|
| 11463 |
raise FileNotFoundError (f"No model weights found at {path }")
|
| 11464 |
|
streaming_state.json
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
{
|
| 2 |
-
"epoch":
|
| 3 |
-
"unique_samples":
|
| 4 |
-
"total_yields":
|
| 5 |
"dataset_positions": {
|
| 6 |
"WebSight": 386,
|
| 7 |
"ScienceQA": 364,
|
|
@@ -30,7 +30,7 @@
|
|
| 30 |
"NoRobots": 800,
|
| 31 |
"Synth-LanguageSetup": 200,
|
| 32 |
"Function-Calling-ChatML": 200,
|
| 33 |
-
"Synth-CoT":
|
| 34 |
"Python-Code-18k": 200,
|
| 35 |
"Code-Feedback": 200,
|
| 36 |
"HumanEval-CPP": 164,
|
|
@@ -97,7 +97,9 @@
|
|
| 97 |
"Cosmopedia-OpenStax": 600,
|
| 98 |
"MedMCQA": 650,
|
| 99 |
"Medical-Reasoning-SFT-Mega": 650,
|
| 100 |
-
"Medical-O1-Reasoning-EN": 650
|
|
|
|
|
|
|
| 101 |
},
|
| 102 |
"modality_positions": {
|
| 103 |
"text": {
|
|
@@ -154,7 +156,10 @@
|
|
| 154 |
"Synth-FactCheck": 550,
|
| 155 |
"Synth-ConfidenceLevel": 550,
|
| 156 |
"Synth-Citation": 550,
|
| 157 |
-
"Synth-Uncertainty": 550
|
|
|
|
|
|
|
|
|
|
| 158 |
},
|
| 159 |
"image": {
|
| 160 |
"WebSight": 386,
|
|
@@ -179,10 +184,11 @@
|
|
| 179 |
"audio": {}
|
| 180 |
},
|
| 181 |
"modality_counts": {
|
| 182 |
-
"text":
|
| 183 |
"image": 0,
|
| 184 |
"video": 0,
|
| 185 |
-
"audio": 0
|
|
|
|
| 186 |
},
|
| 187 |
"last_modality": null
|
| 188 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"epoch": 148,
|
| 3 |
+
"unique_samples": 150,
|
| 4 |
+
"total_yields": 300,
|
| 5 |
"dataset_positions": {
|
| 6 |
"WebSight": 386,
|
| 7 |
"ScienceQA": 364,
|
|
|
|
| 30 |
"NoRobots": 800,
|
| 31 |
"Synth-LanguageSetup": 200,
|
| 32 |
"Function-Calling-ChatML": 200,
|
| 33 |
+
"Synth-CoT": 900,
|
| 34 |
"Python-Code-18k": 200,
|
| 35 |
"Code-Feedback": 200,
|
| 36 |
"HumanEval-CPP": 164,
|
|
|
|
| 97 |
"Cosmopedia-OpenStax": 600,
|
| 98 |
"MedMCQA": 650,
|
| 99 |
"Medical-Reasoning-SFT-Mega": 650,
|
| 100 |
+
"Medical-O1-Reasoning-EN": 650,
|
| 101 |
+
"OpenThoughts-114k": 350,
|
| 102 |
+
"Bespoke-Stratos-17k": 350
|
| 103 |
},
|
| 104 |
"modality_positions": {
|
| 105 |
"text": {
|
|
|
|
| 156 |
"Synth-FactCheck": 550,
|
| 157 |
"Synth-ConfidenceLevel": 550,
|
| 158 |
"Synth-Citation": 550,
|
| 159 |
+
"Synth-Uncertainty": 550,
|
| 160 |
+
"OpenThoughts-114k": 350,
|
| 161 |
+
"Bespoke-Stratos-17k": 350,
|
| 162 |
+
"Synth-CoT": 900
|
| 163 |
},
|
| 164 |
"image": {
|
| 165 |
"WebSight": 386,
|
|
|
|
| 184 |
"audio": {}
|
| 185 |
},
|
| 186 |
"modality_counts": {
|
| 187 |
+
"text": 0,
|
| 188 |
"image": 0,
|
| 189 |
"video": 0,
|
| 190 |
+
"audio": 0,
|
| 191 |
+
"reasoning": 150
|
| 192 |
},
|
| 193 |
"last_modality": null
|
| 194 |
}
|
trainer_state.json
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
{
|
| 2 |
"best_model_checkpoint": "/kaggle/working/xoron-final",
|
| 3 |
-
"best_metric":
|
| 4 |
"epoch": 7,
|
| 5 |
"epochs_completed": 7,
|
| 6 |
-
"global_step":
|
| 7 |
"is_local_process_zero": true,
|
| 8 |
"is_world_process_zero": true,
|
| 9 |
"log_history": [],
|
| 10 |
"logging_steps": 50,
|
| 11 |
-
"max_steps":
|
| 12 |
"num_train_epochs": 7,
|
| 13 |
"total_flos": 0,
|
| 14 |
"train_batch_size": 1,
|
|
|
|
| 1 |
{
|
| 2 |
"best_model_checkpoint": "/kaggle/working/xoron-final",
|
| 3 |
+
"best_metric": 4.672067043383916,
|
| 4 |
"epoch": 7,
|
| 5 |
"epochs_completed": 7,
|
| 6 |
+
"global_step": 126,
|
| 7 |
"is_local_process_zero": true,
|
| 8 |
"is_world_process_zero": true,
|
| 9 |
"log_history": [],
|
| 10 |
"logging_steps": 50,
|
| 11 |
+
"max_steps": 126,
|
| 12 |
"num_train_epochs": 7,
|
| 13 |
"total_flos": 0,
|
| 14 |
"train_batch_size": 1,
|
training_state.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:681e7dfcad848ba9070f7c69e68e4517dd623592df17ce9d5e050436701a3611
|
| 3 |
+
size 1514916733
|