""" Xoron Model for HuggingFace Transformers - Self-Contained Implementation. AUTO-GENERATED FILE - Do not edit directly! This module provides a complete, self-contained HuggingFace-compatible model class for the Xoron multimodal model. All components are embedded directly in this file to enable loading via AutoModel with trust_remote_code=True WITHOUT requiring the full Xoron-Dev package to be installed. Usage: from transformers import AutoModel, AutoConfig config = AutoConfig.from_pretrained("your-repo/xoron-model", trust_remote_code=True) model = AutoModel.from_pretrained("your-repo/xoron-model", trust_remote_code=True) """ import os import math import json import logging from dataclasses import dataclass, field from typing import Optional, Dict, List, Union, Tuple, Any import torch import torch.nn as nn import torch.nn.functional as F try: from safetensors.torch import save_file, load_file except ImportError: save_file, load_file = None, None from transformers import PreTrainedModel, LlamaConfig, LlamaModel, LlamaForCausalLM from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast try: from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaMLP, LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv ) except ImportError: LlamaAttention = LlamaDecoderLayer = LlamaRMSNorm = LlamaMLP = None LlamaRotaryEmbedding = apply_rotary_pos_emb = repeat_kv = None try: from .configuration_xoron import XoronConfig except ImportError: from configuration_xoron import XoronConfig logger = logging.getLogger(__name__) ============================================================================== MODELS.COMPONENTS.LORA ============================================================================== class LoRALinear (nn .Module ): """ SOTA LoRA layer with multiple variants. 
Supports: - Standard LoRA - DoRA (Weight-Decomposed LoRA) - rsLoRA (rank-stabilized scaling) MEMORY OPTIMIZATION: - Does NOT clone base weights - shares them with original module - Only LoRA params (A, B, magnitude) consume additional memory - Base weights are frozen and can be kept in lower precision """ def __init__ ( self , in_features :int , out_features :int , r :int =8 , lora_alpha :int =16 , lora_dropout :float =0.05 , merge_weights :bool =False , use_dora :bool =False , use_rslora :bool =True , base_layer :nn .Linear =None , ): super ().__init__ () self .r =r self .lora_alpha =lora_alpha self .merge_weights =merge_weights self .merged =False self .use_dora =use_dora self .use_rslora =use_rslora self .in_features =in_features self .out_features =out_features if base_layer is not None : self .linear =base_layer else : self .linear =nn .Linear (in_features ,out_features ,bias =False ) if r >0 : self .lora_A =nn .Parameter (torch .zeros (r ,in_features )) self .lora_B =nn .Parameter (torch .zeros (out_features ,r )) if use_rslora : self .scaling =lora_alpha /math .sqrt (r ) else : self .scaling =lora_alpha /r self .lora_dropout =nn .Dropout (p =lora_dropout )if lora_dropout >0 else nn .Identity () nn .init .kaiming_uniform_ (self .lora_A ,a =math .sqrt (5 )) nn .init .zeros_ (self .lora_B ) if use_dora : self .magnitude =nn .Parameter (torch .ones (out_features )) self .linear .weight .requires_grad =False if hasattr (self .linear ,'bias')and self .linear .bias is not None : self .linear .bias .requires_grad =False def forward (self ,x :torch .Tensor )->torch .Tensor : if self .r >0 and not self .merged : lora_out =self .lora_dropout (x )@self .lora_A .T @self .lora_B .T *self .scaling if self .use_dora : weight =self .linear .weight +(self .lora_B @self .lora_A )*self .scaling weight_norm =weight .norm (dim =1 ,keepdim =True ) weight_normalized =weight /(weight_norm +1e-6 ) result =F .linear (x ,weight_normalized *self .magnitude .unsqueeze (1 )) else : result 
=self .linear (x )+lora_out else : result =self .linear (x ) return result def merge_lora_weights (self ): """Merge LoRA weights into the main weights for inference.""" if self .r >0 and not self .merged : delta =(self .lora_B @self .lora_A )*self .scaling if self .use_dora : weight =self .linear .weight +delta weight_norm =weight .norm (dim =1 ,keepdim =True ) self .linear .weight .data =(weight /(weight_norm +1e-6 ))*self .magnitude .unsqueeze (1 ) else : self .linear .weight .data +=delta self .merged =True def unmerge_lora_weights (self ): """Unmerge LoRA weights for continued training.""" if self .r >0 and self .merged : self .linear .weight .data -=(self .lora_B @self .lora_A )*self .scaling self .merged =False class LoRAConfig : """ Configuration for SOTA LoRA adaptation. Supports multiple LoRA variants and configurations. """ def __init__ ( self , r :int =8 , lora_alpha :int =16 , lora_dropout :float =0.05 , target_modules :Optional [List [str ]]=None , enable_lora :bool =True , use_dora :bool =False , use_rslora :bool =True , lora_plus_lr_ratio :float =16.0 , ): self .r =r self .lora_alpha =lora_alpha self .lora_dropout =lora_dropout self .target_modules =target_modules or [ 'q_proj','k_proj','v_proj','o_proj', 'gate_proj','up_proj','down_proj', ] self .enable_lora =enable_lora self .use_dora =use_dora self .use_rslora =use_rslora self .lora_plus_lr_ratio =lora_plus_lr_ratio def apply_lora_to_model (model :nn .Module ,lora_config :LoRAConfig )->nn .Module : """ Apply LoRA to specified modules in a model. Returns the model with LoRA layers applied. 
MEMORY OPTIMIZATION: - Passes the original nn.Linear layer directly to LoRALinear - This SHARES weights instead of cloning them (saves ~50% memory for target modules) - Only LoRA parameters (A, B, magnitude) are newly allocated For a 16GB model with 30% of weights in target modules: - Old behavior: Clone ~5GB = 21GB total - New behavior: Share weights = 16GB + ~50MB LoRA params """ if not lora_config .enable_lora : return model lora_layers_added =0 modules_to_replace =[] total_base_params =0 for name ,module in model .named_modules (): if not isinstance (module ,nn .Linear ): continue module_name =name .split ('.')[-1 ] if module_name in lora_config .target_modules : modules_to_replace .append ((name ,module )) total_base_params +=module .weight .numel () for name ,module in modules_to_replace : parts =name .split ('.') attr_name =parts [-1 ] parent_name ='.'.join (parts [:-1 ]) if parent_name : parent =model .get_submodule (parent_name ) else : parent =model lora_layer =LoRALinear ( in_features =module .in_features , out_features =module .out_features , r =lora_config .r , lora_alpha =lora_config .lora_alpha , lora_dropout =lora_config .lora_dropout , use_dora =lora_config .use_dora , use_rslora =lora_config .use_rslora , base_layer =module , ) setattr (parent ,attr_name ,lora_layer ) lora_layers_added +=1 lora_params =lora_layers_added *(lora_config .r *(modules_to_replace [0 ][1 ].in_features +modules_to_replace [0 ][1 ].out_features ))if modules_to_replace else 0 base_mem_saved_mb =(total_base_params *2 )/(1024 *1024 ) lora_mem_added_mb =(lora_params *4 )/(1024 *1024 ) variant ="DoRA"if lora_config .use_dora else ("rsLoRA"if lora_config .use_rslora else "LoRA") print (f"โœ… {variant } applied to {lora_layers_added } layers (r={lora_config .r }, alpha={lora_config .lora_alpha })") print (f" ๐Ÿ’พ Memory optimization: {base_mem_saved_mb :.1f}MB base weights SHARED (not cloned)") print (f" ๐Ÿ“Š New LoRA params: ~{lora_mem_added_mb :.1f}MB (trainable)") return model 
def get_lora_parameters(model: nn.Module) -> List[nn.Parameter]:
    """
    Get only the LoRA parameters from a model.

    NOTE: This does NOT change requires_grad on any parameters!
    It simply returns the LoRA params (lora_A, lora_B, magnitude).

    Use this when you want to get LoRA params for separate optimizer
    groups or for LoRA-only training mode.
    """
    tags = ('lora_A', 'lora_B', 'magnitude')
    return [
        param
        for name, param in model.named_parameters()
        if any(tag in name for tag in tags)
    ]


def enable_lora_training(model: nn.Module) -> List[nn.Parameter]:
    """
    Enable training for LoRA parameters (ensure requires_grad=True).

    Returns list of LoRA parameters.
    """
    collected = []
    for name, param in model.named_parameters():
        if 'lora_A' in name or 'lora_B' in name or 'magnitude' in name:
            param.requires_grad = True
            collected.append(param)
    return collected


def freeze_non_lora_params(model: nn.Module) -> int:
    """
    Freeze all non-LoRA parameters and clear their gradients.

    USE THIS ONLY FOR LORA-ONLY TRAINING MODE (train_lora_only=True).

    For normal training with parallel fine-tuning (LoRA + full weights on
    active components), use the model's freeze_components() method instead,
    which respects the training mode flags (--text, --video, --image, --voice).

    Returns:
        Number of frozen parameters
    """
    frozen_count = 0
    reclaimed_bytes = 0
    for name, param in model.named_parameters():
        if 'lora_A' in name or 'lora_B' in name or 'magnitude' in name:
            continue
        param.requires_grad = False
        frozen_count += param.numel()
        grad = param.grad
        if grad is not None:
            # Dropping the gradient tensor frees its backing storage.
            reclaimed_bytes += grad.numel() * grad.element_size()
            param.grad = None
    print(f" โ„๏ธ Frozen {frozen_count:,} non-LoRA parameters")
    if reclaimed_bytes > 0:
        print(f" ๐Ÿงน Freed {reclaimed_bytes / (1024 ** 2):.1f}MB of gradient memory")
    return frozen_count


def get_lora_plus_param_groups(
    model: nn.Module,
    base_lr: float,
    lr_ratio: float = 16.0,
) -> List[Dict]:
    """
    Get parameter groups for LoRA+ training.

    LoRA+ uses different learning rates for A and B matrices:
    - B matrix: base_lr * lr_ratio (learns faster)
    - A matrix: base_lr

    This improves convergence and final performance.
    """
    # Insertion order fixes the emitted group order: A, B, magnitude, other.
    buckets = {'lora_A': [], 'lora_B': [], 'magnitude': [], 'other': []}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'lora_A' in name:
            buckets['lora_A'].append(param)
        elif 'lora_B' in name:
            buckets['lora_B'].append(param)
        elif 'magnitude' in name:
            buckets['magnitude'].append(param)
        else:
            buckets['other'].append(param)
    learning_rates = {
        'lora_A': base_lr,
        'lora_B': base_lr * lr_ratio,
        'magnitude': base_lr,
        'other': base_lr,
    }
    return [
        {'params': params, 'lr': learning_rates[group], 'name': group}
        for group, params in buckets.items()
        if params
    ]


def get_trainable_parameters(model: nn.Module, train_lora_only: bool = False) -> List[nn.Parameter]:
    """Get trainable parameters, optionally only LoRA params."""
    if train_lora_only:
        return get_lora_parameters(model)
    return [p for p in model.parameters() if p.requires_grad]


def count_lora_parameters(model: nn.Module) -> Tuple[int, int, float]:
    """
    Count LoRA parameters vs total parameters.

    Returns:
        (lora_params, total_params, percentage)
    """
    lora_total = 0
    grand_total = 0
    for name, param in model.named_parameters():
        n = param.numel()
        grand_total += n
        if 'lora_A' in name or 'lora_B' in name or 'magnitude' in name:
            lora_total += n
    share = 100.0 * lora_total / grand_total if grand_total > 0 else 0.0
    return lora_total, grand_total, share
count_lora_parameters (model :nn .Module )->Tuple [int ,int ,float ]: """ Count LoRA parameters vs total parameters. Returns: (lora_params, total_params, percentage) """ lora_params =0 total_params =0 for name ,param in model .named_parameters (): total_params +=param .numel () if 'lora_A'in name or 'lora_B'in name or 'magnitude'in name : lora_params +=param .numel () percentage =100.0 *lora_params /total_params if total_params >0 else 0.0 return lora_params ,total_params ,percentage ============================================================================== MODELS.COMPONENTS.ATTENTION ============================================================================== logger =logging .getLogger (__name__ ) def flash_attention_available ()->bool : """Check if Flash Attention (via SDPA) is available.""" try : from torch .nn .functional import scaled_dot_product_attention return True except ImportError : return False def compute_qk_scale (head_dim :int )->float : """Compute the Q/K pre-scaling factor for FP16 stability. By scaling both Q and K by head_dim^-0.25, the product Q@K^T is effectively scaled by head_dim^-0.5 (the standard attention scaling). This prevents overflow in FP16 when Q and K have large values. """ return head_dim **-0.25 class AttentionKVCache : """Pre-allocated KV Cache โ€” static buffer with index-based filling. Eliminates VRAM fragmentation from torch.cat during autoregressive generation. Buffer is allocated once at first use and reused via slice assignment. 
""" __slots__ =('key_cache','value_cache','seen_tokens','_max_len') def __init__ (self ,max_seq_len :int =131072 ): self .key_cache :torch .Tensor =None self .value_cache :torch .Tensor =None self .seen_tokens :int =0 self ._max_len =max_seq_len def _allocate (self ,batch :int ,heads :int ,head_dim :int ,device :torch .device ,dtype :torch .dtype ): """Allocate static buffer on first use.""" self .key_cache =torch .zeros (batch ,heads ,self ._max_len ,head_dim ,device =device ,dtype =dtype ) self .value_cache =torch .zeros (batch ,heads ,self ._max_len ,head_dim ,device =device ,dtype =dtype ) def update ( self , key_states :torch .Tensor , value_states :torch .Tensor , )->Tuple [torch .Tensor ,torch .Tensor ]: """ Update cache with new key/value states using index-based filling. Args: key_states: New key states [batch, num_heads, seq_len, head_dim] value_states: New value states [batch, num_heads, seq_len, head_dim] Returns: Updated key and value states including cache (views, no copy) """ batch ,heads ,new_len ,head_dim =key_states .shape if self .key_cache is None : self ._allocate (batch ,heads ,head_dim ,key_states .device ,key_states .dtype ) self .seen_tokens =0 if self .seen_tokens +new_len >self .key_cache .shape [2 ]: new_max =max (self .key_cache .shape [2 ]*2 ,self .seen_tokens +new_len ) new_key =torch .zeros (batch ,heads ,new_max ,head_dim ,device =key_states .device ,dtype =key_states .dtype ) new_val =torch .zeros (batch ,heads ,new_max ,head_dim ,device =key_states .device ,dtype =key_states .dtype ) new_key [:,:,:self .seen_tokens ]=self .key_cache [:,:,:self .seen_tokens ] new_val [:,:,:self .seen_tokens ]=self .value_cache [:,:,:self .seen_tokens ] self .key_cache =new_key self .value_cache =new_val self .key_cache [:,:,self .seen_tokens :self .seen_tokens +new_len ]=key_states self .value_cache [:,:,self .seen_tokens :self .seen_tokens +new_len ]=value_states self .seen_tokens +=new_len return self .key_cache [:,:,:self .seen_tokens ],self 
.value_cache [:,:,:self .seen_tokens ] def get_seq_length (self )->int : """Get current sequence length in cache.""" return self .seen_tokens def reset (self ): """Reset cache position without deallocating the buffer.""" self .seen_tokens =0 class FlashAttention (nn .Module ): """ SOTA Flash Attention with KV cache support and FP16-safe Q/K pre-scaling. Uses PyTorch's scaled_dot_product_attention when available, with fallback to standard attention. Supports: - KV caching for efficient generation - Causal masking - Attention dropout - Pre-scaled Q/K for FP16 stability """ def __init__ ( self , dropout :float =0.0 , causal :bool =False , head_dim :int =None , ): super ().__init__ () self .dropout =dropout self .causal =causal self ._flash_available =flash_attention_available () self ._head_dim =head_dim self ._qk_scale =compute_qk_scale (head_dim )if head_dim else None def forward ( self , query :torch .Tensor , key :torch .Tensor , value :torch .Tensor , attn_mask :torch .Tensor =None , is_causal :bool =None , past_key_value :Tuple [torch .Tensor ,torch .Tensor ]=None , use_cache :bool =False , output_attentions :bool =False , )->Tuple [torch .Tensor ,Tuple [torch .Tensor ,torch .Tensor ],torch .Tensor ]: """ Forward pass with KV cache support. 
Args: query: Query tensor [batch, num_heads, seq_len, head_dim] key: Key tensor [batch, num_heads, seq_len, head_dim] value: Value tensor [batch, num_heads, seq_len, head_dim] attn_mask: Optional attention mask is_causal: Override causal setting past_key_value: Optional tuple of (past_key, past_value) for KV cache use_cache: Whether to return updated KV cache output_attentions: Whether to return attention weights Returns: Tuple of (output, present_key_value, attention_weights) """ causal =is_causal if is_causal is not None else self .causal batch_size ,num_heads ,seq_len ,head_dim =query .shape qk_scale =self ._qk_scale if self ._qk_scale else compute_qk_scale (head_dim ) if past_key_value is not None : past_key ,past_value =past_key_value key =torch .cat ([past_key ,key ],dim =2 ) value =torch .cat ([past_value ,value ],dim =2 ) present_key_value =(key ,value )if use_cache else None kv_seq_len =key .shape [2 ] attn_weights =None if self ._flash_available and not output_attentions : query_scaled =query *qk_scale key_scaled =key *qk_scale dropout_p =self .dropout if self .training else 0.0 use_causal =causal and attn_mask is None and seq_len >1 and seq_len ==kv_seq_len output =F .scaled_dot_product_attention ( query_scaled ,key_scaled ,value , attn_mask =attn_mask , dropout_p =dropout_p , is_causal =use_causal , scale =1.0 , ) else : scale =1.0 /math .sqrt (head_dim ) attn_weights =torch .matmul (query ,key .transpose (-2 ,-1 ))*scale if causal and attn_mask is None and seq_len >1 : causal_mask =torch .triu ( torch .full ((seq_len ,kv_seq_len ),float ('-inf'),device =query .device ,dtype =query .dtype ), diagonal =kv_seq_len -seq_len +1 ) attn_weights =attn_weights +causal_mask .unsqueeze (0 ).unsqueeze (0 ) if attn_mask is not None : attn_weights =attn_weights +attn_mask attn_weights =F .softmax (attn_weights ,dim =-1 ,dtype =query .dtype ) if self .training and self .dropout >0 : attn_weights =F .dropout (attn_weights ,p =self .dropout ) output =torch .matmul 
(attn_weights ,value ) return output ,present_key_value ,attn_weights class MultimodalCrossAttention (nn .Module ): """ SOTA Cross-attention layer for multimodal fusion with KV cache support. Allows text to attend to image/video/audio features with: - KV caching for efficient generation - Gated residual connection for stable training - Flash Attention support with pre-scaled Q/K for FP16 stability - Optional attention weight output """ def __init__ ( self , hidden_size :int , num_heads :int =8 , dropout :float =0.1 , use_flash_attention :bool =True , gate_init :float =0.0 , ): super ().__init__ () self .hidden_size =hidden_size self .num_heads =num_heads self .head_dim =hidden_size //num_heads self .use_flash_attention =use_flash_attention and flash_attention_available () self .dropout_p =dropout self .qk_scale =compute_qk_scale (self .head_dim ) self .q_proj =nn .Linear (hidden_size ,hidden_size ,bias =False ) self .k_proj =nn .Linear (hidden_size ,hidden_size ,bias =False ) self .v_proj =nn .Linear (hidden_size ,hidden_size ,bias =False ) self .o_proj =nn .Linear (hidden_size ,hidden_size ,bias =False ) self .dropout =nn .Dropout (dropout ) self .layer_norm =nn .LayerNorm (hidden_size ) self .gate =nn .Parameter (torch .tensor (gate_init )) def forward ( self , text_hidden :torch .Tensor , modality_hidden :torch .Tensor , modality_mask :torch .Tensor =None , past_key_value :Tuple [torch .Tensor ,torch .Tensor ]=None , use_cache :bool =False , output_attentions :bool =False , )->Tuple [torch .Tensor ,Tuple [torch .Tensor ,torch .Tensor ],torch .Tensor ]: """ Cross-attention: text attends to modality features with KV cache support. 
Args: text_hidden: Text hidden states [batch, text_len, hidden_size] modality_hidden: Modality features [batch, modality_len, hidden_size] modality_mask: Optional attention mask for modality past_key_value: Optional cached (key, value) for this layer use_cache: Whether to return updated KV cache output_attentions: Whether to return attention weights Returns: Tuple of (output, present_key_value, attention_weights) """ batch_size ,text_len ,_ =text_hidden .shape query =self .q_proj (text_hidden ) query =query .view (batch_size ,text_len ,self .num_heads ,self .head_dim ).transpose (1 ,2 ) if past_key_value is not None : key ,value =past_key_value else : modality_len =modality_hidden .shape [1 ] key =self .k_proj (modality_hidden ) value =self .v_proj (modality_hidden ) key =key .view (batch_size ,modality_len ,self .num_heads ,self .head_dim ).transpose (1 ,2 ) value =value .view (batch_size ,modality_len ,self .num_heads ,self .head_dim ).transpose (1 ,2 ) present_key_value =(key ,value )if use_cache else None attn_weights =None if self .use_flash_attention and not output_attentions : query_scaled =query *self .qk_scale key_scaled =key *self .qk_scale dropout_p =self .dropout_p if self .training else 0.0 attn_output =F .scaled_dot_product_attention ( query_scaled ,key_scaled ,value , attn_mask =modality_mask , dropout_p =dropout_p , is_causal =False , scale =1.0 , ) else : scale =1.0 /math .sqrt (self .head_dim ) attn_weights =torch .matmul (query ,key .transpose (-2 ,-1 ))*scale if modality_mask is not None : attn_weights =attn_weights +modality_mask attn_weights =F .softmax (attn_weights ,dim =-1 ,dtype =text_hidden .dtype ) if self .training and self .dropout_p >0 : attn_weights =F .dropout (attn_weights ,p =self .dropout_p ) attn_output =torch .matmul (attn_weights ,value ) attn_output =attn_output .transpose (1 ,2 ).contiguous ().view (batch_size ,text_len ,self .hidden_size ) attn_output =self .o_proj (attn_output ) gate =torch .sigmoid (self .gate ) output 
@dataclass
class MultimodalFusionCache:
    """Cache for multimodal fusion layer KV states."""
    image_kv: Tuple[torch.Tensor, torch.Tensor] = None
    video_kv: Tuple[torch.Tensor, torch.Tensor] = None
    audio_kv: Tuple[torch.Tensor, torch.Tensor] = None


class MultimodalFusionLayer(nn.Module):
    """
    SOTA Multimodal fusion layer with cross-attention for all modalities and
    KV cache support.

    Features:
    - Separate cross-attention for each modality (image, video, audio)
    - KV caching for efficient generation
    - Gated fusion MLP
    - Flash Attention support
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int = 8,
        dropout: float = 0.1,
        use_flash_attention: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.image_cross_attn = MultimodalCrossAttention(
            hidden_size, num_heads, dropout, use_flash_attention
        )
        self.video_cross_attn = MultimodalCrossAttention(
            hidden_size, num_heads, dropout, use_flash_attention
        )
        self.audio_cross_attn = MultimodalCrossAttention(
            hidden_size, num_heads, dropout, use_flash_attention
        )
        self.fusion_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 4, hidden_size),
            nn.Dropout(dropout),
        )
        self.fusion_norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        text_hidden: torch.Tensor,
        image_hidden: torch.Tensor = None,
        video_hidden: torch.Tensor = None,
        audio_hidden: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        video_mask: torch.Tensor = None,
        audio_mask: torch.Tensor = None,
        past_key_values: MultimodalFusionCache = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, MultimodalFusionCache]:
        """
        Fuse text with available modalities via cross-attention with KV cache
        support.

        Args:
            text_hidden: Text hidden states [batch, text_len, hidden_size]
            image_hidden: Image features [batch, image_len, hidden_size]
            video_hidden: Video features [batch, video_len, hidden_size]
            audio_hidden: Audio features [batch, audio_len, hidden_size]
            image_mask: Attention mask for image
            video_mask: Attention mask for video
            audio_mask: Attention mask for audio
            past_key_values: Cached KV states from previous forward pass
            use_cache: Whether to return updated KV cache

        Returns:
            Tuple of (output, present_key_values)
        """
        present_key_values = MultimodalFusionCache() if use_cache else None

        # Identical handling per modality — keep order image, video, audio.
        modalities = (
            ('Image', 'image_kv', self.image_cross_attn, image_hidden, image_mask),
            ('Video', 'video_kv', self.video_cross_attn, video_hidden, video_mask),
            ('Audio', 'audio_kv', self.audio_cross_attn, audio_hidden, audio_mask),
        )
        for label, slot, cross_attn, hidden, mask in modalities:
            past_kv = getattr(past_key_values, slot) if past_key_values else None
            if not (self._has_content(hidden) or past_kv is not None):
                continue
            try:
                if hidden is None:
                    # Placeholder for the cached-decode case; the cross-attn
                    # ignores it when past_kv is present.
                    hidden = torch.zeros(
                        text_hidden.shape[0], 1, self.hidden_size, device=text_hidden.device
                    )
                text_hidden, new_kv, _ = cross_attn(
                    text_hidden,
                    hidden,
                    mask,
                    past_key_value=past_kv,
                    use_cache=use_cache,
                )
                if use_cache:
                    setattr(present_key_values, slot, new_kv)
            except Exception as e:
                # Best-effort fusion: a failing modality must not break text.
                logger.debug(f"{label} cross-attention skipped: {e}")

        residual = text_hidden
        text_hidden = self.fusion_mlp(text_hidden)
        text_hidden = self.fusion_norm(residual + text_hidden)
        return text_hidden, present_key_values

    @staticmethod
    def _has_content(tensor: torch.Tensor) -> bool:
        """Check if tensor has meaningful content."""
        if tensor is None:
            return False
        if not isinstance(tensor, torch.Tensor):
            return False
        try:
            if tensor.numel() == 0:
                return False
            # All-zero tensors count as "no content".
            return bool(tensor.any())
        except Exception:
            return False


# =============================================================================
# MODELS.COMPONENTS.PROJECTORS
# =============================================================================


def compute_2d_rope(height: int, width: int, dim: int, device: torch.device,
                    dtype: torch.dtype, base: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute 2D Rotary Position Embeddings for spatial awareness.

    Args:
        height: Image height in patches
        width: Image width in patches
        dim: Embedding dimension (must be divisible by 4)
        device: Target device
        dtype: Target dtype
        base: RoPE base frequency

    Returns:
        cos, sin: [height*width, dim] position embeddings
    """
    assert dim % 4 == 0, "dim must be divisible by 4 for 2D RoPE"
    quarter_dim = dim // 4
    inv_freq = 1.0 / (base ** (torch.arange(0, quarter_dim, device=device, dtype=torch.float32) / quarter_dim))

    row_pos = torch.arange(height, device=device, dtype=torch.float32)
    col_pos = torch.arange(width, device=device, dtype=torch.float32)
    row_emb = torch.outer(row_pos, inv_freq).unsqueeze(1).expand(-1, width, -1)
    col_emb = torch.outer(col_pos, inv_freq).unsqueeze(0).expand(height, -1, -1)

    # Duplicate each axis so rows and columns each cover half of dim.
    emb = torch.cat([row_emb, row_emb, col_emb, col_emb], dim=-1)
    emb = emb.reshape(height * width, dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)
def compute_3d_rope(
    depth: int, height: int, width: int, dim: int,
    device: torch.device, dtype: torch.dtype, base: float = 10000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute 3D Rotary Position Embeddings for video/temporal awareness.

    Args:
        depth: Temporal depth (number of frames)
        height: Image height in patches
        width: Image width in patches
        dim: Embedding dimension (must be divisible by 6)
        device: Target device
        dtype: Target dtype
        base: RoPE base frequency

    Returns:
        cos, sin: [depth*height*width, dim] position embeddings
    """
    assert dim % 6 == 0, "dim must be divisible by 6 for 3D RoPE"
    sixth_dim = dim // 6
    inv_freq = 1.0 / (base ** (torch.arange(0, sixth_dim, device=device, dtype=torch.float32) / sixth_dim))

    time_emb = torch.outer(torch.arange(depth, device=device, dtype=torch.float32), inv_freq)
    row_emb = torch.outer(torch.arange(height, device=device, dtype=torch.float32), inv_freq)
    col_emb = torch.outer(torch.arange(width, device=device, dtype=torch.float32), inv_freq)

    # Broadcast each axis over the full (depth, height, width) grid.
    time_emb = time_emb.unsqueeze(1).unsqueeze(2).expand(-1, height, width, -1)
    row_emb = row_emb.unsqueeze(0).unsqueeze(2).expand(depth, -1, width, -1)
    col_emb = col_emb.unsqueeze(0).unsqueeze(1).expand(depth, height, -1, -1)

    # Duplicate each axis so time/rows/cols each cover a third of dim.
    emb = torch.cat([time_emb, time_emb, row_emb, row_emb, col_emb, col_emb], dim=-1)
    emb = emb.reshape(depth * height * width, dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    rotated = torch.cat((-second, first), dim=-1)
    return x * cos + rotated * sin


class ResidualBottleneckBlock(nn.Module):
    """
    Residual Bottleneck Block for locality-enhanced feature extraction.
    Preserves small-scale features (OCR, fine audio events) during compression.
    """

    def __init__(self, in_channels: int, out_channels: int, bottleneck_ratio: float = 0.25):
        super().__init__()
        bottleneck_channels = int(out_channels * bottleneck_ratio)
        # 1x1 reduce -> 3x3 spatial -> 1x1 expand, each conv followed by BN.
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        # Projection shortcut only when the channel count changes.
        self.shortcut = nn.Identity() if in_channels == out_channels else nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + identity)


class LocalityEnhancedResNetAbstractor(nn.Module):
    """
    Locality-Enhanced ResNet Abstractor.
    Upgrades the C-Abstractor with residual bottleneck blocks to preserve
    small-scale features (OCR/fine audio events) during compression.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_tokens: int = 64,
        num_blocks: int = 3,
        use_2d_rope: bool = True,
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.use_2d_rope = use_2d_rope
        self.input_proj = nn.Linear(input_dim, output_dim)
        self.blocks = nn.ModuleList([
            ResidualBottleneckBlock(output_dim, output_dim)
            for _ in range(num_blocks)
        ])
        # Learned query tokens pooled from the feature map via cross-attention.
        self.queries = nn.Parameter(torch.randn(1, num_tokens, output_dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=output_dim,
            num_heads=8,
            batch_first=True,
            dropout=0.1,
        )
        self.ff = nn.Sequential(
            nn.LayerNorm(output_dim),
            nn.Linear(output_dim, output_dim * 4),
            nn.GELU(),
            nn.Linear(output_dim * 4, output_dim),
        )
        self.norm = nn.LayerNorm(output_dim)
        print(f" ๐Ÿ—๏ธ LocalityEnhancedResNetAbstractor: {input_dim} -> {output_dim}, {num_tokens} tokens")

    def forward(self, features: torch.Tensor, spatial_size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        """
        Args:
            features: [B, seq_len, input_dim] or [B, H, W, input_dim]
            spatial_size: (H, W) if features are flattened.
                NOTE(review): flattened input without spatial_size assumes a
                square grid (seq_len must be a perfect square) — confirm callers.

        Returns:
            abstracted: [B, num_tokens, output_dim]
        """
        batch_size = features.shape[0]
        x = self.input_proj(features)
        if features.dim() == 3:
            if spatial_size is None:
                h = w = int(math.sqrt(features.shape[1]))
            else:
                h, w = spatial_size
            x = x.view(batch_size, h, w, -1)
        else:
            h, w = features.shape[1], features.shape[2]

        # Conv blocks operate channels-first.
        x = x.permute(0, 3, 1, 2)
        for block in self.blocks:
            x = block(x)
        x = x.permute(0, 2, 3, 1).reshape(batch_size, h * w, -1)

        if self.use_2d_rope:
            cos, sin = compute_2d_rope(h, w, x.shape[-1], x.device, x.dtype)
            x = apply_rope(x, cos.unsqueeze(0), sin.unsqueeze(0))

        queries = self.queries.expand(batch_size, -1, -1)
        abstracted, _ = self.cross_attn(queries, x, x)
        abstracted = abstracted + self.ff(abstracted)
        return self.norm(abstracted)


class MultiScaleFeatureFusion(nn.Module):
    """
    Multi-Scale Feature Fusion (MSFF).
    Extracts and weights features from multiple encoder depths (early, mid,
    late) to capture both low-level textures and high-level semantics.
    """

    def __init__(
        self,
        feature_dims: List[int],
        output_dim: int,
        num_scales: int = 3,
    ):
        super().__init__()
        # NOTE(review): num_scales is assumed to equal len(feature_dims) —
        # not validated here.
        self.num_scales = num_scales
        self.scale_projs = nn.ModuleList([
            nn.Linear(dim, output_dim) for dim in feature_dims
        ])
        # Learnable per-scale mixing weights, softmax-normalized at use.
        self.scale_weights = nn.Parameter(torch.ones(num_scales) / num_scales)
        self.fusion = nn.Sequential(
            nn.Linear(output_dim, output_dim * 2),
            nn.GELU(),
            nn.Linear(output_dim * 2, output_dim),
        )
        self.norm = nn.LayerNorm(output_dim)
        print(f" ๐Ÿ”€ MultiScaleFeatureFusion: {feature_dims} -> {output_dim}")

    def forward(self, multi_scale_features: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            multi_scale_features: List of [B, seq_len, dim] features from
                different depths

        Returns:
            fused: [B, seq_len, output_dim]
        """
        assert len(multi_scale_features) == self.num_scales
        projected = [
            proj(feats)
            for feats, proj in zip(multi_scale_features, self.scale_projs)
        ]
        weights = F.softmax(self.scale_weights, dim=0)
        fused = sum(w * p for w, p in zip(weights, projected))
        fused = fused + self.fusion(fused)
        return self.norm(fused)


class MultiScaleDeformableAttention(nn.Module):
    """
    Multi-Scale Deformable Attention.
    Replaces fixed-grid cross-attention in Perceiver Resamplers, allowing the
    projector to "look" at non-uniform regions of interest.
    """
""" def __init__ ( self , dim :int , num_heads :int =8 , num_levels :int =4 , num_points :int =4 , dropout :float =0.1 , ): super ().__init__ () self .dim =dim self .num_heads =num_heads self .num_levels =num_levels self .num_points =num_points self .head_dim =dim //num_heads self .sampling_offsets =nn .Linear (dim ,num_heads *num_levels *num_points *2 ) self .attention_weights =nn .Linear (dim ,num_heads *num_levels *num_points ) self .value_proj =nn .Linear (dim ,dim ) self .output_proj =nn .Linear (dim ,dim ) self .dropout =nn .Dropout (dropout ) self ._reset_parameters () print (f" ๐ŸŽฏ MultiScaleDeformableAttention: {dim }d, {num_heads }H, {num_levels }L, {num_points }P") def _reset_parameters (self ): nn .init .constant_ (self .sampling_offsets .weight ,0.0 ) nn .init .constant_ (self .sampling_offsets .bias ,0.0 ) nn .init .xavier_uniform_ (self .attention_weights .weight ) nn .init .constant_ (self .attention_weights .bias ,0.0 ) nn .init .xavier_uniform_ (self .value_proj .weight ) nn .init .xavier_uniform_ (self .output_proj .weight ) def forward ( self , query :torch .Tensor , reference_points :torch .Tensor , input_flatten :torch .Tensor , input_spatial_shapes :torch .Tensor , )->torch .Tensor : """ Args: query: [B, num_queries, dim] reference_points: [B, num_queries, num_levels, 2] normalized reference points input_flatten: [B, sum(H*W), dim] flattened multi-scale features input_spatial_shapes: [num_levels, 2] spatial shapes of each level Returns: output: [B, num_queries, dim] """ batch_size ,num_queries ,_ =query .shape offsets =self .sampling_offsets (query ) offsets =offsets .view (batch_size ,num_queries ,self .num_heads ,self .num_levels ,self .num_points ,2 ) attn_weights =self .attention_weights (query ) attn_weights =attn_weights .view (batch_size ,num_queries ,self .num_heads ,self .num_levels *self .num_points ) attn_weights =F .softmax (attn_weights ,dim =-1 ) attn_weights =attn_weights .view (batch_size ,num_queries ,self .num_heads ,self 
.num_levels ,self .num_points ) sampling_locations =reference_points .unsqueeze (2 ).unsqueeze (4 )+offsets *0.1 sampling_locations =sampling_locations .clamp (0 ,1 ) value =self .value_proj (input_flatten ) value =value .view (batch_size ,-1 ,self .num_heads ,self .head_dim ) output =torch .zeros (batch_size ,num_queries ,self .num_heads ,self .head_dim ,device =query .device ,dtype =query .dtype ) start_idx =0 for level_idx in range (self .num_levels ): h ,w =input_spatial_shapes [level_idx ] end_idx =start_idx +h *w level_value =value [:,start_idx :end_idx ] level_value =level_value .view (batch_size ,h ,w ,self .num_heads ,self .head_dim ) level_locs =sampling_locations [:,:,:,level_idx ] level_weights =attn_weights [:,:,:,level_idx ] for point_idx in range (self .num_points ): loc =level_locs [:,:,:,point_idx ] weight =level_weights [:,:,:,point_idx :point_idx +1 ] y_idx =(loc [...,0 ]*(h -1 )).long ().clamp (0 ,h -1 ) x_idx =(loc [...,1 ]*(w -1 )).long ().clamp (0 ,w -1 ) for b in range (batch_size ): for q in range (num_queries ): for head in range (self .num_heads ): y ,x =y_idx [b ,q ,head ].item (),x_idx [b ,q ,head ].item () output [b ,q ,head ]+=weight [b ,q ,head ]*level_value [b ,y ,x ,head ] start_idx =end_idx output =output .view (batch_size ,num_queries ,self .dim ) output =self .output_proj (output ) output =self .dropout (output ) return output class DynamicTokenRouter (nn .Module ): """ Dynamic Token Router. Implements a sparse gating mechanism to drop redundant "background" tokens, drastically reducing KV-cache pressure for Ring Attention. 
""" def __init__ ( self , dim :int , num_tokens :int , keep_ratio :float =0.5 , temperature :float =1.0 , ): super ().__init__ () self .dim =dim self .num_tokens =num_tokens self .keep_ratio =keep_ratio self .temperature =temperature self .scorer =nn .Sequential ( nn .Linear (dim ,dim //2 ), nn .GELU (), nn .Linear (dim //2 ,1 ), ) self .threshold =nn .Parameter (torch .tensor (0.0 )) print (f" ๐Ÿšฆ DynamicTokenRouter: keep_ratio={keep_ratio }") def forward (self ,tokens :torch .Tensor ,return_mask :bool =False )->Tuple [torch .Tensor ,Optional [torch .Tensor ]]: """ Args: tokens: [B, num_tokens, dim] return_mask: Whether to return the selection mask Returns: selected_tokens: [B, num_kept, dim] mask: [B, num_tokens] selection mask (if return_mask=True) """ batch_size ,num_tokens ,_ =tokens .shape num_keep =max (1 ,int (num_tokens *self .keep_ratio )) scores =self .scorer (tokens ).squeeze (-1 ) scores =scores /self .temperature _ ,indices =torch .topk (scores ,num_keep ,dim =-1 ) indices =indices .sort (dim =-1 ).values indices_expanded =indices .unsqueeze (-1 ).expand (-1 ,-1 ,self .dim ) selected_tokens =torch .gather (tokens ,1 ,indices_expanded ) if return_mask : mask =torch .zeros (batch_size ,num_tokens ,device =tokens .device ,dtype =torch .bool ) mask .scatter_ (1 ,indices ,True ) return selected_tokens ,mask return selected_tokens ,None class PerceiverAttention (nn .Module ): """ Perceiver-style cross-attention for resampling with 2D/3D RoPE support. 
""" def __init__ ( self , dim :int , num_heads :int =8 , dim_head :int =64 , dropout :float =0.0 , use_rope :bool =True , ): super ().__init__ () inner_dim =dim_head *num_heads self .num_heads =num_heads self .dim_head =dim_head self .inner_dim =inner_dim self .scale =dim_head **-0.5 self .use_rope =use_rope self .norm_latents =nn .LayerNorm (dim ) self .norm_context =nn .LayerNorm (dim ) self .to_q =nn .Linear (dim ,inner_dim ,bias =False ) self .to_kv =nn .Linear (dim ,inner_dim *2 ,bias =False ) self .to_out =nn .Sequential ( nn .Linear (inner_dim ,dim ), nn .Dropout (dropout ) ) def forward ( self , latents :torch .Tensor , context :torch .Tensor , context_rope :Optional [Tuple [torch .Tensor ,torch .Tensor ]]=None , )->torch .Tensor : """ latents: [B, num_latents, dim] - learnable queries context: [B, seq_len, dim] - input features to attend to context_rope: Optional (cos, sin) for context positions """ latents =self .norm_latents (latents ) context =self .norm_context (context ) b ,n ,_ =latents .shape ctx_len =context .shape [1 ] h =self .num_heads d =self .dim_head q =self .to_q (latents ) kv =self .to_kv (context ).chunk (2 ,dim =-1 ) k ,v =kv q =q .reshape (b ,n ,h ,d ).transpose (1 ,2 ) k =k .reshape (b ,ctx_len ,h ,d ).transpose (1 ,2 ) v =v .reshape (b ,ctx_len ,h ,d ).transpose (1 ,2 ) if self .use_rope and context_rope is not None : cos ,sin =context_rope cos =cos .unsqueeze (0 ).unsqueeze (0 ) sin =sin .unsqueeze (0 ).unsqueeze (0 ) k =apply_rope (k ,cos ,sin ) qk_scale =d **-0.25 out =F .scaled_dot_product_attention ( q *qk_scale ,k *qk_scale ,v , is_causal =False ,scale =1.0 , ) out =out .transpose (1 ,2 ).reshape (b ,n ,self .inner_dim ) return self .to_out (out ) class PerceiverResampler (nn .Module ): """ Perceiver Resampler with 2D/3D RoPE and Dynamic Token Routing. 
""" def __init__ ( self , input_dim :int , output_dim :int , num_latents :int =64 , num_heads :int =8 , num_layers :int =2 , dropout :float =0.0 , use_rope :bool =True , use_dynamic_routing :bool =False , routing_keep_ratio :float =0.5 , ): super ().__init__ () self .num_latents =num_latents self .use_rope =use_rope self .use_dynamic_routing =use_dynamic_routing self .input_proj =nn .Linear (input_dim ,output_dim )if input_dim !=output_dim else nn .Identity () self .latents =nn .Parameter (torch .randn (1 ,num_latents ,output_dim )*0.02 ) self .layers =nn .ModuleList ([ nn .ModuleList ([ PerceiverAttention (output_dim ,num_heads ,output_dim //num_heads ,dropout ,use_rope ), nn .Sequential ( nn .LayerNorm (output_dim ), nn .Linear (output_dim ,output_dim *4 ), nn .GELU (), nn .Dropout (dropout ), nn .Linear (output_dim *4 ,output_dim ), nn .Dropout (dropout ), ) ]) for _ in range (num_layers ) ]) if use_dynamic_routing : self .token_router =DynamicTokenRouter (output_dim ,num_latents ,routing_keep_ratio ) else : self .token_router =None self .norm_out =nn .LayerNorm (output_dim ) def forward ( self , x :torch .Tensor , spatial_size :Optional [Tuple [int ,int ]]=None , temporal_size :Optional [int ]=None , )->torch .Tensor : """ x: [B, seq_len, input_dim] - input features spatial_size: (H, W) for 2D RoPE temporal_size: T for 3D RoPE (video) returns: [B, num_latents, output_dim] - compressed features """ batch_size =x .shape [0 ] x =self .input_proj (x ) context_rope =None if self .use_rope and spatial_size is not None : h ,w =spatial_size if temporal_size is not None : cos ,sin =compute_3d_rope (temporal_size ,h ,w ,x .shape [-1 ],x .device ,x .dtype ) else : cos ,sin =compute_2d_rope (h ,w ,x .shape [-1 ],x .device ,x .dtype ) context_rope =(cos ,sin ) latents =self .latents .expand (batch_size ,-1 ,-1 ) for attn ,ff in self .layers : latents =latents +attn (latents ,x ,context_rope ) latents =latents +ff (latents ) latents =self .norm_out (latents ) if self 
.token_router is not None : latents ,_ =self .token_router (latents ) return latents class SpatialAwareProjector (nn .Module ): """ Spatial-aware projector with 2D RoPE. """ def __init__ ( self , vision_hidden_size :int , llm_hidden_size :int , num_tokens :int =64 , spatial_pool_size :int =8 , use_rope :bool =True , ): super ().__init__ () self .num_tokens =num_tokens self .spatial_pool_size =spatial_pool_size self .use_rope =use_rope self .spatial_conv =nn .Sequential ( nn .Conv2d (vision_hidden_size ,llm_hidden_size ,3 ,padding =1 ), nn .GELU (), nn .Conv2d (llm_hidden_size ,llm_hidden_size ,3 ,padding =1 ), nn .GELU (), ) self .adaptive_pool =nn .AdaptiveAvgPool2d ((spatial_pool_size ,spatial_pool_size )) self .proj =nn .Sequential ( nn .Linear (llm_hidden_size ,llm_hidden_size ), nn .GELU (), nn .Linear (llm_hidden_size ,llm_hidden_size ), ) self .norm =nn .LayerNorm (llm_hidden_size ) def forward (self ,vision_features :torch .Tensor ,spatial_size :Optional [Tuple [int ,int ]]=None )->torch .Tensor : batch_size =vision_features .shape [0 ] if vision_features .dim ()==3 : seq_len =vision_features .shape [1 ] if spatial_size is None : h =w =int (math .sqrt (seq_len )) else : h ,w =spatial_size vision_features =vision_features .view (batch_size ,h ,w ,-1 ) x =vision_features .permute (0 ,3 ,1 ,2 ) x =self .spatial_conv (x ) x =self .adaptive_pool (x ) x =x .flatten (2 ).transpose (1 ,2 ) if self .use_rope : cos ,sin =compute_2d_rope (self .spatial_pool_size ,self .spatial_pool_size ,x .shape [-1 ],x .device ,x .dtype ) x =apply_rope (x ,cos .unsqueeze (0 ),sin .unsqueeze (0 )) x =self .proj (x ) x =self .norm (x ) return x class CAbstractor (nn .Module ): """ C-Abstractor: Compressed Abstraction for efficient multimodal fusion. Now with 2D RoPE support. 
""" def __init__ ( self , vision_hidden_size :int , llm_hidden_size :int , num_tokens :int =64 , num_heads :int =8 , compression_ratio :int =4 , use_rope :bool =True , ): super ().__init__ () self .num_tokens =num_tokens self .use_rope =use_rope self .input_proj =nn .Linear (vision_hidden_size ,llm_hidden_size ) self .compress =nn .Sequential ( nn .Conv1d (llm_hidden_size ,llm_hidden_size ,kernel_size =compression_ratio ,stride =compression_ratio ), nn .GELU (), ) self .queries =nn .Parameter (torch .randn (1 ,num_tokens ,llm_hidden_size )*0.02 ) self .cross_attn =nn .MultiheadAttention ( embed_dim =llm_hidden_size , num_heads =num_heads , batch_first =True , dropout =0.1 , ) self .ff =nn .Sequential ( nn .LayerNorm (llm_hidden_size ), nn .Linear (llm_hidden_size ,llm_hidden_size *4 ), nn .GELU (), nn .Linear (llm_hidden_size *4 ,llm_hidden_size ), ) self .norm =nn .LayerNorm (llm_hidden_size ) def forward (self ,vision_features :torch .Tensor ,spatial_size :Optional [Tuple [int ,int ]]=None )->torch .Tensor : batch_size =vision_features .shape [0 ] x =self .input_proj (vision_features ) if self .use_rope and spatial_size is not None : h ,w =spatial_size cos ,sin =compute_2d_rope (h ,w ,x .shape [-1 ],x .device ,x .dtype ) x =apply_rope (x ,cos .unsqueeze (0 ),sin .unsqueeze (0 )) x =x .transpose (1 ,2 ) x =self .compress (x ) x =x .transpose (1 ,2 ) queries =self .queries .expand (batch_size ,-1 ,-1 ) abstracted ,_ =self .cross_attn (queries ,x ,x ) abstracted =abstracted +self .ff (abstracted ) return self .norm (abstracted ) class MultimodalProjector (nn .Module ): """ SOTA Multimodal Projector with all advanced features. 
Combines: - Locality-Enhanced ResNet Abstractor - Multi-Scale Feature Fusion - Multi-Scale Deformable Attention - Dynamic Token Router - 2D/3D RoPE - Perceiver Resampler """ def __init__ ( self , vision_hidden_size :int , llm_hidden_size :int , num_tokens :int =64 , projector_type :str ="perceiver", num_heads :int =8 , num_layers :int =2 , use_rope :bool =True , use_dynamic_routing :bool =False , use_locality_enhanced :bool =False , use_msff :bool =False , use_deformable_attn :bool =False , ): super ().__init__ () self .num_tokens =num_tokens self .projector_type =projector_type self .use_rope =use_rope if projector_type =="perceiver": self .projector =PerceiverResampler ( input_dim =vision_hidden_size , output_dim =llm_hidden_size , num_latents =num_tokens , num_heads =num_heads , num_layers =num_layers , use_rope =use_rope , use_dynamic_routing =use_dynamic_routing , ) elif projector_type =="spatial": self .projector =SpatialAwareProjector ( vision_hidden_size =vision_hidden_size , llm_hidden_size =llm_hidden_size , num_tokens =num_tokens , use_rope =use_rope , ) elif projector_type =="c_abstractor": self .projector =CAbstractor ( vision_hidden_size =vision_hidden_size , llm_hidden_size =llm_hidden_size , num_tokens =num_tokens , num_heads =num_heads , use_rope =use_rope , ) elif projector_type =="locality_enhanced": self .projector =LocalityEnhancedResNetAbstractor ( input_dim =vision_hidden_size , output_dim =llm_hidden_size , num_tokens =num_tokens , use_2d_rope =use_rope , ) else : self .projector =nn .Sequential ( nn .Linear (vision_hidden_size ,llm_hidden_size ), nn .GELU (), nn .Linear (llm_hidden_size ,llm_hidden_size ), ) self .query_tokens =nn .Parameter (torch .randn (1 ,num_tokens ,llm_hidden_size )*0.02 ) self .cross_attn =nn .MultiheadAttention ( embed_dim =llm_hidden_size , num_heads =num_heads , batch_first =True ) self .norm =nn .LayerNorm (llm_hidden_size ) if use_msff : self .msff =MultiScaleFeatureFusion ( feature_dims =[vision_hidden_size ]*3 
, output_dim =vision_hidden_size , ) else : self .msff =None if use_deformable_attn : self .deformable_attn =MultiScaleDeformableAttention ( dim =llm_hidden_size , num_heads =num_heads , ) else : self .deformable_attn =None if use_dynamic_routing and projector_type !="perceiver": self .token_router =DynamicTokenRouter (llm_hidden_size ,num_tokens ) else : self .token_router =None def forward ( self , vision_features :torch .Tensor , multi_scale_features :Optional [List [torch .Tensor ]]=None , spatial_size :Optional [Tuple [int ,int ]]=None , temporal_size :Optional [int ]=None , )->torch .Tensor : """Project and resample vision features.""" if self .msff is not None and multi_scale_features is not None : vision_features =self .msff (multi_scale_features ) if self .projector_type in ["perceiver"]: output =self .projector (vision_features ,spatial_size ,temporal_size ) elif self .projector_type in ["spatial","c_abstractor","locality_enhanced"]: output =self .projector (vision_features ,spatial_size ) else : batch_size =vision_features .shape [0 ] projected =self .projector (vision_features ) queries =self .query_tokens .expand (batch_size ,-1 ,-1 ) resampled ,_ =self .cross_attn (queries ,projected ,projected ) output =self .norm (resampled ) if self .token_router is not None : output ,_ =self .token_router (output ) return output ============================================================================== MODELS.COMPONENTS.MOE ============================================================================== EPS =1e-5 class ExpertUtilizationTracker : """ Tracks expert utilization across MoE layers. Attach to any MoE layer to log per-expert usage histograms. 
Every `report_interval` steps, prints a report showing: - Frequency of use per expert - Cold experts (used < 1% of tokens) - Count of experts offloaded to CPU (if ExpertOffloadManager is available) Usage: tracker = ExpertUtilizationTracker(num_experts=8, layer_name="layer.3.moe") """ def __init__ ( self , num_experts :int , layer_name :str ="moe", report_interval :int =100 , cold_threshold_pct :float =1.0 , ): self .num_experts =num_experts self .layer_name =layer_name self .report_interval =report_interval self .cold_threshold_pct =cold_threshold_pct self ._counts =torch .zeros (num_experts ,dtype =torch .long ) self ._total_tokens =0 self ._step =0 self ._offload_manager =None def link_offload_manager (self ,manager ): """Link an ExpertOffloadManager for cold-expert reporting.""" self ._offload_manager =manager def record (self ,expert_indices :torch .Tensor ): """ Record expert selections from a forward pass. Args: expert_indices: [num_tokens, top_k] tensor of selected expert indices """ indices_flat =expert_indices .detach ().cpu ().reshape (-1 ) for idx in range (self .num_experts ): self ._counts [idx ]+=(indices_flat ==idx ).sum ().item () self ._total_tokens +=expert_indices .shape [0 ] def step (self ): """Advance step counter. 
Prints report and resets when interval is hit.""" self ._step +=1 if self ._step %self .report_interval ==0 : self ._print_report () self ._reset () def _reset (self ): """Reset accumulators for next interval.""" self ._counts .zero_ () self ._total_tokens =0 def _print_report (self ): """Print expert utilization histogram.""" if self ._total_tokens ==0 : return freqs =self ._counts .float () total_assignments =freqs .sum ().item () if total_assignments ==0 : return pcts =(freqs /total_assignments *100 ).tolist () cold_experts =[i for i ,p in enumerate (pcts )if p 0 else 0 bar ="โ–ˆ"*bar_len cold_tag =" โ„๏ธ"if pct 6d}){cold_tag }") lines .append (f"{'โ”€'*60 }") if cold_experts : lines .append (f" โ„๏ธ Cold experts (<{self .cold_threshold_pct }%): {cold_experts }") else : lines .append (f" โœ… All experts active (no cold experts)") if self ._offload_manager is not None : status =self ._offload_manager .get_status () lines .append (f" ๐Ÿ’พ Offloaded to CPU: {status ['cpu']}/{status ['total']}") ideal_pct =100.0 /self .num_experts balance =1.0 -(sum (abs (p -ideal_pct )for p in pcts )/(2 *100 )) lines .append (f" โš–๏ธ Load balance score: {balance :.3f} (1.0 = perfect)") lines .append (f"{'='*60 }") print ("\n".join (lines )) def get_stats (self )->dict : """Return current stats as a dict (for programmatic access).""" total =self ._counts .sum ().item () if total ==0 : pcts =[0.0 ]*self .num_experts else : pcts =(self ._counts .float ()/total *100 ).tolist () cold =[i for i ,p in enumerate (pcts )if p 0 else 0.0 return { "step":self ._step , "layer_name":self .layer_name , "total_tokens":self ._total_tokens , "expert_counts":self ._counts .tolist (), "expert_pcts":pcts , "cold_experts":cold , "balance_score":balance , } def attach_utilization_trackers ( model :torch .nn .Module , report_interval :int =100 , )->list : """ Find all MoE layers in a model and attach ExpertUtilizationTrackers. Returns list of trackers for manual step() calls in the training loop. 
""" trackers =[] for name ,module in model .named_modules (): if hasattr (module ,'experts')and hasattr (module ,'router'): num_experts =len (module .experts ) tracker =ExpertUtilizationTracker ( num_experts =num_experts , layer_name =name , report_interval =report_interval , ) if hasattr (module ,'_expert_offload_manager'): tracker .link_offload_manager (module ._expert_offload_manager ) module ._utilization_tracker =tracker trackers .append (tracker ) if trackers : print (f" ๐Ÿ“Š Attached {len (trackers )} expert utilization trackers (report every {report_interval } steps)") return trackers class MoERouter (nn .Module ): """ SOTA Router for Mixture of Experts v2.0 - FP16 native. Supports both traditional aux-loss routing and aux-lossless routing. """ def __init__ (self ,hidden_size :int ,num_experts :int ,top_k :int =2 , noise_std :float =0.01 ,capacity_factor :float =1.25 , aux_lossless :bool =True ): super ().__init__ () self .num_experts =num_experts self .top_k =top_k self .noise_std =noise_std self .capacity_factor =capacity_factor self .hidden_size =hidden_size self .aux_lossless =aux_lossless self .input_norm =nn .LayerNorm (hidden_size ,eps =1e-5 ) self .gate =nn .Linear (hidden_size ,num_experts ,bias =False ) nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 ) if aux_lossless : self .expert_bias =nn .Parameter (torch .zeros (num_experts )) def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]: batch_size ,seq_len ,hidden_dim =hidden_states .shape hidden_flat =hidden_states .view (-1 ,hidden_dim ) hidden_norm =self .input_norm (hidden_flat ) router_logits =self .gate (hidden_norm ) if self .aux_lossless : router_logits =router_logits +self .expert_bias if self .training and self .noise_std >0 : noise =torch .randn_like (router_logits )*self .noise_std noisy_logits =router_logits +noise else : noisy_logits =router_logits router_probs =F .softmax (noisy_logits ,dim =-1 ,dtype =hidden_states .dtype 
) top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 ) prob_sum =top_k_probs .sum (dim =-1 ,keepdim =True ).clamp (min =EPS ) top_k_probs =top_k_probs /prob_sum return top_k_probs ,top_k_indices ,router_logits class MoEExpert (nn .Module ): """ Single expert FFN with SwiGLU activation - FP16 native. """ def __init__ (self ,hidden_size :int ,intermediate_size :int ,dropout :float =0.0 ): super ().__init__ () self .hidden_size =hidden_size self .intermediate_size =intermediate_size self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) self .act_fn =nn .SiLU () self .dropout =nn .Dropout (dropout )if dropout >0 else nn .Identity () self ._init_weights () def _init_weights (self ): std =0.02 nn .init .normal_ (self .gate_proj .weight ,mean =0.0 ,std =std ) nn .init .normal_ (self .up_proj .weight ,mean =0.0 ,std =std ) nn .init .normal_ (self .down_proj .weight ,mean =0.0 ,std =std *0.5 ) def forward (self ,x :torch .Tensor )->torch .Tensor : gate =self .act_fn (self .gate_proj (x )) up =self .up_proj (x ) out =self .down_proj (gate *up ) return self .dropout (out ) class SharedExpert (nn .Module ): """ Isolated Shared Expert (v2.0) - FP16 native. Always active, separate from routed experts. The shared expert processes all tokens independently of routing decisions. 
""" def __init__ (self ,hidden_size :int ,intermediate_size :int ,dropout :float =0.0 , isolated :bool =True ): super ().__init__ () self .hidden_size =hidden_size self .intermediate_size =intermediate_size self .isolated =isolated self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) self .act_fn =nn .SiLU () self .dropout =nn .Dropout (dropout )if dropout >0 else nn .Identity () self .shared_gate =nn .Parameter (torch .ones (1 )*0.5 ) if isolated : self .pre_norm =nn .LayerNorm (hidden_size ,eps =1e-5 ) self ._init_weights () def _init_weights (self ): std =0.02 nn .init .normal_ (self .gate_proj .weight ,mean =0.0 ,std =std ) nn .init .normal_ (self .up_proj .weight ,mean =0.0 ,std =std ) nn .init .normal_ (self .down_proj .weight ,mean =0.0 ,std =std *0.5 ) def forward (self ,x :torch .Tensor )->torch .Tensor : if self .isolated : x =self .pre_norm (x ) gate =self .act_fn (self .gate_proj (x )) up =self .up_proj (x ) out =self .down_proj (gate *up ) out =self .dropout (out ) return out *torch .sigmoid (self .shared_gate ) class MoELayer (nn .Module ): """ SOTA Mixture of Experts layer v2.0 - FP16 native. Supports Aux-Lossless MoE with Isolated Shared Expert. 
""" def __init__ ( self , hidden_size :int , intermediate_size :int , num_experts :int =8 , num_experts_per_tok :int =2 , use_shared_expert :bool =True , shared_expert_intermediate_size :Optional [int ]=None , capacity_factor :float =1.25 , expert_dropout :float =0.0 , aux_lossless :bool =True , isolated_shared :bool =True , ): super ().__init__ () self .hidden_size =hidden_size self .num_experts =num_experts self .num_experts_per_tok =num_experts_per_tok self .use_shared_expert =use_shared_expert self .capacity_factor =capacity_factor self .aux_lossless =aux_lossless self .router =MoERouter ( hidden_size ,num_experts ,num_experts_per_tok , capacity_factor =capacity_factor ,aux_lossless =aux_lossless ) self .experts =nn .ModuleList ([ MoEExpert (hidden_size ,intermediate_size ,expert_dropout ) for _ in range (num_experts ) ]) if use_shared_expert : shared_size =shared_expert_intermediate_size or intermediate_size self .shared_expert =SharedExpert ( hidden_size ,shared_size ,expert_dropout ,isolated =isolated_shared ) else : self .shared_expert =None def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: batch_size ,seq_len ,hidden_size =hidden_states .shape hidden_flat =hidden_states .view (-1 ,hidden_size ) num_tokens =hidden_flat .shape [0 ] top_k_probs ,top_k_indices ,router_logits =self .router (hidden_states ) if hasattr (self ,'_utilization_tracker'): self ._utilization_tracker .record (top_k_indices ) final_output =torch .zeros_like (hidden_flat ) for expert_idx in range (self .num_experts ): expert =self .experts [expert_idx ] for k in range (self .num_experts_per_tok ): mask =(top_k_indices [:,k ]==expert_idx ) if mask .any (): expert_input =hidden_flat [mask ] expert_output =expert (expert_input ) weight =top_k_probs [mask ,k :k +1 ] final_output [mask ]=final_output [mask ]+weight *expert_output if self .shared_expert is not None : shared_output =self .shared_expert (hidden_flat ) final_output =final_output 
+shared_output final_output =final_output .view (batch_size ,seq_len ,hidden_size ) aux_loss =self ._compute_aux_loss (router_logits ,top_k_indices ,num_tokens ) return final_output ,aux_loss def _compute_aux_loss (self ,router_logits :torch .Tensor ,top_k_indices :torch .Tensor , num_tokens :int )->torch .Tensor : device =router_logits .device dtype =router_logits .dtype if self .aux_lossless : z_loss =torch .logsumexp (router_logits ,dim =-1 ).square ().mean ()*0.0001 return z_loss router_probs =F .softmax (router_logits ,dim =-1 ,dtype =dtype ) expert_mask =F .one_hot (top_k_indices ,self .num_experts ).to (dtype ) denominator =max (num_tokens *self .num_experts_per_tok ,1 ) tokens_per_expert =expert_mask .sum (dim =(0 ,1 ))/denominator avg_probs =router_probs .mean (dim =0 ) load_balance_loss =self .num_experts *(tokens_per_expert *avg_probs ).sum () z_loss =torch .logsumexp (router_logits ,dim =-1 ).square ().mean ()*0.001 router_probs_safe =router_probs .clamp (EPS ,1.0 -EPS ) log_probs =torch .log (router_probs_safe ) entropy =-(router_probs_safe *log_probs ).sum (dim =-1 ).mean () max_entropy =torch .log (torch .tensor (float (self .num_experts ),device =device ,dtype =dtype )) entropy_loss =(max_entropy -entropy ).clamp (min =0.0 )*0.01 expert_usage =(tokens_per_expert >0.01 ).to (dtype ).mean () utilization_loss =(1.0 -expert_usage )*0.1 total_aux_loss =load_balance_loss +z_loss +entropy_loss +utilization_loss return total_aux_loss class ExpertChoiceMoELayer (nn .Module ): """ Expert Choice MoE - FP16 native. 
""" def __init__ ( self , hidden_size :int , intermediate_size :int , num_experts :int =8 , capacity_factor :float =1.0 , ): super ().__init__ () self .hidden_size =hidden_size self .num_experts =num_experts self .capacity_factor =capacity_factor self .input_norm =nn .LayerNorm (hidden_size ,eps =1e-5 ) self .gate =nn .Linear (hidden_size ,num_experts ,bias =False ) nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 ) self .experts =nn .ModuleList ([ MoEExpert (hidden_size ,intermediate_size ) for _ in range (num_experts ) ]) def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: batch_size ,seq_len ,hidden_size =hidden_states .shape hidden_flat =hidden_states .view (-1 ,hidden_size ) num_tokens =hidden_flat .shape [0 ] hidden_norm =self .input_norm (hidden_flat ) router_logits =self .gate (hidden_norm ) router_probs =F .softmax (router_logits ,dim =0 ,dtype =hidden_states .dtype ) capacity =int (num_tokens *self .capacity_factor /self .num_experts ) capacity =max (capacity ,1 ) final_output =torch .zeros_like (hidden_flat ) token_counts =torch .zeros (num_tokens ,device =hidden_flat .device ,dtype =hidden_flat .dtype ) for expert_idx in range (self .num_experts ): expert =self .experts [expert_idx ] expert_probs =router_probs [:,expert_idx ] top_probs ,top_indices =torch .topk (expert_probs ,min (capacity ,num_tokens )) expert_input =hidden_flat [top_indices ] expert_output =expert (expert_input ) final_output [top_indices ]=final_output [top_indices ]+top_probs .unsqueeze (-1 )*expert_output token_counts [top_indices ]=token_counts [top_indices ]+top_probs token_counts =token_counts .clamp (min =EPS ) final_output =final_output /token_counts .unsqueeze (-1 ) final_output =final_output .view (batch_size ,seq_len ,hidden_size ) aux_loss =torch .logsumexp (router_logits ,dim =-1 ).square ().mean ()*0.001 return final_output ,aux_loss ============================================================================== 
# ==============================================================================
# MODELS.ENCODERS.VISION
# ==============================================================================

EPS = 1e-5


class RoPE2DEncoder(nn.Module):
    """
    2D Rotary Position Embedding for vision encoder patches.
    Matches the 2D-RoPE in image generator for seamless integration.

    The rotary dimension is split per axis: the first ``dim_x`` channels
    rotate with the x (width) position, the remaining ``dim_y`` channels
    with the y (height) position.
    """

    def __init__(self, dim: int, max_height: int = 128, max_width: int = 128, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_height = max_height
        self.max_width = max_width
        self.base = base
        # Split the head dim between the two spatial axes.
        self.dim_x = dim // 2
        self.dim_y = dim - self.dim_x
        inv_freq_x = 1.0 / (base ** (torch.arange(0, self.dim_x, 2, dtype=torch.float32) / self.dim_x))
        inv_freq_y = 1.0 / (base ** (torch.arange(0, self.dim_y, 2, dtype=torch.float32) / self.dim_y))
        self.register_buffer('inv_freq_x', inv_freq_x, persistent=False)
        self.register_buffer('inv_freq_y', inv_freq_y, persistent=False)

    def forward(self, x: torch.Tensor, height: int, width: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build (cos, sin) rotary tables for a ``height x width`` patch grid.

        Args:
            x: Any tensor on the target device/dtype (only device and dtype are read).
            height: Grid height in patches.
            width: Grid width in patches.

        Returns:
            Tuple of (cos, sin), each of shape [height * width, dim], where
            token index ``y * width + w`` carries the frequencies for (y, w).
        """
        device = x.device
        dtype = x.dtype
        pos_x = torch.arange(width, device=device, dtype=torch.float32)
        pos_y = torch.arange(height, device=device, dtype=torch.float32)
        freqs_x = torch.outer(pos_x, self.inv_freq_x.to(device))  # [W, dim_x // 2]
        freqs_y = torch.outer(pos_y, self.inv_freq_y.to(device))  # [H, dim_y // 2]
        # Duplicate to match the rotate-half layout used by apply_rope_2d_encoder.
        freqs_x = torch.cat([freqs_x, freqs_x], dim=-1)  # [W, dim_x]
        freqs_y = torch.cat([freqs_y, freqs_y], dim=-1)  # [H, dim_y]
        # Vectorized grid construction via broadcasting; replaces the original
        # O(H*W) Python double loop with identical values.
        cos_x = freqs_x.cos().to(dtype).unsqueeze(0).expand(height, width, self.dim_x)
        sin_x = freqs_x.sin().to(dtype).unsqueeze(0).expand(height, width, self.dim_x)
        cos_y = freqs_y.cos().to(dtype).unsqueeze(1).expand(height, width, self.dim_y)
        sin_y = freqs_y.sin().to(dtype).unsqueeze(1).expand(height, width, self.dim_y)
        cos_2d = torch.cat([cos_x, cos_y], dim=-1).reshape(height * width, self.dim)
        sin_2d = torch.cat([sin_x, sin_y], dim=-1).reshape(height * width, self.dim)
        return cos_2d, sin_2d


def apply_rope_2d_encoder(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply 2D rotary position embedding to tensor (rotate-half formulation)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin


class TiTokTokenizer(nn.Module):
    """
    TiTok-style 1D Tokenizer for efficient visual representation.
    Converts 2D patch grid to 1D token sequence with learnable compression.

    Learnable query tokens cross-attend into the (projected) patch features,
    compressing ``num_patches`` inputs down to ``num_tokens`` outputs.
    """

    def __init__(self, hidden_size: int, num_tokens: int = 256, num_patches: int = 576):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_tokens = num_tokens
        self.num_patches = num_patches
        # Pre-projection of patch features before cross-attention.
        self.compress = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )
        # Learnable queries, one per output token.
        self.token_queries = nn.Parameter(torch.randn(1, num_tokens, hidden_size) * 0.02)
        self.compress_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            batch_first=True,
            dropout=0.1,
        )
        self.compress_norm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compress patch features to TiTok-style 1D tokens.

        Args:
            x: [B, num_patches, hidden_size] patch features

        Returns:
            [B, num_tokens, hidden_size] compressed token features
        """
        batch_size = x.shape[0]
        queries = self.token_queries.expand(batch_size, -1, -1)
        x_proj = self.compress(x)
        tokens, _ = self.compress_attn(queries, x_proj, x_proj)
        # Residual around the queries, then normalize.
        tokens = self.compress_norm(queries + tokens)
        return tokens


class DeepStack(nn.Module):
    """
    DeepStack: Fuses multi-level ViT features to capture fine-grained details
    and sharpen image-text alignment.

    SOTA: Instead of using only the final layer features, DeepStack combines
    features from multiple intermediate layers of the vision encoder, enabling:
    - Better fine-grained detail capture (early layers have high-resolution features)
    - Stronger image-text alignment (different layers capture different semantic levels)
    - Improved generation quality for both understanding and generation tasks

    Architecture:
    - Collects features from selected layers (typically: early, middle, late)
    - Projects each level to a common dimension
    - Combines via learned weighted sum or attention
    """

    def __init__(self, hidden_size: int, num_layers: int = 3, use_attention: bool = True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.use_attention = use_attention
        # Per-level projection + norm, one pair per fused layer.
        self.level_projs = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)
        ])
        self.level_norms = nn.ModuleList([
            nn.LayerNorm(hidden_size) for _ in range(num_layers)
        ])
        if use_attention:
            self.fusion_query = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
            self.fusion_attn = nn.MultiheadAttention(
                embed_dim=hidden_size,
                num_heads=8,
                batch_first=True,
                dropout=0.1,
            )
            self.fusion_norm = nn.LayerNorm(hidden_size)
        else:
            # Learned softmax-normalized blend weights over levels.
            self.level_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.output_proj = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, multi_level_features: list) -> torch.Tensor:
        """
        Fuse multi-level features.

        Args:
            multi_level_features: List of [B, seq_len, hidden_size] features
                from different layers.

        Returns:
            [B, seq_len, hidden_size] fused features
        """
        # If more levels than expected, keep the last num_layers; if fewer,
        # zip() below silently fuses only the available levels.
        if len(multi_level_features) != self.num_layers:
            if len(multi_level_features) > self.num_layers:
                multi_level_features = multi_level_features[-self.num_layers:]
        batch_size, seq_len, _ = multi_level_features[0].shape
        projected = [
            norm(proj(feat))
            for feat, proj, norm in zip(multi_level_features, self.level_projs, self.level_norms)
        ]
        if self.use_attention:
            # Concatenate levels along the sequence axis and cross-attend.
            stacked = torch.cat(projected, dim=1)
            query = self.fusion_query.expand(batch_size, seq_len, -1)
            fused, _ = self.fusion_attn(query, stacked, stacked)
            fused = self.fusion_norm(query + fused)
        else:
            weights = F.softmax(self.level_weights, dim=0)
            fused = sum(w * feat for w, feat in zip(weights, projected))
        return self.output_proj(fused)


class DualStreamEncoderAttention(nn.Module):
    """
    Symmetric Dual-Stream Self-Attention for vision encoding.
    Matches the dual-stream architecture in image generator.

    Each stream has its own QKV/output projections; both streams attend over
    the concatenated keys/values of the two streams.
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, max_height: int = 64, max_width: int = 64):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** -0.5
        self.to_qkv_a = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.to_qkv_b = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.to_out_a = nn.Linear(hidden_size, hidden_size, bias=False)
        self.to_out_b = nn.Linear(hidden_size, hidden_size, bias=False)
        self.norm_a = nn.LayerNorm(hidden_size)
        self.norm_b = nn.LayerNorm(hidden_size)
        self.rope_2d = RoPE2DEncoder(self.head_dim, max_height, max_width)

    def forward(self, x_a: torch.Tensor, x_b: torch.Tensor, height: int, width: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run symmetric dual-stream attention; returns per-stream outputs (no residual)."""
        batch_size, seq_len, _ = x_a.shape
        x_a = self.norm_a(x_a)
        x_b = self.norm_b(x_b)
        qkv_a = self.to_qkv_a(x_a).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv_b = self.to_qkv_b(x_b).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q_a, k_a, v_a = qkv_a.unbind(dim=2)
        q_b, k_b, v_b = qkv_b.unbind(dim=2)
        cos, sin = self.rope_2d(x_a, height, width)
        # [seq, head_dim] -> [1, 1, seq, head_dim]: broadcasts over batch and heads.
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)
        # [B, seq, heads, head_dim] -> [B, heads, seq, head_dim]
        q_a = q_a.transpose(1, 2)
        k_a = k_a.transpose(1, 2)
        v_a = v_a.transpose(1, 2)
        q_b = q_b.transpose(1, 2)
        k_b = k_b.transpose(1, 2)
        v_b = v_b.transpose(1, 2)
        q_a = apply_rope_2d_encoder(q_a, cos, sin)
        k_a = apply_rope_2d_encoder(k_a, cos, sin)
        q_b = apply_rope_2d_encoder(q_b, cos, sin)
        k_b = apply_rope_2d_encoder(k_b, cos, sin)
        # Both streams attend over the union of keys/values from A and B.
        k_combined = torch.cat([k_a, k_b], dim=2)
        v_combined = torch.cat([v_a, v_b], dim=2)
        attn_a = F.scaled_dot_product_attention(q_a, k_combined, v_combined)
        attn_b = F.scaled_dot_product_attention(q_b, k_combined, v_combined)
        attn_a = attn_a.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
        attn_b = attn_b.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
        out_a = self.to_out_a(attn_a)
        out_b = self.to_out_b(attn_b)
        return out_a, out_b


class VisionEncoderBlock(nn.Module):
    """Single block with dual-stream attention and FFN."""

    def __init__(self, hidden_size: int, num_heads: int = 8, ff_mult: int = 4, max_height: int = 64, max_width: int = 64):
        super().__init__()
        self.dual_attn = DualStreamEncoderAttention(hidden_size, num_heads, max_height, max_width)
        self.ffn_a = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size * ff_mult),
            nn.GELU(),
            nn.Linear(hidden_size * ff_mult, hidden_size),
        )
        self.ffn_b = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size * ff_mult),
            nn.GELU(),
            nn.Linear(hidden_size * ff_mult, hidden_size),
        )

    def forward(self, x_a: torch.Tensor, x_b: torch.Tensor, height: int, width: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pre-norm residual attention followed by per-stream residual FFN."""
        attn_a, attn_b = self.dual_attn(x_a, x_b, height, width)
        x_a = x_a + attn_a
        x_b = x_b + attn_b
        x_a = x_a + self.ffn_a(x_a)
        x_b = x_b + self.ffn_b(x_b)
        return x_a, x_b


class VisionEncoder(nn.Module):
    """
    SOTA Vision Encoder with 2D-RoPE, TiTok tokenization, and Dual-Stream Attention.

    Features:
    - SigLIP 2 / CLIP backbone for robust visual features
    - 2D-RoPE for flexible aspect ratios
    - TiTok-style 1D tokenization for efficient representation
    - Dual-stream attention for symmetric processing
    - FP16-native numerical stability
    """

    def __init__(
        self,
        model_name: str = "google/siglip-so400m-patch14-384",
        freeze: bool = False,
        use_pooled_output: bool = False,
        use_dual_stream: bool = True,
        use_titok: bool = True,
        num_titok_tokens: int = 256,
        num_dual_stream_layers: int = 2,
    ):
        super().__init__()
        self.model_name = model_name
        self.use_pooled_output = use_pooled_output
        self.use_dual_stream = use_dual_stream
        self.use_titok = use_titok
        self._is_siglip = "siglip" in model_name.lower()
        print(f"\n๐Ÿ‘๏ธ Loading Vision Encoder: {model_name}")
        if self._is_siglip:
            self._init_siglip(model_name, freeze)
        else:
            self._init_clip(model_name, freeze)
        self.rope_2d = RoPE2DEncoder(
            dim=self.hidden_size,
            max_height=64,
            max_width=64,
        )
        print(f" ๐Ÿ“ 2D-RoPE: Flexible aspect ratio support")
        if use_dual_stream:
            patch_size = getattr(self.vision_model.config, 'patch_size', 14)
            image_size = getattr(self.vision_model.config, 'image_size', 384)
            # Per-axis patch count (grid side length), used as max H and W.
            max_patches = image_size // patch_size
            self.dual_stream_layers = nn.ModuleList([
                VisionEncoderBlock(
                    hidden_size=self.hidden_size,
                    num_heads=8,
                    ff_mult=4,
                    max_height=max_patches,
                    max_width=max_patches,
                )
                for _ in range(num_dual_stream_layers)
            ])
            print(f" ๐Ÿ”„ Dual-Stream: {num_dual_stream_layers} layers")
        else:
            self.dual_stream_layers = None
        if use_titok:
            self.titok = TiTokTokenizer(
                hidden_size=self.hidden_size,
                num_tokens=num_titok_tokens,
                num_patches=self.num_patches,
            )
            print(f" ๐ŸŽซ TiTok: {self.num_patches} patches -> {num_titok_tokens} tokens")
        else:
            self.titok = None

    def _init_siglip(self, model_name: str, freeze: bool):
        """Initialize SigLIP 2 vision encoder (falls back to CLIP if unavailable)."""
        try:
            from transformers import SiglipVisionModel, SiglipImageProcessor
            self.vision_model = SiglipVisionModel.from_pretrained(model_name)
            self.image_processor = SiglipImageProcessor.from_pretrained(model_name)
            self.hidden_size = self.vision_model.config.hidden_size
            print(f" ๐ŸŽฏ Using SigLIP 2 (recommended for MoE)")
            print(f" โœ… Hidden size: {self.hidden_size}")
            print(f" ๐Ÿ“ Native size: {self.vision_model.config.image_size} (multi-scale: 256-512px)")
            print(f" ๐Ÿ”ฒ Patch size: {self.vision_model.config.patch_size}")
        except ImportError:
            print(" โš ๏ธ SigLIP not available, falling back to CLIP")
            self._is_siglip = False
            # _init_clip handles the freeze flag itself.
            self._init_clip("openai/clip-vit-large-patch14", freeze)
            return
        if freeze:
            for param in self.vision_model.parameters():
                param.requires_grad = False
            print(f" โ„๏ธ Vision encoder backbone frozen")
        else:
            print(f" ๐Ÿ”ฅ Vision encoder backbone trainable")

    def _init_clip(self, model_name: str, freeze: bool):
        """Initialize CLIP vision encoder (legacy support)."""
        from transformers import CLIPVisionModel, CLIPImageProcessor
        self.vision_model = CLIPVisionModel.from_pretrained(model_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(model_name)
        self.hidden_size = self.vision_model.config.hidden_size
        print(f" ๐Ÿ“Ž Using CLIP")
        print(f" โœ… Hidden size: {self.hidden_size}")
        if freeze:
            for param in self.vision_model.parameters():
                param.requires_grad = False
            print(f" โ„๏ธ Vision encoder backbone frozen")
        else:
            print(f" ๐Ÿ”ฅ Vision encoder backbone trainable")

    def forward(self, pixel_values: torch.Tensor, return_titok: bool = None) -> torch.Tensor:
        """
        Extract vision features from images with SOTA enhancements.

        Args:
            pixel_values: [B, C, H, W] tensor of images
            return_titok: Override for TiTok output (None uses self.use_titok)

        Returns:
            [B, num_tokens, hidden_size] tensor (TiTok) or
            [B, num_patches, hidden_size] tensor (standard) or
            [B, hidden_size] if use_pooled_output=True
        """
        outputs = self.vision_model(pixel_values=pixel_values)
        features = outputs.last_hidden_state
        if self.use_pooled_output:
            if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                return outputs.pooler_output
            # No pooler available: mean-pool over the patch axis.
            return features.mean(dim=1)
        _, num_patches, _ = features.shape
        patch_size = getattr(self.vision_model.config, 'patch_size', 14)
        image_size = getattr(self.vision_model.config, 'image_size', 384)
        if num_patches == (image_size // patch_size) ** 2 + 1:
            # Backbone prepends a CLS token (e.g. CLIP) — drop it; downstream
            # stages operate on the pure patch grid.
            features = features[:, 1:]
            num_patches = num_patches - 1
        # Assumes a square patch grid — TODO confirm for non-square inputs.
        height = width = int(math.sqrt(num_patches))
        if self.dual_stream_layers is not None:
            x_a = features
            x_b = features.clone()
            for layer in self.dual_stream_layers:
                x_a, x_b = layer(x_a, x_b, height, width)
            features = (x_a + x_b) / 2
        use_titok_now = return_titok if return_titok is not None else self.use_titok
        if use_titok_now and self.titok is not None:
            features = self.titok(features)
        return features

    def get_image_processor(self):
        """Return the image processor for preprocessing."""
        return self.image_processor

    @property
    def num_patches(self) -> int:
        """Get number of patches for the vision model."""
        config = self.vision_model.config
        image_size = config.image_size
        patch_size = config.patch_size
        return (image_size // patch_size) ** 2

    @property
    def image_size(self) -> int:
        """Get expected image size."""
        return self.vision_model.config.image_size

    @property
    def output_tokens(self) -> int:
        """Get number of output tokens (considering TiTok compression)."""
        if self.use_titok and self.titok is not None:
            return self.titok.num_tokens
        return self.num_patches
SIGLIP_MODELS ={ "siglip-base":"google/siglip-base-patch16-224", "siglip-base-384":"google/siglip-base-patch16-384", "siglip-large":"google/siglip-large-patch16-256", "siglip-large-384":"google/siglip-large-patch16-384", "siglip-so400m":"google/siglip-so400m-patch14-384", "siglip-so400m-224":"google/siglip-so400m-patch14-224", "clip-base":"openai/clip-vit-base-patch16", "clip-large":"openai/clip-vit-large-patch14", } def get_vision_encoder ( model_key :str ="siglip-so400m", freeze :bool =False , use_dual_stream :bool =True , use_titok :bool =True , **kwargs )->VisionEncoder : """ Get a vision encoder by key name with SOTA enhancements. Args: model_key: Key from SIGLIP_MODELS or full model name freeze: Whether to freeze encoder backbone weights use_dual_stream: Enable dual-stream attention use_titok: Enable TiTok 1D tokenization **kwargs: Additional arguments for VisionEncoder Returns: VisionEncoder instance """ model_name =SIGLIP_MODELS .get (model_key ,model_key ) return VisionEncoder ( model_name =model_name , freeze =freeze , use_dual_stream =use_dual_stream , use_titok =use_titok , **kwargs ) ============================================================================== MODELS.ENCODERS.VIDEO ============================================================================== EPS =1e-5 class TextTimestampAlignment (nn .Module ): """ Text-Timestamp Alignment: Precise timestamp-grounded event localization for stronger video temporal modeling. 
SOTA: Moves beyond T-RoPE by explicitly aligning text descriptions with video timestamps, enabling: - Precise temporal localization of events described in text - Better video captioning with accurate time references - Improved video question-answering with temporal reasoning - Enhanced video generation with temporal control Architecture: - Cross-attention between text features and frame-level video features - Learnable timestamp embeddings for each frame - Temporal alignment loss during training """ def __init__ (self ,hidden_size :int ,max_frames :int =64 ,num_heads :int =8 ): super ().__init__ () self .hidden_size =hidden_size self .max_frames =max_frames self .num_heads =num_heads self .timestamp_embedding =nn .Embedding (max_frames ,hidden_size ) self .video_proj =nn .Linear (hidden_size ,hidden_size ) self .text_proj =nn .Linear (hidden_size ,hidden_size ) self .cross_attn =nn .MultiheadAttention ( embed_dim =hidden_size , num_heads =num_heads , batch_first =True , dropout =0.1 , ) self .text_norm =nn .LayerNorm (hidden_size ) self .video_norm =nn .LayerNorm (hidden_size ) self .alignment_head =nn .Sequential ( nn .Linear (hidden_size ,hidden_size //2 ), nn .GELU (), nn .Linear (hidden_size //2 ,1 ), ) self .output_proj =nn .Linear (hidden_size ,hidden_size ) def forward ( self , video_features :torch .Tensor , text_features :torch .Tensor , num_frames :int , return_alignment_scores :bool =False , )->Tuple [torch .Tensor ,Optional [torch .Tensor ]]: """ Align text with video timestamps. 
Args: video_features: [B, T*H*W, hidden_size] video features text_features: [B, text_len, hidden_size] text features num_frames: Number of frames in the video return_alignment_scores: Whether to return alignment scores for loss Returns: aligned_features: [B, T*H*W, hidden_size] timestamp-aligned video features alignment_scores: Optional [B, text_len, T] alignment scores """ batch_size =video_features .shape [0 ] total_tokens =video_features .shape [1 ] spatial_tokens =total_tokens //num_frames timestamp_ids =torch .arange (num_frames ,device =video_features .device ) timestamp_embeds =self .timestamp_embedding (timestamp_ids ) timestamp_embeds =timestamp_embeds .unsqueeze (1 ).expand (-1 ,spatial_tokens ,-1 ) timestamp_embeds =timestamp_embeds .reshape (1 ,total_tokens ,-1 ) timestamp_embeds =timestamp_embeds .expand (batch_size ,-1 ,-1 ) video_feat =self .video_norm (self .video_proj (video_features )+timestamp_embeds ) text_feat =self .text_norm (self .text_proj (text_features )) aligned ,attn_weights =self .cross_attn (text_feat ,video_feat ,video_feat ) alignment_scores =None if return_alignment_scores : attn_reshaped =attn_weights .view (batch_size ,text_features .shape [1 ],num_frames ,spatial_tokens ) alignment_scores =attn_reshaped .mean (dim =-1 ) aligned_text =text_features +self .output_proj (aligned ) return aligned_text ,alignment_scores class AlphaBlender (nn .Module ): """ AlphaBlender operator from VidTok for temporal blending. Blends two inputs with a learnable or fixed alpha parameter. """ def __init__ (self ,alpha :float =0.55 ): super ().__init__ () self .alpha =alpha def forward (self ,x1 :torch .Tensor ,x2 :torch .Tensor )->torch .Tensor : return self .alpha *x1 +(1 -self .alpha )*x2 class VidTokEncoder (nn .Module ): """ VidTok-style Video Encoder following Microsoft's VidTok architecture. 
SOTA: Implements the VidTok encoder with: - 3D convolutions for input and bottleneck (information fusion) - 2D convolutions for spatial downsampling (efficiency) - AlphaBlender + 1D convolutions for temporal downsampling - Layer normalization for stability Compresses video [B, C, T, H, W] -> latent [B, latent_dim, t, h, w] """ def __init__ ( self , in_channels :int =3 , latent_channels :int =4 , base_channels :int =64 , temporal_downsample :int =4 , spatial_downsample :int =8 , causal :bool =True , ): super ().__init__ () self .in_channels =in_channels self .latent_channels =latent_channels self .base_channels =base_channels self .temporal_downsample =temporal_downsample self .spatial_downsample =spatial_downsample self .causal =causal self .num_spatial_downs =int (math .log2 (spatial_downsample )) self .num_temporal_downs =int (math .log2 (temporal_downsample )) self .input_block =nn .Sequential ( nn .Conv3d (in_channels ,base_channels ,kernel_size =3 ,padding =1 ), nn .GroupNorm (8 ,base_channels ), nn .SiLU (), ) self .spatial_down_blocks =nn .ModuleList () ch =base_channels for i in range (self .num_spatial_downs ): out_ch =min (ch *2 ,512 ) self .spatial_down_blocks .append ( self ._make_spatial_down_block (ch ,out_ch ) ) ch =out_ch self .temporal_down_blocks =nn .ModuleList () for i in range (self .num_temporal_downs ): self .temporal_down_blocks .append ( self ._make_temporal_down_block (ch ) ) self .bottleneck =nn .Sequential ( nn .Conv3d (ch ,ch ,kernel_size =3 ,padding =1 ), nn .GroupNorm (8 ,ch ), nn .SiLU (), nn .Conv3d (ch ,ch ,kernel_size =3 ,padding =1 ), nn .GroupNorm (8 ,ch ), nn .SiLU (), ) self .to_latent =nn .Conv3d (ch ,latent_channels ,kernel_size =1 ) print (f" ๐ŸŽฌ VidTokEncoder: {in_channels }ch -> {latent_channels }ch latent") print (f" Spatial: {spatial_downsample }x down ({self .num_spatial_downs } stages)") print (f" Temporal: {temporal_downsample }x down ({self .num_temporal_downs } stages)") def _make_spatial_down_block (self ,in_ch 
:int ,out_ch :int )->nn .Module : """Create a spatial downsampling block using 2D convolutions.""" return nn .Sequential ( Rearrange3Dto2D (), nn .Conv2d (in_ch ,out_ch ,kernel_size =3 ,stride =2 ,padding =1 ), nn .GroupNorm (8 ,out_ch ), nn .SiLU (), nn .Conv2d (out_ch ,out_ch ,kernel_size =3 ,padding =1 ), nn .GroupNorm (8 ,out_ch ), nn .SiLU (), Rearrange2Dto3D (), ) def _make_temporal_down_block (self ,channels :int )->nn .Module : """Create a temporal downsampling block using AlphaBlender + 1D conv.""" return TemporalDownBlock (channels ,causal =self .causal ) def forward (self ,x :torch .Tensor )->torch .Tensor : """ Encode video to latent space. Args: x: [B, C, T, H, W] input video Returns: [B, latent_channels, t, h, w] latent representation """ B ,C ,T ,H ,W =x .shape x =self .input_block (x ) for block in self .spatial_down_blocks : if hasattr (block [0 ],'set_temporal_dim'): block [0 ].set_temporal_dim (x .shape [2 ]) if hasattr (block [-1 ],'set_temporal_dim'): block [-1 ].set_temporal_dim (x .shape [2 ]) x =block (x ) for block in self .temporal_down_blocks : x =block (x ) x =self .bottleneck (x ) x =self .to_latent (x ) return x class VidTokDecoder (nn .Module ): """ VidTok-style Video Decoder following Microsoft's VidTok architecture. 
Reconstructs video from latent [B, latent_dim, t, h, w] -> [B, C, T, H, W] """ def __init__ ( self , out_channels :int =3 , latent_channels :int =4 , base_channels :int =64 , temporal_upsample :int =4 , spatial_upsample :int =8 , causal :bool =True , ): super ().__init__ () self .out_channels =out_channels self .latent_channels =latent_channels self .base_channels =base_channels self .temporal_upsample =temporal_upsample self .spatial_upsample =spatial_upsample self .causal =causal self .num_spatial_ups =int (math .log2 (spatial_upsample )) self .num_temporal_ups =int (math .log2 (temporal_upsample )) ch =min (base_channels *(2 **self .num_spatial_ups ),512 ) self .from_latent =nn .Conv3d (latent_channels ,ch ,kernel_size =1 ) self .bottleneck =nn .Sequential ( nn .Conv3d (ch ,ch ,kernel_size =3 ,padding =1 ), nn .GroupNorm (8 ,ch ), nn .SiLU (), nn .Conv3d (ch ,ch ,kernel_size =3 ,padding =1 ), nn .GroupNorm (8 ,ch ), nn .SiLU (), ) self .temporal_up_blocks =nn .ModuleList () for i in range (self .num_temporal_ups ): self .temporal_up_blocks .append ( TemporalUpBlock (ch ,causal =self .causal ) ) self .spatial_up_blocks =nn .ModuleList () for i in range (self .num_spatial_ups ): out_ch =max (ch //2 ,base_channels ) self .spatial_up_blocks .append ( self ._make_spatial_up_block (ch ,out_ch ) ) ch =out_ch self .output_block =nn .Sequential ( nn .Conv3d (ch ,out_channels ,kernel_size =3 ,padding =1 ), nn .Tanh (), ) print (f" ๐ŸŽฌ VidTokDecoder: {latent_channels }ch latent -> {out_channels }ch") def _make_spatial_up_block (self ,in_ch :int ,out_ch :int )->nn .Module : """Create a spatial upsampling block using 2D convolutions.""" return nn .Sequential ( Rearrange3Dto2D (), nn .ConvTranspose2d (in_ch ,out_ch ,kernel_size =4 ,stride =2 ,padding =1 ), nn .GroupNorm (8 ,out_ch ), nn .SiLU (), nn .Conv2d (out_ch ,out_ch ,kernel_size =3 ,padding =1 ), nn .GroupNorm (8 ,out_ch ), nn .SiLU (), Rearrange2Dto3D (), ) def forward (self ,z :torch .Tensor )->torch .Tensor : """ 
Decode latent to video. Args: z: [B, latent_channels, t, h, w] latent representation Returns: [B, C, T, H, W] reconstructed video """ x =self .from_latent (z ) x =self .bottleneck (x ) for block in self .temporal_up_blocks : x =block (x ) for block in self .spatial_up_blocks : x =block (x ) x =self .output_block (x ) return x class Rearrange3Dto2D (nn .Module ): """Reshape [B, C, T, H, W] -> [B*T, C, H, W] for 2D operations.""" def __init__ (self ): super ().__init__ () self .temporal_dim =None def set_temporal_dim (self ,t :int ): self .temporal_dim =t def forward (self ,x :torch .Tensor )->torch .Tensor : B ,C ,T ,H ,W =x .shape self .temporal_dim =T return x .permute (0 ,2 ,1 ,3 ,4 ).reshape (B *T ,C ,H ,W ) class Rearrange2Dto3D (nn .Module ): """Reshape [B*T, C, H, W] -> [B, C, T, H, W] after 2D operations.""" def __init__ (self ): super ().__init__ () self .temporal_dim =None def set_temporal_dim (self ,t :int ): self .temporal_dim =t def forward (self ,x :torch .Tensor )->torch .Tensor : BT ,C ,H ,W =x .shape T =self .temporal_dim if self .temporal_dim else 1 B =BT //T return x .reshape (B ,T ,C ,H ,W ).permute (0 ,2 ,1 ,3 ,4 ) class TemporalDownBlock (nn .Module ): """Temporal downsampling using AlphaBlender + 1D conv (VidTok style).""" def __init__ (self ,channels :int ,causal :bool =True ): super ().__init__ () self .channels =channels self .causal =causal self .alpha_blender =AlphaBlender () padding =(1 ,0 )if causal else 1 self .temporal_conv =nn .Conv1d (channels ,channels ,kernel_size =2 ,stride =2 ,padding =0 ) self .norm =nn .GroupNorm (8 ,channels ) self .act =nn .SiLU () def forward (self ,x :torch .Tensor )->torch .Tensor : """ Args: x: [B, C, T, H, W] Returns: [B, C, T//2, H, W] """ B ,C ,T ,H ,W =x .shape x =x .permute (0 ,3 ,4 ,1 ,2 ).reshape (B *H *W ,C ,T ) x =self .temporal_conv (x ) x =self .norm (x .unsqueeze (-1 )).squeeze (-1 ) x =self .act (x ) T_new =x .shape [2 ] x =x .reshape (B ,H ,W ,C ,T_new ).permute (0 ,3 ,4 ,1 ,2 ) return x 
class TemporalUpBlock (nn .Module ): """Temporal upsampling using AlphaBlender + 1D conv (VidTok style).""" def __init__ (self ,channels :int ,causal :bool =True ): super ().__init__ () self .channels =channels self .causal =causal self .alpha_blender =AlphaBlender () self .temporal_conv =nn .ConvTranspose1d (channels ,channels ,kernel_size =2 ,stride =2 ) self .norm =nn .GroupNorm (8 ,channels ) self .act =nn .SiLU () def forward (self ,x :torch .Tensor )->torch .Tensor : """ Args: x: [B, C, T, H, W] Returns: [B, C, T*2, H, W] """ B ,C ,T ,H ,W =x .shape x =x .permute (0 ,3 ,4 ,1 ,2 ).reshape (B *H *W ,C ,T ) x =self .temporal_conv (x ) x =self .norm (x .unsqueeze (-1 )).squeeze (-1 ) x =self .act (x ) T_new =x .shape [2 ] x =x .reshape (B ,H ,W ,C ,T_new ).permute (0 ,3 ,4 ,1 ,2 ) return x class VidTokTokenizer (nn .Module ): """ VidTok-style Video Tokenizer (3D VAE) following Microsoft's VidTok architecture. SOTA: Full encoder-decoder architecture for video compression to latent space. 
- Efficient 2D+1D architecture (separates spatial and temporal processing) - AlphaBlender for temporal blending - Supports both continuous (KL) and discrete (FSQ) tokenization - Causal mode for streaming/autoregressive applications Compresses video [B, C, T, H, W] -> latent [B, latent_dim, t, h, w] """ def __init__ ( self , in_channels :int =3 , latent_channels :int =4 , base_channels :int =64 , temporal_compression :int =4 , spatial_compression :int =8 , causal :bool =True , use_fsq :bool =False , fsq_levels :int =8 , ): super ().__init__ () self .in_channels =in_channels self .latent_channels =latent_channels self .temporal_compression =temporal_compression self .spatial_compression =spatial_compression self .causal =causal self .use_fsq =use_fsq self .fsq_levels =fsq_levels self .encoder =VidTokEncoder ( in_channels =in_channels , latent_channels =latent_channels *2 if not use_fsq else latent_channels , base_channels =base_channels , temporal_downsample =temporal_compression , spatial_downsample =spatial_compression , causal =causal , ) self .decoder =VidTokDecoder ( out_channels =in_channels , latent_channels =latent_channels , base_channels =base_channels , temporal_upsample =temporal_compression , spatial_upsample =spatial_compression , causal =causal , ) print (f" ๐ŸŽฌ VidTokTokenizer: {temporal_compression }x{spatial_compression }x{spatial_compression } compression") print (f" Mode: {'FSQ (discrete)'if use_fsq else 'KL (continuous)'}, Causal: {causal }") def encode (self ,x :torch .Tensor )->torch .Tensor : """Encode video to latent space.""" h =self .encoder (x ) if self .use_fsq : return self ._fsq_quantize (h ) else : mean ,logvar =h .chunk (2 ,dim =1 ) std =torch .exp (0.5 *logvar ) eps =torch .randn_like (std ) return mean +eps *std def decode (self ,z :torch .Tensor )->torch .Tensor : """Decode latent to video.""" return self .decoder (z ) def _fsq_quantize (self ,z :torch .Tensor )->torch .Tensor : """Finite Scalar Quantization - quantize each 
channel independently.""" z =torch .tanh (z ) z =torch .round ((z +1 )*(self .fsq_levels -1 )/2 )*2 /(self .fsq_levels -1 )-1 return z def forward (self ,x :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: """ Full forward pass: encode then decode. Args: x: [B, C, T, H, W] input video Returns: Tuple of (reconstructed video, latent representation) """ z =self .encode (x ) x_recon =self .decode (z ) return x_recon ,z class RoPE3DEncoder (nn .Module ): """ 3D Rotary Position Embedding for (x, y, t) dimensions. Matches the 3D-RoPE in video generator for seamless integration. """ def __init__ (self ,dim :int ,max_height :int =64 ,max_width :int =64 ,max_frames :int =32 ,base :float =10000.0 ): super ().__init__ () self .dim =dim self .max_height =max_height self .max_width =max_width self .max_frames =max_frames self .base =base dim_per_axis =dim //3 self .dim_x =dim_per_axis self .dim_y =dim_per_axis self .dim_t =dim -2 *dim_per_axis inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x )) inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y )) inv_freq_t =1.0 /(base **(torch .arange (0 ,self .dim_t ,2 ,dtype =torch .float32 )/self .dim_t )) self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False ) self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False ) self .register_buffer ('inv_freq_t',inv_freq_t ,persistent =False ) def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int )->Tuple [torch .Tensor ,torch .Tensor ]: device =x .device dtype =x .dtype pos_x =torch .arange (width ,device =device ,dtype =torch .float32 ) pos_y =torch .arange (height ,device =device ,dtype =torch .float32 ) pos_t =torch .arange (frames ,device =device ,dtype =torch .float32 ) freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device )) freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device )) freqs_t =torch .outer (pos_t ,self .inv_freq_t .to (device )) 
        # Duplicate each half so (cos, sin) covers the full feature dim, matching
        # the rotate-half RoPE convention used by apply_rope_3d_encoder below.
        freqs_x = torch.cat([freqs_x, freqs_x], dim=-1)
        freqs_y = torch.cat([freqs_y, freqs_y], dim=-1)
        freqs_t = torch.cat([freqs_t, freqs_t], dim=-1)
        cos_x = freqs_x.cos().to(dtype)
        sin_x = freqs_x.sin().to(dtype)
        cos_y = freqs_y.cos().to(dtype)
        sin_y = freqs_y.sin().to(dtype)
        cos_t = freqs_t.cos().to(dtype)
        sin_t = freqs_t.sin().to(dtype)
        # Assemble per-position tables: the feature dim is partitioned into
        # [0:dim_x] (width axis), [dim_x:dim_x+dim_y] (height axis), and the
        # remainder (time axis).
        cos_3d = torch.zeros(frames, height, width, self.dim, device=device, dtype=dtype)
        sin_3d = torch.zeros(frames, height, width, self.dim, device=device, dtype=dtype)
        # NOTE(review): this triple Python loop is O(T*H*W) host-side work; it
        # could be vectorized with broadcasting if it shows up in profiles.
        for t in range(frames):
            for y in range(height):
                for w in range(width):
                    cos_3d[t, y, w, :self.dim_x] = cos_x[w]
                    sin_3d[t, y, w, :self.dim_x] = sin_x[w]
                    cos_3d[t, y, w, self.dim_x:self.dim_x + self.dim_y] = cos_y[y]
                    sin_3d[t, y, w, self.dim_x:self.dim_x + self.dim_y] = sin_y[y]
                    cos_3d[t, y, w, self.dim_x + self.dim_y:] = cos_t[t]
                    sin_3d[t, y, w, self.dim_x + self.dim_y:] = sin_t[t]
        # Flatten the (T, H, W) grid to a single sequence axis: [T*H*W, dim].
        cos_3d = cos_3d.view(frames * height * width, self.dim)
        sin_3d = sin_3d.view(frames * height * width, self.dim)
        return cos_3d, sin_3d


def apply_rope_3d_encoder(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply 3D rotary position embedding to tensor.

    Uses the standard rotate-half formulation: the last dim is split in two,
    rotated as (-x2, x1), and recombined with the cos/sin tables.
    """
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin


class TemporalExpertRouterEncoder(nn.Module):
    """
    Temporal-Aware Expert Router for video encoding.
    Routes tokens based on temporal context and motion patterns.
    """

    def __init__(self, hidden_size: int, num_experts: int = 4, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Projects optional temporal context into token space before routing.
        self.temporal_proj = nn.Linear(hidden_size, hidden_size)
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        # Small init keeps early routing close to uniform.
        nn.init.normal_(self.gate.weight, mean=0.0, std=0.01)

    def forward(self, x: torch.Tensor, temporal_context: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (top_k_probs, top_k_indices) for each token row in x."""
        if temporal_context is not None:
            x = x + self.temporal_proj(temporal_context)
        router_logits = self.gate(x)
        router_probs = F.softmax(router_logits, dim=-1, dtype=x.dtype)
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        # Renormalize the selected probabilities; EPS guards a zero sum in fp16.
        top_k_probs = top_k_probs / (top_k_probs.sum(dim=-1, keepdim=True) + EPS)
        return top_k_probs, top_k_indices


class VideoExpertEncoder(nn.Module):
    """Single expert for video encoding with SwiGLU."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: down(silu(gate(x)) * up(x))
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class TemporalMoELayerEncoder(nn.Module):
    """
    Temporal-Aware MoE Layer for video encoding.
    Uses motion-aware routing for expert selection.
    """

    def __init__(self, hidden_size: int, intermediate_size: int, num_experts: int = 4, top_k: int = 2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = top_k
        self.router = TemporalExpertRouterEncoder(hidden_size, num_experts, top_k)
        self.experts = nn.ModuleList([
            VideoExpertEncoder(hidden_size, intermediate_size)
            for _ in range(num_experts)
        ])
        # Always-on expert added to the routed mixture (DeepSeek-style).
        self.shared_expert = VideoExpertEncoder(hidden_size, intermediate_size)

    def forward(self, x: torch.Tensor, temporal_context: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, hidden_size = x.shape
        x_flat = x.view(-1, hidden_size)
        top_k_probs, top_k_indices = self.router(
            x_flat,
            temporal_context.view(-1, hidden_size) if temporal_context is not None else None
        )
        output = torch.zeros_like(x_flat)
        # Dispatch each token to its top-k experts, weighted by router probs.
        for expert_idx in range(self.num_experts):
            expert = self.experts[expert_idx]
            for k in range(self.top_k):
                mask = (top_k_indices[:, k] == expert_idx)
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = expert(expert_input)
                    weight = top_k_probs[mask, k:k + 1]
                    output[mask] = output[mask] + weight * expert_output
        # Shared expert runs on every token.
        shared_output = self.shared_expert(x_flat)
        output = output + shared_output
        return output.view(batch_size, seq_len, hidden_size)


class Causal3DAttentionEncoder(nn.Module):
    """
    3D Causal Self-Attention with 3D-RoPE for video encoding.
    Attends to all positions for encoding (non-causal during encoding).
""" def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_frames :int =32 ,max_height :int =64 ,max_width :int =64 ): super ().__init__ () self .hidden_size =hidden_size self .num_heads =num_heads self .head_dim =hidden_size //num_heads self .scale =self .head_dim **-0.5 self .to_qkv =nn .Linear (hidden_size ,hidden_size *3 ,bias =False ) self .to_out =nn .Linear (hidden_size ,hidden_size ,bias =False ) self .norm =nn .LayerNorm (hidden_size ) self .rope_3d =RoPE3DEncoder (self .head_dim ,max_height ,max_width ,max_frames ) def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int ,causal :bool =False )->torch .Tensor : batch_size ,seq_len ,_ =x .shape x_norm =self .norm (x ) qkv =self .to_qkv (x_norm ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim ) q ,k ,v =qkv .unbind (dim =2 ) cos ,sin =self .rope_3d (x ,height ,width ,frames ) cos =cos .unsqueeze (0 ).unsqueeze (2 ) sin =sin .unsqueeze (0 ).unsqueeze (2 ) q =q .transpose (1 ,2 ) k =k .transpose (1 ,2 ) v =v .transpose (1 ,2 ) q =apply_rope_3d_encoder (q ,cos ,sin ) k =apply_rope_3d_encoder (k ,cos ,sin ) if causal : attn_output =F .scaled_dot_product_attention (q ,k ,v ,is_causal =True ) else : attn_output =F .scaled_dot_product_attention (q ,k ,v ) attn_output =attn_output .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size ) return self .to_out (attn_output ) class VideoEncoderBlock (nn .Module ): """Single block with 3D causal attention and temporal MoE FFN.""" def __init__ ( self , hidden_size :int , num_heads :int =8 , num_experts :int =4 , max_frames :int =32 , max_height :int =64 , max_width :int =64 , ): super ().__init__ () self .attn =Causal3DAttentionEncoder (hidden_size ,num_heads ,max_frames ,max_height ,max_width ) self .moe =TemporalMoELayerEncoder (hidden_size ,hidden_size *4 ,num_experts ) self .norm =nn .LayerNorm (hidden_size ) def forward (self ,x :torch .Tensor ,height :int ,width :int ,frames :int ,causal :bool =False )->torch 
.Tensor : x =x +self .attn (x ,height ,width ,frames ,causal ) x =self .norm (x +self .moe (x )) return x class VideoTiTokTokenizer (nn .Module ): """ SOTA TiTok-style 1D Tokenizer for video features with temporal awareness. This compresses encoded video features (from vision encoder) to a smaller number of tokens, similar to how TiTokTokenizer works for images but with proper temporal modeling. SOTA Features: - Multi-layer transformer with temporal-aware attention - 3D positional encoding (spatial + temporal) - Hierarchical compression: spatial first, then temporal - Causal temporal attention for streaming compatibility - Gated cross-attention for selective feature extraction Note: This is different from VidTokTokenizer which is a 3D VAE for raw video compression. This tokenizer operates on already-encoded features, not raw pixels. Converts [B, T*H*W, hidden_size] -> [B, num_tokens, hidden_size] """ def __init__ ( self , hidden_size :int , num_tokens :int =64 , num_patches :int =576 , max_frames :int =32 , num_layers :int =2 , num_heads :int =8 , dropout :float =0.1 , ): super ().__init__ () self .hidden_size =hidden_size self .num_tokens =num_tokens self .num_patches =num_patches self .max_frames =max_frames self .num_heads =num_heads self .patches_per_frame =num_patches //max_frames if max_frames >0 else num_patches self .spatial_size =int (self .patches_per_frame **0.5 ) self .temporal_pos =nn .Parameter (torch .randn (1 ,max_frames ,1 ,hidden_size )*0.02 ) self .spatial_pos =nn .Parameter (torch .randn (1 ,1 ,self .patches_per_frame ,hidden_size )*0.02 ) self .input_norm =nn .LayerNorm (hidden_size ) self .input_proj =nn .Linear (hidden_size ,hidden_size ) self .num_temporal_tokens =min (num_tokens //4 ,max_frames ) self .num_content_tokens =num_tokens -self .num_temporal_tokens self .temporal_queries =nn .Parameter (torch .randn (1 ,self .num_temporal_tokens ,hidden_size )*0.02 ) self .content_queries =nn .Parameter (torch .randn (1 ,self .num_content_tokens 
,hidden_size )*0.02 ) self .compress_layers =nn .ModuleList () for i in range (num_layers ): self .compress_layers .append (nn .ModuleDict ({ 'cross_attn':nn .MultiheadAttention ( embed_dim =hidden_size , num_heads =num_heads , batch_first =True , dropout =dropout , ), 'cross_gate':nn .Sequential ( nn .Linear (hidden_size ,hidden_size ), nn .Sigmoid (), ), 'cross_norm':nn .LayerNorm (hidden_size ), 'self_attn':nn .MultiheadAttention ( embed_dim =hidden_size , num_heads =num_heads , batch_first =True , dropout =dropout , ), 'self_norm':nn .LayerNorm (hidden_size ), 'ffn':nn .Sequential ( nn .Linear (hidden_size ,hidden_size *4 ), nn .GELU (), nn .Dropout (dropout ), nn .Linear (hidden_size *4 ,hidden_size ), nn .Dropout (dropout ), ), 'ffn_norm':nn .LayerNorm (hidden_size ), })) self .fusion_attn =nn .MultiheadAttention ( embed_dim =hidden_size , num_heads =num_heads , batch_first =True , dropout =dropout , ) self .fusion_norm =nn .LayerNorm (hidden_size ) self .output_proj =nn .Sequential ( nn .Linear (hidden_size ,hidden_size ), nn .GELU (), nn .Linear (hidden_size ,hidden_size ), ) self .output_norm =nn .LayerNorm (hidden_size ) print (f" ๐ŸŽฌ VideoTiTokTokenizer: {num_patches } patches -> {num_tokens } tokens") print (f" Temporal tokens: {self .num_temporal_tokens }, Content tokens: {self .num_content_tokens }") print (f" Layers: {num_layers }, Heads: {num_heads }") def _load_from_state_dict (self ,state_dict ,prefix ,local_metadata ,strict ,missing_keys ,unexpected_keys ,error_msgs ): """Production-grade hook to handle dynamic frame counts and token counts when loading checkpoints.""" t_pos_key =prefix +'temporal_pos' if t_pos_key in state_dict : ckpt_pos =state_dict [t_pos_key ] if ckpt_pos .shape !=self .temporal_pos .shape : print (f" โš ๏ธ VideoTiTokTokenizer: Interpolating {t_pos_key } from {ckpt_pos .shape [1 ]} to {self .max_frames } frames.") ckpt_pos =ckpt_pos .squeeze (2 ).transpose (1 ,2 ) resized =F .interpolate (ckpt_pos ,size =self .max_frames 
,mode ='linear',align_corners =False ) state_dict [t_pos_key ]=resized .transpose (1 ,2 ).unsqueeze (2 ) t_query_key =prefix +'temporal_queries' if t_query_key in state_dict : ckpt_query =state_dict [t_query_key ] if ckpt_query .shape !=self .temporal_queries .shape : print (f" โš ๏ธ VideoTiTokTokenizer: Interpolating {t_query_key } from {ckpt_query .shape [1 ]} to {self .num_temporal_tokens } tokens.") ckpt_query =ckpt_query .transpose (1 ,2 ) resized =F .interpolate (ckpt_query ,size =self .num_temporal_tokens ,mode ='linear',align_corners =False ) state_dict [t_query_key ]=resized .transpose (1 ,2 ) c_query_key =prefix +'content_queries' if c_query_key in state_dict : ckpt_query =state_dict [c_query_key ] if ckpt_query .shape !=self .content_queries .shape : print (f" โš ๏ธ VideoTiTokTokenizer: Interpolating {c_query_key } from {ckpt_query .shape [1 ]} to {self .num_content_tokens } tokens.") ckpt_query =ckpt_query .transpose (1 ,2 ) resized =F .interpolate (ckpt_query ,size =self .num_content_tokens ,mode ='linear',align_corners =False ) state_dict [c_query_key ]=resized .transpose (1 ,2 ) super ()._load_from_state_dict (state_dict ,prefix ,local_metadata ,strict ,missing_keys ,unexpected_keys ,error_msgs ) def _add_3d_pos_encoding (self ,x :torch .Tensor ,num_frames :int ,patches_per_frame :int )->torch .Tensor : """Add 3D positional encoding (temporal + spatial).""" B ,seq_len ,D =x .shape x =x .reshape (B ,num_frames ,patches_per_frame ,D ) temporal_pos =self .temporal_pos [:,:num_frames ,:,:] x =x +temporal_pos spatial_pos =self .spatial_pos [:,:,:patches_per_frame ,:] x =x +spatial_pos return x .reshape (B ,seq_len ,D ) def forward (self ,x :torch .Tensor ,num_frames :int =None )->torch .Tensor : """ Compress video patch features to TiTok-style 1D tokens. 
Args: x: [B, T*H*W, hidden_size] video patch features (flattened spatial-temporal) or [B, T, H*W, hidden_size] video patch features per frame num_frames: Number of frames (optional, for temporal embedding) Returns: [B, num_tokens, hidden_size] compressed token features """ batch_size =x .shape [0 ] if x .dim ()==4 : B ,T ,HW ,D =x .shape x =x .reshape (B ,T *HW ,D ) num_frames =T patches_per_frame =HW else : seq_len =x .shape [1 ] if num_frames is None : num_frames =min (self .max_frames ,seq_len //self .patches_per_frame ) num_frames =max (1 ,num_frames ) patches_per_frame =seq_len //num_frames if num_frames >0 else seq_len x =self .input_norm (x ) x =self .input_proj (x ) x =self ._add_3d_pos_encoding (x ,num_frames ,patches_per_frame ) temporal_queries =self .temporal_queries [:,:min (self .num_temporal_tokens ,num_frames ),:].expand (batch_size ,-1 ,-1 ) content_queries =self .content_queries .expand (batch_size ,-1 ,-1 ) queries =torch .cat ([temporal_queries ,content_queries ],dim =1 ) for layer in self .compress_layers : cross_out ,_ =layer ['cross_attn'](queries ,x ,x ) gate =layer ['cross_gate'](queries ) queries =layer ['cross_norm'](queries +gate *cross_out ) self_out ,_ =layer ['self_attn'](queries ,queries ,queries ) queries =layer ['self_norm'](queries +self_out ) ffn_out =layer ['ffn'](queries ) queries =layer ['ffn_norm'](queries +ffn_out ) actual_temporal =temporal_queries .shape [1 ] temporal_tokens =queries [:,:actual_temporal ,:] content_tokens =queries [:,actual_temporal :,:] fused ,_ =self .fusion_attn (content_tokens ,temporal_tokens ,temporal_tokens ) content_tokens =self .fusion_norm (content_tokens +fused ) tokens =torch .cat ([temporal_tokens ,content_tokens ],dim =1 ) if tokens .shape [1 ]self .num_tokens : tokens =tokens [:,:self .num_tokens ,:] tokens =self .output_proj (tokens ) tokens =self .output_norm (tokens ) return tokens class VideoEncoder (nn .Module ): """ SOTA Video Encoder with 3D-RoPE, 3D Causal Attention, Temporal Expert 
    Routing, and VidTokTokenizer.

    Features:
    - 3D-RoPE for flexible (x, y, t) positional encodings
    - 3D Causal Attention for temporal understanding
    - Temporal-Aware Expert Routing for motion patterns
    - VidTokTokenizer for efficient 1D token compression (mirrors TiTokTokenizer for images)
    - Integrated with vision encoder backbone
    - FP16-native numerical stability
    """

    def __init__(
        self,
        vision_encoder: VisionEncoder,
        max_frames: int = 32,
        num_encoder_layers: int = 4,
        num_experts: int = 4,
        use_3d_rope: bool = True,
        use_temporal_moe: bool = True,
        use_video_tokenizer: bool = True,
        num_video_tokens: int = 64,
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.max_frames = max_frames
        self.hidden_size = vision_encoder.hidden_size
        self.use_3d_rope = use_3d_rope
        self.use_temporal_moe = use_temporal_moe
        self.use_video_tokenizer = use_video_tokenizer
        # Derive the spatial patch grid from the wrapped vision encoder.
        self.image_size = getattr(vision_encoder, 'image_size', 384)
        self.patch_size = getattr(vision_encoder.vision_model.config, 'patch_size', 14)
        self.patches_per_side = self.image_size // self.patch_size
        self.num_spatial_tokens = self.patches_per_side ** 2
        if use_3d_rope:
            self.rope_3d = RoPE3DEncoder(
                dim=self.hidden_size,
                max_height=self.patches_per_side,
                max_width=self.patches_per_side,
                max_frames=max_frames,
            )
            print(f" ๐Ÿ“ 3D-RoPE: (x,y,t) position encoding")
        else:
            self.rope_3d = None
        self.encoder_blocks = nn.ModuleList([
            VideoEncoderBlock(
                hidden_size=self.hidden_size,
                num_heads=8,
                num_experts=num_experts if use_temporal_moe else 1,
                max_frames=max_frames,
                max_height=self.patches_per_side,
                max_width=self.patches_per_side,
            )
            for _ in range(num_encoder_layers)
        ])
        print(f" ๐ŸŽฌ 3D Causal Transformer: {num_encoder_layers} layers")
        if use_temporal_moe:
            print(f" ๐ŸŽฏ Temporal MoE: {num_experts} experts per layer")
        if use_video_tokenizer:
            # Compresses T*H*W patch features down to num_video_tokens tokens.
            self.vidtok = VideoTiTokTokenizer(
                hidden_size=self.hidden_size,
                num_tokens=num_video_tokens,
                num_patches=self.num_spatial_tokens * max_frames,
                max_frames=max_frames,
            )
            # Alias kept for backward compatibility with older checkpoints/callers.
            self.video_tokenizer = self.vidtok
        else:
            self.vidtok = None
            self.video_tokenizer = None
        # Attention pooling over all spatio-temporal tokens -> one video vector.
        self.temporal_pool_query = nn.Parameter(torch.randn(1, 1, self.hidden_size) * 0.02)
        self.temporal_pool_attn = nn.MultiheadAttention(
            embed_dim=self.hidden_size,
            num_heads=8,
            batch_first=True,
            dropout=0.1,
        )
        self.temporal_pool_norm = nn.LayerNorm(self.hidden_size)
        self.frame_pos_embed = nn.Parameter(torch.randn(1, max_frames, self.hidden_size) * 0.02)
        print(f" ๐ŸŽฌ Video encoder: max {max_frames} frames (multi-scale enabled)")

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Production-grade hook to handle dynamic frame counts when loading checkpoints.

        Interpolates temporal embeddings if the checkpoint frames differ from max_frames.
        """
        embed_key = prefix + 'frame_pos_embed'
        if embed_key in state_dict:
            ckpt_embed = state_dict[embed_key]
            if ckpt_embed.shape != self.frame_pos_embed.shape:
                print(f" โš ๏ธ VideoEncoder: Interpolating {embed_key} from {ckpt_embed.shape[1]} to {self.max_frames} frames.")
                ckpt_embed = ckpt_embed.transpose(1, 2)
                resized = F.interpolate(ckpt_embed, size=self.max_frames, mode='linear', align_corners=False)
                state_dict[embed_key] = resized.transpose(1, 2)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def _extract_frame_features(self, frames: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """Extract per-frame features using vision encoder.

        Returns (frame_features [B*T, P, D], batch_size, num_frames).
        """
        batch_size, num_frames = frames.shape[:2]
        # Fold frames into the batch dim so the image encoder sees [B*T, C, H, W].
        frames_flat = frames.view(-1, *frames.shape[2:])
        if frames_flat.shape[-1] != self.image_size or frames_flat.shape[-2] != self.image_size:
            frames_flat = F.interpolate(
                frames_flat,
                size=(self.image_size, self.image_size),
                mode='bilinear',
                align_corners=False
            )
        # Skip autograd bookkeeping when the vision tower is fully frozen.
        if not any(p.requires_grad for p in self.vision_encoder.parameters()):
            with torch.no_grad():
                frame_features = self.vision_encoder(frames_flat, return_titok=False)
        else:
            frame_features = self.vision_encoder(frames_flat, return_titok=False)
        return frame_features, batch_size, num_frames

    def forward(
        self,
        frames: torch.Tensor,
        return_all_frames: bool = False,
        causal: bool = False,
        return_tokens: bool = False,
    ) -> torch.Tensor:
        """
        Process video frames with 3D-RoPE and Causal Attention.

        Args:
            frames: [B, T, C, H, W] tensor of video frames
            return_all_frames: If True, return all frame features; else return pooled
            causal: If True, use causal attention (for autoregressive)
            return_tokens: If True, return VideoTokenizer compressed tokens

        Returns:
            If return_tokens: [B, num_tokens, hidden_size] video tokens
            If return_all_frames: [B, T, hidden_size] per-frame features
            Else: [B, hidden_size] pooled video representation
        """
        frame_features, batch_size, num_frames = self._extract_frame_features(frames)
        _, num_patches, hidden_size = frame_features.shape
        # assumes a square patch grid — TODO confirm for non-square inputs
        height = width = int(math.sqrt(num_patches))
        frame_features = frame_features.view(batch_size, num_frames, num_patches, hidden_size)
        # Add per-frame temporal embedding, broadcast over patches.
        frame_features = frame_features + self.frame_pos_embed[:, :num_frames].unsqueeze(2)
        x = frame_features.view(batch_size, num_frames * num_patches, hidden_size)
        for block in self.encoder_blocks:
            x = block(x, height, width, num_frames, causal=causal)
        if return_tokens and self.vidtok is not None:
            return self.vidtok(x, num_frames)
        if return_all_frames:
            # Mean-pool patches within each frame.
            x = x.view(batch_size, num_frames, num_patches, hidden_size)
            return x.mean(dim=2)
        else:
            # Attention-pool the whole sequence into a single vector.
            query = self.temporal_pool_query.expand(batch_size, -1, -1)
            pooled, _ = self.temporal_pool_attn(query, x, x)
            pooled = self.temporal_pool_norm(query + pooled)
            return pooled.squeeze(1)

    def encode_frames_separately(self, frames: torch.Tensor) -> torch.Tensor:
        """
        Encode frames without temporal attention (for generation conditioning).
Args: frames: [B, T, C, H, W] tensor of video frames Returns: [B, T, hidden_size] tensor of frame features """ frame_features ,batch_size ,num_frames =self ._extract_frame_features (frames ) frame_features =frame_features .mean (dim =1 ) return frame_features .view (batch_size ,num_frames ,-1 ) def encode_with_spatial (self ,frames :torch .Tensor )->torch .Tensor : """ Encode frames preserving spatial structure (for video generation). Args: frames: [B, T, C, H, W] tensor of video frames Returns: [B, T, H, W, hidden_size] tensor of spatio-temporal features """ frame_features ,batch_size ,num_frames =self ._extract_frame_features (frames ) _ ,num_patches ,hidden_size =frame_features .shape height =width =int (math .sqrt (num_patches )) frame_features =frame_features .view (batch_size ,num_frames ,height ,width ,hidden_size ) return frame_features ============================================================================== MODELS.ENCODERS.AUDIO ============================================================================== EPS =1e-5 class RawWaveformTokenizer (nn .Module ): """ Raw Waveform Tokenizer - directly tokenizes audio waveforms without mel spectrograms. Uses multi-scale 1D convolutions to extract features at different temporal resolutions, then combines them into a unified representation. 
""" def __init__ ( self , hidden_size :int =1024 , num_codebooks :int =8 , codebook_size :int =1024 , sample_rate :int =16000 , hop_length :int =320 , num_conv_layers :int =6 , ): super ().__init__ () self .hidden_size =hidden_size self .num_codebooks =num_codebooks self .codebook_size =codebook_size self .sample_rate =sample_rate self .hop_length =hop_length self .conv_layers =nn .ModuleList () in_channels =1 channels =[32 ,64 ,128 ,256 ,512 ,hidden_size ] kernel_sizes =[7 ,5 ,5 ,3 ,3 ,3 ] strides =[2 ,2 ,2 ,2 ,2 ,2 ] for i in range (num_conv_layers ): out_channels =channels [i ]if i =8 else 1 ,out_channels ), nn .SiLU (), )) in_channels =out_channels self .codebooks =nn .ModuleList ([ nn .Embedding (codebook_size ,hidden_size ) for _ in range (num_codebooks ) ]) self .commitment_weight =0.25 self .output_proj =nn .Linear (hidden_size ,hidden_size ) print (f" ๐ŸŽต RawWaveformTokenizer: {num_codebooks } codebooks x {codebook_size } codes") def encode (self ,waveform :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: """ Encode waveform to continuous features. Args: waveform: [B, T] or [B, 1, T] raw audio waveform Returns: features: [B, T', hidden_size] encoded features indices: [B, T', num_codebooks] quantized indices """ if waveform .dim ()==2 : waveform =waveform .unsqueeze (1 ) x =waveform for conv in self .conv_layers : x =conv (x ) x =x .transpose (1 ,2 ) return x ,None def quantize (self ,features :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]: """ Residual Vector Quantization. 
        Args:
            features: [B, T, hidden_size] continuous features

        Returns:
            quantized: [B, T, hidden_size] quantized features
            indices: [B, T, num_codebooks] codebook indices
            commitment_loss: scalar commitment loss
        """
        batch_size, seq_len, _ = features.shape
        residual = features
        quantized = torch.zeros_like(features)
        all_indices = []
        total_commitment_loss = 0.0
        # Residual VQ: each codebook quantizes what the previous ones missed.
        for codebook in self.codebooks:
            distances = torch.cdist(residual, codebook.weight)
            indices = distances.argmin(dim=-1)
            all_indices.append(indices)
            quantized_step = codebook(indices)
            # Straight-through estimator: forward uses quantized_step, backward
            # passes gradients through residual.
            quantized = quantized + residual + (quantized_step - residual).detach()
            # NOTE(review): detaching `residual` (not `quantized_step`) makes
            # this the codebook loss rather than the classic commitment loss
            # (mse(residual, quantized_step.detach())) — confirm intended.
            commitment_loss = F.mse_loss(residual.detach(), quantized_step)
            total_commitment_loss = total_commitment_loss + commitment_loss
            residual = residual - quantized_step.detach()
        indices = torch.stack(all_indices, dim=-1)
        commitment_loss = total_commitment_loss * self.commitment_weight
        return quantized, indices, commitment_loss

    def forward(self, waveform: torch.Tensor, quantize: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass.

        Args:
            waveform: [B, T] or [B, 1, T] raw audio
            quantize: Whether to apply vector quantization

        Returns:
            features: [B, T', hidden_size] encoded features
            commitment_loss: Optional commitment loss if quantize=True
        """
        features, _ = self.encode(waveform)
        if quantize:
            # Indices are computed but not returned here; use quantize() directly
            # if codebook indices are needed.
            features, indices, commitment_loss = self.quantize(features)
            features = self.output_proj(features)
            return features, commitment_loss
        features = self.output_proj(features)
        return features, None


class SnakeActivation(nn.Module):
    """
    Snake activation function from BigVGAN.
    x + (1/a) * sin^2(a * x)
    Better than ReLU/SiLU for audio generation - preserves periodicity.
    """

    def __init__(self, channels: int, alpha: float = 1.0):
        super().__init__()
        # Per-channel learnable frequency, shaped [1, C, 1] for conv layouts.
        self.alpha = nn.Parameter(torch.ones(1, channels, 1) * alpha)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1e-6 guards division when alpha learns toward zero.
        return x + (1.0 / (self.alpha + 1e-6)) * torch.sin(self.alpha * x) ** 2


class ResidualBlock1D(nn.Module):
    """Residual block with dilated convolutions for multi-receptive field."""

    def __init__(self, channels: int, kernel_size: int = 3, dilation: int = 1):
        super().__init__()
        # "Same" padding for the dilated conv.
        padding = (kernel_size * dilation - dilation) // 2
        self.conv1 = nn.utils.parametrizations.weight_norm(
            nn.Conv1d(channels, channels, kernel_size, padding=padding, dilation=dilation)
        )
        self.conv2 = nn.utils.parametrizations.weight_norm(
            nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2)
        )
        self.activation = SnakeActivation(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.activation(self.conv1(x))
        x = self.activation(self.conv2(x))
        return x + residual


class MultiReceptiveFieldFusion(nn.Module):
    """
    Multi-Receptive Field Fusion (MRF) from HiFi-GAN.
    Processes input through multiple parallel residual stacks with different
    kernel sizes and dilations, then sums results.
    """

    # NOTE: list defaults are read-only here, so the shared-mutable-default
    # pitfall does not bite, but callers should pass fresh lists to customize.
    def __init__(self, channels: int, kernel_sizes: List[int] = [3, 7, 11],
                 dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]):
        super().__init__()
        self.num_kernels = len(kernel_sizes)
        self.resblocks = nn.ModuleList()
        for k, d_list in zip(kernel_sizes, dilations):
            blocks = nn.ModuleList([
                ResidualBlock1D(channels, k, d) for d in d_list
            ])
            self.resblocks.append(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Average of the parallel residual-stack outputs.
        out = None
        for blocks in self.resblocks:
            h = x
            for block in blocks:
                h = block(h)
            out = h if out is None else out + h
        return out / self.num_kernels


class RawWaveformDecoder(nn.Module):
    """
    SOTA Raw Waveform Decoder - BigVGAN/HiFi-GAN style architecture.
    Converts features directly to playable audio waveform without external vocoder.

    SOTA Features:
    - Snake activation (BigVGAN) - preserves audio periodicity
    - Multi-Receptive Field Fusion (HiFi-GAN) - captures patterns at multiple scales
    - Weight normalization - stable training
    - Efficient upsampling with careful kernel/stride ratios
    - Anti-aliased resampling
    - Streaming-capable architecture

    Speed optimizations:
    - Fewer layers with smarter architecture
    - Fused operations where possible
    - Efficient 256x total upsampling (vs 64x before)
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        sample_rate: int = 16000,
        upsample_rates: List[int] = [8, 8, 2, 2],
        upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
        resblock_kernel_sizes: List[int] = [3, 7, 11],
        resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        initial_channels: int = 512,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.sample_rate = sample_rate
        self.num_upsamples = len(upsample_rates)
        self.input_proj = nn.utils.parametrizations.weight_norm(
            nn.Conv1d(hidden_size, initial_channels, kernel_size=7, padding=3)
        )
        # Alternating transposed-conv upsample + MRF stages; channel width
        # halves at each stage.
        self.upsamplers = nn.ModuleList()
        self.mrf_blocks = nn.ModuleList()
        channels = initial_channels
        for i, (rate, kernel) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.upsamplers.append(
                nn.utils.parametrizations.weight_norm(
                    nn.ConvTranspose1d(
                        channels, channels // 2,
                        kernel_size=kernel, stride=rate,
                        padding=(kernel - rate) // 2
                    )
                )
            )
            channels = channels // 2
            self.mrf_blocks.append(
                MultiReceptiveFieldFusion(channels, resblock_kernel_sizes, resblock_dilations)
            )
        self.final_activation = SnakeActivation(channels)
        self.output_conv = nn.utils.parametrizations.weight_norm(
            nn.Conv1d(channels, 1, kernel_size=7, padding=3)
        )
        # Total samples produced per input feature frame.
        self.upsample_factor = 1
        for rate in upsample_rates:
            self.upsample_factor *= rate
        print(f" ๐Ÿ”Š RawWaveformDecoder (SOTA BigVGAN-style):")
        print(f" - Snake activation for audio periodicity")
        print(f" - Multi-Receptive Field Fusion")
        print(f" - {self.upsample_factor}x upsampling")
        print(f" - Weight normalized layers")

    def forward(
        self,
        features: torch.Tensor,
        target_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Decode features to raw waveform.

        Args:
            features: [B, T, hidden_size] encoded features
            target_length: Optional target waveform length (for matching input length)

        Returns:
            waveform: [B, T_audio] raw audio waveform in [-1, 1]
        """
        x = features.transpose(1, 2)  # [B, T, D] -> [B, D, T] for conv layers
        x = self.input_proj(x)
        for upsample, mrf in zip(self.upsamplers, self.mrf_blocks):
            x = upsample(x)
            x = mrf(x)
        x = self.final_activation(x)
        waveform = self.output_conv(x)
        waveform = torch.tanh(waveform)  # clamp to [-1, 1]
        waveform = waveform.squeeze(1)
        # Resample to the exact requested length if the conv math differs.
        if target_length is not None and waveform.shape[-1] != target_length:
            waveform = F.interpolate(
                waveform.unsqueeze(1),
                size=target_length,
                mode='linear',
                align_corners=False
            ).squeeze(1)
        return waveform

    def decode_from_codes(
        self,
        codes: torch.Tensor,
        codebooks: nn.ModuleList,
        target_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Decode directly from codebook indices.

        Args:
            codes: [B, T, num_codebooks] codebook indices
            codebooks: List of nn.Embedding codebooks from encoder
            target_length: Optional target waveform length

        Returns:
            waveform: [B, T_audio] raw audio waveform
        """
        # Residual VQ decode: sum the embedding of every codebook level.
        features = torch.zeros(
            codes.shape[0], codes.shape[1], codebooks[0].embedding_dim,
            device=codes.device, dtype=codebooks[0].weight.dtype
        )
        for i, codebook in enumerate(codebooks):
            features = features + codebook(codes[:, :, i])
        return self.forward(features, target_length)

    @torch.no_grad()
    def stream_decode(
        self,
        features: torch.Tensor,
        chunk_size: int = 10,
    ) -> torch.Tensor:
        """
        Streaming decode for real-time speech synthesis.
        Processes features in chunks for low-latency output.

        Args:
            features: [B, T, hidden_size] encoded features
            chunk_size: Number of feature frames per chunk

        Returns:
            waveform: [B, T_audio] concatenated audio chunks
            (NOTE(review): docstring originally said "Yields", but this
            returns the concatenation — chunks may have boundary artifacts
            since each is decoded independently.)
        """
        batch_size, seq_len, _ = features.shape
        audio_chunks = []
        for start in range(0, seq_len, chunk_size):
            end = min(start + chunk_size, seq_len)
            chunk = features[:, start:end, :]
            audio_chunk = self.forward(chunk)
            audio_chunks.append(audio_chunk)
        return torch.cat(audio_chunks, dim=-1)


class SpeakerEncoder(nn.Module):
    """
    Zero-Shot Speaker Encoder for speaker cloning.
    Extracts speaker embeddings from reference audio that can be used
    to clone the speaker's voice characteristics.
    """

    def __init__(
        self,
        hidden_size: int = 256,
        output_size: int = 256,
        num_layers: int = 3,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        # Per-frame conv feature extractor over 80-bin mel input.
        self.frame_encoder = nn.Sequential(
            nn.Conv1d(80, hidden_size, 5, 1, 2),
            nn.ReLU(),
            nn.GroupNorm(1, hidden_size),
            nn.Conv1d(hidden_size, hidden_size, 5, 1, 2),
            nn.ReLU(),
            nn.GroupNorm(1, hidden_size),
            nn.Conv1d(hidden_size, hidden_size, 5, 1, 2),
            nn.ReLU(),
            nn.GroupNorm(1, hidden_size),
        )
        self.lstm = nn.LSTM(
            hidden_size, hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Attention pooling over time (bi-LSTM doubles the feature width).
        self.attention = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
        )
        self.output_proj = nn.Linear(hidden_size * 2, output_size)
        print(f" ๐Ÿ‘ค SpeakerEncoder: {hidden_size}d -> {output_size}d speaker embedding")

    def forward(self, mel_spectrogram: torch.Tensor) -> torch.Tensor:
        """
        Extract speaker embedding from mel spectrogram.
Args: mel_spectrogram: [B, n_mels, T] mel spectrogram Returns: speaker_embedding: [B, output_size] speaker embedding """ x =self .frame_encoder (mel_spectrogram ) x =x .transpose (1 ,2 ) x ,_ =self .lstm (x ) attn_weights =self .attention (x ) attn_weights =F .softmax (attn_weights ,dim =1 ) x =(x *attn_weights ).sum (dim =1 ) speaker_embedding =self .output_proj (x ) speaker_embedding =F .normalize (speaker_embedding ,p =2 ,dim =-1 ) return speaker_embedding class MonotonicAlignmentSearch (nn .Module ): """ Monotonic Alignment Search (MAS) for text-to-audio alignment. Implements both: 1. Hard MAS for inference (dynamic programming) 2. Soft/Fluid MAS for training (differentiable) """ def __init__ (self ,hidden_size :int =1024 ): super ().__init__ () self .hidden_size =hidden_size self .alignment_proj =nn .Sequential ( nn .Linear (hidden_size *2 ,hidden_size ), nn .ReLU (), nn .Linear (hidden_size ,1 ), ) self .duration_predictor =nn .Sequential ( nn .Conv1d (hidden_size ,hidden_size ,3 ,padding =1 ), nn .ReLU (), nn .GroupNorm (1 ,hidden_size ), nn .Conv1d (hidden_size ,hidden_size ,3 ,padding =1 ), nn .ReLU (), nn .GroupNorm (1 ,hidden_size ), nn .Conv1d (hidden_size ,1 ,1 ), ) @staticmethod def hard_mas (log_probs :torch .Tensor )->torch .Tensor : """ Hard Monotonic Alignment Search using dynamic programming. 
Args: log_probs: [B, T_text, T_audio] log alignment probabilities Returns: alignment: [B, T_text, T_audio] hard alignment matrix """ batch_size ,text_len ,audio_len =log_probs .shape device =log_probs .device Q =torch .full ((batch_size ,text_len ,audio_len ),float ('-inf'),device =device ) Q [:,0 ,0 ]=log_probs [:,0 ,0 ] for j in range (1 ,audio_len ): Q [:,0 ,j ]=Q [:,0 ,j -1 ]+log_probs [:,0 ,j ] for i in range (1 ,text_len ): Q [:,i ,i ]=Q [:,i -1 ,i -1 ]+log_probs [:,i ,i ] for j in range (i +1 ,audio_len ): Q [:,i ,j ]=torch .max (Q [:,i -1 ,j -1 ],Q [:,i ,j -1 ])+log_probs [:,i ,j ] alignment =torch .zeros_like (log_probs ) for b in range (batch_size ): i ,j =text_len -1 ,audio_len -1 while i >=0 and j >=0 : alignment [b ,i ,j ]=1 if i ==0 : j -=1 elif j ==0 : i -=1 elif Q [b ,i -1 ,j -1 ]>=Q [b ,i ,j -1 ]: i -=1 j -=1 else : j -=1 return alignment def soft_mas ( self , text_hidden :torch .Tensor , audio_hidden :torch .Tensor , temperature :float =1.0 , )->torch .Tensor : """ Soft/Differentiable Monotonic Alignment Search. 
Args: text_hidden: [B, T_text, hidden_size] text features audio_hidden: [B, T_audio, hidden_size] audio features temperature: Softmax temperature Returns: soft_alignment: [B, T_text, T_audio] soft alignment matrix """ batch_size ,text_len ,_ =text_hidden .shape audio_len =audio_hidden .shape [1 ] text_expanded =text_hidden .unsqueeze (2 ).expand (-1 ,-1 ,audio_len ,-1 ) audio_expanded =audio_hidden .unsqueeze (1 ).expand (-1 ,text_len ,-1 ,-1 ) combined =torch .cat ([text_expanded ,audio_expanded ],dim =-1 ) logits =self .alignment_proj (combined ).squeeze (-1 ) logits =logits /temperature position_bias =torch .arange (audio_len ,device =logits .device ).float () position_bias =position_bias .unsqueeze (0 ).unsqueeze (0 ) text_positions =torch .arange (text_len ,device =logits .device ).float () text_positions =text_positions .unsqueeze (0 ).unsqueeze (2 ) expected_pos =text_positions *(audio_len /text_len ) monotonic_bias =-0.1 *(position_bias -expected_pos ).abs () logits =logits +monotonic_bias soft_alignment =F .softmax (logits ,dim =-1 ) return soft_alignment def predict_durations (self ,text_hidden :torch .Tensor )->torch .Tensor : """ Predict durations for each text token. Args: text_hidden: [B, T_text, hidden_size] text features Returns: durations: [B, T_text] predicted durations """ x =text_hidden .transpose (1 ,2 ) durations =self .duration_predictor (x ).squeeze (1 ) durations =F .softplus (durations ) return durations def forward ( self , text_hidden :torch .Tensor , audio_hidden :Optional [torch .Tensor ]=None , use_hard :bool =False , )->Tuple [torch .Tensor ,torch .Tensor ]: """ Compute alignment and durations. 
        Args:
            text_hidden: [B, T_text, hidden_size] text features
            audio_hidden: [B, T_audio, hidden_size] audio features (optional for inference)
            use_hard: Use hard MAS instead of soft

        Returns:
            alignment: [B, T_text, T_audio] alignment matrix
            durations: [B, T_text] predicted durations
        """
        durations = self.predict_durations(text_hidden)
        # Inference without target audio: only durations are available.
        if audio_hidden is None:
            return None, durations
        if use_hard:
            # Cosine-similarity scores feed the hard (Viterbi) search.
            text_norm = F.normalize(text_hidden, dim=-1)
            audio_norm = F.normalize(audio_hidden, dim=-1)
            log_probs = torch.bmm(text_norm, audio_norm.transpose(1, 2))
            alignment = self.hard_mas(log_probs)
        else:
            alignment = self.soft_mas(text_hidden, audio_hidden)
        return alignment, durations


class RotaryMultiHeadLatentAttention(nn.Module):
    """
    Rotary Multi-Head Latent Attention (RMLA).

    Combines:
    - Multi-Head Latent Attention (MLA) for compressed KV cache
    - Rotary Position Embeddings (RoPE) for position awareness
    - Efficient attention computation
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        num_heads: int = 16,
        num_kv_heads: int = 4,
        head_dim: int = 64,
        kv_lora_rank: int = 256,
        max_position_embeddings: int = 8192,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.kv_lora_rank = kv_lora_rank
        # GQA: each KV head serves num_heads // num_kv_heads query heads.
        self.num_key_value_groups = num_heads // num_kv_heads
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        # MLA: project hidden states down to a low-rank KV latent plus a
        # head_dim-sized positional component.
        self.kv_a_proj = nn.Linear(hidden_size, kv_lora_rank + head_dim, bias=False)
        self.kv_b_proj = nn.Linear(kv_lora_rank, num_kv_heads * head_dim * 2, bias=False)
        self.kv_norm = nn.LayerNorm(kv_lora_rank)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)
        # NOTE: registers cos/sin buffers as a side effect and returns None.
        self.rotary_emb = self._create_rotary_embedding(head_dim, max_position_embeddings)
        self.dropout = nn.Dropout(dropout)
        self.scale = head_dim ** -0.5

    def _create_rotary_embedding(self, dim: int, max_seq_len: int) -> nn.Module:
        """Create rotary position
embeddings."""
        # Standard RoPE inverse frequencies; tables cover max_seq_len positions.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        t = torch.arange(max_seq_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        # Duplicate frequencies so cos/sin span the full head dimension.
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())
        return None

    def _apply_rotary(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Apply rotary position embeddings."""
        # NOTE(review): positions always start at 0 here; callers decoding with
        # a KV cache need position-offset handling - confirm call sites.
        cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        # Rotate-half trick: (x1, x2) -> (-x2, x1).
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        rotated = torch.cat([-x2, x1], dim=-1)
        return x * cos.to(x.dtype) + rotated * sin.to(x.dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Forward pass with RMLA.
Args: hidden_states: [B, T, hidden_size] attention_mask: Optional attention mask past_key_value: Optional cached KV states use_cache: Whether to return updated cache Returns: output: [B, T, hidden_size] present_key_value: Optional updated cache """ batch_size ,seq_len ,_ =hidden_states .shape query =self .q_proj (hidden_states ) query =query .view (batch_size ,seq_len ,self .num_heads ,self .head_dim ).transpose (1 ,2 ) kv_compressed =self .kv_a_proj (hidden_states ) kv_latent ,k_pe =kv_compressed .split ([self .kv_lora_rank ,self .head_dim ],dim =-1 ) kv_latent =self .kv_norm (kv_latent ) kv =self .kv_b_proj (kv_latent ) key ,value =kv .split (self .num_kv_heads *self .head_dim ,dim =-1 ) key =key .view (batch_size ,seq_len ,self .num_kv_heads ,self .head_dim ).transpose (1 ,2 ) value =value .view (batch_size ,seq_len ,self .num_kv_heads ,self .head_dim ).transpose (1 ,2 ) query =self ._apply_rotary (query ,seq_len ) key =self ._apply_rotary (key ,seq_len ) if past_key_value is not None : past_key ,past_value =past_key_value key =torch .cat ([past_key ,key ],dim =2 ) value =torch .cat ([past_value ,value ],dim =2 ) present_key_value =(key ,value )if use_cache else None qk_scale =self .head_dim **-0.25 kv_len =key .shape [2 ] use_causal =(attention_mask is None and seq_len >1 and seq_len ==kv_len ) dropout_p =self .dropout .p if self .training else 0.0 output =F .scaled_dot_product_attention ( query *qk_scale , key *qk_scale , value , attn_mask =attention_mask , is_causal =use_causal , dropout_p =dropout_p , scale =1.0 , enable_gqa =(self .num_key_value_groups >1 ), ) output =output .transpose (1 ,2 ).contiguous ().view (batch_size ,seq_len ,-1 ) output =self .o_proj (output ) return output ,present_key_value class InContextAudioPrompting (nn .Module ): """ In-Context Audio Prompting for conditioning generation on reference audio. Allows the model to use a reference audio clip to guide the style, speaker characteristics, and prosody of generated audio. 
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        num_prompt_tokens: int = 32,
        num_heads: int = 8,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_prompt_tokens = num_prompt_tokens
        # Learned query tokens that attend into the reference audio.
        self.prompt_tokens = nn.Parameter(torch.randn(1, num_prompt_tokens, hidden_size) * 0.02)
        self.cross_attn = nn.MultiheadAttention(
            hidden_size, num_heads,
            dropout=0.1,
            batch_first=True,
        )
        self.prompt_encoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        # Zero-initialized gate: prompting starts as a no-op and is learned.
        self.gate = nn.Parameter(torch.zeros(1))
        self.norm = nn.LayerNorm(hidden_size)

    def encode_prompt(self, audio_features: torch.Tensor) -> torch.Tensor:
        """
        Encode reference audio into prompt tokens.

        Args:
            audio_features: [B, T, hidden_size] reference audio features

        Returns:
            prompt: [B, num_prompt_tokens, hidden_size] encoded prompt
        """
        batch_size = audio_features.shape[0]
        prompt = self.prompt_tokens.expand(batch_size, -1, -1)
        # Queries are the learned tokens; keys/values come from the reference.
        prompt, _ = self.cross_attn(prompt, audio_features, audio_features)
        prompt = self.prompt_encoder(prompt)
        return prompt

    def forward(
        self,
        hidden_states: torch.Tensor,
        prompt_features: Optional[torch.Tensor] = None,
        audio_prompt: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply in-context audio prompting.
        Args:
            hidden_states: [B, T, hidden_size] input features
            prompt_features: [B, num_prompt_tokens, hidden_size] pre-encoded prompt
            audio_prompt: [B, T_prompt, hidden_size] raw audio features to encode

        Returns:
            output: [B, T, hidden_size] conditioned features
        """
        if prompt_features is None and audio_prompt is not None:
            prompt_features = self.encode_prompt(audio_prompt)
        # No prompt available: pass features through untouched.
        if prompt_features is None:
            return hidden_states
        attended, _ = self.cross_attn(hidden_states, prompt_features, prompt_features)
        # Gated residual keeps the un-prompted path dominant early in training.
        gate = torch.sigmoid(self.gate)
        output = hidden_states + gate * attended
        output = self.norm(output)
        return output


class ConvolutionModule(nn.Module):
    """Conformer convolution module with gating."""

    def __init__(self, channels: int, kernel_size: int = 31, dropout: float = 0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(channels)
        # First pointwise conv doubles channels to feed the GLU gate.
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1)
        self.depthwise_conv = nn.Conv1d(
            channels, channels, kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2, groups=channels
        )
        # GroupNorm(1, C): batch-size-independent normalization over channels.
        self.batch_norm = nn.GroupNorm(1, channels)
        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, T, C]"""
        x = self.layer_norm(x)
        # Conv1d layers operate on [B, C, T].
        x = x.transpose(1, 2)
        x = self.pointwise_conv1(x)
        x = F.glu(x, dim=1)
        x = self.depthwise_conv(x)
        x = self.batch_norm(x)
        x = F.silu(x)
        x = self.pointwise_conv2(x)
        x = self.dropout(x)
        return x.transpose(1, 2)


class ConformerBlock(nn.Module):
    """Single Conformer block with RMLA, feed-forward, and convolution."""

    def __init__(
        self,
        d_model: int,
        num_heads: int = 8,
        ff_expansion: int = 4,
        conv_kernel_size: int = 31,
        dropout: float = 0.1,
        use_rmla: bool = True,
    ):
        super().__init__()
        self.use_rmla = use_rmla
        self.ff1_norm = nn.LayerNorm(d_model)
        self.ff1 = nn.Sequential(
            nn.Linear(d_model, d_model * ff_expansion),
            nn.SiLU(),
            nn.Dropout(dropout
),
            nn.Linear(d_model * ff_expansion, d_model),
            nn.Dropout(dropout)
        )
        if use_rmla:
            self.attn = RotaryMultiHeadLatentAttention(
                hidden_size=d_model,
                num_heads=num_heads,
                num_kv_heads=max(1, num_heads // 4),
                head_dim=d_model // num_heads,
                kv_lora_rank=d_model // 4,
                dropout=dropout,
            )
        else:
            self.attn_norm = nn.LayerNorm(d_model)
            self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.attn_dropout = nn.Dropout(dropout)
        self.conv = ConvolutionModule(d_model, conv_kernel_size, dropout)
        self.ff2_norm = nn.LayerNorm(d_model)
        self.ff2 = nn.Sequential(
            nn.Linear(d_model, d_model * ff_expansion),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * ff_expansion, d_model),
            nn.Dropout(dropout)
        )
        self.final_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple]]:
        # Macaron-style: two half-weighted feed-forward modules around the
        # attention and convolution modules.
        x = x + 0.5 * self.ff1(self.ff1_norm(x))
        if self.use_rmla:
            attn_mask = None
            if mask is not None:
                # Convert a boolean padding mask (True = pad, matching the
                # key_padding_mask convention below) into an additive mask:
                # -inf at padded positions, 0 at valid ones.
                attn_mask = mask.unsqueeze(1).unsqueeze(2)
                attn_mask = attn_mask.to(dtype=x.dtype)
                attn_mask = attn_mask.masked_fill(attn_mask.bool(), float('-inf'))
            attn_out, present_kv = self.attn(x, attention_mask=attn_mask, past_key_value=past_key_value, use_cache=use_cache)
        else:
            attn_out, _ = self.attn(self.attn_norm(x), self.attn_norm(x), self.attn_norm(x), key_padding_mask=mask)
            present_kv = None
        x = x + self.attn_dropout(attn_out)
        x = x + self.conv(x)
        x = x + 0.5 * self.ff2(self.ff2_norm(x))
        return self.final_norm(x), present_kv


class AudioEncoder(nn.Module):
    """
    SOTA Audio Encoder with Raw Waveform Tokenization, RMLA, and Voice Enhancement.
    Features:
    - Raw waveform tokenization (no mel spectrogram)
    - Conformer blocks with RMLA
    - Zero-shot speaker encoding
    - In-context audio prompting
    - Gradient checkpointing support for memory efficiency

    Voice Enhancement Features (SOTA):
    - Prosody-aware EoT Prediction (interruption detection)
    - AVD Emotion Recognition (arousal/valence/dominance)
    - Dynamic Latent Vocalizations (singing/rapping)
    - Neural Sound Effects (beatboxing, breathing, expressions)
    - Speculative Decoding (mid-stream token rewriting)
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        n_mels: int = 80,
        max_audio_length: int = 3000,
        num_layers: int = 6,
        num_heads: int = 8,
        dropout: float = 0.1,
        use_raw_waveform: bool = True,
        enable_eot: bool = True,
        enable_emotion: bool = True,
        enable_singing: bool = True,
        enable_effects: bool = True,
        enable_speculative: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.max_audio_length = max_audio_length
        self.use_raw_waveform = use_raw_waveform
        self.gradient_checkpointing = False
        self.enable_eot = enable_eot
        self.enable_emotion = enable_emotion
        self.enable_singing = enable_singing
        self.enable_effects = enable_effects
        self.enable_speculative = enable_speculative
        if use_raw_waveform:
            self.waveform_tokenizer = RawWaveformTokenizer(
                hidden_size=hidden_size,
                num_codebooks=8,
                codebook_size=1024,
            )
        else:
            self.waveform_tokenizer = None
        # Mel-spectrogram fallback path: two stride-2 convs (4x subsampling).
        self.conv_subsample = nn.Sequential(
            nn.Conv1d(n_mels, hidden_size // 2, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv1d(hidden_size // 2, hidden_size, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
        )
        self.speaker_encoder = SpeakerEncoder(
            hidden_size=256,
            output_size=hidden_size // 4,
        )
        self.audio_prompting = InContextAudioPrompting(
            hidden_size=hidden_size,
            num_prompt_tokens=32,
        )
        self.conformer_blocks = nn.ModuleList([
            ConformerBlock(
                hidden_size, num_heads,
                ff_expansion=4,
                conv_kernel_size=31,
                dropout=dropout,
                use_rmla=True,
            )
            for _ in range(num_layers)
        ])
        self.output_proj = nn.Linear(hidden_size, hidden_size)
        # Optional voice-enhancement heads; None when disabled.
        if enable_eot:
            self.eot_predictor = ProsodyAwareEoTPredictor(hidden_size, dropout=dropout)
        else:
            self.eot_predictor = None
        if enable_emotion:
            self.emotion_recognizer = AVDEmotionRecognizer(hidden_size, dropout=dropout)
        else:
            self.emotion_recognizer = None
        if enable_singing:
            self.vocalizer = DynamicLatentVocalizer(hidden_size)
        else:
            self.vocalizer = None
        if enable_effects:
            self.effects_generator = NeuralSoundEffectGenerator(hidden_size)
        else:
            self.effects_generator = None
        if enable_speculative:
            self.speculative_decoder = SpeculativeAudioDecoder(hidden_size)
        else:
            self.speculative_decoder = None
        print(f" ๐ŸŽค AudioEncoder (RMLA Conformer): {hidden_size}d, {num_layers} layers")
        if use_raw_waveform:
            print(f" - Raw Waveform Tokenizer enabled")
        print(f" - Zero-Shot Speaker Encoder enabled")
        print(f" - In-Context Audio Prompting enabled")
        print(f" - EoT/Interruption Detection: {enable_eot}")
        print(f" - Emotion Recognition (AVD): {enable_emotion}")
        print(f" - Singing/Rapping (Vocalizer): {enable_singing}")
        print(f" - Sound Effects Generator: {enable_effects}")
        print(f" - Speculative Decoding: {enable_speculative}")

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing to save memory during training."""
        self.gradient_checkpointing = True
        # Propagate the flag to submodules that support it.
        if hasattr(self, 'waveform_tokenizer') and self.waveform_tokenizer is not None:
            if hasattr(self.waveform_tokenizer, 'gradient_checkpointing'):
                self.waveform_tokenizer.gradient_checkpointing = True
        if hasattr(self, 'speaker_encoder') and self.speaker_encoder is not None:
            if hasattr(self.speaker_encoder, 'gradient_checkpointing'):
                self.speaker_encoder.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self.gradient_checkpointing = False

    def forward(
        self,
        audio_input: torch.Tensor,
        speaker_ref: Optional[torch.Tensor] = None,
        audio_prompt: Optional
[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        return_eot: bool = False,
        return_emotion: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[dict]]:
        """
        Process audio to features with optional voice enhancement outputs.

        Args:
            audio_input: [B, T] raw waveform or [B, n_mels, T] mel spectrogram
            speaker_ref: [B, n_mels, T_ref] reference audio for speaker cloning
            audio_prompt: [B, T_prompt, hidden_size] audio prompt features
            mask: Optional attention mask
            return_eot: Whether to return EoT/interruption predictions
            return_emotion: Whether to return emotion/AVD predictions

        Returns:
            features: [B, T', hidden_size] audio features
            speaker_embedding: [B, hidden_size//4] speaker embedding (if speaker_ref provided)
            extras: dict with EoT/emotion predictions (if requested)
        """
        # NOTE(review): commitment_loss from the tokenizer is computed but
        # never returned or stored - confirm whether it should feed training.
        commitment_loss = None
        if self.use_raw_waveform and self.waveform_tokenizer is not None:
            # Collapse any channel dimension to mono before tokenizing.
            if audio_input.dim() == 3 and audio_input.shape[1] == 1:
                audio_input = audio_input.squeeze(1)
            elif audio_input.dim() == 3:
                audio_input = audio_input.mean(dim=1)
            x, commitment_loss = self.waveform_tokenizer(audio_input)
        elif hasattr(self, 'conv_subsample') and self.conv_subsample is not None:
            # Mel path: Conv1d expects [B, n_mels, T].
            if audio_input.dim() == 2:
                audio_input = audio_input.unsqueeze(1)
            x = self.conv_subsample(audio_input)
            x = x.transpose(1, 2)
        else:
            raise RuntimeError(
                f"AudioEncoder: Incompatible configuration. "
                f"use_raw_waveform={self.use_raw_waveform}, "
                f"waveform_tokenizer={self.waveform_tokenizer is not None}, "
                f"conv_subsample={hasattr(self, 'conv_subsample') and self.conv_subsample is not None}"
            )
        speaker_embedding = None
        if speaker_ref is not None:
            speaker_embedding = self.speaker_encoder(speaker_ref)
        if audio_prompt is not None:
            x = self.audio_prompting(x, audio_prompt=audio_prompt)
        if self.gradient_checkpointing and self.training:
            from torch.utils.checkpoint import checkpoint
            for block in self.conformer_blocks:
                # Wrapper keeps the block callable compatible with checkpoint.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward
                x, _ = checkpoint(create_custom_forward(block), x, mask, use_reentrant=False)
        else:
            for block in self.conformer_blocks:
                x, _ = block(x, mask)
        x = self.output_proj(x)
        extras = {}
        if return_eot and self.eot_predictor is not None:
            extras["eot"] = self.eot_predictor(x, mask)
        if return_emotion and self.emotion_recognizer is not None:
            extras["emotion"] = self.emotion_recognizer(x, mask)
        return x, speaker_embedding, extras if extras else None

    def detect_interruption(
        self,
        audio_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Optional[dict]:
        """
        Detect interruptions, backchannels, and turn-taking events.

        Args:
            audio_features: [B, T, hidden_size] encoded audio
            attention_mask: [B, T] optional mask

        Returns:
            dict with eot_logits, event_logits, vad_logits, backoff_prob
        """
        if self.eot_predictor is None:
            return None
        return self.eot_predictor(audio_features, attention_mask)

    def recognize_emotion(
        self,
        audio_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Optional[dict]:
        """
        Recognize emotion with AVD (arousal/valence/dominance) values.
        Args:
            audio_features: [B, T, hidden_size] encoded audio
            attention_mask: [B, T] optional mask

        Returns:
            dict with emotion_logits, arousal, valence, dominance, response_mode
        """
        if self.emotion_recognizer is None:
            return None
        return self.emotion_recognizer(audio_features, attention_mask)

    def generate_vocals(
        self,
        text_features: torch.Tensor,
        style_id: Optional[torch.Tensor] = None,
        mode_id: Optional[torch.Tensor] = None,
        target_pitch: Optional[torch.Tensor] = None,
        tempo_bpm: Optional[torch.Tensor] = None,
    ) -> Optional[dict]:
        """
        Generate singing/rapping vocals from text/lyrics.

        Args:
            text_features: [B, T, hidden_size] text embeddings
            style_id: [B] style indices (pop, rock, jazz, etc.)
            mode_id: [B] mode indices (speak, sing, rap, hum, etc.)
            target_pitch: [B, T] pitch targets
            tempo_bpm: [B] tempo in BPM

        Returns:
            dict with vocal_features, pitch_logits, alignment, durations
        """
        if self.vocalizer is None:
            return None
        return self.vocalizer(text_features, style_id, mode_id, target_pitch, tempo_bpm)

    def generate_effects(
        self,
        effect_ids: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        intensity: Optional[torch.Tensor] = None,
    ) -> Optional[dict]:
        """
        Generate sound effects (beatbox, clicks, breathing, etc.).

        Args:
            effect_ids: [B] or [B, N] effect type indices
            context: [B, T, hidden_size] optional context
            intensity: [B] intensity values

        Returns:
            dict with effect_features, waveform, duration, intensity
        """
        if self.effects_generator is None:
            return None
        return self.effects_generator(effect_ids, context, intensity)

    def speculative_generate(
        self,
        context: torch.Tensor,
        generate_draft: bool = True,
        verify_with: Optional[torch.Tensor] = None,
    ) -> Optional[dict]:
        """
        Generate speculative draft tokens for mid-stream rewriting.
        Args:
            context: [B, T, hidden_size] current context
            generate_draft: whether to generate new draft
            verify_with: [B, T', hidden_size] new context to verify against

        Returns:
            dict with checkpoint, draft_tokens, confidence, accept_prob
        """
        if self.speculative_decoder is None:
            return None
        return self.speculative_decoder(context, generate_draft, verify_with)


class VariancePredictor(nn.Module):
    """Variance predictor for duration, pitch, and energy."""

    def __init__(self, hidden_size: int, kernel_size: int = 3, dropout: float = 0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(hidden_size, hidden_size, kernel_size, padding=kernel_size // 2)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.conv2 = nn.Conv1d(hidden_size, hidden_size, kernel_size, padding=kernel_size // 2)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, T, C] -> [B, T]"""
        # Conv1d wants [B, C, T]; LayerNorm wants [B, T, C] - hence the flips.
        out = self.conv1(x.transpose(1, 2)).transpose(1, 2)
        out = F.relu(out)
        out = self.norm1(out)
        out = self.dropout(out)
        out = self.conv2(out.transpose(1, 2)).transpose(1, 2)
        out = F.relu(out)
        out = self.norm2(out)
        out = self.dropout(out)
        return self.linear(out).squeeze(-1)


class FFTBlock(nn.Module):
    """FFT block for mel decoder."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int = 4,
        ff_expansion: int = 4,
        kernel_size: int = 9,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.attn = RotaryMultiHeadLatentAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=max(1, num_heads // 2),
            head_dim=hidden_size // num_heads,
            kv_lora_rank=hidden_size // 4,
            dropout=dropout,
        )
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.attn_dropout = nn.Dropout(dropout)
        self.ff_norm = nn.LayerNorm(hidden_size)
        # Convolutional position-wise feed-forward.
        self.ff = nn.Sequential(
            nn.Conv1d(hidden_size, hidden_size * ff_expansion, kernel_size, padding=kernel_size // 2),
            nn.ReLU(),
            nn
.Conv1d(hidden_size * ff_expansion, hidden_size, kernel_size, padding=kernel_size // 2),
            nn.Dropout(dropout)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm self-attention with residual connection.
        residual = x
        x = self.attn_norm(x)
        x, _ = self.attn(x)
        x = residual + self.attn_dropout(x)
        # Convolutional feed-forward with residual; Conv1d needs [B, C, T].
        residual = x
        x = self.ff_norm(x)
        x = self.ff(x.transpose(1, 2)).transpose(1, 2)
        x = residual + x
        return x


class AudioDecoder(nn.Module):
    """
    SOTA Audio Decoder with MAS, Zero-Shot Speaker Cloning, and Voice Enhancement Support.

    Features:
    - Monotonic Alignment Search for text-to-audio alignment
    - Zero-shot speaker cloning via speaker embeddings
    - In-context audio prompting
    - Variance adaptor with duration, pitch, energy prediction
    - RMLA-based FFT blocks
    - Gradient checkpointing support for memory efficiency

    Voice Enhancement Features (matching AudioEncoder):
    - Emotion conditioning for emotional speech synthesis
    - Singing/vocal style synthesis support
    - Sound effect generation and integration
    - Raw waveform output support (optional)
    - Speculative decoding integration
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        n_mels: int = 80,
        max_audio_length: int = 1000,
        num_speakers: int = 256,
        num_decoder_layers: int = 4,
        dropout: float = 0.1,
        enable_emotion: bool = True,
        enable_singing: bool = True,
        enable_effects: bool = True,
        enable_raw_waveform: bool = True,
        enable_speculative: bool = True,
        num_emotions: int = 10,
        num_vocal_styles: int = 8,
        num_vocal_modes: int = 6,
        num_effect_types: int = 20,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_mels = n_mels
        self.max_audio_length = max_audio_length
        self.gradient_checkpointing = False
        self.enable_emotion = enable_emotion
        self.enable_singing = enable_singing
        self.enable_effects = enable_effects
        self.enable_raw_waveform = enable_raw_waveform
        self.enable_speculative = enable_speculative
        self.mas = MonotonicAlignmentSearch(hidden_size)
        # Multi-speaker table plus a projection for zero-shot embeddings.
        self.speaker_embed = nn.Embedding(num_speakers, hidden_size // 4)
        self.speaker_proj = nn.Linear(hidden_size // 4, hidden_size // 4)
        self.audio_prompting = InContextAudioPrompting(
            hidden_size=hidden_size,
            num_prompt_tokens=32,
        )
        # Optional conditioning branches; each contributes hidden_size // 4
        # features to the concatenated conditioning vector.
        if enable_emotion:
            self.emotion_embed = nn.Embedding(num_emotions, hidden_size // 4)
            self.avd_proj = nn.Sequential(
                nn.Linear(3, hidden_size // 8),
                nn.SiLU(),
                nn.Linear(hidden_size // 8, hidden_size // 4),
            )
            self.emotion_cond_size = hidden_size // 4
        else:
            self.emotion_embed = None
            self.avd_proj = None
            self.emotion_cond_size = 0
        if enable_singing:
            self.vocal_style_embed = nn.Embedding(num_vocal_styles, hidden_size // 4)
            self.vocal_mode_embed = nn.Embedding(num_vocal_modes, hidden_size // 4)
            self.tempo_proj = nn.Sequential(
                nn.Linear(1, hidden_size // 8),
                nn.SiLU(),
                nn.Linear(hidden_size // 8, hidden_size // 4),
            )
            self.singing_cond_size = hidden_size // 4
        else:
            self.vocal_style_embed = None
            self.vocal_mode_embed = None
            self.tempo_proj = None
            self.singing_cond_size = 0
        if enable_effects:
            self.effect_embed = nn.Embedding(num_effect_types, hidden_size // 4)
            self.effect_intensity_proj = nn.Sequential(
                nn.Linear(1, hidden_size // 8),
                nn.SiLU(),
                nn.Linear(hidden_size // 8, hidden_size // 4),
            )
            self.effect_cond_size = hidden_size // 4
        else:
            self.effect_embed = None
            self.effect_intensity_proj = None
            self.effect_cond_size = 0
        # Speaker conditioning is always present; others are optional.
        total_cond_size = hidden_size // 4
        total_cond_size += self.emotion_cond_size
        total_cond_size += self.singing_cond_size
        total_cond_size += self.effect_cond_size
        self.input_proj = nn.Linear(hidden_size + total_cond_size, hidden_size)
        self.encoder_blocks = nn.ModuleList([
            FFTBlock(hidden_size, num_heads=4, ff_expansion=4, dropout=dropout)
            for _ in range(4)
        ])
        # Variance adaptor heads.
        self.duration_predictor = VariancePredictor(hidden_size, dropout=dropout)
        self.pitch_predictor = VariancePredictor(hidden_size, dropout=dropout)
        self.energy_predictor = VariancePredictor(hidden_size, dropout=dropout)
        self.pitch_embed = nn.Conv1d(1, hidden_size, kernel_size=9, padding=4)
        self.energy_embed = nn.Conv1d(1, hidden_size, kernel_size=9, padding=4)
        self.decoder_blocks = nn.ModuleList([
            FFTBlock(hidden_size, num_heads=4, ff_expansion=4, dropout=dropout)
            for _ in range(num_decoder_layers)
        ])
        self.mel_linear = nn.Linear(hidden_size, n_mels)
        # Tacotron-style postnet: 5 conv stages refining the mel output.
        self.postnet = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(n_mels, 256, kernel_size=5, padding=2),
                nn.GroupNorm(1, 256),
                nn.Tanh(),
            ),
            nn.Sequential(
                nn.Conv1d(256, 256, kernel_size=5, padding=2),
                nn.GroupNorm(1, 256),
                nn.Tanh(),
            ),
            nn.Sequential(
                nn.Conv1d(256, 256, kernel_size=5, padding=2),
                nn.GroupNorm(1, 256),
                nn.Tanh(),
            ),
            nn.Sequential(
                nn.Conv1d(256, 256, kernel_size=5, padding=2),
                nn.GroupNorm(1, 256),
                nn.Tanh(),
            ),
            nn.Conv1d(256, n_mels, kernel_size=5, padding=2),
        ])
        if enable_raw_waveform:
            self.waveform_decoder = RawWaveformDecoder(
                hidden_size=hidden_size,
                sample_rate=16000,
            )
        else:
            self.waveform_decoder = None
        if enable_speculative:
            self.speculative_decoder = SpeculativeAudioDecoder(
                hidden_size=hidden_size,
                draft_length=10,
            )
        else:
            self.speculative_decoder = None
        print(f" ๐Ÿ”Š AudioDecoder (MAS + RMLA): {hidden_size}d -> {n_mels} mels")
        print(f" - Monotonic Alignment Search enabled")
        print(f" - Zero-Shot Speaker Cloning enabled")
        print(f" - In-Context Audio Prompting enabled")
        print(f" - Emotion Conditioning: {enable_emotion}")
        print(f" - Singing/Vocal Styles: {enable_singing}")
        print(f" - Sound Effects: {enable_effects}")
        print(f" - Raw Waveform Output: {enable_raw_waveform}")
        print(f" - Speculative Decoding: {enable_speculative}")

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing to save memory during training."""
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self.gradient_checkpointing = False

    def forward(
        self,
        text_embeds: torch.Tensor,
        target_length: Optional[int] = None,
        speaker: Optional[torch.Tensor] = None
,
        speaker_embedding: Optional[torch.Tensor] = None,
        audio_prompt: Optional[torch.Tensor] = None,
        audio_features: Optional[torch.Tensor] = None,
        duration_target: Optional[torch.Tensor] = None,
        pitch_target: Optional[torch.Tensor] = None,
        energy_target: Optional[torch.Tensor] = None,
        use_mas: bool = True,
        emotion_id: Optional[torch.Tensor] = None,
        avd_values: Optional[torch.Tensor] = None,
        vocal_style_id: Optional[torch.Tensor] = None,
        vocal_mode_id: Optional[torch.Tensor] = None,
        tempo_bpm: Optional[torch.Tensor] = None,
        effect_id: Optional[torch.Tensor] = None,
        effect_intensity: Optional[torch.Tensor] = None,
        output_waveform: bool = False,
        use_speculative: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[dict]]:
        """
        Generate mel-spectrogram from text embeddings with voice enhancement support.

        Args:
            text_embeds: [B, T, hidden_size] text embeddings
            target_length: target mel length (for training)
            speaker: [B] speaker IDs (for multi-speaker)
            speaker_embedding: [B, hidden_size//4] zero-shot speaker embedding
            audio_prompt: [B, T_prompt, hidden_size] audio prompt features
            audio_features: [B, T_audio, hidden_size] target audio features (for MAS training)
            duration_target: [B, T] ground truth durations
            pitch_target: [B, T'] ground truth pitch
            energy_target: [B, T'] ground truth energy
            use_mas: Whether to use MAS for alignment

        Voice enhancement args:
            emotion_id: [B] discrete emotion category (0-9)
            avd_values: [B, 3] continuous arousal/valence/dominance values
            vocal_style_id: [B] singing style (0-7: pop, rock, jazz, etc.)
            vocal_mode_id: [B] vocal mode (0-5: speak, sing, rap, hum, whistle, chant)
            tempo_bpm: [B] tempo in BPM for singing/rapping
            effect_id: [B] sound effect type (0-19)
            effect_intensity: [B] effect intensity (0-1)
            output_waveform: Whether to also output raw waveform
            use_speculative: Whether to use speculative decoding

        Returns:
            mel: [B, n_mels, T'] generated mel spectrogram
            durations: [B, T] predicted durations
            alignment: [B, T_text, T_audio] alignment matrix (if use_mas and audio_features provided)
            extras: dict with optional outputs (waveform, speculative results)
        """
        batch_size, seq_len, _ = text_embeds.shape
        device = text_embeds.device
        dtype = text_embeds.dtype
        extras = {}
        # Speaker conditioning: zero-shot embedding wins over the ID table;
        # fall back to speaker 0 when neither is given.
        if speaker_embedding is not None:
            spk_emb = self.speaker_proj(speaker_embedding)
        elif speaker is not None:
            spk_emb = self.speaker_embed(speaker)
        else:
            speaker = torch.zeros(batch_size, dtype=torch.long, device=device)
            spk_emb = self.speaker_embed(speaker)
        spk_emb = spk_emb.unsqueeze(1).expand(-1, seq_len, -1).to(dtype)
        cond_embeds = [spk_emb]
        if self.enable_emotion:
            if emotion_id is not None:
                emo_emb = self.emotion_embed(emotion_id).unsqueeze(1).expand(-1, seq_len, -1).to(dtype)
            elif avd_values is not None:
                emo_emb = self.avd_proj(avd_values.to(dtype)).unsqueeze(1).expand(-1, seq_len, -1)
            else:
                # Index 6 is the neutral-emotion default here.
                neutral = torch.full((batch_size,), 6, dtype=torch.long, device=device)
                emo_emb = self.emotion_embed(neutral).unsqueeze(1).expand(-1, seq_len, -1).to(dtype)
            cond_embeds.append(emo_emb)
        if self.enable_singing:
            if vocal_style_id is not None:
                style_emb = self.vocal_style_embed(vocal_style_id).unsqueeze(1).expand(-1, seq_len, -1).to(dtype)
            else:
                default_style = torch.zeros(batch_size, dtype=torch.long, device=device)
                style_emb = self.vocal_style_embed(default_style).unsqueeze(1).expand(-1, seq_len, -1).to(dtype)
            if vocal_mode_id is not None:
                mode_emb = self.vocal_mode_embed(vocal_mode_id).unsqueeze(1).expand(-1, seq_len, -1).to(dtype)
            else:
                default_mode = torch.zeros(batch_size, dtype=torch.long, device=device)
                mode_emb = self.vocal_mode_embed(default_mode).unsqueeze(1).expand(-1, seq_len, -1).to(dtype)
            if tempo_bpm is not None:
                # Roughly map 60-180 BPM onto [0, 1].
                tempo_norm = (tempo_bpm.float() - 60) / 120
                tempo_emb = self.tempo_proj(tempo_norm.unsqueeze(-1).to(dtype)).unsqueeze(1).expand(-1, seq_len, -1)
            else:
                tempo_emb = torch.zeros(batch_size, seq_len, self.hidden_size // 4, device=device, dtype=dtype)
            singing_emb = style_emb + mode_emb + tempo_emb
            cond_embeds.append(singing_emb)
        if self.enable_effects:
            if effect_id is not None:
                eff_emb = self.effect_embed(effect_id).unsqueeze(1).expand(-1, seq_len, -1).to(dtype)
                if effect_intensity is not None:
                    intensity_emb = self.effect_intensity_proj(effect_intensity.unsqueeze(-1).to(dtype))
                    eff_emb = eff_emb * intensity_emb.unsqueeze(1)
            else:
                eff_emb = torch.zeros(batch_size, seq_len, self.hidden_size // 4, device=device, dtype=dtype)
            cond_embeds.append(eff_emb)
        # Concatenate all conditioning and project back to hidden_size.
        all_cond = torch.cat(cond_embeds, dim=-1)
        x = torch.cat([text_embeds, all_cond], dim=-1)
        x = self.input_proj(x)
        if audio_prompt is not None:
            x = self.audio_prompting(x, audio_prompt=audio_prompt)
        if self.gradient_checkpointing and self.training:
            from torch.utils.checkpoint import checkpoint
            for block in self.encoder_blocks:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward
                x = checkpoint(create_custom_forward(block), x, use_reentrant=False)
        else:
            for block in self.encoder_blocks:
                x = block(x)
        alignment = None
        if use_mas and audio_features is not None:
            # Hard alignment at eval, soft (differentiable) during training.
            alignment, durations = self.mas(x, audio_features, use_hard=not self.training)
        else:
            _, durations = self.mas(x)
        if duration_target is not None:
            durations = duration_target
        pitch_pred = self.pitch_predictor(x)
        energy_pred = F.softplus(self.energy_predictor(x))
        # NOTE(review): minimum length is 1 with target_length but 16 without
        # - confirm whether the inconsistency is intentional.
        MIN_MEL_LENGTH = 1
        if target_length is not None:
            mel_length = max(MIN_MEL_LENGTH, target_length)
        else:
            mel_length = int(durations.sum(dim=1).max().item())
            mel_length = max(16, min(mel_length, self.max_audio_length))
        # Length regulation via interpolation rather than token repetition.
        x = F.interpolate(x.transpose(1, 2), size=mel_length, mode='linear', align_corners=False).transpose(1, 2)
        pitch = pitch_target if pitch_target is not None else pitch_pred
        energy = energy_target if energy_target is not None else energy_pred
        pitch_up = F.interpolate(pitch.unsqueeze(1), size=mel_length, mode='linear', align_corners=False)
        energy_up = F.interpolate(energy.unsqueeze(1), size=mel_length, mode='linear', align_corners=False)
        pitch_emb = self.pitch_embed(pitch_up).transpose(1, 2)
        energy_emb = self.energy_embed(energy_up).transpose(1, 2)
        x = x + pitch_emb + energy_emb
        if self.gradient_checkpointing and self.training:
            from torch.utils.checkpoint import checkpoint
            for block in self.decoder_blocks:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward
                x = checkpoint(create_custom_forward(block), x, use_reentrant=False)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        mel = self.mel_linear(x).transpose(1, 2)
        # Residual postnet refinement.
        mel_post = mel
        for layer in self.postnet:
            mel_post = layer(mel_post)
        mel = mel + mel_post
        if output_waveform and self.waveform_decoder is not None:
            waveform = self.waveform_decoder(x)
            extras["waveform"] = waveform
        if use_speculative and self.speculative_decoder is not None:
            spec_results = self.speculative_decoder(x)
            extras["speculative"] = spec_results
        return mel, durations, alignment, extras if extras else None


class ProsodyAwareEoTPredictor(nn.Module):
    """
    Prosody-aware End-of-Turn (EoT) Prediction for real-time interruption detection.

    Detects when a speaker is about to finish their turn, allowing the model to:
    - Detect user interruptions (coughs, laughs, "uh-huh", etc.)
- Yield the floor when appropriate - Adjust response mid-stream Uses prosodic features (pitch, energy, rhythm) combined with semantic features. """ def __init__ ( self , hidden_size :int =1024 , num_eot_classes :int =5 , prosody_dim :int =128 , num_heads :int =4 , dropout :float =0.1 , ): super ().__init__ () self .hidden_size =hidden_size self .num_eot_classes =num_eot_classes self .pitch_conv =nn .Sequential ( nn .Conv1d (1 ,prosody_dim //2 ,kernel_size =5 ,padding =2 ), nn .SiLU (), nn .Conv1d (prosody_dim //2 ,prosody_dim ,kernel_size =3 ,padding =1 ), ) self .energy_conv =nn .Sequential ( nn .Conv1d (1 ,prosody_dim //2 ,kernel_size =5 ,padding =2 ), nn .SiLU (), nn .Conv1d (prosody_dim //2 ,prosody_dim ,kernel_size =3 ,padding =1 ), ) self .vad_head =nn .Sequential ( nn .Linear (hidden_size ,hidden_size //2 ), nn .SiLU (), nn .Linear (hidden_size //2 ,2 ), ) self .event_classifier =nn .Sequential ( nn .Linear (hidden_size +prosody_dim *2 ,hidden_size ), nn .SiLU (), nn .Dropout (dropout ), nn .Linear (hidden_size ,hidden_size //2 ), nn .SiLU (), nn .Linear (hidden_size //2 ,8 ), ) self .temporal_attn =nn .MultiheadAttention ( embed_dim =hidden_size , num_heads =num_heads , dropout =dropout , batch_first =True , ) self .eot_head =nn .Sequential ( nn .Linear (hidden_size +prosody_dim *2 ,hidden_size ), nn .SiLU (), nn .Dropout (dropout ), nn .Linear (hidden_size ,num_eot_classes ), ) self .backoff_head =nn .Sequential ( nn .Linear (hidden_size ,hidden_size //4 ), nn .SiLU (), nn .Linear (hidden_size //4 ,1 ), nn .Sigmoid (), ) print (f" ๐ŸŽ™๏ธ ProsodyAwareEoTPredictor: {num_eot_classes } turn states, {prosody_dim }d prosody") def extract_prosody (self ,audio_features :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: """Extract pitch and energy prosodic features.""" batch_size ,seq_len ,hidden =audio_features .shape x =audio_features .transpose (1 ,2 ) pitch_proxy =x [:,:1 ,:] energy_proxy =x .pow (2 ).mean (dim =1 ,keepdim =True ) pitch_features =self 
.pitch_conv (pitch_proxy ).transpose (1 ,2 ) energy_features =self .energy_conv (energy_proxy ).transpose (1 ,2 ) return pitch_features ,energy_features def forward ( self , audio_features :torch .Tensor , attention_mask :Optional [torch .Tensor ]=None , )->dict : """ Predict end-of-turn and interruption events. Args: audio_features: [B, T, hidden_size] encoded audio attention_mask: [B, T] optional mask Returns: dict with: - eot_logits: [B, T, num_eot_classes] turn state predictions - event_logits: [B, T, 8] interruption event predictions - vad_logits: [B, T, 2] voice activity predictions - backoff_prob: [B, T, 1] backoff probability """ batch_size ,seq_len ,_ =audio_features .shape pitch_features ,energy_features =self .extract_prosody (audio_features ) if attention_mask is not None : key_padding_mask =~attention_mask .bool () else : key_padding_mask =None contextualized ,_ =self .temporal_attn ( audio_features ,audio_features ,audio_features , key_padding_mask =key_padding_mask , ) combined =torch .cat ([contextualized ,pitch_features ,energy_features ],dim =-1 ) eot_logits =self .eot_head (combined ) event_logits =self .event_classifier (combined ) vad_logits =self .vad_head (contextualized ) backoff_prob =self .backoff_head (contextualized ) return { "eot_logits":eot_logits , "event_logits":event_logits , "vad_logits":vad_logits , "backoff_prob":backoff_prob , } class AVDEmotionRecognizer (nn .Module ): """ Continuous AVD (Arousal/Valence/Dominance) Emotion Recognition. Predicts both discrete emotion categories and continuous AVD values for nuanced emotion understanding and response adaptation. 
""" def __init__ ( self , hidden_size :int =1024 , num_emotions :int =10 , num_layers :int =2 , dropout :float =0.1 , ): super ().__init__ () self .hidden_size =hidden_size self .num_emotions =num_emotions self .emotion_query =nn .Parameter (torch .randn (1 ,1 ,hidden_size )) self .emotion_attn =nn .MultiheadAttention ( embed_dim =hidden_size , num_heads =8 , dropout =dropout , batch_first =True , ) self .temporal_conv =nn .Sequential ( nn .Conv1d (hidden_size ,hidden_size ,kernel_size =5 ,padding =2 ,groups =8 ), nn .SiLU (), nn .Conv1d (hidden_size ,hidden_size ,kernel_size =3 ,padding =1 ), ) self .emotion_classifier =nn .Sequential ( nn .Linear (hidden_size ,hidden_size //2 ), nn .SiLU (), nn .Dropout (dropout ), nn .Linear (hidden_size //2 ,num_emotions ), ) self .arousal_head =nn .Sequential ( nn .Linear (hidden_size ,hidden_size //4 ), nn .SiLU (), nn .Linear (hidden_size //4 ,1 ), nn .Sigmoid (), ) self .valence_head =nn .Sequential ( nn .Linear (hidden_size ,hidden_size //4 ), nn .SiLU (), nn .Linear (hidden_size //4 ,1 ), nn .Tanh (), ) self .dominance_head =nn .Sequential ( nn .Linear (hidden_size ,hidden_size //4 ), nn .SiLU (), nn .Linear (hidden_size //4 ,1 ), nn .Sigmoid (), ) self .response_adaptation =nn .Sequential ( nn .Linear (hidden_size +3 ,hidden_size //2 ), nn .SiLU (), nn .Linear (hidden_size //2 ,4 ), ) print (f" ๐Ÿ˜Š AVDEmotionRecognizer: {num_emotions } emotions + continuous AVD") def forward ( self , audio_features :torch .Tensor , attention_mask :Optional [torch .Tensor ]=None , )->dict : """ Recognize emotion from audio features. 
Args: audio_features: [B, T, hidden_size] encoded audio attention_mask: [B, T] optional mask Returns: dict with: - emotion_logits: [B, num_emotions] discrete emotion - arousal: [B, 1] arousal value (0-1) - valence: [B, 1] valence value (-1 to 1) - dominance: [B, 1] dominance value (0-1) - response_mode: [B, 4] response adaptation logits """ batch_size ,seq_len ,_ =audio_features .shape x_conv =self .temporal_conv (audio_features .transpose (1 ,2 )).transpose (1 ,2 ) x =audio_features +x_conv query =self .emotion_query .expand (batch_size ,-1 ,-1 ) if attention_mask is not None : key_padding_mask =~attention_mask .bool () else : key_padding_mask =None emotion_context ,_ =self .emotion_attn ( query ,x ,x , key_padding_mask =key_padding_mask , ) emotion_vec =emotion_context .squeeze (1 ) emotion_logits =self .emotion_classifier (emotion_vec ) arousal =self .arousal_head (emotion_vec ) valence =self .valence_head (emotion_vec ) dominance =self .dominance_head (emotion_vec ) avd_concat =torch .cat ([emotion_vec ,arousal ,valence ,dominance ],dim =-1 ) response_mode =self .response_adaptation (avd_concat ) return { "emotion_logits":emotion_logits , "arousal":arousal , "valence":valence , "dominance":dominance , "response_mode":response_mode , } class DynamicLatentVocalizer (nn .Module ): """ Dynamic Latent Vocalizations for singing, rapping, humming, etc. 
Extends speech synthesis to include: - Singing with pitch control - Rapping with rhythm control - Humming, whistling, chanting - Musical style transfer """ def __init__ ( self , hidden_size :int =1024 , num_styles :int =8 , num_vocal_modes :int =6 , pitch_bins :int =256 , tempo_range :Tuple [int ,int ]=(60 ,180 ), ): super ().__init__ () self .hidden_size =hidden_size self .num_styles =num_styles self .num_vocal_modes =num_vocal_modes self .pitch_bins =pitch_bins self .tempo_range =tempo_range self .style_embed =nn .Embedding (num_styles ,hidden_size //4 ) self .mode_embed =nn .Embedding (num_vocal_modes ,hidden_size //4 ) self .pitch_embed =nn .Embedding (pitch_bins ,hidden_size //4 ) self .pitch_predictor =nn .Sequential ( nn .Linear (hidden_size ,hidden_size //2 ), nn .SiLU (), nn .Linear (hidden_size //2 ,pitch_bins ), ) self .tempo_encoder =nn .Sequential ( nn .Linear (1 ,hidden_size //8 ), nn .SiLU (), nn .Linear (hidden_size //8 ,hidden_size //4 ), ) self .rhythm_attn =nn .MultiheadAttention ( embed_dim =hidden_size , num_heads =4 , dropout =0.1 , batch_first =True , ) self .style_transfer =nn .Sequential ( nn .Linear (hidden_size +hidden_size //2 ,hidden_size ), nn .SiLU (), nn .Linear (hidden_size ,hidden_size ), ) self .lyrics_aligner =MonotonicAlignmentSearch (hidden_size ) self .output_proj =nn .Linear (hidden_size ,hidden_size ) print (f" ๐ŸŽต DynamicLatentVocalizer: {num_styles } styles, {num_vocal_modes } modes") def forward ( self , text_features :torch .Tensor , style_id :Optional [torch .Tensor ]=None , mode_id :Optional [torch .Tensor ]=None , target_pitch :Optional [torch .Tensor ]=None , tempo_bpm :Optional [torch .Tensor ]=None , )->dict : """ Generate vocalization features for singing/rapping/etc. 
Args: text_features: [B, T, hidden_size] text/lyrics embeddings style_id: [B] style indices (0-7) mode_id: [B] vocal mode indices (0-5) target_pitch: [B, T] optional pitch targets tempo_bpm: [B] optional tempo in BPM Returns: dict with: - vocal_features: [B, T', hidden_size] vocalization features - pitch_logits: [B, T, pitch_bins] predicted pitch - alignment: [B, T, T'] text-to-audio alignment """ batch_size ,seq_len ,_ =text_features .shape device =text_features .device if style_id is None : style_id =torch .zeros (batch_size ,dtype =torch .long ,device =device ) if mode_id is None : mode_id =torch .zeros (batch_size ,dtype =torch .long ,device =device ) style_emb =self .style_embed (style_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ) mode_emb =self .mode_embed (mode_id ).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ) if tempo_bpm is not None : tempo_norm =(tempo_bpm .float ()-self .tempo_range [0 ])/(self .tempo_range [1 ]-self .tempo_range [0 ]) tempo_emb =self .tempo_encoder (tempo_norm .unsqueeze (-1 )).unsqueeze (1 ).expand (-1 ,seq_len ,-1 ) else : tempo_emb =torch .zeros (batch_size ,seq_len ,self .hidden_size //4 ,device =device ) pitch_logits =self .pitch_predictor (text_features ) if target_pitch is not None : pitch_emb =self .pitch_embed (target_pitch ) else : pitch_idx =pitch_logits .argmax (dim =-1 ) pitch_emb =self .pitch_embed (pitch_idx ) conditions =torch .cat ([style_emb ,mode_emb ,tempo_emb ,pitch_emb ],dim =-1 ) combined =torch .cat ([text_features ,conditions ],dim =-1 ) vocal_features =self .style_transfer (combined ) vocal_features ,_ =self .rhythm_attn (vocal_features ,vocal_features ,vocal_features ) alignment ,durations =self .lyrics_aligner (text_features ) vocal_features =self .output_proj (vocal_features ) return { "vocal_features":vocal_features , "pitch_logits":pitch_logits , "alignment":alignment , "durations":durations , } class NeuralSoundEffectGenerator (nn .Module ): """ Neural Style Transfer for Sound Effects and Non-verbal 
Vocalizations. Generates: - Beatboxing (kicks, snares, hi-hats) - Vocal clicks, pops, tongue sounds - Breathing, sighing, gasping - Non-verbal expressions (hmm, aha, wow, etc.) - Polyphonic ad-libs and harmonies """ def __init__ ( self , hidden_size :int =1024 , num_effect_types :int =20 , num_layers :int =3 , ): super ().__init__ () self .hidden_size =hidden_size self .num_effect_types =num_effect_types self .effect_embed =nn .Embedding (num_effect_types ,hidden_size ) self .generator =nn .Sequential ( nn .Linear (hidden_size ,hidden_size *4 ), nn .SiLU (), nn .Unflatten (1 ,(hidden_size ,4 )), nn .ConvTranspose1d (hidden_size ,hidden_size //2 ,4 ,2 ,1 ), nn .SiLU (), nn .ConvTranspose1d (hidden_size //2 ,hidden_size //4 ,4 ,2 ,1 ), nn .SiLU (), nn .ConvTranspose1d (hidden_size //4 ,hidden_size //8 ,4 ,2 ,1 ), nn .SiLU (), nn .ConvTranspose1d (hidden_size //8 ,1 ,4 ,2 ,1 ), nn .Tanh (), ) self .duration_head =nn .Sequential ( nn .Linear (hidden_size ,hidden_size //4 ), nn .SiLU (), nn .Linear (hidden_size //4 ,1 ), nn .Softplus (), ) self .intensity_head =nn .Sequential ( nn .Linear (hidden_size ,hidden_size //4 ), nn .SiLU (), nn .Linear (hidden_size //4 ,1 ), nn .Sigmoid (), ) self .blend_attn =nn .MultiheadAttention ( embed_dim =hidden_size , num_heads =4 , batch_first =True , ) print (f" ๐Ÿฅ NeuralSoundEffectGenerator: {num_effect_types } effect types") def forward ( self , effect_ids :torch .Tensor , context :Optional [torch .Tensor ]=None , intensity :Optional [torch .Tensor ]=None , )->dict : """ Generate sound effect features. 
Args: effect_ids: [B] or [B, N] effect type indices context: [B, T, hidden_size] optional context features intensity: [B] or [B, N] optional intensity values Returns: dict with: - effect_features: [B, T', hidden_size] generated features - waveform: [B, 1, samples] raw waveform (if generating directly) - duration: [B, 1] predicted duration """ if effect_ids .dim ()==1 : effect_ids =effect_ids .unsqueeze (1 ) batch_size ,num_effects =effect_ids .shape device =effect_ids .device effect_emb =self .effect_embed (effect_ids ) if num_effects >1 : effect_emb ,_ =self .blend_attn (effect_emb ,effect_emb ,effect_emb ) effect_vec =effect_emb .mean (dim =1 ) if context is not None : context_vec =context .mean (dim =1 ) effect_vec =effect_vec +context_vec duration =self .duration_head (effect_vec ) pred_intensity =self .intensity_head (effect_vec ) if intensity is not None : pred_intensity =intensity .unsqueeze (-1 )if intensity .dim ()==1 else intensity effect_vec =effect_vec *pred_intensity waveform =self .generator (effect_vec ) return { "effect_features":effect_emb , "waveform":waveform , "duration":duration , "intensity":pred_intensity , } class SpeculativeAudioDecoder (nn .Module ): """ Mid-stream Token Rewriting support for Speculative Decoding in audio. 
Allows the model to: - Generate draft audio tokens speculatively - Accept/reject based on user feedback or context change - Rollback and regenerate from checkpoints - Smooth transitions during rewrites """ def __init__ ( self , hidden_size :int =1024 , draft_length :int =10 , num_heads :int =8 , ): super ().__init__ () self .hidden_size =hidden_size self .draft_length =draft_length self .draft_head =nn .Sequential ( nn .Linear (hidden_size ,hidden_size ), nn .SiLU (), nn .Linear (hidden_size ,hidden_size ), ) self .verify_head =nn .Sequential ( nn .Linear (hidden_size *2 ,hidden_size ), nn .SiLU (), nn .Linear (hidden_size ,1 ), nn .Sigmoid (), ) self .checkpoint_encoder =nn .GRU ( input_size =hidden_size , hidden_size =hidden_size , num_layers =1 , batch_first =True , ) self .smoother =nn .Sequential ( nn .Linear (hidden_size *2 ,hidden_size ), nn .SiLU (), nn .Linear (hidden_size ,hidden_size ), ) self .confidence_head =nn .Sequential ( nn .Linear (hidden_size ,hidden_size //4 ), nn .SiLU (), nn .Linear (hidden_size //4 ,1 ), nn .Sigmoid (), ) print (f" โšก SpeculativeAudioDecoder: draft_length={draft_length }") def generate_draft ( self , context :torch .Tensor , num_tokens :int =None , )->Tuple [torch .Tensor ,torch .Tensor ]: """ Generate draft tokens speculatively. 
Args: context: [B, T, hidden_size] context features num_tokens: number of draft tokens (default: self.draft_length) Returns: draft_tokens: [B, N, hidden_size] draft features confidence: [B, N, 1] confidence per token """ if num_tokens is None : num_tokens =self .draft_length batch_size =context .shape [0 ] device =context .device seed =context [:,-1 :,:] draft_tokens =[] confidences =[] current =seed for _ in range (num_tokens ): draft =self .draft_head (current ) conf =self .confidence_head (draft ) draft_tokens .append (draft ) confidences .append (conf ) current =draft draft_tokens =torch .cat (draft_tokens ,dim =1 ) confidences =torch .cat (confidences ,dim =1 ) return draft_tokens ,confidences def verify_draft ( self , draft_tokens :torch .Tensor , new_context :torch .Tensor , )->torch .Tensor : """ Verify if draft tokens should be accepted given new context. Args: draft_tokens: [B, N, hidden_size] draft features new_context: [B, T, hidden_size] updated context Returns: accept_prob: [B, N, 1] probability to accept each token """ context_summary =new_context .mean (dim =1 ,keepdim =True ).expand (-1 ,draft_tokens .shape [1 ],-1 ) combined =torch .cat ([draft_tokens ,context_summary ],dim =-1 ) accept_prob =self .verify_head (combined ) return accept_prob def create_checkpoint (self ,hidden_state :torch .Tensor )->torch .Tensor : """Save hidden state for potential rollback.""" _ ,checkpoint =self .checkpoint_encoder (hidden_state ) return checkpoint .squeeze (0 ) def smooth_transition ( self , old_features :torch .Tensor , new_features :torch .Tensor , )->torch .Tensor : """Create smooth transition between old and new features.""" combined =torch .cat ([old_features ,new_features ],dim =-1 ) return self .smoother (combined ) def forward ( self , context :torch .Tensor , generate_draft :bool =True , verify_with :Optional [torch .Tensor ]=None , )->dict : """ Full speculative decoding step. 
Args: context: [B, T, hidden_size] current context generate_draft: whether to generate new draft verify_with: [B, T', hidden_size] new context to verify against Returns: dict with draft tokens, confidence, verification results """ results ={} results ["checkpoint"]=self .create_checkpoint (context ) if generate_draft : draft ,confidence =self .generate_draft (context ) results ["draft_tokens"]=draft results ["confidence"]=confidence if verify_with is not None and "draft_tokens"in results : accept_prob =self .verify_draft (results ["draft_tokens"],verify_with ) results ["accept_prob"]=accept_prob return results ============================================================================== MODELS.GENERATORS.IMAGE ============================================================================== EPS =1e-5 class RoPE2D (nn .Module ): """ 2D Rotary Position Embedding for flexible aspect ratios. Encodes (x, y) spatial positions for patch-based DiT. """ def __init__ (self ,dim :int ,max_height :int =128 ,max_width :int =128 ,base :float =10000.0 ): super ().__init__ () self .dim =dim self .max_height =max_height self .max_width =max_width self .base =base self .dim_x =dim //2 self .dim_y =dim -self .dim_x inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x )) inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y )) self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False ) self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False ) def forward (self ,x :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]: device =x .device dtype =x .dtype pos_x =torch .arange (width ,device =device ,dtype =torch .float32 ) pos_y =torch .arange (height ,device =device ,dtype =torch .float32 ) freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device )) freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device )) freqs_x =torch .cat ([freqs_x ,freqs_x ],dim =-1 ) 
freqs_y =torch .cat ([freqs_y ,freqs_y ],dim =-1 ) cos_2d =torch .zeros (height ,width ,self .dim ,device =device ,dtype =dtype ) sin_2d =torch .zeros (height ,width ,self .dim ,device =device ,dtype =dtype ) for y in range (height ): for w in range (width ): cos_2d [y ,w ,:self .dim_x ]=freqs_x [w ].cos ().to (dtype ) sin_2d [y ,w ,:self .dim_x ]=freqs_x [w ].sin ().to (dtype ) cos_2d [y ,w ,self .dim_x :]=freqs_y [y ].cos ().to (dtype ) sin_2d [y ,w ,self .dim_x :]=freqs_y [y ].sin ().to (dtype ) cos_2d =cos_2d .view (height *width ,self .dim ) sin_2d =sin_2d .view (height *width ,self .dim ) return cos_2d ,sin_2d def apply_rope_2d (x :torch .Tensor ,cos :torch .Tensor ,sin :torch .Tensor )->torch .Tensor : x1 =x [...,:x .shape [-1 ]//2 ] x2 =x [...,x .shape [-1 ]//2 :] rotated =torch .cat ((-x2 ,x1 ),dim =-1 ) return x *cos +rotated *sin class ImageExpert (nn .Module ): """Single expert for DiT with SwiGLU activation.""" def __init__ (self ,hidden_size :int ,intermediate_size :int ): super ().__init__ () self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) self .act_fn =nn .SiLU () def forward (self ,x :torch .Tensor )->torch .Tensor : return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x )) class ImageMoERouter (nn .Module ): """Router for Image MoE with spatial awareness.""" def __init__ (self ,hidden_size :int ,num_experts :int =4 ,top_k :int =2 ): super ().__init__ () self .num_experts =num_experts self .top_k =top_k self .norm =nn .LayerNorm (hidden_size ) self .gate =nn .Linear (hidden_size ,num_experts ,bias =False ) nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 ) def forward (self ,x :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: x_norm =self .norm (x ) router_logits =self .gate (x_norm ) router_probs =F .softmax (router_logits ,dim =-1 
,dtype =x .dtype ) top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 ) top_k_probs =top_k_probs /(top_k_probs .sum (dim =-1 ,keepdim =True )+EPS ) return top_k_probs ,top_k_indices class ImageMoELayer (nn .Module ): """MoE Layer for DiT with shared expert.""" def __init__ (self ,hidden_size :int ,intermediate_size :int ,num_experts :int =4 ,top_k :int =2 ): super ().__init__ () self .hidden_size =hidden_size self .num_experts =num_experts self .top_k =top_k self .router =ImageMoERouter (hidden_size ,num_experts ,top_k ) self .experts =nn .ModuleList ([ ImageExpert (hidden_size ,intermediate_size ) for _ in range (num_experts ) ]) self .shared_expert =ImageExpert (hidden_size ,intermediate_size ) def forward (self ,x :torch .Tensor )->torch .Tensor : batch_size ,seq_len ,hidden_size =x .shape x_flat =x .view (-1 ,hidden_size ) top_k_probs ,top_k_indices =self .router (x_flat ) output =torch .zeros_like (x_flat ) for expert_idx in range (self .num_experts ): expert =self .experts [expert_idx ] for k in range (self .top_k ): mask =(top_k_indices [:,k ]==expert_idx ) if mask .any (): expert_input =x_flat [mask ] expert_output =expert (expert_input ) weight =top_k_probs [mask ,k :k +1 ] output [mask ]=output [mask ]+weight *expert_output shared_output =self .shared_expert (x_flat ) output =output +shared_output return output .view (batch_size ,seq_len ,hidden_size ) class DualStreamSelfAttention (nn .Module ): """ Symmetric Dual-Stream Self-Attention (SD3/Flux-style). Two parallel streams with cross-stream information exchange. Uses Flash Attention 2.0 via SDPA for O(N) memory. 
""" def __init__ (self ,hidden_size :int ,num_heads :int =8 ,max_height :int =64 ,max_width :int =64 ): super ().__init__ () self .hidden_size =hidden_size self .num_heads =num_heads self .head_dim =hidden_size //num_heads self .scale =self .head_dim **-0.5 self ._qk_scale =self .head_dim **-0.25 self .to_qkv_a =nn .Linear (hidden_size ,hidden_size *3 ,bias =False ) self .to_qkv_b =nn .Linear (hidden_size ,hidden_size *3 ,bias =False ) self .to_out_a =nn .Linear (hidden_size ,hidden_size ,bias =False ) self .to_out_b =nn .Linear (hidden_size ,hidden_size ,bias =False ) self .norm_a =nn .LayerNorm (hidden_size ) self .norm_b =nn .LayerNorm (hidden_size ) self .rope_2d =RoPE2D (self .head_dim ,max_height ,max_width ) def forward (self ,x_a :torch .Tensor ,x_b :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]: batch_size ,seq_len ,_ =x_a .shape x_a =self .norm_a (x_a ) x_b =self .norm_b (x_b ) qkv_a =self .to_qkv_a (x_a ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim ) qkv_b =self .to_qkv_b (x_b ).reshape (batch_size ,seq_len ,3 ,self .num_heads ,self .head_dim ) q_a ,k_a ,v_a =qkv_a .unbind (dim =2 ) q_b ,k_b ,v_b =qkv_b .unbind (dim =2 ) cos ,sin =self .rope_2d (x_a ,height ,width ) cos =cos .unsqueeze (0 ).unsqueeze (1 ) sin =sin .unsqueeze (0 ).unsqueeze (1 ) q_a =q_a .transpose (1 ,2 ) k_a =k_a .transpose (1 ,2 ) v_a =v_a .transpose (1 ,2 ) q_b =q_b .transpose (1 ,2 ) k_b =k_b .transpose (1 ,2 ) v_b =v_b .transpose (1 ,2 ) q_a =apply_rope_2d (q_a ,cos ,sin ) k_a =apply_rope_2d (k_a ,cos ,sin ) q_b =apply_rope_2d (q_b ,cos ,sin ) k_b =apply_rope_2d (k_b ,cos ,sin ) k_combined =torch .cat ([k_a ,k_b ],dim =2 ) v_combined =torch .cat ([v_a ,v_b ],dim =2 ) out_a =F .scaled_dot_product_attention ( q_a *self ._qk_scale ,k_combined *self ._qk_scale ,v_combined , is_causal =False ,scale =1.0 , ) out_b =F .scaled_dot_product_attention ( q_b *self ._qk_scale ,k_combined *self ._qk_scale ,v_combined , is_causal =False ,scale 
=1.0 , ) out_a =out_a .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size ) out_b =out_b .transpose (1 ,2 ).reshape (batch_size ,seq_len ,self .hidden_size ) out_a =self .to_out_a (out_a ) out_b =self .to_out_b (out_b ) return out_a ,out_b class CrossAttention (nn .Module ): """Cross-attention for text conditioning.""" def __init__ (self ,query_dim :int ,context_dim :int =None ,heads :int =8 ): super ().__init__ () self .heads =heads context_dim =context_dim or query_dim self .head_dim =query_dim //heads self .scale =self .head_dim **-0.5 self .norm =nn .LayerNorm (query_dim ) self .to_q =nn .Linear (query_dim ,query_dim ,bias =False ) self .to_k =nn .Linear (context_dim ,query_dim ,bias =False ) self .to_v =nn .Linear (context_dim ,query_dim ,bias =False ) self .to_out =nn .Linear (query_dim ,query_dim ,bias =False ) def forward (self ,x :torch .Tensor ,context :torch .Tensor )->torch .Tensor : batch_size ,seq_len ,_ =x .shape ctx_len =context .shape [1 ] x =self .norm (x ) q =self .to_q (x ).reshape (batch_size ,seq_len ,self .heads ,self .head_dim ).transpose (1 ,2 ) k =self .to_k (context ).reshape (batch_size ,ctx_len ,self .heads ,self .head_dim ).transpose (1 ,2 ) v =self .to_v (context ).reshape (batch_size ,ctx_len ,self .heads ,self .head_dim ).transpose (1 ,2 ) qk_scale =self .head_dim **-0.25 out =F .scaled_dot_product_attention ( q *qk_scale ,k *qk_scale ,v , is_causal =False ,scale =1.0 , ) out =out .transpose (1 ,2 ).reshape (batch_size ,seq_len ,-1 ) out =self .to_out (out ) return out class DiTBlock (nn .Module ): """ DiT Block with Dual-Stream Attention and MoE FFN. 
""" def __init__ (self ,hidden_size :int ,context_dim :int ,num_heads :int =8 ,num_experts :int =4 ,max_height :int =64 ,max_width :int =64 ): super ().__init__ () self .dual_attn =DualStreamSelfAttention (hidden_size ,num_heads ,max_height ,max_width ) self .cross_attn_a =CrossAttention (hidden_size ,context_dim ,num_heads ) self .cross_attn_b =CrossAttention (hidden_size ,context_dim ,num_heads ) self .moe_a =ImageMoELayer (hidden_size ,hidden_size *4 ,num_experts ) self .moe_b =ImageMoELayer (hidden_size ,hidden_size *4 ,num_experts ) self .adaLN_a =nn .Sequential ( nn .SiLU (), nn .Linear (hidden_size ,hidden_size *6 ), ) self .adaLN_b =nn .Sequential ( nn .SiLU (), nn .Linear (hidden_size ,hidden_size *6 ), ) self .norm1_a =nn .LayerNorm (hidden_size ,elementwise_affine =False ) self .norm1_b =nn .LayerNorm (hidden_size ,elementwise_affine =False ) self .norm2_a =nn .LayerNorm (hidden_size ,elementwise_affine =False ) self .norm2_b =nn .LayerNorm (hidden_size ,elementwise_affine =False ) def forward (self ,x_a :torch .Tensor ,x_b :torch .Tensor ,context :torch .Tensor ,t_emb :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]: shift_a ,scale_a ,gate_a ,shift2_a ,scale2_a ,gate2_a =self .adaLN_a (t_emb ).chunk (6 ,dim =-1 ) shift_b ,scale_b ,gate_b ,shift2_b ,scale2_b ,gate2_b =self .adaLN_b (t_emb ).chunk (6 ,dim =-1 ) shift_a =shift_a .unsqueeze (1 ) scale_a =scale_a .unsqueeze (1 ) gate_a =gate_a .unsqueeze (1 ) shift2_a =shift2_a .unsqueeze (1 ) scale2_a =scale2_a .unsqueeze (1 ) gate2_a =gate2_a .unsqueeze (1 ) shift_b =shift_b .unsqueeze (1 ) scale_b =scale_b .unsqueeze (1 ) gate_b =gate_b .unsqueeze (1 ) shift2_b =shift2_b .unsqueeze (1 ) scale2_b =scale2_b .unsqueeze (1 ) gate2_b =gate2_b .unsqueeze (1 ) x_a_norm =self .norm1_a (x_a )*(1 +scale_a )+shift_a x_b_norm =self .norm1_b (x_b )*(1 +scale_b )+shift_b attn_out_a ,attn_out_b =self .dual_attn (x_a_norm ,x_b_norm ,height ,width ) x_a =x_a +gate_a *attn_out_a x_b =x_b 
+gate_b *attn_out_b x_a =x_a +self .cross_attn_a (x_a ,context ) x_b =x_b +self .cross_attn_b (x_b ,context ) x_a_norm =self .norm2_a (x_a )*(1 +scale2_a )+shift2_a x_b_norm =self .norm2_b (x_b )*(1 +scale2_b )+shift2_b x_a =x_a +gate2_a *self .moe_a (x_a_norm ) x_b =x_b +gate2_b *self .moe_b (x_b_norm ) return x_a ,x_b class FlowMatchingScheduler : """Flow Matching scheduler for image generation.""" def __init__ (self ,num_steps :int =50 ,sigma_min :float =0.002 ): self .num_steps =num_steps self .sigma_min =sigma_min self .timesteps =torch .linspace (1 ,0 ,num_steps +1 ) def get_velocity (self ,x_t :torch .Tensor ,x_0 :torch .Tensor ,t :torch .Tensor )->torch .Tensor : return x_0 -x_t def step (self ,model_output :torch .Tensor ,t :torch .Tensor ,t_prev :torch .Tensor ,x_t :torch .Tensor )->torch .Tensor : dt =t -t_prev x_prev =x_t +model_output *dt .view (-1 ,1 ,1 ,1 ) return x_prev def add_noise (self ,x_0 :torch .Tensor ,t :torch .Tensor )->torch .Tensor : noise =torch .randn_like (x_0 ) t =t .to (x_0 .dtype ).view (-1 ,1 ,1 ,1 ) x_t =t *noise +(1 -t )*x_0 return x_t class PatchEmbed (nn .Module ): """Patch embedding for DiT.""" def __init__ (self ,patch_size :int =2 ,in_channels :int =4 ,hidden_size :int =512 ): super ().__init__ () self .patch_size =patch_size self .proj =nn .Conv2d (in_channels ,hidden_size ,kernel_size =patch_size ,stride =patch_size ) def forward (self ,x :torch .Tensor )->torch .Tensor : x =self .proj (x ) x =x .flatten (2 ).transpose (1 ,2 ) return x class UnpatchEmbed (nn .Module ): """Unpatch embedding to reconstruct image from patches.""" def __init__ (self ,patch_size :int =2 ,out_channels :int =4 ,hidden_size :int =512 ): super ().__init__ () self .patch_size =patch_size self .out_channels =out_channels self .proj =nn .Linear (hidden_size ,patch_size *patch_size *out_channels ) def forward (self ,x :torch .Tensor ,height :int ,width :int )->torch .Tensor : x =self .proj (x ) batch_size =x .shape [0 ] x =x .reshape (batch_size 
, height, width, self.patch_size, self.patch_size, self.out_channels)
        x = x.permute(0, 5, 1, 3, 2, 4).reshape(batch_size, self.out_channels, height * self.patch_size, width * self.patch_size)
        return x


class MoEDiT(nn.Module):
    """
    MoE Diffusion Transformer with Dual-Stream Attention.
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 4,
        hidden_size: int = 512,
        context_dim: int = 1024,
        num_layers: int = 8,
        num_heads: int = 8,
        num_experts: int = 4,
        patch_size: int = 2,
        max_image_size: int = 64,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        max_patches = max_image_size // patch_size
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.patch_embed = PatchEmbed(patch_size, in_channels, hidden_size)
        self.context_proj = nn.Linear(context_dim, hidden_size)
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, hidden_size, num_heads, num_experts, max_patches, max_patches)
            for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(hidden_size)
        self.unpatch_embed = UnpatchEmbed(patch_size, out_channels, hidden_size)
        self.gradient_checkpointing = False
        self._init_weights()

    def _init_weights(self):
        # Zero-init the output projection so the model predicts zero velocity
        # at the start of training (a common DiT stabilization trick).
        nn.init.zeros_(self.unpatch_embed.proj.weight)
        nn.init.zeros_(self.unpatch_embed.proj.bias)

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for memory efficiency."""
        self.gradient_checkpointing = True

    def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, channels, height, width = x.shape
        patch_height = height // self.patch_size
        patch_width = width // self.patch_size
        # Sinusoidal timestep embedding, computed in x's dtype.
        half_dim = self.hidden_size // 2
        t_emb = math.log(10000) / (half_dim - 1)
        t_emb = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -t_emb)
        t_emb = timesteps[:, None].to(x.dtype) * t_emb[None, :]
        t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
        t_emb = self.time_embed(t_emb)
        x_patches = self.patch_embed(x)
        context_proj = self.context_proj(context)
        # Two parallel token streams processed by dual-stream DiT blocks.
        x_a = x_patches
        x_b = x_patches.clone()
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                x_a, x_b = torch.utils.checkpoint.checkpoint(
                    block, x_a, x_b, context_proj, t_emb, patch_height, patch_width,
                    use_reentrant=False
                )
            else:
                x_a, x_b = block(x_a, x_b, context_proj, t_emb, patch_height, patch_width)
        # Merge the two streams by averaging before the output head.
        x_combined = (x_a + x_b) / 2
        x_combined = self.final_norm(x_combined)
        velocity = self.unpatch_embed(x_combined, patch_height, patch_width)
        return velocity


class ImageVAE(nn.Module):
    """Lightweight VAE for image encoding/decoding."""

    def __init__(self, in_channels: int = 3, latent_channels: int = 4, base_channels: int = 64):
        super().__init__()
        # 4x spatial downsampling; the final conv emits mean and logvar
        # stacked along the channel dim (hence latent_channels * 2).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_channels, base_channels * 2, 3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_channels * 2, base_channels * 4, 3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_channels * 4, latent_channels * 2, 3, padding=1),
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(latent_channels, base_channels * 4, 3, padding=1),
            nn.SiLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(base_channels * 4, base_channels * 2, 3, padding=1),
            nn.SiLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(base_channels * 2, base_channels, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_channels, in_channels, 3, padding=1),
        )

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Reparameterized sample plus the posterior parameters; logvar is
        # clamped to keep exp() finite in FP16.
        h = self.encoder(x)
        mean, logvar = h.chunk(2, dim=1)
        logvar = torch.clamp(logvar, -30, 20)
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)
        return z, mean, logvar

    def decode(self, z
: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)


class MobileDiffusionGenerator(nn.Module):
    """
    SOTA Image Diffusion with MoE-DiT, Flow Matching, 2D-RoPE, Dual-Stream.
    Optimized for 2x T4 GPUs with FP16.
    """

    def __init__(
        self,
        latent_channels: int = 4,
        base_channels: int = 128,
        context_dim: int = 1024,
        num_inference_steps: int = 50,
        image_size: int = 256,
        cfg_scale: float = 7.5,
    ):
        super().__init__()
        self.latent_channels = latent_channels
        self.context_dim = context_dim
        self.image_size = image_size
        # ImageVAE downsamples spatially by 4x.
        self.latent_size = image_size // 4
        self.num_inference_steps = num_inference_steps
        self.cfg_scale = cfg_scale
        self.vae_encoder = ImageVAE(3, latent_channels, base_channels // 2)
        # Encoder and decoder share one VAE instance.
        self.vae_decoder = self.vae_encoder
        self.unet = MoEDiT(
            in_channels=latent_channels,
            out_channels=latent_channels,
            hidden_size=base_channels * 4,
            context_dim=context_dim,
            num_layers=8,
            num_heads=8,
            num_experts=4,
            patch_size=2,
            max_image_size=self.latent_size,
        )
        self.scheduler = FlowMatchingScheduler(num_inference_steps)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.vae_encoder.encode(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.vae_decoder.decode(z)

    def training_step(self, images: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> dict:
        """One flow-matching training step; returns a dict of loss tensors."""
        device = images.device
        dtype = images.dtype
        batch_size = images.shape[0]
        # Map [0, 1] images into [-1, 1] before VAE encoding.
        z, mean, logvar = self.encode(images * 2 - 1)
        del images
        t = torch.rand(batch_size, device=device, dtype=dtype)
        x_t = self.scheduler.add_noise(z, t)
        target_velocity = self.scheduler.get_velocity(x_t, z, t)
        if self.training:
            # Classifier-free guidance training: drop conditioning 10% of the time.
            drop_mask = torch.rand(batch_size, device=device) < 0.1
            drop_mask_expanded = drop_mask.view(batch_size, 1, 1).expand_as(context)
            null_ctx = torch.zeros_like(context)
            context = torch.where(drop_mask_expanded, null_ctx, context)
            del drop_mask, drop_mask_expanded, null_ctx
        pred_velocity = self.unet(x_t, (t * 1000).to(dtype), context, mask)
        del x_t, context
        flow_loss = F.mse_loss(pred_velocity, target_velocity)
        del pred_velocity, target_velocity
        kl_loss = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
        del z, mean, logvar
        total_loss = flow_loss + 0.0001 * kl_loss
        return {
            'flow_loss': flow_loss,
            'kl_loss': kl_loss,
            'total_loss': total_loss,
        }

    @torch.no_grad()
    def generate(self, context: torch.Tensor, guidance_scale: float = None, num_steps: int = None, init_latents: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, masked_image_latents: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Sample images from text context with classifier-free guidance."""
        device = context.device
        batch_size = context.shape[0]
        seq_len = context.shape[1]
        guidance_scale = guidance_scale or self.cfg_scale
        num_steps = num_steps or self.num_inference_steps
        if init_latents is not None:
            latents = init_latents
        else:
            # NOTE(review): latents are created in the default dtype (no
            # dtype=context.dtype); confirm this matches the unet under FP16.
            latents = torch.randn(batch_size, self.latent_channels, self.latent_size, self.latent_size, device=device)
        timesteps = torch.linspace(1, 0, num_steps + 1, device=device)
        if guidance_scale > 1.0:
            # Unconditional batch is prepended for CFG.
            null_ctx = torch.zeros(batch_size, seq_len, self.context_dim, device=device, dtype=context.dtype)
            context = torch.cat([null_ctx, context])
        for i in range(num_steps):
            t = timesteps[i]
            t_prev = timesteps[i + 1]
            t_batch = t.expand(batch_size) * 1000
            if guidance_scale > 1.0:
                latent_input = torch.cat([latents, latents])
                t_input = torch.cat([t_batch, t_batch])
                velocity_pred = self.unet(latent_input, t_input, context, mask)
                velocity_uncond, velocity_cond = velocity_pred.chunk(2)
                velocity_pred = velocity_uncond + guidance_scale * (velocity_cond - velocity_uncond)
            else:
                velocity_pred = self.unet(latents, t_batch, context, mask)
            latents = self.scheduler.step(velocity_pred, t, t_prev, latents)
        if mask is not None and masked_image_latents is not None:
            # Inpainting: keep known regions from the masked image latents.
            latents = masked_image_latents * mask + latents * (1 - mask)
        images = self.decode(latents)
        images = (images + 1) / 2
        return torch.clamp(images
, 0, 1)

    @torch.no_grad()
    def edit_image(self, image: torch.Tensor, context: torch.Tensor, mask: torch.Tensor, strength: float = 0.8, guidance_scale: float = None) -> torch.Tensor:
        """Inpaint/edit: noise the image latents to `strength`, then denoise with masking."""
        device = image.device
        image_norm = image * 2 - 1
        z, _, _ = self.encode(image_norm)
        mask_latent = F.interpolate(mask, size=(self.latent_size, self.latent_size), mode='nearest')
        # Fewer steps for partial denoising: start from t = strength.
        num_steps = int(self.num_inference_steps * strength)
        t = torch.tensor([strength], device=device)
        noisy_z = self.scheduler.add_noise(z, t.expand(z.shape[0]))
        return self.generate(
            context,
            guidance_scale=guidance_scale,
            num_steps=num_steps,
            init_latents=noisy_z,
            mask=mask_latent,
            masked_image_latents=z,
        )


# ==============================================================================
# MODELS.GENERATORS.VIDEO
# ==============================================================================

# Small constant guarding divisions in router-probability renormalization.
EPS = 1e-5


class InterleavedMRoPE(nn.Module):
    """
    Interleaved Multi-dimensional Rotary Position Embedding (MRoPE).

    SOTA: Full-frequency allocation over time, width, and height via robust
    positional embeddings.

    Unlike separate spatial and temporal RoPE, Interleaved-MRoPE allocates
    frequencies across all three dimensions jointly, enhancing long-horizon
    video reasoning.
Key advantages: - Better temporal-spatial correlation modeling - More robust for variable aspect ratios and frame counts - Improved long-range video understanding """ def __init__ (self ,dim :int ,max_height :int =64 ,max_width :int =64 ,max_frames :int =64 ,base :float =10000.0 ): super ().__init__ () self .dim =dim self .max_height =max_height self .max_width =max_width self .max_frames =max_frames self .base =base self .dim_t =dim //3 self .dim_y =dim //3 self .dim_x =dim -self .dim_t -self .dim_y inv_freq_t =1.0 /(base **(torch .arange (0 ,self .dim_t ,2 ,dtype =torch .float32 )/self .dim_t )) inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y )) inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x )) self .register_buffer ('inv_freq_t',inv_freq_t ,persistent =False ) self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False ) self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False ) def forward (self ,x :torch .Tensor ,height :int ,width :int ,num_frames :int )->Tuple [torch .Tensor ,torch .Tensor ]: """ Compute interleaved 3D positional embeddings. 
Args: x: Input tensor for device/dtype reference height: Spatial height width: Spatial width num_frames: Temporal frames Returns: cos, sin: [T * H * W, dim] positional embeddings """ device =x .device dtype =x .dtype pos_t =torch .arange (num_frames ,device =device ,dtype =torch .float32 ) pos_y =torch .arange (height ,device =device ,dtype =torch .float32 ) pos_x =torch .arange (width ,device =device ,dtype =torch .float32 ) freqs_t =torch .outer (pos_t ,self .inv_freq_t .to (device )) freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device )) freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device )) freqs_t =torch .cat ([freqs_t ,freqs_t ],dim =-1 ) freqs_y =torch .cat ([freqs_y ,freqs_y ],dim =-1 ) freqs_x =torch .cat ([freqs_x ,freqs_x ],dim =-1 ) seq_len =num_frames *height *width cos_3d =torch .zeros (num_frames ,height ,width ,self .dim ,device =device ,dtype =dtype ) sin_3d =torch .zeros (num_frames ,height ,width ,self .dim ,device =device ,dtype =dtype ) for t in range (num_frames ): for h in range (height ): for w in range (width ): cos_3d [t ,h ,w ,:self .dim_t ]=freqs_t [t ].cos ().to (dtype ) sin_3d [t ,h ,w ,:self .dim_t ]=freqs_t [t ].sin ().to (dtype ) cos_3d [t ,h ,w ,self .dim_t :self .dim_t +self .dim_y ]=freqs_y [h ].cos ().to (dtype ) sin_3d [t ,h ,w ,self .dim_t :self .dim_t +self .dim_y ]=freqs_y [h ].sin ().to (dtype ) cos_3d [t ,h ,w ,self .dim_t +self .dim_y :]=freqs_x [w ].cos ().to (dtype ) sin_3d [t ,h ,w ,self .dim_t +self .dim_y :]=freqs_x [w ].sin ().to (dtype ) cos_3d =cos_3d .view (seq_len ,self .dim ) sin_3d =sin_3d .view (seq_len ,self .dim ) return cos_3d ,sin_3d class RoPE2D (nn .Module ): """ 2D Rotary Position Embedding for spatial dimensions (memory efficient). Used for spatial attention in factorized video attention. 
""" def __init__ (self ,dim :int ,max_height :int =64 ,max_width :int =64 ,base :float =10000.0 ): super ().__init__ () self .dim =dim self .dim_x =dim //2 self .dim_y =dim -self .dim_x inv_freq_x =1.0 /(base **(torch .arange (0 ,self .dim_x ,2 ,dtype =torch .float32 )/self .dim_x )) inv_freq_y =1.0 /(base **(torch .arange (0 ,self .dim_y ,2 ,dtype =torch .float32 )/self .dim_y )) self .register_buffer ('inv_freq_x',inv_freq_x ,persistent =False ) self .register_buffer ('inv_freq_y',inv_freq_y ,persistent =False ) def forward (self ,x :torch .Tensor ,height :int ,width :int )->Tuple [torch .Tensor ,torch .Tensor ]: device =x .device dtype =x .dtype pos_x =torch .arange (width ,device =device ,dtype =torch .float32 ) pos_y =torch .arange (height ,device =device ,dtype =torch .float32 ) freqs_x =torch .outer (pos_x ,self .inv_freq_x .to (device )) freqs_y =torch .outer (pos_y ,self .inv_freq_y .to (device )) cos_x =torch .cat ([freqs_x .cos (),freqs_x .cos ()],dim =-1 ) sin_x =torch .cat ([freqs_x .sin (),freqs_x .sin ()],dim =-1 ) cos_y =torch .cat ([freqs_y .cos (),freqs_y .cos ()],dim =-1 ) sin_y =torch .cat ([freqs_y .sin (),freqs_y .sin ()],dim =-1 ) cos_2d =torch .zeros (height ,width ,self .dim ,device =device ,dtype =dtype ) sin_2d =torch .zeros (height ,width ,self .dim ,device =device ,dtype =dtype ) cos_2d [:,:,:self .dim_x ]=cos_x .unsqueeze (0 ).expand (height ,-1 ,-1 ) sin_2d [:,:,:self .dim_x ]=sin_x .unsqueeze (0 ).expand (height ,-1 ,-1 ) cos_2d [:,:,self .dim_x :]=cos_y .unsqueeze (1 ).expand (-1 ,width ,-1 ) sin_2d [:,:,self .dim_x :]=sin_y .unsqueeze (1 ).expand (-1 ,width ,-1 ) return cos_2d .view (height *width ,self .dim ).to (dtype ),sin_2d .view (height *width ,self .dim ).to (dtype ) class RoPE1D (nn .Module ): """ 1D Rotary Position Embedding for temporal dimension. Used for temporal attention in factorized video attention. 
""" def __init__ (self ,dim :int ,max_len :int =64 ,base :float =10000.0 ): super ().__init__ () self .dim =dim inv_freq =1.0 /(base **(torch .arange (0 ,dim ,2 ,dtype =torch .float32 )/dim )) self .register_buffer ('inv_freq',inv_freq ,persistent =False ) def forward (self ,x :torch .Tensor ,seq_len :int )->Tuple [torch .Tensor ,torch .Tensor ]: device =x .device dtype =x .dtype pos =torch .arange (seq_len ,device =device ,dtype =torch .float32 ) freqs =torch .outer (pos ,self .inv_freq .to (device )) freqs =torch .cat ([freqs ,freqs ],dim =-1 ) return freqs .cos ().to (dtype ),freqs .sin ().to (dtype ) def apply_rope (x :torch .Tensor ,cos :torch .Tensor ,sin :torch .Tensor )->torch .Tensor : """Apply rotary position embedding.""" x1 =x [...,:x .shape [-1 ]//2 ] x2 =x [...,x .shape [-1 ]//2 :] rotated =torch .cat ((-x2 ,x1 ),dim =-1 ) return x *cos +rotated *sin class TemporalExpertRouter (nn .Module ): """ Temporal-Aware Expert Router for video generation. Routes tokens based on temporal context and motion patterns. 
""" def __init__ (self ,hidden_size :int ,num_experts :int =4 ,top_k :int =2 ): super ().__init__ () self .num_experts =num_experts self .top_k =top_k self .temporal_proj =nn .Linear (hidden_size ,hidden_size ) self .gate =nn .Linear (hidden_size ,num_experts ,bias =False ) nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 ) def forward (self ,x :torch .Tensor ,temporal_context :Optional [torch .Tensor ]=None )->Tuple [torch .Tensor ,torch .Tensor ]: if temporal_context is not None : x =x +self .temporal_proj (temporal_context ) router_logits =self .gate (x ) router_probs =F .softmax (router_logits ,dim =-1 ,dtype =x .dtype ) top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 ) top_k_probs =top_k_probs /(top_k_probs .sum (dim =-1 ,keepdim =True )+EPS ) return top_k_probs ,top_k_indices class VideoExpert (nn .Module ): """Single expert for video processing with SwiGLU.""" def __init__ (self ,hidden_size :int ,intermediate_size :int ): super ().__init__ () self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) self .act_fn =nn .SiLU () def forward (self ,x :torch .Tensor )->torch .Tensor : return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x )) class TemporalMoELayer (nn .Module ): """ Temporal-Aware MoE Layer for video generation. Uses motion-aware routing for expert selection. 
""" def __init__ (self ,hidden_size :int ,intermediate_size :int ,num_experts :int =4 ,top_k :int =2 ): super ().__init__ () self .hidden_size =hidden_size self .num_experts =num_experts self .top_k =top_k self .router =TemporalExpertRouter (hidden_size ,num_experts ,top_k ) self .experts =nn .ModuleList ([ VideoExpert (hidden_size ,intermediate_size ) for _ in range (num_experts ) ]) self .shared_expert =VideoExpert (hidden_size ,intermediate_size ) def forward (self ,x :torch .Tensor ,temporal_context :Optional [torch .Tensor ]=None )->torch .Tensor : batch_size ,seq_len ,hidden_size =x .shape x_flat =x .view (-1 ,hidden_size ) top_k_probs ,top_k_indices =self .router (x_flat ,temporal_context .view (-1 ,hidden_size )if temporal_context is not None else None ) output =torch .zeros_like (x_flat ) for expert_idx in range (self .num_experts ): expert =self .experts [expert_idx ] for k in range (self .top_k ): mask =(top_k_indices [:,k ]==expert_idx ) if mask .any (): expert_input =x_flat [mask ] expert_output =expert (expert_input ) weight =top_k_probs [mask ,k :k +1 ] output [mask ]=output [mask ]+weight *expert_output shared_output =self .shared_expert (x_flat ) output =output +shared_output return output .view (batch_size ,seq_len ,hidden_size ) class SpatialAttention (nn .Module ): """ Spatial self-attention: each frame attends only within itself. 
    Memory: O(T * (H*W)^2) instead of O((T*H*W)^2)
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, max_height: int = 64, max_width: int = 64):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** -0.5
        self.to_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.to_out = nn.Linear(hidden_size, hidden_size, bias=False)
        self.rope_2d = RoPE2D(self.head_dim, max_height, max_width)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor, height: int, width: int, frames: int) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape
        spatial_len = height * width
        x = self.norm(x)
        # Fold frames into the batch so attention is restricted to one frame.
        x = x.view(batch_size * frames, spatial_len, self.hidden_size)
        qkv = self.to_qkv(x).reshape(batch_size * frames, spatial_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        cos, sin = self.rope_2d(x, height, width)
        cos = cos.unsqueeze(0).unsqueeze(1)
        sin = sin.unsqueeze(0).unsqueeze(1)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)
        # Split the softmax scale evenly across q and k (helps FP16 stability),
        # then tell SDPA not to rescale again (scale=1.0).
        qk_scale = self.head_dim ** -0.25
        out = F.scaled_dot_product_attention(
            q * qk_scale, k * qk_scale, v,
            is_causal=False, scale=1.0,
        )
        out = out.transpose(1, 2).reshape(batch_size * frames, spatial_len, self.hidden_size)
        out = self.to_out(out)
        return out.view(batch_size, seq_len, self.hidden_size)


class TemporalAttention(nn.Module):
    """
    Temporal self-attention: each spatial position attends across time.

    Memory: O(H*W * T^2) instead of O((T*H*W)^2)
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, max_frames: int = 32):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** -0.5
        self.to_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.to_out = nn.Linear(hidden_size, hidden_size, bias=False)
        self.rope_1d = RoPE1D(self.head_dim, max_frames)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor, height: int, width: int, frames: int, causal: bool = True) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape
        spatial_len = height * width
        x = self.norm(x)
        # Fold spatial positions into the batch so attention runs across frames.
        x = x.view(batch_size, frames, spatial_len, self.hidden_size)
        x = x.permute(0, 2, 1, 3).reshape(batch_size * spatial_len, frames, self.hidden_size)
        qkv = self.to_qkv(x).reshape(batch_size * spatial_len, frames, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        cos, sin = self.rope_1d(x, frames)
        cos = cos.unsqueeze(0).unsqueeze(1)
        sin = sin.unsqueeze(0).unsqueeze(1)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)
        qk_scale = self.head_dim ** -0.25
        out = F.scaled_dot_product_attention(
            q * qk_scale, k * qk_scale, v,
            is_causal=causal, scale=1.0,
        )
        out = out.transpose(1, 2).reshape(batch_size * spatial_len, frames, self.hidden_size)
        # Unfold back to [B, T*H*W, hidden].
        out = out.view(batch_size, spatial_len, frames, self.hidden_size)
        out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.hidden_size)
        out = self.to_out(out)
        return out


class FactorizedSpatioTemporalAttention(nn.Module):
    """
    Factorized Spatial-Temporal Attention (like CogVideo, Open-Sora, SVD).

    Instead of full 3D attention O((T*H*W)^2), uses:
    1. Spatial attention per frame: O(T * (H*W)^2)
    2.
       Temporal attention per position: O(H*W * T^2)

    Total: O(T*(H*W)^2 + H*W*T^2) << O((T*H*W)^2)

    For T=8, H=W=64:
    - Full 3D: 32768^2 = 1B attention scores
    - Factorized: 8*4096^2 + 4096*64 = 134M attention scores (7.5x less!)
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, max_frames: int = 32, max_height: int = 64, max_width: int = 64):
        super().__init__()
        self.spatial_attn = SpatialAttention(hidden_size, num_heads, max_height, max_width)
        self.temporal_attn = TemporalAttention(hidden_size, num_heads, max_frames)

    def forward(self, x: torch.Tensor, height: int, width: int, frames: int, causal: bool = True) -> torch.Tensor:
        # Residual connection around each factorized attention pass.
        x = x + self.spatial_attn(x, height, width, frames)
        x = x + self.temporal_attn(x, height, width, frames, causal)
        return x


class CrossAttention3D(nn.Module):
    """Cross-attention for text-to-video conditioning."""

    def __init__(self, query_dim: int, context_dim: int = None, heads: int = 8):
        super().__init__()
        self.heads = heads
        context_dim = context_dim or query_dim
        self.head_dim = query_dim // heads
        self.scale = self.head_dim ** -0.5
        self.norm = nn.LayerNorm(query_dim)
        self.to_q = nn.Linear(query_dim, query_dim, bias=False)
        self.to_k = nn.Linear(context_dim, query_dim, bias=False)
        self.to_v = nn.Linear(context_dim, query_dim, bias=False)
        self.to_out = nn.Linear(query_dim, query_dim, bias=False)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape
        ctx_len = context.shape[1]
        x = self.norm(x)
        q = self.to_q(x).reshape(batch_size, seq_len, self.heads, self.head_dim).transpose(1, 2)
        k = self.to_k(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
        v = self.to_v(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
        # Split scaling between q and k; SDPA rescaling disabled via scale=1.0.
        qk_scale = self.head_dim ** -0.25
        out = F.scaled_dot_product_attention(
            q * qk_scale, k * qk_scale, v,
            is_causal=False, scale=1.0,
        )
        out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
        out = self.to_out(out)
        return out


class Causal3DTransformerBlock(nn.Module):
    """
    3D Causal Transformer Block with Factorized Spatial-Temporal Attention.

    Uses memory-efficient factorized attention instead of full 3D attention:
    - Spatial: Each frame attends within itself O(T * (H*W)^2)
    - Temporal: Each position attends across frames O(H*W * T^2)

    This reduces memory from O((T*H*W)^2) to O(T*(H*W)^2 + H*W*T^2)
    """

    def __init__(self, hidden_size: int, context_dim: int, num_heads: int = 8, num_experts: int = 4, max_frames: int = 32, max_height: int = 64, max_width: int = 64):
        super().__init__()
        self.self_attn = FactorizedSpatioTemporalAttention(hidden_size, num_heads, max_frames, max_height, max_width)
        self.cross_attn = CrossAttention3D(hidden_size, context_dim, num_heads)
        self.moe = TemporalMoELayer(hidden_size, hidden_size * 4, num_experts)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.norm3 = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor, context: torch.Tensor, height: int, width: int, frames: int, temporal_context: Optional[torch.Tensor] = None) -> torch.Tensor:
        # NOTE(review): unlike the cross-attn/MoE branches below, the raw x is
        # not carried forward here — the self_attn residuals apply to norm1(x),
        # not x. Confirm this is intentional and not a missing `x + `.
        x = self.self_attn(self.norm1(x), height, width, frames, causal=True)
        x = x + self.cross_attn(self.norm2(x), context)
        x = x + self.moe(self.norm3(x), temporal_context)
        return x


class FlowMatchingScheduler:
    """
    Flow Matching scheduler for video generation.

    Uses optimal transport paths for superior generation quality.
""" def __init__ (self ,num_steps :int =50 ,sigma_min :float =0.002 ): self .num_steps =num_steps self .sigma_min =sigma_min self .timesteps =torch .linspace (1 ,0 ,num_steps +1 ) def get_velocity (self ,x_t :torch .Tensor ,x_0 :torch .Tensor ,t :torch .Tensor )->torch .Tensor : """Compute target velocity for flow matching.""" return x_0 -x_t def step (self ,model_output :torch .Tensor ,t :torch .Tensor ,t_prev :torch .Tensor ,x_t :torch .Tensor )->torch .Tensor : """Single step of flow matching ODE.""" dt =t -t_prev x_prev =x_t +model_output *dt .view (-1 ,1 ,1 ,1 ,1 ) return x_prev def add_noise (self ,x_0 :torch .Tensor ,t :torch .Tensor )->torch .Tensor : """Add noise for training (linear interpolation).""" noise =torch .randn_like (x_0 ) t =t .to (x_0 .dtype ).view (-1 ,1 ,1 ,1 ,1 ) x_t =t *noise +(1 -t )*x_0 return x_t class VideoUNet3D (nn .Module ): """ 3D U-Net for video generation with Factorized Spatial-Temporal Attention. Uses memory-efficient factorized attention that processes spatial and temporal dimensions separately, reducing memory from O((T*H*W)^2) to O(T*(H*W)^2 + H*W*T^2). 
""" def __init__ ( self , in_channels :int =4 , out_channels :int =4 , hidden_size :int =512 , context_dim :int =1024 , num_layers :int =4 , num_heads :int =8 , num_experts :int =4 , num_frames :int =16 , max_height :int =64 , max_width :int =64 , ): super ().__init__ () self .hidden_size =hidden_size self .num_frames =num_frames self .time_embed =nn .Sequential ( nn .Linear (hidden_size ,hidden_size *4 ), nn .SiLU (), nn .Linear (hidden_size *4 ,hidden_size ), ) self .input_proj =nn .Conv3d (in_channels ,hidden_size ,kernel_size =3 ,padding =1 ) self .transformer_blocks =nn .ModuleList ([ Causal3DTransformerBlock (hidden_size ,context_dim ,num_heads ,num_experts ,num_frames ,max_height ,max_width ) for _ in range (num_layers ) ]) self .output_proj =nn .Sequential ( nn .GroupNorm (32 ,hidden_size ), nn .SiLU (), nn .Conv3d (hidden_size ,out_channels ,kernel_size =3 ,padding =1 ), ) nn .init .zeros_ (self .output_proj [-1 ].weight ) nn .init .zeros_ (self .output_proj [-1 ].bias ) self .gradient_checkpointing =False def enable_gradient_checkpointing (self ): """Enable gradient checkpointing for memory efficiency.""" self .gradient_checkpointing =True def forward (self ,x :torch .Tensor ,timesteps :torch .Tensor ,context :torch .Tensor ,first_frame_latent :Optional [torch .Tensor ]=None )->torch .Tensor : batch_size ,channels ,frames ,height ,width =x .shape half_dim =self .hidden_size //2 t_emb =math .log (10000 )/(half_dim -1 ) t_emb =torch .exp (torch .arange (half_dim ,device =x .device ,dtype =x .dtype )*-t_emb ) t_emb =timesteps [:,None ].to (x .dtype )*t_emb [None ,:] t_emb =torch .cat ([torch .sin (t_emb ),torch .cos (t_emb )],dim =-1 ) t_emb =self .time_embed (t_emb ) h =self .input_proj (x ) h =h .permute (0 ,2 ,3 ,4 ,1 ).reshape (batch_size ,frames *height *width ,self .hidden_size ) temporal_context =t_emb .unsqueeze (1 ).expand (-1 ,frames *height *width ,-1 ) for block in self .transformer_blocks : if self .gradient_checkpointing and self .training : h 
=torch .utils .checkpoint .checkpoint ( block ,h ,context ,height ,width ,frames ,temporal_context , use_reentrant =False ) else : h =block (h ,context ,height ,width ,frames ,temporal_context ) h =h .reshape (batch_size ,frames ,height ,width ,self .hidden_size ).permute (0 ,4 ,1 ,2 ,3 ) velocity =self .output_proj (h ) return velocity class VideoVAE3D (nn .Module ): """ 3D VAE for video encoding/decoding using VidTok architecture. This replaces the simple placeholder with proper temporal+spatial compression following Microsoft's VidTok architecture for high-quality video tokenization. Features: - Proper temporal compression (4x default) - Proper spatial compression (8x default, same as image VAE) - AlphaBlender for temporal blending - Causal mode support for streaming - Both KL (continuous) and FSQ (discrete) tokenization Compression: [B, C, T, H, W] -> [B, latent_ch, T/4, H/8, W/8] """ def __init__ ( self , in_channels :int =3 , latent_channels :int =4 , base_channels :int =64 , temporal_compression :int =4 , spatial_compression :int =8 , causal :bool =True , use_fsq :bool =False , ): super ().__init__ () self .in_channels =in_channels self .latent_channels =latent_channels self .temporal_compression =temporal_compression self .spatial_compression =spatial_compression self .causal =causal self .use_fsq =use_fsq self .temporal_stages =int (math .log2 (temporal_compression )) self .spatial_stages =int (math .log2 (spatial_compression )) encoder_layers =[] ch_in =in_channels ch_out =base_channels encoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,padding =1 )) encoder_layers .append (nn .SiLU ()) for i in range (self .spatial_stages -self .temporal_stages ): ch_in =ch_out ch_out =min (ch_out *2 ,base_channels *8 ) encoder_layers .append (nn .Conv3d (ch_in ,ch_out ,3 ,stride =(1 ,2 ,2 ),padding =1 )) encoder_layers .append (nn .SiLU ()) for i in range (self .temporal_stages ): ch_in =ch_out ch_out =min (ch_out *2 ,base_channels *8 ) encoder_layers .append (nn 
.Conv3d(ch_in, ch_out, 3, stride=(2, 2, 2), padding=1))
            encoder_layers.append(nn.SiLU())
        # KL head emits mean+logvar (2x channels); FSQ head emits latents directly.
        out_ch = latent_channels * 2 if not use_fsq else latent_channels
        encoder_layers.append(nn.Conv3d(ch_out, out_ch, 3, padding=1))
        self.encoder = nn.Sequential(*encoder_layers)
        decoder_layers = []
        ch_in = latent_channels
        ch_out = base_channels * (2 ** min(self.spatial_stages, 3))
        decoder_layers.append(nn.Conv3d(ch_in, ch_out, 3, padding=1))
        decoder_layers.append(nn.SiLU())
        # Joint temporal+spatial upsampling stages (mirror of the encoder).
        for i in range(self.temporal_stages):
            ch_in = ch_out
            ch_out = max(ch_out // 2, base_channels)
            decoder_layers.append(nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear', align_corners=False))
            decoder_layers.append(nn.Conv3d(ch_in, ch_out, 3, padding=1))
            decoder_layers.append(nn.SiLU())
        # Spatial-only upsampling stages.
        for i in range(self.spatial_stages - self.temporal_stages):
            ch_in = ch_out
            ch_out = max(ch_out // 2, base_channels)
            decoder_layers.append(nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear', align_corners=False))
            decoder_layers.append(nn.Conv3d(ch_in, ch_out, 3, padding=1))
            decoder_layers.append(nn.SiLU())
        decoder_layers.append(nn.Conv3d(ch_out, in_channels, 3, padding=1))
        self.decoder = nn.Sequential(*decoder_layers)
        print(f" ๐ŸŽฌ VideoVAE3D (VidTok): {temporal_compression}x{spatial_compression}x{spatial_compression} compression")
        print(f" Temporal stages: {self.temporal_stages}, Spatial stages: {self.spatial_stages}")
        print(f" Mode: {'FSQ (discrete)' if use_fsq else 'KL (continuous)'}, Causal: {causal}")

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode video to latent space.

        Args:
            x: [B, C, T, H, W] video tensor, values in [0, 1] or [-1, 1]

        Returns:
            Tuple of (z, mean, logvar) where z is the sampled latent
        """
        h = self.encoder(x)
        if self.use_fsq:
            # FSQ path: deterministic quantization, no KL term (zero logvar).
            z = self._fsq_quantize(h)
            return z, z, torch.zeros_like(z)
        else:
            mean, logvar = h.chunk(2, dim=1)
            logvar = torch.clamp(logvar, -30, 20)
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
            return z, mean, logvar

    def _fsq_quantize(self, z: torch.Tensor, levels: int = 8) -> torch.Tensor:
        """Finite Scalar Quantization."""
        # Bound to (-1, 1), then snap to one of `levels` evenly spaced values.
        z = torch.tanh(z)
        z = torch.round((z + 1) * (levels - 1) / 2) * 2 / (levels - 1) - 1
        return z

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode latent to video.

        Args:
            z: [B, latent_ch, t, h, w] latent tensor

        Returns:
            [B, C, T, H, W] reconstructed video
        """
        return self.decoder(z)


class MobileVideoDiffusion(nn.Module):
    """
    SOTA Video Diffusion with Flow Matching, Factorized Attention, Temporal MoE.

    Uses memory-efficient factorized spatial-temporal attention:
    - Full 3D attention: O((T*H*W)^2) = 1B+ attention scores (OOM!)
    - Factorized: O(T*(H*W)^2 + H*W*T^2) = ~134M scores (7.5x less memory)

    Optimized for 2x T4 GPUs (15GB each) with FP16.
""" def __init__ ( self , latent_channels :int =4 , base_channels :int =64 , context_dim :int =1024 , num_frames :int =16 , image_size :int =256 , num_inference_steps :int =50 , cfg_scale :float =7.5 , temporal_compression :int =4 , spatial_compression :int =8 , causal :bool =True , use_fsq :bool =False , ): super ().__init__ () self .latent_channels =latent_channels self .context_dim =context_dim self .num_frames =num_frames self .image_size =image_size self .temporal_compression =temporal_compression self .spatial_compression =spatial_compression self .latent_size =image_size //spatial_compression self .latent_frames =num_frames //temporal_compression self .num_inference_steps =num_inference_steps self .cfg_scale =cfg_scale self .vae =VideoVAE3D ( in_channels =3 , latent_channels =latent_channels , base_channels =base_channels , temporal_compression =temporal_compression , spatial_compression =spatial_compression , causal =causal , use_fsq =use_fsq , ) self .unet =VideoUNet3D ( in_channels =latent_channels , out_channels =latent_channels , hidden_size =base_channels *4 , context_dim =context_dim , num_layers =4 , num_heads =8 , num_experts =4 , num_frames =num_frames , max_height =self .latent_size , max_width =self .latent_size , ) self .scheduler =FlowMatchingScheduler (num_inference_steps ) def encode_video (self ,video :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]: return self .vae .encode (video *2 -1 ) def decode_video (self ,z :torch .Tensor )->torch .Tensor : return self .vae .decode (z ) def encode_image (self ,image :torch .Tensor )->torch .Tensor : image_expanded =image .unsqueeze (2 ) z ,_ ,_ =self .vae .encode (image_expanded ) return z .squeeze (2 ) def training_step (self ,video :torch .Tensor ,context :torch .Tensor ,first_frame :Optional [torch .Tensor ]=None )->dict : device =video .device dtype =video .dtype batch_size =video .shape [0 ] z ,mean ,logvar =self .encode_video (video ) del video t =torch .rand (batch_size 
,device =device ,dtype =dtype ) x_t =self .scheduler .add_noise (z ,t ) target_velocity =self .scheduler .get_velocity (x_t ,z ,t ) if self .training : drop_mask =torch .rand (batch_size ,device =device )<0.1 drop_mask_expanded =drop_mask .view (batch_size ,1 ,1 ).expand_as (context ) null_ctx =torch .zeros_like (context ) context =torch .where (drop_mask_expanded ,null_ctx ,context ) del drop_mask ,drop_mask_expanded ,null_ctx pred_velocity =self .unet (x_t ,(t *1000 ).to (dtype ),context ,None ) del x_t ,context flow_loss =F .mse_loss (pred_velocity ,target_velocity ) del pred_velocity ,target_velocity kl_loss =-0.5 *torch .mean (1 +logvar -mean .pow (2 )-logvar .exp ()) temporal_loss =torch .tensor (0.0 ,device =device ,dtype =dtype ) if z .shape [2 ]>1 : z_diff =z [:,:,1 :]-z [:,:,:-1 ] temporal_loss =torch .mean (z_diff **2 ) del z_diff del z ,mean ,logvar total_loss =flow_loss +0.0001 *kl_loss +0.01 *temporal_loss return { 'flow_loss':flow_loss , 'kl_loss':kl_loss , 'temporal_loss':temporal_loss , 'total_loss':total_loss , } @torch .no_grad () def generate_t2v (self ,context :torch .Tensor ,num_frames :int =None ,guidance_scale :float =None ,num_steps :int =None )->torch .Tensor : device =context .device batch_size =context .shape [0 ] seq_len =context .shape [1 ] num_frames =num_frames or self .num_frames guidance_scale =guidance_scale or self .cfg_scale num_steps =num_steps or self .num_inference_steps latents =torch .randn ( batch_size ,self .latent_channels ,num_frames , self .latent_size ,self .latent_size ,device =device ) timesteps =torch .linspace (1 ,0 ,num_steps +1 ,device =device ) if guidance_scale >1.0 : null_ctx =torch .zeros (batch_size ,seq_len ,self .context_dim ,device =device ,dtype =context .dtype ) context =torch .cat ([null_ctx ,context ]) for i in range (num_steps ): t =timesteps [i ] t_prev =timesteps [i +1 ] t_batch =t .expand (batch_size )*1000 if guidance_scale >1.0 : latent_input =torch .cat ([latents ,latents ]) t_input =torch 
.cat ([t_batch ,t_batch ]) velocity_pred =self .unet (latent_input ,t_input ,context ,None ) velocity_uncond ,velocity_cond =velocity_pred .chunk (2 ) velocity_pred =velocity_uncond +guidance_scale *(velocity_cond -velocity_uncond ) else : velocity_pred =self .unet (latents ,t_batch ,context ,None ) latents =self .scheduler .step (velocity_pred ,t ,t_prev ,latents ) video =self .decode_video (latents ) return torch .clamp ((video +1 )/2 ,0 ,1 ) @torch .no_grad () def generate_i2v (self ,first_frame :torch .Tensor ,context :Optional [torch .Tensor ]=None ,num_frames :int =None ,guidance_scale :float =None ,num_steps :int =None )->torch .Tensor : device =first_frame .device batch_size =first_frame .shape [0 ] num_frames =num_frames or self .num_frames guidance_scale =guidance_scale or self .cfg_scale num_steps =num_steps or self .num_inference_steps first_frame_latent =self .encode_image (first_frame *2 -1 ) latents =torch .randn ( batch_size ,self .latent_channels ,num_frames , self .latent_size ,self .latent_size ,device =device ) latents [:,:,0 ]=first_frame_latent if context is None : context =torch .zeros (batch_size ,77 ,self .context_dim ,device =device ) seq_len =context .shape [1 ] timesteps =torch .linspace (1 ,0 ,num_steps +1 ,device =device ) if guidance_scale >1.0 : null_ctx =torch .zeros (batch_size ,seq_len ,self .context_dim ,device =device ,dtype =context .dtype ) context =torch .cat ([null_ctx ,context ]) for i in range (num_steps ): t =timesteps [i ] t_prev =timesteps [i +1 ] t_batch =t .expand (batch_size )*1000 if guidance_scale >1.0 : latent_input =torch .cat ([latents ,latents ]) t_input =torch .cat ([t_batch ,t_batch ]) velocity_pred =self .unet (latent_input ,t_input ,context ,None ) velocity_uncond ,velocity_cond =velocity_pred .chunk (2 ) velocity_pred =velocity_uncond +guidance_scale *(velocity_cond -velocity_uncond ) else : velocity_pred =self .unet (latents ,t_batch ,context ,None ) latents =self .scheduler .step (velocity_pred ,t 
,t_prev ,latents ) latents [:,:,0 ]=first_frame_latent video =self .decode_video (latents ) return torch .clamp ((video +1 )/2 ,0 ,1 ) ============================================================================== MODELS.LLM.MOE_LLAMA ============================================================================== EPS =1e-5 class YaRNRotaryEmbedding (nn .Module ): """ YaRN (Yet another RoPE extensioN) with LongRoPE-style improvements. Supports up to 128K+ context with proper frequency scaling. """ def __init__ ( self , dim :int , max_position_embeddings :int =131072 , base :float =500000.0 , original_max_position_embeddings :int =8192 , beta_fast :float =32.0 , beta_slow :float =1.0 , mscale :float =1.0 , ): super ().__init__ () self .dim =dim self .max_position_embeddings =max_position_embeddings self .base =base self .original_max_position =original_max_position_embeddings self .beta_fast =beta_fast self .beta_slow =beta_slow self .mscale =mscale self .scaling_factor =max_position_embeddings /original_max_position_embeddings inv_freq =self ._compute_yarn_inv_freq () self .register_buffer ('inv_freq',inv_freq ,persistent =False ) def _compute_yarn_inv_freq (self )->torch .Tensor : """Compute YaRN-scaled inverse frequencies.""" pos_freqs =self .base **(torch .arange (0 ,self .dim ,2 ,dtype =torch .float32 )/self .dim ) inv_freq_extrapolation =1.0 /pos_freqs inv_freq_interpolation =1.0 /(self .scaling_factor *pos_freqs ) low =max (math .floor (self .dim *math .log (self .original_max_position /(self .beta_fast *2 *math .pi ))/ (2 *math .log (self .base ))),0 ) high =min (math .ceil (self .dim *math .log (self .original_max_position /(self .beta_slow *2 *math .pi ))/ (2 *math .log (self .base ))),self .dim -1 ) inv_freq =torch .zeros (self .dim //2 ,dtype =torch .float32 ) for i in range (self .dim //2 ): if i high : inv_freq [i ]=inv_freq_extrapolation [i ] else : smooth =(i -low )/max (high -low ,1 ) inv_freq [i ]=(1 -smooth )*inv_freq_interpolation [i ]+smooth 
*inv_freq_extrapolation [i ] return inv_freq def _get_mscale (self ,scale :float )->float : """Get attention scaling factor for YaRN.""" if scale <=1 : return 1.0 return 0.1 *math .log (scale )+1.0 def forward (self ,x :torch .Tensor ,position_ids :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: device =x .device inv_freq =self .inv_freq .to (device ) inv_freq_expanded =inv_freq [None ,:,None ].float ().expand (position_ids .shape [0 ],-1 ,1 ) position_ids_expanded =position_ids [:,None ,:].float () freqs =(inv_freq_expanded @position_ids_expanded ).transpose (1 ,2 ) emb =torch .cat ((freqs ,freqs ),dim =-1 ) mscale =self ._get_mscale (self .scaling_factor )*self .mscale cos =emb .cos ().to (dtype =x .dtype )*mscale sin =emb .sin ().to (dtype =x .dtype )*mscale return cos ,sin LlamaRotaryEmbedding =YaRNRotaryEmbedding def rotate_half (x :torch .Tensor )->torch .Tensor : x1 =x [...,:x .shape [-1 ]//2 ] x2 =x [...,x .shape [-1 ]//2 :] return torch .cat ((-x2 ,x1 ),dim =-1 ) def apply_rotary_pos_emb ( q :torch .Tensor , k :torch .Tensor , cos :torch .Tensor , sin :torch .Tensor , position_ids :Optional [torch .Tensor ]=None , unsqueeze_dim :int =1 , )->Tuple [torch .Tensor ,torch .Tensor ]: cos =cos .unsqueeze (unsqueeze_dim ) sin =sin .unsqueeze (unsqueeze_dim ) q_embed =(q *cos )+(rotate_half (q )*sin ) k_embed =(k *cos )+(rotate_half (k )*sin ) return q_embed ,k_embed class KVCache : """Pre-allocated KV Cache โ€” static buffer with index-based filling. Eliminates VRAM fragmentation from torch.cat during autoregressive generation. Buffer is allocated once at first use and reused via slice assignment. 
""" __slots__ =('key_cache','value_cache','seen_tokens','_max_len') def __init__ ( self , key_cache :torch .Tensor =None , value_cache :torch .Tensor =None , seen_tokens :int =0 , max_seq_len :int =131072 , ): self .key_cache =key_cache self .value_cache =value_cache self .seen_tokens =seen_tokens self ._max_len =max_seq_len def _allocate (self ,batch :int ,heads :int ,head_dim :int ,device :torch .device ,dtype :torch .dtype ): """Allocate static buffer on first use.""" self .key_cache =torch .zeros (batch ,heads ,self ._max_len ,head_dim ,device =device ,dtype =dtype ) self .value_cache =torch .zeros (batch ,heads ,self ._max_len ,head_dim ,device =device ,dtype =dtype ) def update ( self , key_states :torch .Tensor , value_states :torch .Tensor , chunk_size :Optional [int ]=None , )->Tuple [torch .Tensor ,torch .Tensor ]: batch ,heads ,new_len ,head_dim =key_states .shape if self .key_cache is None : self ._allocate (batch ,heads ,head_dim ,key_states .device ,key_states .dtype ) self .seen_tokens =0 if chunk_size is not None and self .seen_tokens +new_len >chunk_size *2 : keep =chunk_size if self .seen_tokens >keep : self .key_cache [:,:,:keep ]=self .key_cache [:,:,self .seen_tokens -keep :self .seen_tokens ].clone () self .value_cache [:,:,:keep ]=self .value_cache [:,:,self .seen_tokens -keep :self .seen_tokens ].clone () self .seen_tokens =keep if self .seen_tokens +new_len >self .key_cache .shape [2 ]: new_max =max (self .key_cache .shape [2 ]*2 ,self .seen_tokens +new_len ) new_key =torch .zeros (batch ,heads ,new_max ,head_dim ,device =key_states .device ,dtype =key_states .dtype ) new_val =torch .zeros (batch ,heads ,new_max ,head_dim ,device =key_states .device ,dtype =key_states .dtype ) new_key [:,:,:self .seen_tokens ]=self .key_cache [:,:,:self .seen_tokens ] new_val [:,:,:self .seen_tokens ]=self .value_cache [:,:,:self .seen_tokens ] self .key_cache =new_key self .value_cache =new_val self .key_cache [:,:,self .seen_tokens :self .seen_tokens 
+new_len ]=key_states self .value_cache [:,:,self .seen_tokens :self .seen_tokens +new_len ]=value_states self .seen_tokens +=new_len return self .key_cache [:,:,:self .seen_tokens ],self .value_cache [:,:,:self .seen_tokens ] def reset (self ): """Reset cache position without deallocating the buffer.""" self .seen_tokens =0 def ring_attention ( query :torch .Tensor , key :torch .Tensor , value :torch .Tensor , chunk_size :int =4096 , causal :bool =True , )->torch .Tensor : """ Ring Attention for distributed long-context processing. Processes sequence in chunks with online softmax accumulation. Args: query: [batch, heads, seq_len, head_dim] key: [batch, heads, kv_len, head_dim] value: [batch, heads, kv_len, head_dim] chunk_size: Size of each attention chunk causal: Whether to apply causal masking Returns: Output tensor [batch, heads, seq_len, head_dim] """ batch_size ,num_heads ,seq_len ,head_dim =query .shape kv_len =key .shape [2 ] if seq_len <=chunk_size and kv_len <=chunk_size : qk_scale =head_dim **-0.25 use_causal =causal and seq_len ==kv_len and seq_len >1 if use_causal : return F .scaled_dot_product_attention ( query *qk_scale ,key *qk_scale ,value , is_causal =True ,scale =1.0 , ) elif causal and kv_len >seq_len : causal_mask =torch .zeros (seq_len ,kv_len ,device =query .device ,dtype =query .dtype ) q_pos =torch .arange (seq_len ,device =query .device )+(kv_len -seq_len ) k_pos =torch .arange (kv_len ,device =query .device ) causal_mask =torch .where (k_pos .unsqueeze (0 )>q_pos .unsqueeze (1 ),float ('-inf'),0.0 ) return F .scaled_dot_product_attention ( query *qk_scale ,key *qk_scale ,value , attn_mask =causal_mask ,scale =1.0 , ) else : return F .scaled_dot_product_attention ( query *qk_scale ,key *qk_scale ,value , is_causal =False ,scale =1.0 , ) scale =head_dim **-0.5 output =torch .zeros_like (query ) max_logits =torch .full ((batch_size ,num_heads ,seq_len ,1 ),float ('-inf'),device =query .device ,dtype =query .dtype ) sum_exp =torch .zeros 
((batch_size ,num_heads ,seq_len ,1 ),device =query .device ,dtype =query .dtype ) if causal : q_positions =torch .arange (seq_len ,device =query .device ) if kv_len >seq_len : q_positions =q_positions +(kv_len -seq_len ) num_kv_chunks =(kv_len +chunk_size -1 )//chunk_size for kv_idx in range (num_kv_chunks ): kv_start =kv_idx *chunk_size kv_end =min ((kv_idx +1 )*chunk_size ,kv_len ) key_chunk =key [:,:,kv_start :kv_end ,:] value_chunk =value [:,:,kv_start :kv_end ,:] attn_chunk =torch .matmul (query ,key_chunk .transpose (-1 ,-2 ))*scale if causal : k_positions =torch .arange (kv_start ,kv_end ,device =query .device ) causal_mask =k_positions .unsqueeze (0 )>q_positions .unsqueeze (1 ) attn_chunk =attn_chunk .masked_fill (causal_mask .unsqueeze (0 ).unsqueeze (0 ),float ('-inf')) chunk_max =attn_chunk .max (dim =-1 ,keepdim =True )[0 ] new_max =torch .maximum (max_logits ,chunk_max ) exp_weights =torch .exp (attn_chunk -new_max ) exp_sum_chunk =exp_weights .sum (dim =-1 ,keepdim =True ) correction =torch .exp (max_logits -new_max ) output =output *correction +torch .matmul (exp_weights ,value_chunk ) sum_exp =sum_exp *correction +exp_sum_chunk max_logits =new_max output =output /(sum_exp +EPS ) return output class MultiHeadLatentAttention (nn .Module ): """ Multi-Head Latent Attention (MLA) from DeepSeek-V2. Compresses KV cache using low-rank projections for memory efficiency. 
""" def __init__ ( self , hidden_size :int , num_heads :int , num_kv_heads :int =None , head_dim :int =None , kv_lora_rank :int =512 , q_lora_rank :int =0 , rope_theta :float =500000.0 , max_position_embeddings :int =131072 , use_ring_attention :bool =True , ring_chunk_size :int =4096 , ): super ().__init__ () self .hidden_size =hidden_size self .num_heads =num_heads self .num_kv_heads =num_kv_heads or num_heads self .head_dim =head_dim or hidden_size //num_heads self .kv_lora_rank =kv_lora_rank self .q_lora_rank =q_lora_rank self .use_ring_attention =use_ring_attention self .ring_chunk_size =ring_chunk_size self .num_key_value_groups =self .num_heads //self .num_kv_heads self .scale =self .head_dim **-0.5 if q_lora_rank >0 : self .q_a_proj =nn .Linear (hidden_size ,q_lora_rank ,bias =False ) self .q_b_proj =nn .Linear (q_lora_rank ,num_heads *self .head_dim ,bias =False ) self .q_a_layernorm =LlamaRMSNorm (q_lora_rank ) else : self .q_proj =nn .Linear (hidden_size ,num_heads *self .head_dim ,bias =False ) self .kv_a_proj =nn .Linear (hidden_size ,kv_lora_rank +self .head_dim ,bias =False ) self .kv_b_proj =nn .Linear (kv_lora_rank ,self .num_kv_heads *self .head_dim *2 ,bias =False ) self .kv_a_layernorm =LlamaRMSNorm (kv_lora_rank ) self .o_proj =nn .Linear (num_heads *self .head_dim ,hidden_size ,bias =False ) self .rotary_emb =YaRNRotaryEmbedding ( dim =self .head_dim , max_position_embeddings =max_position_embeddings , base =rope_theta , ) self ._init_weights () def _init_weights (self ): std =0.02 for name ,module in self .named_modules (): if isinstance (module ,nn .Linear ): nn .init .normal_ (module .weight ,mean =0.0 ,std =std ) def forward ( self , hidden_states :torch .Tensor , attention_mask :Optional [torch .Tensor ]=None , position_ids :Optional [torch .Tensor ]=None , past_key_value :Optional [KVCache ]=None , output_attentions :bool =False , use_cache :bool =False , )->Tuple [torch .Tensor ,Optional [torch .Tensor ],Optional [KVCache ]]: batch_size 
,seq_len ,_ =hidden_states .shape if self .q_lora_rank >0 : q_compressed =self .q_a_layernorm (self .q_a_proj (hidden_states )) query_states =self .q_b_proj (q_compressed ) else : query_states =self .q_proj (hidden_states ) kv_compressed =self .kv_a_proj (hidden_states ) kv_latent ,k_pe =kv_compressed .split ([self .kv_lora_rank ,self .head_dim ],dim =-1 ) kv_latent =self .kv_a_layernorm (kv_latent ) kv_states =self .kv_b_proj (kv_latent ) query_states =query_states .view (batch_size ,seq_len ,self .num_heads ,self .head_dim ).transpose (1 ,2 ) key_states ,value_states =kv_states .split (self .num_kv_heads *self .head_dim ,dim =-1 ) key_states =key_states .view (batch_size ,seq_len ,self .num_kv_heads ,self .head_dim ).transpose (1 ,2 ) value_states =value_states .view (batch_size ,seq_len ,self .num_kv_heads ,self .head_dim ).transpose (1 ,2 ) if position_ids is None : position_ids =torch .arange (seq_len ,device =hidden_states .device ).unsqueeze (0 ).expand (batch_size ,-1 ) if past_key_value is not None and past_key_value .seen_tokens >0 : position_ids =position_ids +past_key_value .seen_tokens cos ,sin =self .rotary_emb (hidden_states ,position_ids ) query_states ,key_states =apply_rotary_pos_emb (query_states ,key_states ,cos ,sin ) if past_key_value is not None : key_states ,value_states =past_key_value .update ( key_states ,value_states , self .ring_chunk_size if self .use_ring_attention else None ) if self .use_ring_attention : if self .num_key_value_groups >1 : key_expanded =key_states .repeat_interleave (self .num_key_value_groups ,dim =1 ) value_expanded =value_states .repeat_interleave (self .num_key_value_groups ,dim =1 ) else : key_expanded =key_states value_expanded =value_states attn_output =ring_attention ( query_states ,key_expanded ,value_expanded , chunk_size =self .ring_chunk_size , causal =True , ) else : qk_scale =self .head_dim **-0.25 kv_len =key_states .shape [2 ] use_causal =(attention_mask is None and seq_len >1 and seq_len ==kv_len ) 
attn_output =F .scaled_dot_product_attention ( query_states *qk_scale , key_states *qk_scale , value_states , attn_mask =attention_mask , is_causal =use_causal , scale =1.0 , enable_gqa =(self .num_key_value_groups >1 ), ) attn_output =attn_output .transpose (1 ,2 ).contiguous ().view (batch_size ,seq_len ,-1 ) attn_output =self .o_proj (attn_output ) return attn_output ,None ,past_key_value if use_cache else None class AuxLosslessMoERouter (nn .Module ): """ Aux-Lossless MoE Router with Shared Expert Isolation. Eliminates auxiliary loss while maintaining load balance through architecture. """ def __init__ ( self , hidden_size :int , num_experts :int , top_k :int =2 , norm_topk_prob :bool =True , ): super ().__init__ () self .num_experts =num_experts self .top_k =top_k self .norm_topk_prob =norm_topk_prob self .input_norm =LlamaRMSNorm (hidden_size ) self .gate =nn .Linear (hidden_size ,num_experts ,bias =False ) nn .init .normal_ (self .gate .weight ,mean =0.0 ,std =0.01 ) self .expert_bias =nn .Parameter (torch .zeros (num_experts )) # Deep experts gate (4 deep experts) self .num_deep_experts = 4 self .deep_gate = nn .Linear (hidden_size , self .num_deep_experts , bias =False ) nn .init .normal_ (self .deep_gate .weight , mean =0.0 , std =0.01 ) self .deep_expert_bias = nn .Parameter (torch .zeros (self .num_deep_experts )) def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]: batch_size ,seq_len ,hidden_dim =hidden_states .shape hidden_flat =hidden_states .view (-1 ,hidden_dim ) hidden_norm =self .input_norm (hidden_flat ) # Standard experts router_logits_std =self .gate (hidden_norm ) biased_logits_std =router_logits_std +self .expert_bias # Deep experts router_logits_deep = self .deep_gate (hidden_norm ) biased_logits_deep = router_logits_deep + self .deep_expert_bias # Concatenate: [batch*seq, num_experts + num_deep_experts] router_logits = torch .cat ([biased_logits_std , biased_logits_deep ], dim =-1 ) 
router_probs =F .softmax (router_logits ,dim =-1 ,dtype =hidden_states .dtype ) top_k_probs ,top_k_indices =torch .topk (router_probs ,self .top_k ,dim =-1 ) if self .norm_topk_prob : top_k_probs =top_k_probs /(top_k_probs .sum (dim =-1 ,keepdim =True )+EPS ) return top_k_probs ,top_k_indices ,router_logits class MoEExpert (nn .Module ): """Single MoE Expert with SwiGLU activation.""" def __init__ (self ,hidden_size :int ,intermediate_size :int ): super ().__init__ () self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) self .act_fn =nn .SiLU () self ._init_weights () def _init_weights (self ): std =0.02 nn .init .normal_ (self .gate_proj .weight ,mean =0.0 ,std =std ) nn .init .normal_ (self .up_proj .weight ,mean =0.0 ,std =std ) nn .init .normal_ (self .down_proj .weight ,mean =0.0 ,std =std *0.5 ) def forward (self ,x :torch .Tensor )->torch .Tensor : return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x )) class DeepMoEExpert (nn .Module ): """Deep MoE Expert with multiple sequential SwiGLU transformations.""" def __init__ (self ,hidden_size :int ,intermediate_size :int ,depth :int =2 ): super ().__init__ () self .depth = depth self .gate_projs = nn .ModuleList ([nn .Linear (hidden_size if i == 0 else intermediate_size , intermediate_size , bias =False ) for i in range (depth )]) self .up_projs = nn .ModuleList ([nn .Linear (hidden_size if i == 0 else intermediate_size , intermediate_size , bias =False ) for i in range (depth )]) self .down_projs = nn .ModuleList ([nn .Linear (intermediate_size , intermediate_size if i < depth - 1 else hidden_size , bias =False ) for i in range (depth )]) self .act_fn = nn .SiLU () self ._init_weights () def _init_weights (self ): std =0.02 for g , u , d in zip (self .gate_projs , self .up_projs , self .down_projs ): nn .init .normal_ (g 
.weight ,mean =0.0 ,std =std ) nn .init .normal_ (u .weight ,mean =0.0 ,std =std ) nn .init .normal_ (d .weight ,mean =0.0 ,std =std *0.5 ) def forward (self ,x :torch .Tensor )->torch .Tensor : for i in range (self .depth ): # Optional residual connection if intermediate sizes match, but standard SwiGLU doesn't usually use them internally unless specified. # We'll stick to sequential application as defined: Input -> SwiGLU -> SwiGLU ... -> DownProj gate = self .act_fn (self .gate_projs [i ](x )) up = self .up_projs [i ](x ) x = self .down_projs [i ](gate * up ) return x class IsolatedSharedExpert (nn .Module ): """ Isolated Shared Expert that always processes all tokens. Separate from routed experts to prevent competition. """ def __init__ (self ,hidden_size :int ,intermediate_size :int ): super ().__init__ () self .gate_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) self .up_proj =nn .Linear (hidden_size ,intermediate_size ,bias =False ) self .down_proj =nn .Linear (intermediate_size ,hidden_size ,bias =False ) self .act_fn =nn .SiLU () self ._init_weights () def _init_weights (self ): std =0.02 nn .init .normal_ (self .gate_proj .weight ,mean =0.0 ,std =std ) nn .init .normal_ (self .up_proj .weight ,mean =0.0 ,std =std ) nn .init .normal_ (self .down_proj .weight ,mean =0.0 ,std =std *0.5 ) def forward (self ,x :torch .Tensor )->torch .Tensor : return self .down_proj (self .act_fn (self .gate_proj (x ))*self .up_proj (x )) class AuxLosslessMoELayer (nn .Module ): """ Aux-Lossless MoE Layer with Isolated Shared Expert. No auxiliary loss needed - load balance maintained through isolation. 
""" def __init__ ( self , hidden_size :int , intermediate_size :int , num_experts :int =8 , num_experts_per_tok :int =2 , shared_expert_intermediate_size :int =None , ): super ().__init__ () self .hidden_size =hidden_size self .num_experts =num_experts self .num_experts_per_tok =num_experts_per_tok self .router =AuxLosslessMoERouter (hidden_size ,num_experts ,num_experts_per_tok ) self .experts =nn .ModuleList ([ MoEExpert (hidden_size ,intermediate_size ) for _ in range (num_experts ) ]) # Deep Experts: Depths 2, 3, 4, 5 self .num_deep_experts = 4 self .deep_experts = nn .ModuleList ([ DeepMoEExpert (hidden_size , intermediate_size , depth =d ) for d in range (2 , 6 ) ]) shared_size =shared_expert_intermediate_size or intermediate_size self .shared_expert =IsolatedSharedExpert (hidden_size ,shared_size ) def forward (self ,hidden_states :torch .Tensor )->Tuple [torch .Tensor ,torch .Tensor ]: batch_size ,seq_len ,hidden_size =hidden_states .shape original_dtype =hidden_states .dtype hidden_flat =hidden_states .view (-1 ,hidden_size ) num_tokens =hidden_flat .shape [0 ] top_k_probs ,top_k_indices ,router_logits =self .router (hidden_states ) if hasattr (self ,'_utilization_tracker'): self ._utilization_tracker .record (top_k_indices ) final_output =torch .zeros_like (hidden_flat ) total_experts = self .num_experts + self .num_deep_experts for expert_idx in range (total_experts ): # Determine which expert list to use if expert_idx < self .num_experts : expert =self .experts [expert_idx ] else : expert =self .deep_experts [expert_idx - self .num_experts ] for k in range (self .num_experts_per_tok ): mask =(top_k_indices [:,k ]==expert_idx ) if mask .any (): expert_input =hidden_flat [mask ] expert_output =expert (expert_input ) weight =top_k_probs [mask ,k :k +1 ] weighted_output =(weight *expert_output ).to (original_dtype ) final_output [mask ]=final_output [mask ]+weighted_output shared_output =self .shared_expert (hidden_flat ) final_output =final_output 
+shared_output .to (original_dtype ) final_output =final_output .view (batch_size ,seq_len ,hidden_size ) aux_loss =self ._compute_aux_loss (router_logits ,top_k_indices ,num_tokens ) return final_output ,aux_loss def _compute_aux_loss ( self , router_logits :torch .Tensor , top_k_indices :torch .Tensor , num_tokens :int , )->torch .Tensor : """ Aux-lossless auxiliary loss. Uses z-loss to keep router logits from growing unboundedly (FP16 stability), plus a soft utilization penalty that activates only when experts go completely cold. The expert_bias parameter handles routine load balancing. """ z_loss =torch .logsumexp (router_logits ,dim =-1 ).square ().mean ()*0.0001 # Add penalty for choosing deep experts # Depths are 2, 3, 4, 5 for indices (num_experts) to (num_experts + 3) # Cost is roughly proportional to depth deep_penalty = torch .tensor (0.0 , device =router_logits .device , dtype =router_logits .dtype ) # Calculate how often each deep expert was selected # top_k_indices shape: [batch*seq, top_k] for i in range (self .num_deep_experts ): expert_idx = self .num_experts + i depth = i + 2 # depths 2, 3, 4, 5 # Count how many times this deep expert was chosen in top-k selection_count = (top_k_indices == expert_idx ).sum () # Simple penalty: deeper experts cost more # Multiplied by a small scalar to act as a soft deterrent # The model must truly need the depth to offset this loss increase deep_penalty += selection_count .float () * depth * 0.00005 return z_loss + deep_penalty expert_mask =F .one_hot (top_k_indices ,self .num_experts ).float () tokens_per_expert =expert_mask .sum (dim =(0 ,1 )) fraction_used =(tokens_per_expert >0 ).float ().mean () utilization_loss =(1.0 -fraction_used )*0.01 return z_loss +utilization_loss MoELayer =AuxLosslessMoELayer class MoELlamaDecoderLayer (nn .Module ): """Decoder layer with MLA and Aux-Lossless MoE.""" def __init__ (self ,config ,layer_idx :int ,moe_config :dict =None ): super ().__init__ () self .hidden_size =config 
.hidden_size self .layer_idx =layer_idx use_ring =getattr (config ,'use_ring_attention',True ) ring_chunk =getattr (config ,'ring_attention_chunk_size',4096 ) num_kv_heads =getattr (config ,'num_key_value_heads',config .num_attention_heads //4 ) self .self_attn =MultiHeadLatentAttention ( hidden_size =config .hidden_size , num_heads =config .num_attention_heads , num_kv_heads =num_kv_heads , rope_theta =getattr (config ,'rope_theta',500000.0 ), max_position_embeddings =config .max_position_embeddings , use_ring_attention =use_ring , ring_chunk_size =ring_chunk , ) self .input_layernorm =LlamaRMSNorm (config .hidden_size ,eps =config .rms_norm_eps ) self .post_attention_layernorm =LlamaRMSNorm (config .hidden_size ,eps =config .rms_norm_eps ) self .use_moe =moe_config and moe_config .get ('use_moe',False ) moe_freq =moe_config .get ('moe_layer_freq',2 )if moe_config else 2 if self .use_moe and layer_idx %moe_freq ==(moe_freq -1 ): self .mlp =AuxLosslessMoELayer ( hidden_size =config .hidden_size , intermediate_size =moe_config .get ('intermediate_size',config .intermediate_size ), num_experts =moe_config .get ('num_experts',8 ), num_experts_per_tok =moe_config .get ('num_experts_per_tok',2 ), ) self .is_moe_layer =True else : self .mlp =MoEExpert (config .hidden_size ,config .intermediate_size ) self .is_moe_layer =False def forward ( self , hidden_states :torch .Tensor , attention_mask :Optional [torch .Tensor ]=None , position_ids :Optional [torch .Tensor ]=None , past_key_value :Optional [KVCache ]=None , output_attentions :bool =False , use_cache :bool =False , )->Tuple [torch .Tensor ,Optional [torch .Tensor ],Optional [KVCache ],Optional [torch .Tensor ]]: residual =hidden_states hidden_states =self .input_layernorm (hidden_states ) hidden_states ,_ ,present_key_value =self .self_attn ( hidden_states =hidden_states , attention_mask =attention_mask , position_ids =position_ids , past_key_value =past_key_value , output_attentions =output_attentions , use_cache 
=use_cache ,
        )
        # NOTE: tail of MoELlamaDecoderLayer.forward — the definition starts
        # before this chunk.
        hidden_states = residual + hidden_states  # residual around self-attention
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        aux_loss = None
        # MoE layers return (output, router aux loss); dense layers return output only.
        if self.is_moe_layer:
            hidden_states, aux_loss = self.mlp(hidden_states)
        else:
            hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states  # residual around MLP / MoE block
        return hidden_states, None, present_key_value, aux_loss


@dataclass
class MoELlamaModelOutput:
    """Output container for MoELlamaModel.forward."""
    last_hidden_state: torch.Tensor
    past_key_values: Optional[List[KVCache]] = None  # per-layer caches when use_cache=True
    hidden_states: Optional[Tuple[torch.Tensor]] = None  # per-layer outputs when requested
    attentions: Optional[Tuple[torch.Tensor]] = None
    aux_loss: Optional[torch.Tensor] = None  # summed MoE router auxiliary loss


class MoELlamaModel(nn.Module):
    """MoE LLaMA Model with MLA and Ring Attention."""

    def __init__(self, config, moe_config: dict = None):
        super().__init__()
        self.config = config
        self.moe_config = moe_config
        self.gradient_checkpointing = False
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            MoELlamaDecoderLayer(config, layer_idx, moe_config)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Count MoE layers for logging/inspection; is_moe_layer is set per decoder layer.
        self.num_moe_layers = sum(1 for layer in self.layers if layer.is_moe_layer)
        # -- Coconut: Continuous Thought components --
        # Learned gate controls how much recurrent thought vs original input
        # to retain at each thinking step. Sigmoid output in [0,1].
        self.thought_gate = nn.Linear(config.hidden_size, 1, bias=True)
        nn.init.constant_(self.thought_gate.bias, -2.0)  # gate biased toward original (sigmoid(-2) ~ 0.12)
        self.thought_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Halt head: dynamically decides when to stop thinking.
        self.thought_halt_head = nn.Linear(config.hidden_size, 1, bias=True)
        nn.init.constant_(self.thought_halt_head.bias, -2.0)  # biased toward continuing to think initially
        # Fast Ponder Block for attention-free latent reasoning:
        # bypasses O(N^2) attention, uses pure deep SwiGLU logic.
        self.fast_ponder_block = DeepMoEExpert(config.hidden_size, config.intermediate_size, depth=3)
        self._init_weights()

    def _init_weights(self):
        # Only the token embedding is initialized here; sub-modules own their init.
        nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=0.02)

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory efficiency."""
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self.gradient_checkpointing = False

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[KVCache]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        cache_position: Optional[torch.Tensor] = None,
        thinking_depth: int = 0,
    ) -> Union[Tuple, MoELlamaModelOutput]:
        """Run the decoder stack, accumulate MoE aux losses, then (optionally)
        perform `thinking_depth` Coconut latent-reasoning steps before the
        final norm."""
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds
        batch_size, seq_len = hidden_states.shape[:2]
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).expand(batch_size, -1)
        if past_key_values is None:
            past_key_values = [None] * len(self.layers)
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        next_cache = [] if use_cache else None
        total_aux_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # Checkpointing is incompatible with KV caching, hence `not use_cache`.
            if self.gradient_checkpointing and self.training and not use_cache:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values[idx],
                    output_attentions,
                    use_cache,
                    use_reentrant=False,
                )
                hidden_states, attn_weights, present_key_value, aux_loss = layer_outputs
            else:
                hidden_states, attn_weights, present_key_value, aux_loss = layer(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values[idx],
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            if use_cache:
                next_cache.append(present_key_value)
            if aux_loss is not None:
                total_aux_loss = total_aux_loss + aux_loss
            if output_attentions and attn_weights is not None:
                all_attentions = all_attentions + (attn_weights,)
        # -- Coconut: Continuous Thought Loop --
        # After the normal pass, loop hidden states back through extra
        # computation in latent space. No tokens are decoded - pure
        # continuous reasoning.
        if thinking_depth > 0:
            original_hidden = hidden_states.clone()
            # NOTE(review): thought_position_ids is computed but never used below.
            thought_position_ids = torch.arange(
                seq_len, device=hidden_states.device
            ).unsqueeze(0).expand(batch_size, -1)
            for thought_step in range(thinking_depth):
                # Check if we should halt thinking (only during inference or if forced).
                # We evaluate the halt head on the *current* hidden state of the last token.
                halt_logits = self.thought_halt_head(hidden_states[:, -1:, :])
                halt_prob = torch.sigmoid(halt_logits)
                # If during generation we decide to stop, break early.
                if not self.training and (halt_prob > 0.5).all():
                    break
                # Normalize before processing.
                hidden_states = self.thought_layernorm(hidden_states)
                # Run purely through the attention-free fast ponder block,
                # completely bypassing the O(N^2) self-attention stack.
                hidden_states = self.fast_ponder_block(hidden_states)
                # Gated residual: blend thought with original. gate in [0,1],
                # initialized small so early training stays close to original behavior.
                gate = torch.sigmoid(self.thought_gate(hidden_states))
                hidden_states = gate * hidden_states + (1.0 - gate) * original_hidden
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        return MoELlamaModelOutput(
            last_hidden_state=hidden_states,
            past_key_values=next_cache if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            aux_loss=total_aux_loss,
        )


@dataclass
class CausalLMOutput:
    """Output container for MoELlamaForCausalLM.forward."""
    loss: Optional[torch.Tensor] = None
    logits: torch.Tensor = None
    past_key_values: Optional[List[KVCache]] = None
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    attentions: Optional[Tuple[torch.Tensor]] = None
    aux_loss: Optional[torch.Tensor] = None


class MoELlamaForCausalLM(nn.Module):
    """MoE LLaMA for Causal Language Modeling with MLA and Ring Attention."""

    def __init__(self, config, moe_config: dict = None):
        super().__init__()
        self.config = config
        self
.moe_config =moe_config self .model =MoELlamaModel (config ,moe_config ) self .lm_head =nn .Linear (config .hidden_size ,config .vocab_size ,bias =False ) if getattr (config ,'tie_word_embeddings',True ): self .lm_head .weight =self .model .embed_tokens .weight self .apply (self ._init_weights ) def _init_weights (self ,module ): std =0.02 if isinstance (module ,nn .Linear ): nn .init .normal_ (module .weight ,mean =0.0 ,std =std ) if module .bias is not None : nn .init .zeros_ (module .bias ) elif isinstance (module ,nn .Embedding ): nn .init .normal_ (module .weight ,mean =0.0 ,std =std ) def get_input_embeddings (self )->nn .Embedding : return self .model .embed_tokens def set_input_embeddings (self ,value :nn .Embedding ): self .model .embed_tokens =value def get_output_embeddings (self )->nn .Linear : return self .lm_head def set_output_embeddings (self ,new_embeddings :nn .Linear ): self .lm_head =new_embeddings def gradient_checkpointing_enable (self ): """Enable gradient checkpointing for memory efficiency.""" self .model .gradient_checkpointing_enable () def gradient_checkpointing_disable (self ): """Disable gradient checkpointing.""" self .model .gradient_checkpointing_disable () def prepare_inputs_for_generation ( self , input_ids :torch .Tensor , past_key_values :Optional [List [KVCache ]]=None , attention_mask :Optional [torch .Tensor ]=None , inputs_embeds :Optional [torch .Tensor ]=None , **kwargs , )->dict : if past_key_values is not None : input_ids =input_ids [:,-1 :] position_ids =kwargs .get ("position_ids",None ) if attention_mask is not None and position_ids is None : position_ids =attention_mask .long ().cumsum (-1 )-1 position_ids .masked_fill_ (attention_mask ==0 ,1 ) if past_key_values is not None : position_ids =position_ids [:,-1 :] return { "input_ids":input_ids , "past_key_values":past_key_values , "use_cache":kwargs .get ("use_cache",True ), "position_ids":position_ids , "attention_mask":attention_mask , } def forward ( self , 
input_ids :Optional [torch .Tensor ]=None , attention_mask :Optional [torch .Tensor ]=None , position_ids :Optional [torch .Tensor ]=None , inputs_embeds :Optional [torch .Tensor ]=None , labels :Optional [torch .Tensor ]=None , past_key_values :Optional [List [KVCache ]]=None , use_cache :bool =False , output_attentions :bool =False , output_hidden_states :bool =False , return_dict :bool =True , cache_position :Optional [torch .Tensor ]=None , thinking_depth :int =0 , **kwargs , )->Union [Tuple ,CausalLMOutput ]: outputs =self .model ( input_ids =input_ids , attention_mask =attention_mask , position_ids =position_ids , inputs_embeds =inputs_embeds , past_key_values =past_key_values , use_cache =use_cache , output_attentions =output_attentions , output_hidden_states =output_hidden_states , return_dict =True , cache_position =cache_position , thinking_depth =thinking_depth , ) hidden_states =outputs .last_hidden_state aux_loss =outputs .aux_loss logits =self .lm_head (hidden_states ) loss =None if labels is not None : shift_logits =logits [...,:-1 ,:].contiguous () shift_labels =labels [...,1 :].contiguous () if shift_labels .dtype !=torch .long : shift_labels =shift_labels .long () valid_mask =(shift_labels !=-100 ) num_valid =valid_mask .sum ().item () if num_valid >0 : loss_fct =nn .CrossEntropyLoss (ignore_index =-100 ) loss =loss_fct ( shift_logits .view (-1 ,shift_logits .size (-1 )), shift_labels .view (-1 ) ) loss =torch .clamp (loss ,min =0.0 ,max =100.0 ) else : loss =torch .tensor (0.0 ,device =logits .device ,dtype =logits .dtype ,requires_grad =True ) return CausalLMOutput ( loss =loss , logits =logits , past_key_values =outputs .past_key_values , hidden_states =outputs .hidden_states , attentions =outputs .attentions , aux_loss =aux_loss , ) @torch .no_grad () def generate ( self , input_ids :torch .Tensor , max_new_tokens :int =100 , temperature :float =1.0 , top_k :int =50 , top_p :float =0.9 , do_sample :bool =True , pad_token_id :Optional [int 
]=None , eos_token_id :Optional [int ]=None , attention_mask :Optional [torch .Tensor ]=None , thinking_depth :int =0 , **kwargs , )->torch .Tensor : batch_size =input_ids .shape [0 ] device =input_ids .device past_key_values =None is_prefill =True # Deep thinking only on first pass (full context) if attention_mask is None : attention_mask =torch .ones_like (input_ids ) for _ in range (max_new_tokens ): model_inputs =self .prepare_inputs_for_generation ( input_ids , past_key_values =past_key_values , attention_mask =attention_mask , ) # Apply thinking depth only on prefill, not per-token steps current_depth = thinking_depth if is_prefill else 0 outputs =self .forward (**model_inputs ,use_cache =True ,return_dict =True ,thinking_depth =current_depth ) is_prefill =False next_token_logits =outputs .logits [:,-1 ,:] if temperature !=1.0 : next_token_logits =next_token_logits /temperature if do_sample : if top_k >0 : indices_to_remove =next_token_logits top_p sorted_indices_to_remove [...,1 :]=sorted_indices_to_remove [...,:-1 ].clone () sorted_indices_to_remove [...,0 ]=0 indices_to_remove =sorted_indices_to_remove .scatter (1 ,sorted_indices ,sorted_indices_to_remove ) next_token_logits [indices_to_remove ]=float ('-inf') probs =F .softmax (next_token_logits ,dim =-1 ) next_tokens =torch .multinomial (probs ,num_samples =1 ).squeeze (-1 ) else : next_tokens =torch .argmax (next_token_logits ,dim =-1 ) input_ids =torch .cat ([input_ids ,next_tokens .unsqueeze (-1 )],dim =-1 ) attention_mask =torch .cat ([attention_mask ,torch .ones ((batch_size ,1 ),device =device )],dim =-1 ) past_key_values =outputs .past_key_values if eos_token_id is not None and (next_tokens ==eos_token_id ).all (): break return input_ids ============================================================================== MODELS.XORON ============================================================================== logger =logging .getLogger (__name__ ) MAX_HIDDEN =10000.0 def safe_clamp_tensor (x :torch 
.Tensor, max_val: float = MAX_HIDDEN) -> torch.Tensor:
    """Clamp tensor values for FP16 safety, handling NaN/Inf properly.

    WARNING: Only use for linear/hidden states, NOT for attention scores before softmax!
    For attention scores, use a max of ~11.0 to prevent exp() overflow.

    CRITICAL: torch.clamp does NOT fix NaN! clamp(nan, -10, 10) = nan
    Must use nan_to_num first.
    """
    if x is None or x.numel() == 0:
        return x
    # nan_to_num first (clamp propagates NaN), then symmetric clamp.
    x = torch.nan_to_num(x, nan=0.0, posinf=max_val, neginf=-max_val)
    return x.clamp(-max_val, max_val)


# Maps logical component-group names (used by freeze/parallelism logic)
# to the attribute names of the corresponding sub-modules on the model.
COMPONENT_GROUPS = {
    'vision': ['vision_encoder', 'projector'],
    'video': ['video_encoder'],
    'audio': ['audio_encoder', 'audio_decoder', 'audio_projector', 'waveform_decoder'],
    'speech': ['waveform_decoder'],
    'llm': ['llm'],
    'cross_attention': ['cross_attention_layers'],
    'image_generation': ['generator'],
    'video_generation': ['video_generator'],
    'modality_markers': ['image_start', 'image_end', 'video_start', 'video_end', 'audio_start', 'audio_end'],
}


class MultimodalModelOutput(dict):
    """Output class for multimodal model."""

    def __getattr__(self, name):
        # Attribute access falls through to dict lookup so callers can use
        # either outputs['loss'] or outputs.loss.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")

    def __setattr__(self, name, value):
        self[name] = value


class XoronMultimodalModel(nn.Module):
    """
    Xoron-Dev: Complete multimodal model with:
    - Image/video understanding (CLIP)
    - Text generation (MoE LLM)
    - Image/video generation (MobileDiffusion)
    - Voice understanding and generation (ASR/TTS)
    - Cross-attention for multimodal fusion
    - LoRA support for efficient fine-tuning
    - Flash Attention for faster training
    - Model Parallelism support for multi-GPU training
    """

    def __init__(self, config: XoronConfig, device_map: Dict[str, str] = None):
        super().__init__()
        self.config = config
        self.device_map = device_map
        # Model parallelism is active when the device map spans >1 distinct device.
        if device_map is not None:
            device_values = [v for v in device_map.values() if isinstance(v, str)]
            self._model_parallel = len(set(device_values)) > 1
        else:
            self._model_parallel = False
        logger.info("Initializing Xoron-Dev Multimodal Model Build")
        if self._model_parallel:
            logger.info(" โšก Model Parallelism: ENABLED")
        # Vision stack: the video encoder reuses the image encoder backbone.
        self.vision_encoder = VisionEncoder(config.vision_model_name, freeze=config.freeze_vision)
        self.video_encoder = VideoEncoder(self.vision_encoder, max_frames=config.video_max_frames)
        logger.info("Building SOTA Audio Encoder...")
        self.audio_encoder = AudioEncoder(
            hidden_size=config.hidden_size,
            n_mels=80,
            max_audio_length=3000,
            use_raw_waveform=getattr(config, 'use_raw_waveform', True),
        )
        logger.info("Building SOTA Audio Decoder...")
        self.audio_decoder = AudioDecoder(
            hidden_size=config.hidden_size,
            n_mels=80,
            max_audio_length=1000,
        )
        logger.info("Building Raw Waveform Decoder (Speech-to-Speech)...")
        self.waveform_decoder = RawWaveformDecoder(
            hidden_size=config.hidden_size,
            sample_rate=getattr(config, 'audio_sample_rate', 16000),
        )
        llm_config = LlamaConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            num_hidden_layers=config.num_layers,
            num_attention_heads=config.num_heads,
            max_position_embeddings=config.max_position_embeddings,
            rms_norm_eps=1e-6,
            tie_word_embeddings=getattr(config, 'tie_word_embeddings', True),
            pad_token_id=0,
        )
        # Extra attention knobs carried as ad-hoc attributes on the config.
        llm_config.use_flash_attention = config.use_flash_attention
        llm_config.use_ring_attention = getattr(config, 'use_ring_attention', True)
        llm_config.ring_attention_chunk_size = getattr(config, 'ring_attention_chunk_size', 4096)
        moe_config = {
            'use_moe': config.use_moe,
            'num_experts': config.num_experts,
            'num_experts_per_tok': config.num_experts_per_tok,
            'moe_layer_freq': config.moe_layer_freq,
            'intermediate_size': config.intermediate_size,
        }
        logger.info(f"Building LLM Core: {config.hidden_size}d, {config.num_layers}L")
        logger.info(f" ๐Ÿ“ Context: {config.max_position_embeddings//1024}K positions")
        if config.use_ring_attention:
            logger.info(f" ๐Ÿ”„ Ring Attention Enabled (chunk size: {config.ring_attention_chunk_size})")
        logger.info(f" ๐ŸŽฏ MoE: {config.num_experts} experts, top-{config.num_experts_per_tok}")
        self.llm = MoELlamaForCausalLM(llm_config, moe_config)
        logger.info(f" โœ… MoE layers initialized: {self.llm.model.num_moe_layers}/{config.num_layers}")
        # Projects vision features into the LLM embedding space.
        self.projector = MultimodalProjector(
            self.vision_encoder.hidden_size,
            config.hidden_size,
            config.num_vision_tokens
        )
        logger.info(f" ๐Ÿ”— Projector initialized: {self.vision_encoder.hidden_size} -> {config.hidden_size}")
        self.audio_projector = nn.Linear(config.hidden_size, config.hidden_size)
        # Learned modality boundary markers (one embedding vector each).
        self.image_start = nn.Parameter(torch.randn(1, 1, config.hidden_size) * 0.02)
        self.image_end = nn.Parameter(torch.randn(1, 1, config.hidden_size) * 0.02)
        self.video_start = nn.Parameter(torch.randn(1, 1, config.hidden_size) * 0.02)
        self.video_end = nn.Parameter(torch.randn(1, 1, config.hidden_size) * 0.02)
        self.audio_start = nn.Parameter(torch.randn(1, 1, config.hidden_size) * 0.02)
        self.audio_end = nn.Parameter(torch.randn(1, 1, config.hidden_size) * 0.02)
        self.cross_attention_layers = None
        if config.use_cross_attention:
            logger.info(f"Building Cross-Attention Fusion ({config.cross_attention_layers} layers)...")
            self.cross_attention_layers = nn.ModuleList([
                MultimodalFusionLayer(
                    hidden_size=config.hidden_size,
                    num_heads=config.cross_attention_heads,
                    dropout=config.cross_attention_dropout,
                    use_flash_attention=config.use_flash_attention,
                )
                for _ in range(config.cross_attention_layers)
            ])
            logger.info(f" โœ… Cross-attention: {config.cross_attention_layers} layers, {config.cross_attention_heads} heads")
        self.generator = None
        if config.enable_generation:
            logger.info("Building MobileDiffusion Generators (Image & Video)...")
            self.generator = MobileDiffusionGenerator(
                latent_channels=config.generation_latent_channels,
                base_channels=config.generation_base_channels,
                context_dim=config.hidden_size,
                num_inference_steps=config.generation_inference_steps,
                image_size=config.image_max_size,
            )
        self.video_generator = None
        if config.enable_generation:
            # Video generator uses half the base channels of the image generator.
            self.video_generator = MobileVideoDiffusion(
                latent_channels=config.generation_latent_channels,
                base_channels=config.generation_base_channels // 2,
                context_dim=config.hidden_size,
                num_frames=config.video_max_frames,
                image_size=config.video_max_size,
                num_inference_steps=config.generation_inference_steps,
            )
        self.num_vision_tokens = config.num_vision_tokens
        self.video_max_frames = config.video_max_frames
        self.lora_applied = False
        self._print_stats()
        logger.info("Xoron-Dev Multimodal Model Build Complete")

    def apply_model_parallel(self, device_map: Dict[str, str]):
        """Apply Model Parallelism by sharding components across devices.

        Trained components get their layers split across all training GPUs.
        Frozen components go to CPU. Small components (projectors, markers)
        go to the primary GPU.
""" self .device_map =device_map training_gpus = device_map .get ('training_gpus', ['cuda:0']) primary = device_map .get ('primary', 'cuda:0') if len (training_gpus ) <= 1 and not any (v == 'cpu' for v in device_map .values () if isinstance (v, str)): logger .info (" โ„น๏ธ Single device - no model parallelism needed") return self self ._model_parallel = True logger .info ("Applying Model Parallelism (layer sharding)...") def _shard_module (module, name, gpus): """Shard a module's sub-layers across GPUs.""" # Find shardable sub-layers (nn.ModuleList children) layer_lists = [] for attr_name in dir (module): attr = getattr (module, attr_name, None) if isinstance (attr, nn .ModuleList) and len (attr) > 0: layer_lists .append ((attr_name, attr)) if layer_lists: # Shard the largest ModuleList across GPUs layer_lists .sort (key=lambda x: len (x[1]), reverse=True) list_name, layers = layer_lists [0] for i, layer in enumerate (layers): target_gpu = gpus [i % len (gpus)] layer .to (target_gpu) # Put remaining params on primary GPU for param_name, param in module .named_parameters (): if not any (f'{list_name}.' 
in param_name for _ in [1]): param .data = param .data .to (gpus [0]) logger .info (f" โœ… {name}: {len(layers)} layers sharded across {gpus}") else: # No layers to shard โ€” put whole module on first GPU module .to (gpus [0]) logger .info (f" โœ… {name} -> {gpus[0]}") # Map component names to actual attributes component_attrs = { 'vision_encoder': 'vision_encoder', 'video_encoder': 'video_encoder', 'audio_encoder': 'audio_encoder', 'audio_decoder': 'audio_decoder', 'waveform_decoder': 'waveform_decoder', 'projector': 'projector', 'audio_projector': 'audio_projector', 'llm': 'llm', 'cross_attention': 'cross_attention_layers', 'generator': 'generator', 'video_generator': 'video_generator', } for comp_name, attr_name in component_attrs .items (): comp = getattr (self, attr_name, None) if comp is None: continue target = device_map .get (comp_name, 'cpu') if target == 'cpu': comp .to ('cpu') logger .info (f" โ„๏ธ {comp_name} -> cpu (frozen)") else: # Shard across all training GPUs _shard_module (comp, comp_name, training_gpus) # Modality markers โ†’ primary GPU marker_device = device_map .get ('modality_markers', primary) if marker_device != 'cpu': marker_device = primary for marker_name in ['image_start', 'image_end', 'video_start', 'video_end', 'audio_start', 'audio_end']: marker = getattr (self, marker_name, None) if marker is not None: setattr (self, marker_name, nn .Parameter (marker .data .to (marker_device))) logger .info (f" โœ… Modality markers -> {marker_device}") logger .info ("Model Parallelism applied successfully!") return self def get_llm_device (self ): """Get the device where LLM is located.""" if self .device_map is not None : return torch .device (self .device_map ['llm']) return next (self .llm .parameters ()).device def generate (self ,*args ,**kwargs ): """ Delegates generation to the internal LLM. This allows the model to be treated as a causal LM in many pipelines. 
""" return self .llm .generate (*args ,**kwargs ) def get_encoder_device (self ): """Get the device where encoders are located.""" if self .device_map is not None : return torch .device (self .device_map ['vision_encoder']) return next (self .vision_encoder .parameters ()).device def apply_lora (self ): """ Apply LoRA to the LLM and optionally cross-attention layers. MEMORY OPTIMIZATION: - LoRA layers share base weights (no cloning) - Base weights in LoRA layers are frozen (requires_grad=False) - LoRA params (A, B, magnitude) are always trainable NOTE: This does NOT freeze other components! Component freezing is handled separately by freeze_components() based on training mode (--text, --video, --image, --voice flags). This allows PARALLEL FINE-TUNING: - LoRA adapters on LLM for efficient adaptation - Full weight training on active components (vision, audio, etc.) """ if self .lora_applied : logger .warning ("LoRA already applied") return if not self .config .use_lora : logger .info ("LoRA disabled in config") return lora_config =LoRAConfig ( r =self .config .lora_r , lora_alpha =self .config .lora_alpha , lora_dropout =self .config .lora_dropout , target_modules =list (self .config .lora_target_modules ), enable_lora =True , ) logger .info ("Applying LoRA to LLM Core...") self .llm =apply_lora_to_model (self .llm ,lora_config ) if self .cross_attention_layers is not None : logger .info ("Applying LoRA to cross-attention layers...") cross_attn_lora_config =LoRAConfig ( r =lora_config .r , lora_alpha =lora_config .lora_alpha , lora_dropout =lora_config .lora_dropout , target_modules =['q_proj','k_proj','v_proj','o_proj'], enable_lora =True , ) for i ,layer in enumerate (self .cross_attention_layers ): self .cross_attention_layers [i ]=apply_lora_to_model (layer ,cross_attn_lora_config ) self .lora_applied =True self ._print_stats () def get_trainable_params (self ): """ Get trainable parameters, respecting LoRA settings and component freezing. 
If train_lora_only=True and LoRA is applied: - Freezes all non-LoRA params - Returns only LoRA params Otherwise: - Returns all params with requires_grad=True - This includes both LoRA params AND unfrozen component weights - Allows parallel fine-tuning: LoRA + full weights on active components """ if self .config .train_lora_only and self .lora_applied : freeze_non_lora_params (self ) return get_lora_parameters (self ) return [p for p in self .parameters ()if p .requires_grad ] def _print_stats (self ): total =sum (p .numel ()for p in self .parameters ()) trainable =sum (p .numel ()for p in self .parameters ()if p .requires_grad ) logger .info ("Model Statistics:") logger .info (f" Total parameters: {total /1e6 :.1f}M") logger .info (f" Trainable parameters: {trainable /1e6 :.1f}M") if self .lora_applied : lora_params =sum (p .numel ()for n ,p in self .named_parameters ()if 'lora_'in n ) logger .info (f" LoRA parameters: {lora_params /1e6 :.2f}M") def encode_image (self ,pixel_values :torch .Tensor )->torch .Tensor : encoder_device =self .get_encoder_device () pixel_values =pixel_values .to (encoder_device ) vision_features =self .vision_encoder (pixel_values ) projected =self .projector (vision_features ) llm_device =self .get_llm_device () return projected .to (llm_device ) def encode_video (self ,video_frames :torch .Tensor )->torch .Tensor : encoder_device =self .get_encoder_device () video_frames =video_frames .to (encoder_device ) video_features =self .video_encoder (video_frames ) projected =self .projector (video_features ) llm_device =self .get_llm_device () return projected .to (llm_device ) def encode_audio (self ,audio_features :torch .Tensor )->torch .Tensor : encoder_device =self .get_encoder_device () audio_features =audio_features .to (encoder_device ) audio_embeds =self .audio_encoder (audio_features ) projected =self .audio_projector (audio_embeds ) llm_device =self .get_llm_device () return projected .to (llm_device ) def get_text_embeddings (self 
, input_ids: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        # Token-embedding lookup on the LLM device; attention_mask is accepted
        # but unused here.
        llm_device = self.get_llm_device()
        input_ids = input_ids.to(llm_device)
        embeddings = self.llm.model.embed_tokens(input_ids)
        return embeddings

    def _apply_cross_attention(
        self,
        text_embeds: torch.Tensor,
        image_embeds: torch.Tensor = None,
        video_embeds: torch.Tensor = None,
        audio_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Fuse text embeddings with available modality embeddings through the
        stacked cross-attention fusion layers; no-op when fusion is disabled."""
        if self.cross_attention_layers is None:
            return text_embeds
        for fusion_layer in self.cross_attention_layers:
            text_embeds, _ = fusion_layer(
                text_hidden=text_embeds,
                image_hidden=image_embeds,
                video_hidden=video_embeds,
                audio_hidden=audio_embeds,
                use_cache=False,
            )
        return text_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        pixel_values: torch.Tensor = None,
        video_frames: torch.Tensor = None,
        audio_features: torch.Tensor = None,
        labels: torch.Tensor = None,
    ):
        """Forward pass - FP16 native.

        Each present modality is encoded, wrapped in its start/end marker
        embeddings, and PREPENDED to the text embeddings; attention mask and
        labels are extended accordingly (modality positions get label -100).
        Modality encoding failures are deliberately swallowed (best-effort)
        and only logged at debug level.
        """
        batch_size = input_ids.shape[0]
        llm_device = self.get_llm_device()
        input_ids_llm = input_ids.to(llm_device)
        text_embeds = self.llm.model.embed_tokens(input_ids_llm)
        text_embeds = safe_clamp_tensor(text_embeds)
        device = text_embeds.device
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)
        image_embeds_for_cross = None
        video_embeds_for_cross = None
        audio_embeds_for_cross = None

        def has_content(tensor):
            # True only for a non-empty tensor containing a non-zero element.
            if tensor is None:
                return False
            if not isinstance(tensor, torch.Tensor):
                return False
            try:
                if tensor.numel() == 0:
                    return False
                return bool(tensor.any())
            except Exception:
                return False

        if has_content(pixel_values):
            try:
                image_embeds = self.encode_image(pixel_values)
                image_embeds = safe_clamp_tensor(image_embeds)
                image_embeds_for_cross = image_embeds
                image_start = self.image_start.expand(batch_size, -1, -1)
                image_end = self.image_end.expand(batch_size, -1, -1)
                image_embeds = torch.cat([image_start, image_embeds, image_end], dim=1)
                text_embeds = torch.cat([image_embeds, text_embeds], dim=1)
                text_embeds = safe_clamp_tensor(text_embeds)
                if attention_mask is not None:
                    image_mask = torch.ones(batch_size, image_embeds.shape[1], device=device)
                    attention_mask = torch.cat([image_mask, attention_mask], dim=1)
                if labels is not None:
                    # Modality positions do not contribute to the LM loss.
                    image_labels = torch.full((batch_size, image_embeds.shape[1]), -100, device=device, dtype=labels.dtype)
                    labels = torch.cat([image_labels, labels], dim=1)
            except Exception as e:
                logger.debug(f"Image encoding skipped: {e}")
        if has_content(video_frames):
            try:
                video_embeds = self.encode_video(video_frames)
                video_embeds = safe_clamp_tensor(video_embeds)
                video_embeds_for_cross = video_embeds
                video_start = self.video_start.expand(batch_size, -1, -1)
                video_end = self.video_end.expand(batch_size, -1, -1)
                video_embeds = torch.cat([video_start, video_embeds, video_end], dim=1)
                text_embeds = torch.cat([video_embeds, text_embeds], dim=1)
                text_embeds = safe_clamp_tensor(text_embeds)
                if attention_mask is not None:
                    video_mask = torch.ones(batch_size, video_embeds.shape[1], device=device)
                    attention_mask = torch.cat([video_mask, attention_mask], dim=1)
                if labels is not None:
                    video_labels = torch.full((batch_size, video_embeds.shape[1]), -100, device=device, dtype=labels.dtype)
                    labels = torch.cat([video_labels, labels], dim=1)
            except Exception as e:
                logger.debug(f"Video encoding skipped: {e}")
        if has_content(audio_features):
            try:
                audio_embeds = self.encode_audio(audio_features)
                audio_embeds = safe_clamp_tensor(audio_embeds)
                audio_embeds_for_cross = audio_embeds
                audio_start = self.audio_start.expand(batch_size, -1, -1)
                audio_end = self.audio_end.expand(batch_size, -1, -1)
                audio_embeds = torch.cat([audio_start, audio_embeds, audio_end], dim=1)
                text_embeds = torch.cat([audio_embeds, text_embeds], dim=1)
                text_embeds = safe_clamp_tensor(text_embeds)
                if attention_mask is not None:
                    audio_mask = torch.ones(batch_size, audio_embeds.shape[1], device=device)
                    attention_mask = torch.cat([audio_mask, attention_mask], dim=1)
                if labels is not None:
                    audio_labels = torch.full((batch_size, audio_embeds.shape[1]), -100, device=device, dtype=labels.dtype)
                    labels = torch.cat([audio_labels, labels], dim=1)
            except Exception as e:
                logger.debug(f"Audio encoding skipped: {e}")
        if self.cross_attention_layers is not None:
            try:
                text_embeds = self._apply_cross_attention(
                    text_embeds,
                    image_embeds=image_embeds_for_cross,
                    video_embeds=video_embeds_for_cross,
                    audio_embeds=audio_embeds_for_cross,
                )
                text_embeds = safe_clamp_tensor(text_embeds)
            except Exception as e:
                logger.debug(f"Cross-attention skipped: {e}")
        text_embeds = safe_clamp_tensor(text_embeds)
        outputs = self.llm(inputs_embeds=text_embeds, attention_mask=attention_mask, labels=labels)
        return MultimodalModelOutput(
            loss=outputs.loss if hasattr(outputs, 'loss') else None,
            logits=outputs.logits if hasattr(outputs, 'logits') else None,
            aux_loss=outputs.aux_loss if hasattr(outputs, 'aux_loss') else None,
        )

    @torch.no_grad()
    def generate_image(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        """Generate image from text."""
        if self.generator is None:
            raise ValueError("Image generator not enabled")
        context = self.get_text_embeddings(input_ids, attention_mask)
        images = self.generator.generate(context)
        return images

    @torch.no_grad()
    def generate_video(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                       first_frame: torch.Tensor = None, num_frames: int = None):
        """Generate video from text (T2V) or from image (I2V)."""
        if self.video_generator is None:
            raise ValueError("Video generator not enabled")
        context = self.get_text_embeddings(input_ids, attention_mask)
        # Video generator conditions on a pooled (mean over sequence) context vector.
        context = context.mean(dim=1)
        if first_frame is not None:
            video = self.video_generator.generate_i2v(first_frame, context, num_frames)
        else:
            video = self.video_generator.generate_t2v(context, num_frames)
        return video

    @torch.no_grad()
    def generate_speech(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        """Generate speech (mel-spectrogram) from text (TTS)."""
        text_embeds = self.get_text_embeddings(input_ids, attention_mask)
        mel, durations, _, _ = self.audio_decoder(text_embeds)
        return mel, durations

    @torch.no_grad()
    def speak(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        speaker_embedding: torch.Tensor = None,
        return_mel: bool = False,
    ) -> torch.Tensor:
        """
        Generate playable audio waveform from text (Speech-to-Speech TTS).

        This is the main method for making the model talk. It converts text
        directly to audio waveform without needing an external vocoder.

        Args:
            input_ids: [B, T] tokenized text input
            attention_mask: [B, T] attention mask
            speaker_embedding: [B, D] optional speaker embedding for voice cloning
            return_mel: If True, also return intermediate mel spectrogram

        Returns:
            waveform: [B, T_audio] raw audio waveform in [-1, 1] range at 16kHz
                Can be played directly or saved as WAV file
            mel (optional): [B, 80, T_mel] mel spectrogram if return_mel=True
        """
        text_embeds = self.get_text_embeddings(input_ids, attention_mask)
        mel, durations, _, _ = self.audio_decoder(
            text_embeds,
            speaker_embedding=speaker_embedding,
        )
        mel_features = mel.transpose(1, 2)
        # NOTE(review): _mel_to_hidden is created lazily here, inside
        # @torch.no_grad, with random init and is never registered/trained
        # beforehand โ€” confirm this is intended.
        if not hasattr(self, '_mel_to_hidden'):
            self._mel_to_hidden = nn.Linear(80, self.config.hidden_size).to(mel.device)
        audio_features = self._mel_to_hidden(mel_features)
        waveform = self.waveform_decoder(audio_features)
        if return_mel:
            return waveform, mel
        return waveform

    @torch.no_grad()
    def listen(self, audio_waveform: torch.Tensor) -> torch.Tensor:
        """
        Transcribe audio to text embeddings (Speech-to-Speech ASR).

        This is the listening component - converts speech to embeddings
        that can be fed to the LLM for understanding.

        Args:
            audio_waveform: [B, T_audio] raw audio waveform

        Returns:
            audio_embeds: [B, T, hidden_size] encoded audio features
        """
        return self.encode_audio(audio_waveform)

    @torch.no_grad()
    def listen_and_respond(
        self,
        audio_waveform: torch.Tensor,
        tokenizer=None,
        max_new_tokens: int = 512,
        speaker_embedding: torch.Tensor = None,
        temperature: float = 0.7,
        top_p: float = 0.9,
        tool_executor=None,
        available_tools: list = None,
        system_prompt: str = None,
        max_tool_calls: int = 5,
    ) -> Dict[str, Any]:
        """
        Agentic Speech-to-Speech: Listen, think, use tools, speak back.

        This is the full agentic pipeline for live voice conversations.
        The model can detect when the user is asking for actions (e.g.
        "write me a Python script") and execute tools mid-generation.

        Pipeline:
        1. Encode input audio -> audio embeddings (ASR)
        2. Build context (system prompt with tools + audio embeddings)
        3. Generate tokens, watching for <|tool_call|> sequences
        4. When tool call detected: parse, execute, inject result, resume
        5. Synthesize final spoken response from non-tool text

        Args:
            audio_waveform: [B, T_audio] input audio waveform
            tokenizer: Tokenizer for decoding tokens to text (required for tools)
            max_new_tokens: Maximum total tokens to generate
            speaker_embedding: [B, D] optional speaker embedding for voice cloning
            temperature: Sampling temperature
            top_p: Nucleus sampling probability
            tool_executor: Callable(tool_name, args_dict) -> str result.
                If None, tool calls are detected but not executed.
            available_tools: List of tool definition dicts for system prompt.
            system_prompt: Optional system prompt override.
            max_tool_calls: Maximum number of tool calls per response (safety limit).
Returns: Dict with: 'waveform': [B, T_response] audio waveform tensor (in-memory, no file I/O) 'text': str full response text (excluding tool call markup) 'token_ids': [B, T_tokens] all generated token IDs 'mel': [B, 80, T_mel] intermediate mel spectrogram 'tool_calls': List[Dict] executed tool calls and their results 'speaking_text': str clean text that was spoken (no tool markup) """ import re import json as _json device = audio_waveform .device batch_size = audio_waveform .shape [0 ] llm_device = self .get_llm_device () # โ”€โ”€ 1. Listen: encode input audio โ”€โ”€ audio_embeds = self .encode_audio (audio_waveform ) # Wrap with start/end markers audio_start = self .audio_start .expand (batch_size , -1 , -1 ).to (llm_device ) audio_end = self .audio_end .expand (batch_size , -1 , -1 ).to (llm_device ) audio_embeds = audio_embeds .to (llm_device ) # โ”€โ”€ 2. Build context with system prompt + tools โ”€โ”€ context_parts = [] if tokenizer is not None and (system_prompt or tool_executor): sys_text = system_prompt or "You are Xoron, an intelligent voice assistant. You can use tools to help the user." 
            # Prefer the executor's own prompt text; fall back to formatting the
            # provided tool definitions.
            if tool_executor and hasattr (tool_executor , 'get_tool_prompt' ):
                sys_text = sys_text + "\n\n" + tool_executor .get_tool_prompt ()
            elif available_tools :
                # NOTE(review): project-local import resolved at call time; fails
                # if the package layout changes — confirm availability.
                from utils .tool_executor import format_tools_for_prompt
                sys_text = sys_text + "\n\n" + format_tools_for_prompt (available_tools )
            # Encode system prompt and prepend
            sys_str = "<|system|>" + sys_text + "<|/system|>"
            sys_token_ids = tokenizer .encode (sys_str , return_tensors ="pt" ).to (llm_device )
            sys_embeds = self .llm .model .embed_tokens (sys_token_ids )
            # NOTE(review): squeeze(0) drops the batch dim of a [1, T, D] tensor,
            # but the audio markers concatenated below are 3-D; verify that
            # torch.cat(dim=1) receives consistent ranks for batch_size > 0.
            context_parts .append (sys_embeds .squeeze (0 ) if sys_embeds .dim () == 3 else sys_embeds )
        # Audio context
        context_parts .extend ([audio_start , audio_embeds , audio_end ])
        # Assistant generation prompt
        if tokenizer is not None :
            asst_str = "<|assistant|>"
            asst_ids = tokenizer .encode (asst_str , return_tensors ="pt" ).to (llm_device )
            asst_embeds = self .llm .model .embed_tokens (asst_ids )
            context_parts .append (asst_embeds .squeeze (0 ) if asst_embeds .dim () == 3 else asst_embeds )
        input_embeds = torch .cat (context_parts , dim =1 )
        # -- 3. Agentic generation loop with tool call detection --
        # Special-token markup the model is expected to emit around tool calls.
        tool_call_start_token = "<|tool_call|>"
        tool_call_end_token = "<|/tool_call|>"
        fn_name_start = "<|function_name|>"
        fn_name_end = "<|/function_name|>"
        fn_args_start = "<|function_args|>"
        fn_args_end = "<|/function_args|>"
        tool_result_start = "<|tool_result|>"
        tool_result_end = "<|/tool_result|>"
        eos_token = "<|eos|>"
        all_generated_ids = []
        tool_calls_made = []
        num_tool_calls = 0
        generated_text = ""
        total_tokens = 0
        # Use standard generation if no tool executor (no need to inspect tokens).
        if tool_executor is None or tokenizer is None :
            gen_kwargs = {
                'inputs_embeds': input_embeds ,
                'max_new_tokens': max_new_tokens ,
                'do_sample': True ,
                'temperature': temperature ,
                'top_p': top_p ,
                'use_cache': True ,
            }
            generated_ids = self .llm .generate (**gen_kwargs )
            all_generated_ids = [generated_ids ]
            if tokenizer is not None :
                generated_text = tokenizer .batch_decode (generated_ids , skip_special_tokens =True )[0 ]
        else :
            # Token-by-token generation with tool call detection
            current_embeds = input_embeds
            past_key_values = None
            in_tool_call = False
            tool_call_buffer = ""
            while total_tokens < max_new_tokens :
                outputs = self .llm (
                    inputs_embeds =current_embeds ,
                    past_key_values =past_key_values ,
                    use_cache =True ,
                )
                past_key_values = outputs .past_key_values
                # Only the last position's logits matter for the next token.
                logits = outputs .logits [:, -1 :, :]
                # Sample next token
                if temperature > 0 :
                    logits = logits / temperature
                    if top_p < 1.0 :
                        # Nucleus (top-p) filtering: mask tokens whose cumulative
                        # probability mass (excluding themselves) already reaches top_p.
                        sorted_logits , sorted_indices = torch .sort (logits , descending =True , dim =-1 )
                        cumulative_probs = torch .cumsum (F .softmax (sorted_logits , dim =-1 ), dim =-1 )
                        # Subtracting the per-token prob shifts the mask so the first
                        # token crossing the threshold is still kept.
                        sorted_mask = cumulative_probs - F .softmax (sorted_logits , dim =-1 ) >= top_p
                        sorted_logits [sorted_mask ] = float ('-inf' )
                        logits .scatter_ (-1 , sorted_indices , sorted_logits )
                    probs = F .softmax (logits , dim =-1 )
                    next_token = torch .multinomial (probs .squeeze (1 ), num_samples =1 )
                else :
                    # Greedy decoding when temperature is disabled.
                    next_token = logits .argmax (dim =-1 )
                total_tokens += 1
                all_generated_ids .append (next_token )
                # Decode the token
token_text = tokenizer .decode (next_token [0 ], skip_special_tokens =False ) generated_text = generated_text + token_text # Check for EOS if eos_token in token_text or next_token .item () == tokenizer .eos_token_id : break # โ”€โ”€ Tool call detection โ”€โ”€ if tool_call_start_token in generated_text and not in_tool_call : in_tool_call = True # Extract everything after the tool_call_start tc_start_idx = generated_text .rfind (tool_call_start_token ) tool_call_buffer = generated_text [tc_start_idx :] if in_tool_call : tool_call_buffer = tool_call_buffer + token_text if tool_call_buffer else generated_text # Check if we have a complete tool call if tool_call_end_token in tool_call_buffer : in_tool_call = False num_tool_calls += 1 # Parse the tool call tool_name = "" tool_args = {} try : # Extract function name name_start = tool_call_buffer .find (fn_name_start ) + len (fn_name_start ) name_end = tool_call_buffer .find (fn_name_end ) if name_start > 0 and name_end > 0 : tool_name = tool_call_buffer [name_start :name_end ].strip () # Extract arguments args_start = tool_call_buffer .find (fn_args_start ) + len (fn_args_start ) args_end = tool_call_buffer .find (fn_args_end ) if args_start > 0 and args_end > 0 : args_str = tool_call_buffer [args_start :args_end ].strip () try : import json as _json tool_args = _json .loads (args_str ) except Exception : tool_args = {"raw": args_str } except Exception : pass # Execute the tool tool_result = "[error]: Failed to parse tool call" if tool_name : tool_result = tool_executor (tool_name , tool_args ) tool_calls_made .append ({ "name": tool_name , "arguments": tool_args , "result": tool_result , }) # Inject tool result back into generation context result_str = tool_result_start + tool_result + tool_result_end result_ids = tokenizer .encode (result_str , return_tensors ="pt" ).to (llm_device ) result_embeds = self .llm .model .embed_tokens (result_ids ) current_embeds = result_embeds past_key_values = None # Reset KV cache to 
include result all_generated_ids .append (result_ids .squeeze (0 )) generated_text = generated_text + result_str tool_call_buffer = "" if num_tool_calls >= max_tool_calls : break continue # Prepare next input next_embeds = self .llm .model .embed_tokens (next_token ) current_embeds = next_embeds # Combine all generated IDs if all_generated_ids : flat_ids = [] for t in all_generated_ids : if t .dim () == 0 : flat_ids .append (t .unsqueeze (0 )) elif t .dim () == 1 : flat_ids .append (t ) else : flat_ids .append (t .view (-1 )) generated_ids = torch .cat (flat_ids , dim =0 ).unsqueeze (0 ) else : generated_ids = torch .tensor ([[]], dtype =torch .long , device =llm_device ) # โ”€โ”€ 4. Extract speaking text (strip tool call/result markup) โ”€โ”€ speaking_text = generated_text # Remove tool call blocks while tool_call_start_token in speaking_text : tc_s = speaking_text .find (tool_call_start_token ) tc_e = speaking_text .find (tool_call_end_token ) if tc_e > tc_s : speaking_text = speaking_text [:tc_s ] + speaking_text [tc_e + len (tool_call_end_token ):] else : break # Remove tool result blocks while tool_result_start in speaking_text : tr_s = speaking_text .find (tool_result_start ) tr_e = speaking_text .find (tool_result_end ) if tr_e > tr_s : speaking_text = speaking_text [:tr_s ] + speaking_text [tr_e + len (tool_result_end ):] else : break speaking_text = speaking_text .strip () # โ”€โ”€ 5. 
Speak: encode โ†’ mel โ†’ stream_decode โ†’ waveform โ”€โ”€ response_embeds = self .llm .model .embed_tokens (generated_ids .to (llm_device )) mel , durations , _ , _ = self .audio_decoder ( response_embeds , speaker_embedding =speaker_embedding , ) mel_features = mel .transpose (1 , 2 ) if not hasattr (self , '_mel_to_hidden' ): self ._mel_to_hidden = nn .Linear (80 , self .config .hidden_size ).to (mel .device ) audio_features = self ._mel_to_hidden (mel_features ) waveform = self .waveform_decoder .stream_decode (audio_features ) return { 'waveform': waveform , 'text': generated_text , 'speaking_text': speaking_text , 'token_ids': generated_ids , 'mel': mel , 'tool_calls': tool_calls_made , } def merge_lora_weights (self ): """Merge LoRA weights into main weights for inference.""" if not self .lora_applied : return for module in self .modules (): if isinstance (module ,LoRALinear ): module .merge_lora_weights () logger .info ("LoRA weights merged into base model") def unmerge_lora_weights (self ): """Unmerge LoRA weights for continued training.""" if not self .lora_applied : return for module in self .modules (): if isinstance (module ,LoRALinear ): module .unmerge_lora_weights () logger .info ("LoRA weights unmerged") def save_pretrained ( self , path :str , optimizer =None , scheduler =None , global_step :int =0 , epoch :int =0 , best_loss :float =float ('inf'), sharded :bool =False , max_shard_size :int =2 *1024 *1024 *1024 , save_separately :bool =True , ): """ Save model and optionally training state for resuming. 
Args: path: Directory to save the model optimizer: Optional optimizer to save state scheduler: Optional scheduler to save state global_step: Current training step epoch: Current epoch best_loss: Best loss achieved so far sharded: If True, save model in multiple .safetensors files max_shard_size: Maximum size per shard in bytes (default 2GB) save_separately: If True, save each component as separate .safetensors files (default) This avoids safetensors issues with shared storage in LSTM weights """ os .makedirs (path ,exist_ok =True ) if save_separately : self ._save_components_safe (path ) elif sharded : self ._save_sharded (path ,max_shard_size ) else : self ._save_single_file_safe (path ) config_dict =self .config .to_dict () config_dict ['has_audio_encoder']=True config_dict ['has_audio_decoder']=True config_dict ['has_waveform_decoder']=hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None config_dict ['has_vision_encoder']=hasattr (self ,'vision_encoder')and self .vision_encoder is not None config_dict ['has_video_encoder']=hasattr (self ,'video_encoder')and self .video_encoder is not None config_dict ['has_generator']=hasattr (self ,'generator')and self .generator is not None config_dict ['has_video_generator']=hasattr (self ,'video_generator')and self .video_generator is not None config_dict ['has_cross_attention']=hasattr (self ,'cross_attention_layers')and self .cross_attention_layers is not None config_dict ['lora_applied']=self .lora_applied config_dict ['architecture_version']=2 config_dict ['auto_map']={ 'AutoConfig':'configuration_xoron.XoronConfig', 'AutoModel':'modeling_xoron.XoronModel', 'AutoModelForCausalLM':'modeling_xoron.XoronForCausalLM', } with open (os .path .join (path ,"config.json"),"w")as f : json .dump (config_dict ,f ,indent =2 ) self ._copy_huggingface_files (path ) if optimizer is not None or scheduler is not None : training_state ={ 'global_step':global_step , 'epoch':epoch , 'best_loss':best_loss , } if optimizer 
is not None : training_state ['optimizer_state_dict']=optimizer .state_dict () if scheduler is not None : training_state ['scheduler_state_dict']=scheduler .state_dict () torch .save (training_state ,os .path .join (path ,"training_state.pt")) logger .info (f"Training state saved (step {global_step }, epoch {epoch })") logger .info (f"Model saved to {path }") def _copy_huggingface_files (self ,path :str ): """ Build and copy HuggingFace custom code files for trust_remote_code support. This DYNAMICALLY BUILDS a self-contained modeling_xoron.py by combining all model components, so users can load from HuggingFace Hub with: model = AutoModel.from_pretrained("repo/model", trust_remote_code=True) WITHOUT needing to install the full Xoron-Dev package. Args: path: Directory to save the files """ import shutil current_dir =os .path .dirname (os .path .abspath (__file__ )) project_root =os .path .dirname (current_dir ) config_src =os .path .join (project_root ,'configuration_xoron.py') config_dst =os .path .join (path ,'configuration_xoron.py') if os .path .exists (config_src ): shutil .copy2 (config_src ,config_dst ) logger .info ("Copied configuration_xoron.py") modeling_dst =os .path .join (path ,'modeling_xoron.py') self ._build_self_contained_modeling_file (project_root ,modeling_dst ) logger .info ("HuggingFace custom code files ready") def _build_self_contained_modeling_file (self ,project_root :str ,output_path :str ): """ Build a self-contained modeling_xoron.py by combining all model components. This creates a single file with ALL model code embedded, removing internal imports so it works standalone on HuggingFace without the full package. 
""" import re component_files =[ "models/components/lora.py", "models/components/attention.py", "models/components/projectors.py", "models/components/moe.py", "models/encoders/vision.py", "models/encoders/video.py", "models/encoders/audio.py", "models/generators/image.py", "models/generators/video.py", "models/llm/moe_llama.py", "models/xoron.py", ] internal_import_patterns =[ r"^from config import.*$", r"^from config\..*import.*$", r"^from models\..*import.*$", r"^from models import.*$", ] def is_internal_import (line ): line =line .strip () for pattern in internal_import_patterns : if re .match (pattern ,line ): return True return False def is_module_level_import (line ): """Check if this is a module-level import (no indentation).""" stripped =line .strip () if line and not line [0 ].isspace (): return (stripped .startswith ("import ")or stripped .startswith ("from ")) return False def extract_code_body (content ): """Extract code body, removing module docstring and module-level imports only.""" lines =content .split ('\n') code_lines =[] i =0 in_multiline_import =False while i =2 : i +=1 else : i +=1 while i {saved_vocab_size}") new_embed = nn.Embedding(saved_vocab_size, hidden_size) new_embed.weight.data = state_dict[embed_key] component.model.embed_tokens = new_embed if lm_head_key in state_dict: new_lm_head = nn.Linear(hidden_size, saved_vocab_size, bias=False) new_lm_head.weight.data = state_dict[lm_head_key] component.lm_head = new_lm_head del state_dict[embed_key] if lm_head_key in state_dict: del state_dict[lm_head_key] component.load_state_dict(state_dict, strict=False) logger.info(f"Loaded {comp_name}") markers_path = os.path.join(model_path, "modality_markers.safetensors") if os.path.exists(markers_path): with safe_open(markers_path, framework="pt") as f: model._internal_model.image_start.data = f.get_tensor('image_start') model._internal_model.image_end.data = f.get_tensor('image_end') model._internal_model.video_start.data = 
f.get_tensor('video_start') model._internal_model.video_end.data = f.get_tensor('video_end') model._internal_model.audio_start.data = f.get_tensor('audio_start') model._internal_model.audio_end.data = f.get_tensor('audio_end') logger.info("Loaded modality markers") logger.info(f"Xoron model loaded from {pretrained_model_name_or_path}") return model def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, pixel_values: Optional[torch.Tensor] = None, video_frames: Optional[torch.Tensor] = None, audio_features: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: self._ensure_model_initialized() return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self._internal_model( input_ids=input_ids, attention_mask=attention_mask, images=pixel_values, video=video_frames, audio=audio_features, labels=labels, ) if return_dict: return CausalLMOutputWithPast( loss=outputs.get("loss"), logits=outputs.get("logits"), past_key_values=outputs.get("past_key_values"), hidden_states=outputs.get("hidden_states"), attentions=outputs.get("attentions"), ) return (outputs.get("loss"), outputs.get("logits")) def generate_image(self, prompt_embeds: torch.Tensor, **kwargs): self._ensure_model_initialized() return self._internal_model.generate_image(prompt_embeds, **kwargs) def generate_video(self, prompt_embeds: torch.Tensor, **kwargs): self._ensure_model_initialized() return self._internal_model.generate_video(prompt_embeds, **kwargs) def generate_speech(self, text_embeds: torch.Tensor, **kwargs): self._ensure_model_initialized() return self._internal_model.generate_speech(text_embeds, **kwargs) class XoronForCausalLM(XoronModel): """Alias for XoronModel for compatibility.""" pass XoronConfig.register_for_auto_class() 
XoronModel.register_for_auto_class("AutoModel") XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM") ''' all_code .append (hf_wrapper ) final_content ='\n'.join (all_code ) with open (output_path ,'w',encoding ='utf-8')as f : f .write (final_content ) line_count =final_content .count ('\n') logger .info (f"Built self-contained modeling_xoron.py ({line_count :,} lines)") def _save_single_file_safe (self ,path :str ): """ Save model as single safetensors file with cloned tensors. Cloning breaks shared storage that causes safetensors errors. Args: path: Directory to save the model """ from safetensors .torch import save_file state_dict =self .state_dict () safe_state_dict ={} for key ,tensor in state_dict .items (): safe_state_dict [key ]=tensor .clone ().contiguous () save_file (safe_state_dict ,os .path .join (path ,"model.safetensors")) size_mb =sum (t .numel ()*t .element_size ()for t in safe_state_dict .values ())/(1024 *1024 ) logger .info (f"Saved model.safetensors ({size_mb :.1f} MB)") def _save_components_safe (self ,path :str ): """ Save model components as separate .safetensors files with cloned tensors. This is the default and most robust saving method that: 1. Handles LSTM weight sharing issues in safetensors 2. Allows surgical component loading/updates 3. 
Better for debugging and inspection Args: path: Directory to save component files """ from safetensors .torch import save_file os .makedirs (path ,exist_ok =True ) component_map ={ 'llm':self .llm , 'vision_encoder':self .vision_encoder , 'video_encoder':self .video_encoder , 'audio_encoder':self .audio_encoder , 'audio_decoder':self .audio_decoder , 'projector':self .projector , 'audio_projector':self .audio_projector , } if self .cross_attention_layers is not None : component_map ['cross_attention']=self .cross_attention_layers if self .generator is not None : component_map ['generator']=self .generator if self .video_generator is not None : component_map ['video_generator']=self .video_generator if hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None : component_map ['waveform_decoder']=self .waveform_decoder saved_files =[] total_size =0 for comp_name ,component in component_map .items (): if component is None : continue comp_state =component .state_dict () if not comp_state : continue safe_comp_state ={} for key ,tensor in comp_state .items (): safe_comp_state [key ]=tensor .clone ().contiguous () comp_path =os .path .join (path ,f"{comp_name }.safetensors") save_file (safe_comp_state ,comp_path ) size_mb =sum (t .numel ()*t .element_size ()for t in safe_comp_state .values ())/(1024 *1024 ) total_size +=size_mb logger .info (f"Saved {comp_name }: {size_mb :.1f} MB") saved_files .append (comp_name ) markers ={ 'image_start':self .image_start .data .clone ().contiguous (), 'image_end':self .image_end .data .clone ().contiguous (), 'video_start':self .video_start .data .clone ().contiguous (), 'video_end':self .video_end .data .clone ().contiguous (), 'audio_start':self .audio_start .data .clone ().contiguous (), 'audio_end':self .audio_end .data .clone ().contiguous (), } save_file (markers ,os .path .join (path ,"modality_markers.safetensors")) logger .info ("Saved modality_markers") manifest ={ "components":saved_files 
+["modality_markers"], "save_format":"components", } with open (os .path .join (path ,"components.json"),"w")as f : json .dump (manifest ,f ,indent =2 ) weight_map ={} total_bytes =0 for comp_name ,component in component_map .items (): if component is None : continue comp_state =component .state_dict () if not comp_state : continue safetensor_file =f"{comp_name }.safetensors" for key in comp_state .keys (): full_key =f"{comp_name }.{key }" weight_map [full_key ]=safetensor_file total_bytes +=comp_state [key ].numel ()*comp_state [key ].element_size () marker_names =['image_start','image_end','video_start','video_end','audio_start','audio_end'] for marker_name in marker_names : weight_map [marker_name ]="modality_markers.safetensors" marker_tensor =getattr (self ,marker_name ) total_bytes +=marker_tensor .numel ()*marker_tensor .element_size () index ={ "metadata":{ "total_size":total_bytes , "format":"components", }, "weight_map":weight_map , } index_path =os .path .join (path ,"model.safetensors.index.json") with open (index_path ,"w")as f : json .dump (index ,f ,indent =2 ) logger .info ("Saved model.safetensors.index.json for HuggingFace compatibility") logger .info (f"Total size: {total_size :.1f} MB across {len (saved_files )} components") def _save_sharded (self ,path :str ,max_shard_size :int ): """ Save model weights in sharded .safetensors files. Components are surgically split across shards. 
Args: path: Directory to save shards max_shard_size: Maximum bytes per shard """ from safetensors .torch import save_file state_dict =self .state_dict () component_groups ={ 'llm':{}, 'vision_encoder':{}, 'video_encoder':{}, 'audio_encoder':{}, 'audio_decoder':{}, 'waveform_decoder':{}, 'generator':{}, 'video_generator':{}, 'projector':{}, 'audio_projector':{}, 'cross_attention_layers':{}, 'other':{}, } for key ,tensor in state_dict .items (): placed =False for comp_name in component_groups .keys (): if comp_name !='other'and key .startswith (comp_name ): component_groups [comp_name ][key ]=tensor placed =True break if not placed : component_groups ['other'][key ]=tensor shards =[] current_shard ={} current_size =0 shard_index_map ={} for comp_name ,comp_tensors in component_groups .items (): for key ,tensor in comp_tensors .items (): tensor_size =tensor .numel ()*tensor .element_size () if current_size +tensor_size >max_shard_size and current_shard : shards .append (current_shard ) current_shard ={} current_size =0 current_shard [key ]=tensor current_size +=tensor_size if current_shard : shards .append (current_shard ) total_shards =len (shards ) weight_map ={} for i ,shard in enumerate (shards ): shard_name =f"model-{i +1 :05d}-of-{total_shards :05d}.safetensors" shard_path =os .path .join (path ,shard_name ) shard_contiguous ={k :v .clone ().contiguous ()for k ,v in shard .items ()} save_file (shard_contiguous ,shard_path ) for key in shard .keys (): weight_map [key ]=shard_name shard_size_mb =sum (t .numel ()*t .element_size ()for t in shard .values ())/(1024 *1024 ) logger .info (f"Saved shard {i +1 }/{total_shards }: {shard_name } ({shard_size_mb :.1f} MB)") index ={ "metadata":{ "total_size":sum (t .numel ()*t .element_size ()for t in state_dict .values ()), "total_shards":total_shards , }, "weight_map":weight_map , } index_path =os .path .join (path ,"model.safetensors.index.json") with open (index_path ,"w")as f : json .dump (index ,f ,indent =2 ) logger 
.info ("Saved index: model.safetensors.index.json") def save_components_separately (self ,path :str ): """ Save model components as separate .safetensors files. Useful for surgical component updates and debugging. NOTE: This method now clones tensors to handle LSTM shared storage issues. Args: path: Directory to save component files """ from safetensors .torch import save_file os .makedirs (path ,exist_ok =True ) component_map ={ 'llm':self .llm , 'vision_encoder':self .vision_encoder , 'video_encoder':self .video_encoder , 'audio_encoder':self .audio_encoder , 'audio_decoder':self .audio_decoder , 'projector':self .projector , 'audio_projector':self .audio_projector , } if self .cross_attention_layers is not None : component_map ['cross_attention']=self .cross_attention_layers if self .generator is not None : component_map ['generator']=self .generator if self .video_generator is not None : component_map ['video_generator']=self .video_generator if hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None : component_map ['waveform_decoder']=self .waveform_decoder saved_files =[] for comp_name ,component in component_map .items (): if component is None : continue comp_state =component .state_dict () if not comp_state : continue comp_state ={k :v .clone ().contiguous ()for k ,v in comp_state .items ()} comp_path =os .path .join (path ,f"{comp_name }.safetensors") save_file (comp_state ,comp_path ) size_mb =sum (t .numel ()*t .element_size ()for t in comp_state .values ())/(1024 *1024 ) logger .info (f"Saved {comp_name }: {size_mb :.1f} MB") saved_files .append (comp_name ) markers ={ 'image_start':self .image_start .data .clone ().contiguous (), 'image_end':self .image_end .data .clone ().contiguous (), 'video_start':self .video_start .data .clone ().contiguous (), 'video_end':self .video_end .data .clone ().contiguous (), 'audio_start':self .audio_start .data .clone ().contiguous (), 'audio_end':self .audio_end .data .clone ().contiguous (), } 
save_file (markers ,os .path .join (path ,"modality_markers.safetensors")) logger .info ("Saved modality_markers") manifest ={ "components":saved_files +["modality_markers"], "config":self .config .to_dict (), "lora_applied":self .lora_applied , } with open (os .path .join (path ,"components.json"),"w")as f : json .dump (manifest ,f ,indent =2 ) weight_map ={} total_bytes =0 for comp_name ,component in component_map .items (): if component is None : continue comp_state =component .state_dict () if not comp_state : continue safetensor_file =f"{comp_name }.safetensors" for key in comp_state .keys (): full_key =f"{comp_name }.{key }" weight_map [full_key ]=safetensor_file total_bytes +=comp_state [key ].numel ()*comp_state [key ].element_size () marker_names =['image_start','image_end','video_start','video_end','audio_start','audio_end'] for marker_name in marker_names : weight_map [marker_name ]="modality_markers.safetensors" marker_tensor =getattr (self ,marker_name ) total_bytes +=marker_tensor .numel ()*marker_tensor .element_size () index ={ "metadata":{ "total_size":total_bytes , "format":"components", }, "weight_map":weight_map , } index_path =os .path .join (path ,"model.safetensors.index.json") with open (index_path ,"w")as f : json .dump (index ,f ,indent =2 ) logger .info ("Saved model.safetensors.index.json for HuggingFace compatibility") logger .info (f"Components saved to {path }") @classmethod def from_pretrained ( cls , path :str , device :str =None , device_map :Dict [str ,str ]=None , apply_lora :bool =True , strict :bool =False , )->'XoronMultimodalModel': """ Load a pretrained Xoron model from a checkpoint or final model directory. 
Args: path: Path to the saved model directory device: Device to load the model to (if not using device_map) device_map: Device map for model parallelism apply_lora: Whether to apply LoRA after loading strict: If False, allows loading weights even if architecture changed Returns: Loaded XoronMultimodalModel instance """ from safetensors import safe_open logger .info (f"Loading model from {path }...") config_path =os .path .join (path ,"config.json") if not os .path .exists (config_path ): raise FileNotFoundError (f"Config file not found at {config_path }") with open (config_path ,'r')as f : config_dict =json .load (f ) lora_was_applied =config_dict .pop ('lora_applied',False ) architecture_version =config_dict .pop ('architecture_version',1 ) has_waveform_decoder =config_dict .pop ('has_waveform_decoder',False ) has_vision_encoder =config_dict .pop ('has_vision_encoder',True ) has_video_encoder =config_dict .pop ('has_video_encoder',True ) has_generator =config_dict .pop ('has_generator',True ) has_video_generator =config_dict .pop ('has_video_generator',True ) has_cross_attention =config_dict .pop ('has_cross_attention',True ) config_dict .pop ('has_audio_encoder',None ) config_dict .pop ('has_audio_decoder',None ) logger .info (f"Saved model architecture (version {architecture_version }):") logger .info (f" - Waveform Decoder: {'โœ…'if has_waveform_decoder else 'โŒ (will init randomly)'}") logger .info (f" - Vision Encoder: {'โœ…'if has_vision_encoder else 'โŒ'}") logger .info (f" - Video Encoder: {'โœ…'if has_video_encoder else 'โŒ'}") logger .info (f" - Image Generator: {'โœ…'if has_generator else 'โŒ'}") logger .info (f" - Video Generator: {'โœ…'if has_video_generator else 'โŒ'}") logger .info (f" - Cross Attention: {'โœ…'if has_cross_attention else 'โŒ'}") logger .info (f" - LoRA Applied: {'โœ…'if lora_was_applied else 'โŒ'}") config =XoronConfig .from_dict (config_dict ) model =cls (config ,device_map =device_map ) if lora_was_applied: logger .info 
("Checkpoint has LoRA weights. Applying LoRA structure before loading...") model .apply_lora () components_json =os .path .join (path ,"components.json") model_path =os .path .join (path ,"model.safetensors") if os .path .exists (components_json ): logger .info ("Loading from component-based format...") model ._load_components (path ,strict =strict ) model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights) elif os .path .exists (model_path ): logger .info ("Loading weights from safetensors...") if strict : load_model (model ,model_path ) else : checkpoint_state_dict ={} with safe_open (model_path ,framework ="pt",device ="cpu")as f : for key in f .keys (): checkpoint_state_dict [key ]=f .get_tensor (key ) model .load_state_dict (checkpoint_state_dict ,strict =False ) logger .info ("Loaded weights from checkpoint") model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights) else : pytorch_path =os .path .join (path ,"pytorch_model.bin") if os .path .exists (pytorch_path ): logger .info ("Loading weights from pytorch_model.bin...") checkpoint_state_dict =torch .load (pytorch_path ,map_location ='cpu') model .load_state_dict (checkpoint_state_dict ,strict =False ) logger .info ("Loaded weights from checkpoint") model .lora_applied =False # Always allow fresh LoRA application (checkpoint has merged weights) else : raise FileNotFoundError (f"No model weights found at {path }") if apply_lora and config .use_lora and not model .lora_applied : model .apply_lora () if device_map is not None : model .apply_model_parallel (device_map ) elif device is not None : model =model .to (device ) logger .info ("Model loaded successfully!") model ._print_stats () return model def _load_components (self ,path :str ,strict :bool =False ): """ Load model from component-based safetensors files. 
Args: path: Directory containing component files strict: If True, require exact match; if False, allow partial loading """ from safetensors import safe_open component_map ={ 'llm':self .llm , 'vision_encoder':self .vision_encoder , 'video_encoder':self .video_encoder , 'audio_encoder':self .audio_encoder , 'audio_decoder':self .audio_decoder , 'projector':self .projector , 'audio_projector':self .audio_projector , } if self .cross_attention_layers is not None : component_map ['cross_attention']=self .cross_attention_layers if self .generator is not None : component_map ['generator']=self .generator if self .video_generator is not None : component_map ['video_generator']=self .video_generator if hasattr (self ,'waveform_decoder')and self .waveform_decoder is not None : component_map ['waveform_decoder']=self .waveform_decoder for comp_name ,component in component_map .items (): if component is None : continue comp_path =os .path .join (path ,f"{comp_name }.safetensors") if not os .path .exists (comp_path ): continue try : checkpoint_state ={} with safe_open (comp_path ,framework ="pt",device ="cpu")as f : for key in f .keys (): checkpoint_state [key ]=f .get_tensor (key ) component .load_state_dict (checkpoint_state ,strict =strict ) size_mb =sum (t .numel ()*t .element_size ()for t in checkpoint_state .values ())/(1024 *1024 ) logger .info (f"Loaded {comp_name } ({size_mb :.1f} MB)") except Exception as e : logger .warning (f"Error loading {comp_name }: {e }") markers_path =os .path .join (path ,"modality_markers.safetensors") if os .path .exists (markers_path ): try : with safe_open (markers_path ,framework ="pt",device ="cpu")as f : self .image_start .data =f .get_tensor ('image_start') self .image_end .data =f .get_tensor ('image_end') self .video_start .data =f .get_tensor ('video_start') self .video_end .data =f .get_tensor ('video_end') self .audio_start .data =f .get_tensor ('audio_start') self .audio_end .data =f .get_tensor ('audio_end') logger .info 
("Loaded modality_markers") except Exception as e : logger .warning (f"Error loading modality_markers: {e }") logger .info ("Components loaded successfully") @staticmethod def load_training_state (path :str )->Optional [Dict ]: """ Load training state from a checkpoint. Args: path: Path to the checkpoint directory Returns: Dictionary with training state or None if not found """ state_path =os .path .join (path ,"training_state.pt") if os .path .exists (state_path ): logger .info (f"Loading training state from {state_path }...") return torch .load (state_path ,map_location ='cpu') return None def freeze_components (self ,components :List [str ],hard_freeze :bool =True ): """ Freeze specific components of the model. IMPORTANT RULES: 1. LLM is NEVER frozen - it's trained from scratch and always needs full weight training 2. LoRA parameters are usually kept trainable, UNLESS hard_freeze=True Args: components: List of component group names to freeze. Valid groups: 'vision', 'video', 'audio', 'cross_attention', 'image_generation', 'video_generation', 'modality_markers' NOTE: 'llm' is NOT a valid group to freeze - will be ignored! hard_freeze: If True, completely freezes the component including its LoRA adapters. This prevents inactive components from updating via weight decay/momentum. 
""" if 'llm'in components : logger .warning ("Ignoring 'llm' in freeze list - LLM must always train (from scratch)") components =[c for c in components if c !='llm'] logger .info (f"Freezing components: {components } (hard_freeze={hard_freeze })") for group_name in components : if group_name not in COMPONENT_GROUPS : logger .warning (f" โš ๏ธ Unknown component group: {group_name }") continue for attr_name in COMPONENT_GROUPS [group_name ]: if hasattr (self ,attr_name ): component =getattr (self ,attr_name ) if component is not None : if isinstance (component ,nn .Parameter ): component .requires_grad =False elif isinstance (component ,nn .Module ): for name ,param in component .named_parameters (): path_lora ='lora_A'in name or 'lora_B'in name or 'magnitude'in name if hard_freeze or not path_lora : param .requires_grad =False logger .info (f"Frozen: {attr_name }") if self .lora_applied and not hard_freeze: enable_lora_training (self ) logger .info ("LoRA parameters remain trainable") self ._print_stats () def unfreeze_components (self ,components :List [str ]): """ Unfreeze specific components of the model. Args: components: List of component group names to unfreeze. """ logger .info (f"Unfreezing components: {components }") for group_name in components : if group_name not in COMPONENT_GROUPS : logger .warning (f" โš ๏ธ Unknown component group: {group_name }") continue for attr_name in COMPONENT_GROUPS [group_name ]: if hasattr (self ,attr_name ): component =getattr (self ,attr_name ) if component is not None : if isinstance (component ,nn .Parameter ): component .requires_grad =True elif isinstance (component ,nn .Module ): for param in component .parameters (): param .requires_grad =True logger .info (f"Unfrozen: {attr_name }") self ._print_stats () def freeze_all_except (self ,components :List [str ],hard_freeze :bool =True ): """ Freeze all components except the specified ones. 
NOTE: LLM is always kept trainable regardless of input - it's trained from scratch. Args: components: List of component group names to keep trainable. """ if 'llm'not in components : components =components +['llm'] all_groups =list (COMPONENT_GROUPS .keys ()) groups_to_freeze =[g for g in all_groups if g not in components ] self .freeze_components (groups_to_freeze ,hard_freeze =hard_freeze ) def get_trainable_component_names (self )->List [str ]: """Get list of component groups that have trainable parameters.""" trainable =[] for group_name ,attr_names in COMPONENT_GROUPS .items (): for attr_name in attr_names : if hasattr (self ,attr_name ): component =getattr (self ,attr_name ) if component is not None : if isinstance (component ,nn .Parameter ): if component .requires_grad : trainable .append (group_name ) break elif isinstance (component ,nn .Module ): if any (p .requires_grad for p in component .parameters ()): trainable .append (group_name ) break return trainable def get_frozen_component_names (self )->List [str ]: """Get list of component groups that are frozen (no trainable parameters).""" frozen =[] for group_name ,attr_names in COMPONENT_GROUPS .items (): has_component =False is_trainable =False for attr_name in attr_names : if hasattr (self ,attr_name ): component =getattr (self ,attr_name ) if component is not None : has_component =True if isinstance (component ,nn .Parameter ): if component .requires_grad : is_trainable =True break elif isinstance (component ,nn .Module ): if any (p .requires_grad for p in component .parameters ()): is_trainable =True break if has_component and not is_trainable : frozen .append (group_name ) return frozen def get_component_status (self )->tuple : """ Get tuple of (trainable_components, frozen_components) for display. 
Returns: tuple: (list of trainable component names, list of frozen component names) """ trainable =self .get_trainable_component_names () frozen =self .get_frozen_component_names () return trainable ,frozen class XoronPreTrainedModel(PreTrainedModel): """Base class for Xoron models providing HuggingFace integration.""" config_class = XoronConfig base_model_prefix = "xoron" supports_gradient_checkpointing = True _no_split_modules = ["XoronMultimodalModel"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True def _init_weights(self, module): std = 0.02 if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) class XoronModel(XoronPreTrainedModel): """Xoron Multimodal Model for HuggingFace.""" def __init__(self, config: XoronConfig): super().__init__(config) self.config = config self._internal_model = None self._model_initialized = False def _ensure_model_initialized(self): """Lazily initialize the internal model to avoid meta device conflicts.""" if not self._model_initialized: self._internal_model = XoronMultimodalModel(self.config) self._model_initialized = True @property def internal_model(self): self._ensure_model_initialized() return self._internal_model @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): """ Load pretrained Xoron model from HuggingFace Hub or local path. This override ensures proper initialization without meta device conflicts. 
""" kwargs.pop('device_map', None) config = kwargs.pop('config', None) if config is None: config = XoronConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) model = cls(config) model._internal_model = XoronMultimodalModel(config) model._model_initialized = True import os from safetensors import safe_open if os.path.isdir(pretrained_model_name_or_path): model_path = pretrained_model_name_or_path else: from huggingface_hub import snapshot_download model_path = snapshot_download(repo_id=pretrained_model_name_or_path) components_json = os.path.join(model_path, "components.json") if os.path.exists(components_json): with open(components_json, 'r') as f: manifest = json.load(f) component_map = { 'llm': model._internal_model.llm, 'vision_encoder': model._internal_model.vision_encoder, 'video_encoder': model._internal_model.video_encoder, 'audio_encoder': model._internal_model.audio_encoder, 'audio_decoder': model._internal_model.audio_decoder, 'projector': model._internal_model.projector, 'audio_projector': model._internal_model.audio_projector, } if model._internal_model.cross_attention_layers is not None: component_map['cross_attention'] = model._internal_model.cross_attention_layers if model._internal_model.generator is not None: component_map['generator'] = model._internal_model.generator if model._internal_model.video_generator is not None: component_map['video_generator'] = model._internal_model.video_generator if hasattr(model._internal_model, 'waveform_decoder') and model._internal_model.waveform_decoder is not None: component_map['waveform_decoder'] = model._internal_model.waveform_decoder for comp_name in manifest.get('components', []): if comp_name == 'modality_markers': continue comp_path = os.path.join(model_path, f"{comp_name}.safetensors") if os.path.exists(comp_path) and comp_name in component_map: component = component_map[comp_name] if component is not None: with safe_open(comp_path, framework="pt") as f: state_dict = {k: f.get_tensor(k) for k 
in f.keys()} if comp_name == 'llm': embed_key = 'model.embed_tokens.weight' lm_head_key = 'lm_head.weight' if embed_key in state_dict: saved_vocab_size = state_dict[embed_key].shape[0] hidden_size = state_dict[embed_key].shape[1] current_vocab_size = component.model.embed_tokens.weight.shape[0] if saved_vocab_size != current_vocab_size: logger.info(f"Resizing embeddings: {current_vocab_size} -> {saved_vocab_size}") new_embed = nn.Embedding(saved_vocab_size, hidden_size) new_embed.weight.data = state_dict[embed_key] component.model.embed_tokens = new_embed if lm_head_key in state_dict: new_lm_head = nn.Linear(hidden_size, saved_vocab_size, bias=False) new_lm_head.weight.data = state_dict[lm_head_key] component.lm_head = new_lm_head del state_dict[embed_key] if lm_head_key in state_dict: del state_dict[lm_head_key] component.load_state_dict(state_dict, strict=False) logger.info(f"Loaded {comp_name}") markers_path = os.path.join(model_path, "modality_markers.safetensors") if os.path.exists(markers_path): with safe_open(markers_path, framework="pt") as f: model._internal_model.image_start.data = f.get_tensor('image_start') model._internal_model.image_end.data = f.get_tensor('image_end') model._internal_model.video_start.data = f.get_tensor('video_start') model._internal_model.video_end.data = f.get_tensor('video_end') model._internal_model.audio_start.data = f.get_tensor('audio_start') model._internal_model.audio_end.data = f.get_tensor('audio_end') logger.info("Loaded modality markers") logger.info(f"Xoron model loaded from {pretrained_model_name_or_path}") return model def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, pixel_values: Optional[torch.Tensor] = None, video_frames: Optional[torch.Tensor] = None, audio_features: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> 
Union[Tuple, CausalLMOutputWithPast]: self._ensure_model_initialized() return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self._internal_model( input_ids=input_ids, attention_mask=attention_mask, images=pixel_values, video=video_frames, audio=audio_features, labels=labels, ) if return_dict: return CausalLMOutputWithPast( loss=outputs.get("loss"), logits=outputs.get("logits"), past_key_values=outputs.get("past_key_values"), hidden_states=outputs.get("hidden_states"), attentions=outputs.get("attentions"), ) return (outputs.get("loss"), outputs.get("logits")) def generate_image(self, prompt_embeds: torch.Tensor, **kwargs): self._ensure_model_initialized() return self._internal_model.generate_image(prompt_embeds, **kwargs) def generate_video(self, prompt_embeds: torch.Tensor, **kwargs): self._ensure_model_initialized() return self._internal_model.generate_video(prompt_embeds, **kwargs) def generate_speech(self, text_embeds: torch.Tensor, **kwargs): self._ensure_model_initialized() return self._internal_model.generate_speech(text_embeds, **kwargs) class XoronForCausalLM(XoronModel): """Alias for XoronModel for compatibility.""" pass XoronConfig.register_for_auto_class() XoronModel.register_for_auto_class("AutoModel") XoronForCausalLM.register_for_auto_class("AutoModelForCausalLM")