diff --git "a/caca_transformers.py" "b/caca_transformers.py" new file mode 100644--- /dev/null +++ "b/caca_transformers.py" @@ -0,0 +1,2485 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from typing import Optional, Tuple, List +from transformers import PreTrainedModel, PretrainedConfig +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.generation.utils import GenerationMixin +from collections import OrderedDict +import logging +from functools import lru_cache + +logger = logging.getLogger(__name__) + +try: + from flash_attn import flash_attn_func + HAS_FLASH_ATTN = True +except ImportError: + HAS_FLASH_ATTN = False + +try: + from xformers.ops import memory_efficient_attention + HAS_XFORMERS = True +except ImportError: + HAS_XFORMERS = False + +HAS_SDPA = hasattr(F, 'scaled_dot_product_attention') + +# --- config --- +class CacaConfig(PretrainedConfig): + model_type = "caca" + def __init__( + self, + vocab_size=32000, + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=24, + num_attention_heads=32, + num_key_value_heads=8, + head_dim=None, + max_position_embeddings=8192, + rms_norm_eps=1e-6, + qk_norm_eps=1e-6, + use_qk_norm=True, + initializer_range=0.02, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_rotary_embeddings=True, + use_alibi=False, + attention_bias=False, + attention_dropout=0.0, + use_flash_attn=True, + use_grouped_query_attention=False, + use_multi_query_attention=False, + sliding_window=None, + use_longformer_attention=False, + longformer_attention_window=512, + attn_logit_softcapping=None, + final_logit_softcapping=None, + attention_sink_size=4, + attention_sink_window=1024, + use_attention_sink=False, + attention_pattern="all_global", + global_attention_every_n_layers=2, + mlp_bias=False, + hidden_dropout=0.1, + residual_dropout=0.1, + use_moe=False, + num_experts=8, + num_experts_per_tok=2, + use_expert_choice=False, + expert_choice_k=0.125, + router_aux_loss_coef=0.01, + router_z_loss_coef=0.001, + moe_layer_frequency=2, + expert_capacity_factor=1.0, + use_grouped_moe=False, + num_expert_groups=1, + use_layer_scale=False, + layer_scale_init=1e-5, + use_stochastic_depth=False, + stochastic_depth_prob=0.1, + use_mixture_of_depths=False, + mod_capacity_factor=0.5, + mod_route_method="learned", + use_cross_attention=False, + cross_attention_frequency=4, + use_multimodal=False, + vision_config=None, + audio_config=None, + projector_hidden_size=None, + use_soft_merging=False, + merge_threshold=0.5, + pretraining_tp=1, + tensor_parallel_size=1, + pipeline_parallel_size=1, + chat_template=None, + **kwargs + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim or (hidden_size // num_attention_heads if hidden_size and num_attention_heads else None) + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps + self.qk_norm_eps = qk_norm_eps + self.initializer_range = initializer_range + self.use_cache = use_cache + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.rope_theta = rope_theta + 
self.rope_scaling = rope_scaling + self.use_rotary_embeddings = use_rotary_embeddings + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.use_qk_norm = use_qk_norm + self.use_alibi = use_alibi + self.use_flash_attn = use_flash_attn + self.use_grouped_query_attention = use_grouped_query_attention + self.use_multi_query_attention = use_multi_query_attention + self.sliding_window = sliding_window + self.use_longformer_attention = use_longformer_attention + self.longformer_attention_window = longformer_attention_window + self.attn_logit_softcapping = attn_logit_softcapping + self.final_logit_softcapping = final_logit_softcapping + self.attention_sink_size = attention_sink_size + self.attention_sink_window = attention_sink_window + self.use_attention_sink = use_attention_sink + self.attention_pattern = attention_pattern + self.global_attention_every_n_layers = global_attention_every_n_layers + self.mlp_bias = mlp_bias + self.hidden_dropout = hidden_dropout + self.residual_dropout = residual_dropout + self.use_moe = use_moe + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.use_expert_choice = use_expert_choice + self.expert_choice_k = expert_choice_k + self.router_aux_loss_coef = router_aux_loss_coef + self.router_z_loss_coef = router_z_loss_coef + self.moe_layer_frequency = moe_layer_frequency + self.expert_capacity_factor = expert_capacity_factor + self.use_grouped_moe = use_grouped_moe + self.num_expert_groups = num_expert_groups + self.use_layer_scale = use_layer_scale + self.layer_scale_init = layer_scale_init + self.use_stochastic_depth = use_stochastic_depth + self.stochastic_depth_prob = stochastic_depth_prob + self.use_mixture_of_depths = use_mixture_of_depths + self.mod_capacity_factor = mod_capacity_factor + self.mod_route_method = mod_route_method + self.use_cross_attention = use_cross_attention + self.cross_attention_frequency = cross_attention_frequency + self.use_multimodal = use_multimodal + self.vision_config = vision_config or {} + self.audio_config = audio_config or {} + self.projector_hidden_size = projector_hidden_size or hidden_size + self.use_soft_merging = use_soft_merging + self.merge_threshold = merge_threshold + self.pretraining_tp = pretraining_tp + self.tensor_parallel_size = tensor_parallel_size + self.pipeline_parallel_size = pipeline_parallel_size + if chat_template is None: + self.chat_template = ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "System: {{ message['content'] }}\n" + "{% elif message['role'] == 'user' %}" + "User: {{ message['content'] }}\n" + "{% elif message['role'] == 'assistant' %}" + "Assistant: {{ message['content'] }}\n" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}Assistant:{% endif %}" + ) + else: + self.chat_template = chat_template + self._validate_config() + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs + ) + def _validate_config(self): + if self.num_attention_heads % self.num_key_value_heads != 0: + raise ValueError( + f"num_attention_heads ({self.num_attention_heads}) must be divisible by " + f"num_key_value_heads ({self.num_key_value_heads})" + ) + if self.num_key_value_heads > self.num_attention_heads: + raise ValueError( + f"num_key_value_heads ({self.num_key_value_heads}) cannot be > " + f"num_attention_heads ({self.num_attention_heads})" + ) + if self.hidden_size % self.num_attention_heads 
!= 0: + raise ValueError( + f"hidden_size ({self.hidden_size}) must be divisible by " + f"num_attention_heads ({self.num_attention_heads})" + ) + expected_head_dim = self.hidden_size // self.num_attention_heads + if self.head_dim is None: + self.head_dim = expected_head_dim + else: + total_dim = self.head_dim * self.num_attention_heads + if total_dim != self.hidden_size: + logger.warning( + f"head_dim ({self.head_dim}) * num_attention_heads ({self.num_attention_heads}) " + f"= {total_dim} != hidden_size ({self.hidden_size}). " + f"This may cause dimension mismatch issues." + ) + if self.intermediate_size <= 0: + raise ValueError(f"intermediate_size must be > 0, got {self.intermediate_size}") + ffn_ratio = self.intermediate_size / self.hidden_size + if ffn_ratio < 1.5 or ffn_ratio > 10: + logger.warning( + f"intermediate_size/hidden_size ratio ({ffn_ratio:.2f}) is unusual. " + f"Typical range is 2-8x. Got intermediate_size={self.intermediate_size}, " + f"hidden_size={self.hidden_size}" + ) + if self.vocab_size <= 0: + raise ValueError(f"vocab_size must be > 0, got {self.vocab_size}") + if self.vocab_size > 1000000: + logger.warning( + f"vocab_size ({self.vocab_size}) is very large. " + f"This may cause memory issues." + ) + if self.use_flash_attn and not HAS_FLASH_ATTN: + logger.warning( + "use_flash_attn=True but flash-attn not installed. " + "Will fallback to SDPA/standard attention." + ) + if self.sliding_window is not None: + if self.sliding_window > self.max_position_embeddings: + raise ValueError( + f"sliding_window ({self.sliding_window}) cannot be > " + f"max_position_embeddings ({self.max_position_embeddings})" + ) + if self.sliding_window < 128: + logger.warning( + f"sliding_window ({self.sliding_window}) is very small. " + f"This may limit context significantly." + ) + if self.use_moe: + if self.num_experts < self.num_experts_per_tok: + raise ValueError( + f"num_experts ({self.num_experts}) must be >= " + f"num_experts_per_tok ({self.num_experts_per_tok})" + ) + if self.moe_layer_frequency <= 0: + raise ValueError(f"moe_layer_frequency must be > 0") + if self.moe_layer_frequency > self.num_hidden_layers: + logger.warning( + f"moe_layer_frequency ({self.moe_layer_frequency}) > " + f"num_hidden_layers ({self.num_hidden_layers}). " + f"MoE will not be used." 
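# Usage sketch for the validation rules above (illustrative values; assumes the
# file is importable as `caca_transformers`): num_attention_heads must be a
# multiple of num_key_value_heads, and head_dim defaults to hidden_size // heads.
from caca_transformers import CacaConfig

cfg = CacaConfig(
    vocab_size=32000,
    hidden_size=1024,
    intermediate_size=4096,
    num_hidden_layers=8,
    num_attention_heads=16,
    num_key_value_heads=4,   # 16 % 4 == 0, so the GQA check passes
)
assert cfg.head_dim == 1024 // 16  # 64, derived in __init__ and re-checked here
# CacaConfig(num_attention_heads=16, num_key_value_heads=5) would raise ValueError.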
+ ) + if self.expert_capacity_factor <= 0: + raise ValueError(f"expert_capacity_factor must be > 0") + def to_dict(self): + quantization_config_backup = getattr(self, 'quantization_config', None) + if quantization_config_backup is None and hasattr(self, 'quantization_config'): + delattr(self, 'quantization_config') + temp_removed = True + else: + temp_removed = False + try: + output = super().to_dict() + output['auto_map'] = { + "AutoConfig": "caca_transformers.CacaConfig", + "AutoModel": "caca_transformers.CacaModel", + "AutoModelForCausalLM": "caca_transformers.CacaForCausalLM" + } + return output + finally: + if temp_removed: + self.quantization_config = None + elif quantization_config_backup is not None: + self.quantization_config = quantization_config_backup + @classmethod + def from_dict(cls, config_dict, **kwargs): + config_dict = {k: v for k, v in config_dict.items() if k != 'auto_map'} + config_dict.update(kwargs) + return cls(**config_dict) + +# --- Arsitektur Model --- +class MetricsTracker: + def __init__(self): + self.metrics = defaultdict(list) + self.reset_interval = 100 + self.step_count = 0 + + def log(self, name, value): + if isinstance(value, torch.Tensor): + value = value.item() + self.metrics[name].append(value) + + def step(self): + self.step_count += 1 + if self.step_count % self.reset_interval == 0: + self.clear() + + def get_summary(self): + summary = {} + for name, values in self.metrics.items(): + if values: + summary[name] = { + "mean": np.mean(values), + "std": np.std(values), + "min": np.min(values), + "max": np.max(values), + "last": values[-1], + } + return summary + + def clear(self): + self.metrics.clear() + +class CacaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + self.variance_epsilon = eps + + def forward(self, x): + input_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * x).to(input_dtype) + +class LayerScale(nn.Module): + def __init__(self, dim, init_value=1e-5): + super().__init__() + self.gamma = nn.Parameter(init_value * torch.ones(dim)) + + def forward(self, x): + return self.gamma * x + +class StochasticDepth(nn.Module): + def __init__(self, drop_prob=0.0): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x, training=True): + if not training or self.drop_prob == 0.0: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() + return x.div(keep_prob) * random_tensor + +class CacaRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + max_position_embeddings=8192, + base=10000.0, + scaling_factor=1.0, + scaling_type=None, + ): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.scaling_factor = scaling_factor + self.scaling_type = scaling_type + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float() / self.dim) + ) + if scaling_type == "linear": + inv_freq = inv_freq / scaling_factor + elif scaling_type == "dynamic": + inv_freq = inv_freq + elif scaling_type == "yarn": + inv_freq = self._yarn_get_inv_freq(inv_freq) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._cos_cache = None + self._sin_cache = None + self._cached_seq_len = 0 + + def _yarn_get_inv_freq(self, 
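# Note on MetricsTracker above: it relies on `defaultdict` (in __init__) and `np`
# (in get_summary), neither of which appears in the import block at the top of
# this file, so constructing CacaModel would raise NameError without adding:
from collections import defaultdict
import numpy as np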
inv_freq): + if len(inv_freq) == 0: + return inv_freq + alpha = self.scaling_factor + beta_fast = 32 + beta_slow = 1 + freq_threshold = 1 / (self.max_position_embeddings * beta_fast) + low_freq_mask = inv_freq > freq_threshold + high_freq_mask = ~low_freq_mask + low_freq = inv_freq[low_freq_mask] + high_freq = inv_freq[high_freq_mask] + if len(low_freq) > 0: + low_freq = low_freq / alpha + if len(high_freq) > 0: + smooth_factor = ( + self.max_position_embeddings * beta_slow / high_freq - beta_fast + ) / (beta_slow - beta_fast) + smooth_factor = torch.clamp(smooth_factor, 0.0, 1.0) + high_freq = (1 - smooth_factor) * ( + high_freq / alpha + ) + smooth_factor * high_freq + result = torch.zeros_like(inv_freq) + result[low_freq_mask] = low_freq + result[high_freq_mask] = high_freq + return result + + def forward(self, x, seq_len, position_offset=0): + if ( + self._cos_cache is not None + and self._sin_cache is not None + and self._cached_seq_len >= seq_len + and position_offset == 0 + ): + return ( + self._cos_cache[:, :, :seq_len, :].to(x.dtype), + self._sin_cache[:, :, :seq_len, :].to(x.dtype), + ) + t = torch.arange( + position_offset, position_offset + seq_len, device=x.device + ).type_as(self.inv_freq) + if self.scaling_type == "dynamic": + if seq_len > self.max_position_embeddings: + dynamic_scale = seq_len / self.max_position_embeddings + t = t / dynamic_scale + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos()[None, None, :, :] + sin = emb.sin()[None, None, :, :] + if position_offset == 0 and seq_len > self._cached_seq_len: + self._cos_cache = cos + self._sin_cache = sin + self._cached_seq_len = seq_len + return cos.to(x.dtype), sin.to(x.dtype) + +class ALiBiPositionalBias(nn.Module): + def __init__(self, num_heads, max_positions=8192): + super().__init__() + self.num_heads = num_heads + self.max_positions = max_positions + slopes = torch.tensor(self._get_slopes(num_heads)) + self.register_buffer("slopes", slopes, persistent=False) + + def _get_slopes(self, n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * (ratio**i) for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + self._get_slopes(2 * closest_power_of_2)[0::2][ + : n - closest_power_of_2 + ] + ) + + def forward(self, seq_len, key_len=None): + if key_len is None: + key_len = seq_len + query_pos = torch.arange(seq_len, device=self.slopes.device).unsqueeze(1) + key_pos = torch.arange(key_len, device=self.slopes.device).unsqueeze(0) + relative_pos = key_pos - query_pos + bias = relative_pos.unsqueeze(0) * self.slopes.unsqueeze(1).unsqueeze(2) + return bias.unsqueeze(0) + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin): + cos = cos.to(q.dtype) + sin = sin.to(q.dtype) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def soft_cap_logits(x, cap): + if cap is None or cap <= 0: + return x + return x.clamp(-cap * 0.99, cap * 0.99) + +class TopKRouter(nn.Module): + def __init__(self, hidden_size, num_experts, num_experts_per_tok): + super().__init__() + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.gate = nn.Linear(hidden_size, num_experts, 
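# Usage sketch for the rotary helpers above (illustrative shapes): cos/sin come
# back as (1, 1, seq_len, dim) and broadcast over (batch, heads, seq_len, head_dim).
import torch
from caca_transformers import CacaRotaryEmbedding, apply_rotary_pos_emb

rope = CacaRotaryEmbedding(dim=64, max_position_embeddings=2048)
q = torch.randn(2, 8, 16, 64)   # (batch, num_heads, seq_len, head_dim)
k = torch.randn(2, 8, 16, 64)
cos, sin = rope(q, seq_len=16)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape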
bias=False) + self.gate_norm = nn.LayerNorm(hidden_size) + self.temperature = nn.Parameter(torch.ones(1)) + self.jitter_noise = 0.01 + + def forward(self, hidden_states): + batch_size, seq_len, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_size) + hidden_states = self.gate_norm(hidden_states) + router_logits = self.gate(hidden_states) + router_logits = torch.clamp(router_logits, min=-20, max=20) + temperature = torch.clamp(self.temperature, min=0.1, max=10.0) + router_logits = router_logits / temperature + if self.training and self.jitter_noise > 0: + noise = torch.randn_like(router_logits) * self.jitter_noise + router_logits = router_logits + noise + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32) + top_k_weights, top_k_indices = torch.topk( + routing_weights, self.num_experts_per_tok, dim=-1 + ) + top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9) + router_probs = routing_weights + expert_usage = router_probs.mean(dim=0) + mean_usage = expert_usage.mean() + aux_loss = expert_usage.std() / (mean_usage + 1e-10) + router_logits_for_z = router_logits.to(torch.float32) + z_loss = torch.logsumexp(router_logits_for_z, dim=-1).pow(2).mean() + router_probs_log = torch.log(router_probs + 1e-10) + entropy = -(router_probs * router_probs_log).sum(dim=-1).mean() + entropy_loss = -0.01 * entropy + total_aux_loss = aux_loss + entropy_loss + return top_k_weights, top_k_indices, total_aux_loss, z_loss + +class ExpertChoiceRouter(nn.Module): + def __init__(self, hidden_size, num_experts, expert_choice_k): + super().__init__() + self.num_experts = num_experts + self.expert_choice_k = expert_choice_k + self.gate = nn.Linear(hidden_size, num_experts, bias=False) + + def forward(self, hidden_states): + batch_size, seq_len, hidden_size = hidden_states.shape + total_tokens = batch_size * seq_len + hidden_states_flat = hidden_states.view(-1, hidden_size) + router_logits = self.gate(hidden_states_flat) + router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) + router_probs_t = router_probs.t() + capacity = max(1, int(self.expert_choice_k * total_tokens / self.num_experts)) + top_k_values, top_k_indices = torch.topk( + router_probs_t, k=min(capacity, total_tokens), dim=-1 + ) + expert_mask = torch.zeros( + self.num_experts, total_tokens, device=hidden_states.device + ) + for expert_idx in range(self.num_experts): + expert_mask[expert_idx, top_k_indices[expert_idx]] = 1.0 + routing_weights = expert_mask.t() * router_probs + aux_loss = (router_probs.mean(dim=0) ** 2).sum() * self.num_experts + z_loss = torch.logsumexp(router_logits, dim=-1).mean() + return routing_weights, aux_loss, z_loss + +class Expert(nn.Module): + def __init__(self, config): + super().__init__() + self.gate_proj = nn.Linear( + config.hidden_size, config.intermediate_size, bias=config.mlp_bias + ) + self.up_proj = nn.Linear( + config.hidden_size, config.intermediate_size, bias=config.mlp_bias + ) + self.down_proj = nn.Linear( + config.intermediate_size, config.hidden_size, bias=config.mlp_bias + ) + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, x): + gate = F.silu(self.gate_proj(x)) + up = self.up_proj(x) + hidden = gate * up + hidden = self.dropout(hidden) + return self.down_proj(hidden) + +class MixtureOfExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.use_expert_choice = 
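# Usage sketch for TopKRouter above (illustrative sizes): it flattens the batch,
# normalizes and temperature-scales the gate logits, and returns per-token
# top-k weights/indices plus the auxiliary balance loss and the z-loss.
import torch
from caca_transformers import TopKRouter

router = TopKRouter(hidden_size=256, num_experts=8, num_experts_per_tok=2).eval()
x = torch.randn(2, 16, 256)
weights, indices, aux_loss, z_loss = router(x)
print(weights.shape, indices.shape)  # both (32, 2): one row per flattened token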
config.use_expert_choice + self.experts = nn.ModuleList([Expert(config) for _ in range(self.num_experts)]) + if self.use_expert_choice: + self.router = ExpertChoiceRouter( + config.hidden_size, config.num_experts, config.expert_choice_k + ) + else: + self.router = TopKRouter( + config.hidden_size, config.num_experts, config.num_experts_per_tok + ) + self.register_buffer( + "expert_usage_count", torch.zeros(self.num_experts, dtype=torch.long) + ) + + def forward(self, hidden_states): + batch_size, seq_len, hidden_size = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_size) + if ( + torch.isnan(hidden_states_flat).any() + or torch.isinf(hidden_states_flat).any() + ): + logger.error("NaN or Inf detected in MoE input. Using fallback.") + return ( + hidden_states, + torch.tensor(0.0, device=hidden_states.device), + torch.tensor(0.0, device=hidden_states.device), + ) + if self.use_expert_choice: + routing_weights, aux_loss, z_loss = self.router(hidden_states) + final_output = torch.zeros_like(hidden_states_flat) + for expert_idx, expert in enumerate(self.experts): + expert_mask = routing_weights[:, expert_idx] > 1e-6 + if expert_mask.any(): + if not self.training: + self.expert_usage_count[expert_idx] += expert_mask.sum() + try: + expert_input = hidden_states_flat[expert_mask] + expert_output = expert(expert_input) + if ( + torch.isnan(expert_output).any() + or torch.isinf(expert_output).any() + ): + logger.warning( + f"Expert {expert_idx} produced NaN/Inf. Skipping." + ) + continue + final_output[expert_mask] += ( + expert_output + * routing_weights[expert_mask, expert_idx : expert_idx + 1] + ) + except RuntimeError as e: + logger.error(f"Expert {expert_idx} failed: {e}") + continue + else: + top_k_weights, top_k_indices, aux_loss, z_loss = self.router(hidden_states) + final_output = torch.zeros_like(hidden_states_flat) + for expert_idx in range(self.num_experts): + expert_mask = (top_k_indices == expert_idx).any(dim=-1) + if expert_mask.any(): + if not self.training: + self.expert_usage_count[expert_idx] += expert_mask.sum() + try: + expert_input = hidden_states_flat[expert_mask] + expert_output = self.experts[expert_idx](expert_input) + if ( + torch.isnan(expert_output).any() + or torch.isinf(expert_output).any() + ): + logger.warning( + f"Expert {expert_idx} produced NaN/Inf. Skipping." + ) + continue + token_indices = torch.where(expert_mask)[0] + for i, token_idx in enumerate(token_indices): + expert_positions = ( + top_k_indices[token_idx] == expert_idx + ).nonzero(as_tuple=True)[0] + if len(expert_positions) > 0: + weight = top_k_weights[token_idx, expert_positions[0]] + final_output[token_idx] += expert_output[i] * weight + except RuntimeError as e: + logger.error(f"Expert {expert_idx} failed: {e}") + continue + final_output = final_output.view(batch_size, seq_len, hidden_size) + if torch.isnan(final_output).any() or torch.isinf(final_output).any(): + logger.error("MoE output contains NaN/Inf. 
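# Usage sketch for the expert dispatch above (illustrative config): the default
# top-k path routes each token to num_experts_per_tok experts and returns the
# combined output together with the router's aux and z losses.
import torch
from caca_transformers import CacaConfig, MixtureOfExperts

cfg = CacaConfig(hidden_size=128, intermediate_size=512, num_attention_heads=4,
                 num_key_value_heads=4, num_hidden_layers=2,
                 use_moe=True, num_experts=4, num_experts_per_tok=2)
moe = MixtureOfExperts(cfg).eval()
out, aux_loss, z_loss = moe(torch.randn(2, 8, 128))
print(out.shape)  # (2, 8, 128)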
Returning input.") + return hidden_states, aux_loss, z_loss + return final_output, aux_loss, z_loss + +class MixtureOfDepthsRouter(nn.Module): + def __init__(self, hidden_size, capacity_factor=0.5, route_method="learned"): + super().__init__() + self.capacity_factor = capacity_factor + self.route_method = route_method + if route_method == "learned": + self.router = nn.Linear(hidden_size, 1) + + def forward(self, hidden_states): + batch_size, seq_len, hidden_size = hidden_states.shape + if self.route_method == "learned": + routing_logits = self.router(hidden_states).squeeze(-1) + elif self.route_method == "random": + routing_logits = torch.rand( + batch_size, seq_len, device=hidden_states.device + ) + else: + routing_logits = torch.zeros( + batch_size, seq_len, device=hidden_states.device + ) + capacity = max(1, int(seq_len * self.capacity_factor)) + _, top_indices = torch.topk(routing_logits, k=capacity, dim=-1) + process_mask = torch.zeros( + batch_size, seq_len, dtype=torch.bool, device=hidden_states.device + ) + process_mask.scatter_(1, top_indices, True) + return process_mask + +class CacaAttention(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.sliding_window = config.sliding_window + self.attn_logit_softcapping = config.attn_logit_softcapping + self.attention_sink_size = config.attention_sink_size + self.attention_sink_window = config.attention_sink_window + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias + ) + if config.use_qk_norm: + self.q_norm = CacaRMSNorm(self.head_dim, eps=config.qk_norm_eps) + self.k_norm = CacaRMSNorm(self.head_dim, eps=config.qk_norm_eps) + else: + self.q_norm = None + self.k_norm = None + if config.use_rotary_embeddings: + scaling_factor = 1.0 + scaling_type = None + if config.rope_scaling is not None: + scaling_type = config.rope_scaling.get("type", "linear") + scaling_factor = config.rope_scaling.get("factor", 1.0) + self.rotary_emb = CacaRotaryEmbedding( + self.head_dim, + config.max_position_embeddings, + config.rope_theta, + scaling_factor=scaling_factor, + scaling_type=scaling_type, + ) + else: + self.rotary_emb = None + if config.use_alibi: + self.alibi = ALiBiPositionalBias( + self.num_heads, config.max_position_embeddings + ) + else: + self.alibi = None + self.attention_dropout = nn.Dropout(config.attention_dropout) + self.is_global_attention = self._determine_attention_type(config, layer_idx) + self.has_flash_attn = HAS_FLASH_ATTN and config.use_flash_attn + self.has_xformers = HAS_XFORMERS + self.has_sdpa = HAS_SDPA + self._mask_cache = OrderedDict() + self._max_cache_size = 10 + self._cache_hits = 0 + self._cache_misses = 0 + + def _determine_attention_type(self, config, layer_idx): + if layer_idx is None: + return False + if config.attention_pattern == "all_global": + return True + elif 
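# Usage sketch for MixtureOfDepthsRouter above: with capacity_factor=0.5 and a
# 16-token sequence, 8 positions per sequence are selected for full processing.
import torch
from caca_transformers import MixtureOfDepthsRouter

mod = MixtureOfDepthsRouter(hidden_size=128, capacity_factor=0.5, route_method="learned")
mask = mod(torch.randn(2, 16, 128))
print(mask.shape, mask.sum(dim=-1))  # torch.Size([2, 16]), 8 tokens kept per sequence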
config.attention_pattern == "all_local": + return False + elif config.attention_pattern == "interleaved": + return (layer_idx % config.global_attention_every_n_layers) == ( + config.global_attention_every_n_layers - 1 + ) + return False + + def _create_causal_mask( + self, query_length, key_length, dtype, device, use_sliding_window + ): + device_key = (device.type, device.index if device.type == "cuda" else None) + cache_key = ( + query_length, + key_length, + str(dtype), + device_key, + use_sliding_window, + self.sliding_window if use_sliding_window else None, + ) + if cache_key in self._mask_cache: + self._cache_hits += 1 + self._mask_cache.move_to_end(cache_key) + cached_mask = self._mask_cache[cache_key] + return cached_mask.to(dtype).to(device) + self._cache_misses += 1 + if query_length > key_length: + key_length = query_length + query_pos = torch.arange(query_length, device=device) + ( + key_length - query_length + ) + key_pos = torch.arange(key_length, device=device) + distance = query_pos[:, None] - key_pos[None, :] + mask = distance < 0 + if use_sliding_window and self.sliding_window is not None: + if self.config.use_attention_sink and self.attention_sink_size > 0: + is_sink = key_pos[None, :] < self.attention_sink_size + in_window = (distance >= 0) & (distance <= self.sliding_window) + mask = (distance < 0) | ((~is_sink) & (~in_window)) + else: + too_far_mask = distance > self.sliding_window + mask = mask | too_far_mask + float_mask = torch.zeros( + 1, 1, query_length, key_length, dtype=dtype, device=device + ) + float_mask.masked_fill_(mask.unsqueeze(0).unsqueeze(0), -1e9) + if len(self._mask_cache) >= self._max_cache_size: + self._mask_cache.popitem(last=False) + self._mask_cache[cache_key] = float_mask.detach().cpu() + return float_mask + + def get_cache_stats(self): + total_requests = self._cache_hits + self._cache_misses + hit_rate = self._cache_hits / total_requests if total_requests > 0 else 0 + return { + "hits": self._cache_hits, + "misses": self._cache_misses, + "hit_rate": hit_rate, + "cache_size": len(self._mask_cache), + } + + def forward( + self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False + ): + if hidden_states is None: + raise ValueError("hidden_states cannot be None") + if hidden_states.shape[-1] != self.hidden_size: + raise ValueError( + f"Expected hidden_size {self.hidden_size}, got {hidden_states.shape[-1]}" + ) + batch_size, seq_length, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + query_states = query_states.view( + batch_size, seq_length, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + batch_size, seq_length, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + batch_size, seq_length, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + if self.q_norm is not None and self.k_norm is not None: + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + position_offset = 0 + if past_key_value is not None: + try: + if ( + isinstance(past_key_value, (tuple, list)) + and len(past_key_value) >= 2 + and past_key_value[0] is not None + and past_key_value[0].numel() > 0 + ): + position_offset = past_key_value[0].shape[2] + except (IndexError, AttributeError, TypeError) as e: + logger.debug(f"Could not get position_offset: {e}") + position_offset = 0 + if self.rotary_emb is not None: + cos, sin = 
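# Usage sketch for the mask construction above (hypothetical small config): with
# attention_pattern="all_local" and a sliding window, only the most recent
# `sliding_window` keys (plus optional sink tokens) stay visible to each query.
import torch
from caca_transformers import CacaConfig, CacaAttention

cfg = CacaConfig(hidden_size=256, intermediate_size=1024, num_hidden_layers=2,
                 num_attention_heads=8, num_key_value_heads=8,
                 sliding_window=4, attention_pattern="all_local")
attn = CacaAttention(cfg, layer_idx=0)
mask = attn._create_causal_mask(8, 8, torch.float32, torch.device("cpu"),
                                use_sliding_window=True)
print(mask.shape)              # (1, 1, 8, 8), additive float mask (0 or -1e9)
print(attn.get_cache_stats())  # {'hits': 0, 'misses': 1, ...}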
self.rotary_emb(query_states, seq_length, position_offset) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + if past_key_value is not None and past_key_value[0] is not None: + try: + if past_key_value[0].numel() > 0 and past_key_value[1].numel() > 0: + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + except RuntimeError as e: + logger.error(f"Failed to concatenate past_key_value: {e}") + if use_cache: + present_key_value = (key_states, value_states) + else: + present_key_value = None + key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1) + value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1) + kv_seq_len = key_states.shape[-2] + use_sliding_window = (not self.is_global_attention) and ( + self.sliding_window is not None + ) + if self.has_flash_attn and attention_mask is None: + if query_states.device.type == "cuda" and query_states.dtype in [ + torch.float16, + torch.bfloat16, + ]: + try: + attn_output = self._flash_attention( + query_states, key_states, value_states, use_sliding_window + ) + except Exception as e: + logger.warning(f"Flash Attention failed: {e}") + attn_output = self._fallback_attention( + query_states, + key_states, + value_states, + attention_mask, + kv_seq_len, + use_sliding_window, + ) + else: + attn_output = self._fallback_attention( + query_states, + key_states, + value_states, + attention_mask, + kv_seq_len, + use_sliding_window, + ) + else: + attn_output = self._fallback_attention( + query_states, + key_states, + value_states, + attention_mask, + kv_seq_len, + use_sliding_window, + ) + attn_output = self.o_proj(attn_output) + return attn_output, present_key_value + + def _flash_attention( + self, query_states, key_states, value_states, use_sliding_window + ): + batch_size, num_heads, seq_length, head_dim = query_states.shape + kv_seq_len = key_states.shape[-2] + if key_states.shape[0] != batch_size or key_states.shape[1] != num_heads: + raise ValueError( + f"Shape mismatch: query={query_states.shape}, key={key_states.shape}" + ) + original_dtype = query_states.dtype + if original_dtype == torch.bfloat16: + if not torch.cuda.is_bf16_supported(): + logger.warning("BF16 not supported, using FP16") + original_dtype = torch.float16 + compute_dtype = ( + torch.bfloat16 + if original_dtype not in [torch.float16, torch.bfloat16] + else original_dtype + ) + query_states = query_states.transpose(1, 2).contiguous().to(compute_dtype) + key_states = key_states.transpose(1, 2).contiguous().to(compute_dtype) + value_states = value_states.transpose(1, 2).contiguous().to(compute_dtype) + if use_sliding_window and self.sliding_window < kv_seq_len: + window_size = (self.sliding_window, 0) + else: + window_size = (-1, 0) + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=self.config.attention_dropout if self.training else 0.0, + softmax_scale=None, + causal=True, + window_size=window_size, + ) + attn_output = attn_output.to(original_dtype) + attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) + return attn_output + + def _fallback_attention( + self, + query_states, + key_states, + value_states, + attention_mask, + kv_seq_len, + use_sliding_window, + ): + device_type = query_states.device.type + if self.has_xformers and device_type == "cuda" and attention_mask is None: + try: + return self._xformers_attention( + query_states, + key_states, + 
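# Usage sketch for the forward path above (illustrative sizes, flash-attn
# disabled so the SDPA/standard fallback is exercised): the returned cache holds
# K/V in grouped form, i.e. num_key_value_heads before repeat_interleave.
import torch
from caca_transformers import CacaConfig, CacaAttention

cfg = CacaConfig(hidden_size=256, intermediate_size=1024, num_hidden_layers=2,
                 num_attention_heads=8, num_key_value_heads=4, use_flash_attn=False)
attn = CacaAttention(cfg, layer_idx=0).eval()
out, kv = attn(torch.randn(1, 5, 256), use_cache=True)
print(out.shape, kv[0].shape)     # (1, 5, 256) and (1, 4, 5, 32)
out2, kv2 = attn(torch.randn(1, 1, 256), past_key_value=kv, use_cache=True)
print(kv2[0].shape)               # (1, 4, 6, 32): one decoded step appended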
value_states, + kv_seq_len, + use_sliding_window, + ) + except Exception as e: + logger.debug(f"xFormers failed: {e}") + if self.has_sdpa: + return self._sdpa_attention( + query_states, + key_states, + value_states, + attention_mask, + kv_seq_len, + use_sliding_window, + ) + else: + return self._standard_attention( + query_states, + key_states, + value_states, + attention_mask, + kv_seq_len, + use_sliding_window, + ) + + def _sdpa_attention( + self, + query_states, + key_states, + value_states, + attention_mask, + kv_seq_len, + use_sliding_window, + ): + batch_size, num_heads, seq_length, head_dim = query_states.shape + if attention_mask is None: + attention_mask = self._create_causal_mask( + seq_length, + kv_seq_len, + query_states.dtype, + query_states.device, + use_sliding_window, + ) + if self.alibi is not None: + alibi_bias = self.alibi(seq_length, kv_seq_len) + attention_mask = attention_mask + alibi_bias + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.config.attention_dropout if self.training else 0.0, + is_causal=False, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) + return attn_output + + def _xformers_attention( + self, query_states, key_states, value_states, kv_seq_len, use_sliding_window + ): + batch_size, num_heads, seq_length, head_dim = query_states.shape + attn_bias = self._create_causal_mask( + seq_length, + kv_seq_len, + query_states.dtype, + query_states.device, + use_sliding_window, + ) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_output = memory_efficient_attention( + query_states, + key_states, + value_states, + attn_bias=attn_bias, + p=self.config.attention_dropout if self.training else 0.0, + ) + attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) + return attn_output + + def _standard_attention( + self, + query_states, + key_states, + value_states, + attention_mask, + kv_seq_len, + use_sliding_window, + ): + batch_size, num_heads, seq_length, head_dim = query_states.shape + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(head_dim) + attn_weights = torch.clamp(attn_weights, min=-50.0, max=50.0) + attn_weights = soft_cap_logits(attn_weights, self.attn_logit_softcapping) + if attention_mask is None: + attention_mask = self._create_causal_mask( + seq_length, + kv_seq_len, + attn_weights.dtype, + attn_weights.device, + use_sliding_window, + ) + if self.alibi is not None: + alibi_bias = self.alibi(seq_length, kv_seq_len) + attention_mask = attention_mask + alibi_bias + attn_weights = attn_weights + attention_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = self.attention_dropout(attn_weights) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) + return attn_output + +class CacaCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + 
self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias + ) + self.attention_dropout = nn.Dropout(config.attention_dropout) + + def forward(self, hidden_states, encoder_hidden_states, attention_mask=None): + batch_size, seq_length, _ = hidden_states.size() + encoder_seq_length = encoder_hidden_states.size(1) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + query_states = query_states.view( + batch_size, seq_length, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + batch_size, encoder_seq_length, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + batch_size, encoder_seq_length, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1) + value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1) + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = self.attention_dropout(attn_weights) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output + +class CacaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=config.mlp_bias + ) + self.up_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=config.mlp_bias + ) + self.down_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=config.mlp_bias + ) + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, x): + gate = F.silu(self.gate_proj(x)) + up = self.up_proj(x) + hidden = gate * up + hidden = self.dropout(hidden) + output = self.down_proj(hidden) + return output + +class CacaDecoderLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.self_attn = CacaAttention(config, layer_idx=layer_idx) + self.use_moe = config.use_moe and (layer_idx % config.moe_layer_frequency == 0) + if self.use_moe: + self.mlp = MixtureOfExperts(config) + else: + self.mlp = CacaMLP(config) + self.use_cross_attention = config.use_cross_attention and ( + layer_idx % config.cross_attention_frequency == 0 + ) + if self.use_cross_attention: + self.cross_attn = CacaCrossAttention(config) + self.cross_attn_layernorm = CacaRMSNorm( + config.hidden_size, config.rms_norm_eps + ) + self.input_layernorm = CacaRMSNorm(config.hidden_size, config.rms_norm_eps) + self.post_attention_layernorm = CacaRMSNorm( + config.hidden_size, config.rms_norm_eps + ) + 
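# Usage sketch for CacaMLP above (illustrative sizes): a gated SiLU ("SwiGLU")
# block, down-projected back to hidden_size.
import torch
from caca_transformers import CacaConfig, CacaMLP

cfg = CacaConfig(hidden_size=256, intermediate_size=1024, num_hidden_layers=2,
                 num_attention_heads=8, num_key_value_heads=8)
mlp = CacaMLP(cfg).eval()
print(mlp(torch.randn(2, 4, 256)).shape)  # (2, 4, 256)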
self.residual_dropout = nn.Dropout(config.residual_dropout) + if config.use_layer_scale: + self.layer_scale_1 = LayerScale(config.hidden_size, config.layer_scale_init) + self.layer_scale_2 = LayerScale(config.hidden_size, config.layer_scale_init) + if self.use_cross_attention: + self.layer_scale_cross = LayerScale( + config.hidden_size, config.layer_scale_init + ) + else: + self.layer_scale_1 = None + self.layer_scale_2 = None + self.layer_scale_cross = None + if config.use_stochastic_depth: + drop_prob = ( + config.stochastic_depth_prob * layer_idx / config.num_hidden_layers + ) + self.stochastic_depth = StochasticDepth(drop_prob) + else: + self.stochastic_depth = None + if config.use_mixture_of_depths: + self.mod_router = MixtureOfDepthsRouter( + config.hidden_size, config.mod_capacity_factor, config.mod_route_method + ) + else: + self.mod_router = None + self.gradient_stats = { + "max_grad_norm": 0.0, + "mean_grad_norm": 0.0, + "grad_clip_count": 0, + } + + def _gradient_monitor_hook(self, grad): + if grad is None: + return grad + grad_norm = grad.norm().item() + self.gradient_stats["max_grad_norm"] = max( + self.gradient_stats["max_grad_norm"], grad_norm + ) + self.gradient_stats["mean_grad_norm"] = ( + 0.9 * self.gradient_stats["mean_grad_norm"] + 0.1 * grad_norm + ) + max_grad = 10.0 + if grad_norm > max_grad: + self.gradient_stats["grad_clip_count"] += 1 + if self.gradient_stats["grad_clip_count"] % 100 == 0: + logger.warning( + f"Layer {self.layer_idx}: High gradient norm {grad_norm:.2f} " + f"(clipped {self.gradient_stats['grad_clip_count']} times)" + ) + grad = torch.clamp(grad, min=-max_grad, max=max_grad) + return grad + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + use_cache=False, + ): + if hidden_states is None: + raise ValueError("hidden_states cannot be None") + if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any(): + logger.error(f"Layer {self.layer_idx}: NaN/Inf in input!") + hidden_states = torch.nan_to_num( + hidden_states, nan=0.0, posinf=1e4, neginf=-1e4 + ) + aux_loss = 0.0 + z_loss = 0.0 + if self.training and hidden_states.requires_grad: + hidden_states.register_hook(self._gradient_monitor_hook) + if self.mod_router is not None: + process_mask = self.mod_router(hidden_states) + tokens_to_process = hidden_states[process_mask] + if tokens_to_process.numel() == 0: + present_key_value = past_key_value if use_cache else None + return hidden_states, present_key_value, aux_loss, z_loss + else: + process_mask = None + tokens_to_process = hidden_states + residual = tokens_to_process + tokens_to_process = self.input_layernorm(tokens_to_process) + attn_output, present_key_value = self.self_attn( + tokens_to_process, + attention_mask, + past_key_value=past_key_value, + use_cache=use_cache, + ) + if torch.isnan(attn_output).any() or torch.isinf(attn_output).any(): + logger.warning( + f"Layer {self.layer_idx}: NaN/Inf in attention. Using zero output." 
+ ) + attn_output = torch.zeros_like(attn_output) + if self.layer_scale_1 is not None: + attn_output = self.layer_scale_1(attn_output) + if self.stochastic_depth is not None: + attn_output = self.stochastic_depth(attn_output, self.training) + tokens_to_process = residual + self.residual_dropout(attn_output) + if self.training: + tokens_to_process = torch.clamp(tokens_to_process, min=-1e4, max=1e4) + if self.use_cross_attention and encoder_hidden_states is not None: + residual = tokens_to_process + tokens_to_process = self.cross_attn_layernorm(tokens_to_process) + cross_attn_output = self.cross_attn( + tokens_to_process, + encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + if ( + torch.isnan(cross_attn_output).any() + or torch.isinf(cross_attn_output).any() + ): + logger.warning( + f"Layer {self.layer_idx}: NaN/Inf in cross-attention. Skipping." + ) + cross_attn_output = torch.zeros_like(cross_attn_output) + if self.layer_scale_cross is not None: + cross_attn_output = self.layer_scale_cross(cross_attn_output) + if self.stochastic_depth is not None: + cross_attn_output = self.stochastic_depth( + cross_attn_output, self.training + ) + tokens_to_process = residual + self.residual_dropout(cross_attn_output) + if self.training: + tokens_to_process = torch.clamp(tokens_to_process, min=-1e4, max=1e4) + residual = tokens_to_process + tokens_to_process = self.post_attention_layernorm(tokens_to_process) + if self.use_moe: + mlp_output, moe_aux_loss, moe_z_loss = self.mlp(tokens_to_process) + aux_loss += moe_aux_loss + z_loss += moe_z_loss + else: + mlp_output = self.mlp(tokens_to_process) + if torch.isnan(mlp_output).any() or torch.isinf(mlp_output).any(): + logger.warning( + f"Layer {self.layer_idx}: NaN/Inf in MLP output. Using zero output." 
+ ) + mlp_output = torch.zeros_like(mlp_output) + if self.layer_scale_2 is not None: + mlp_output = self.layer_scale_2(mlp_output) + if self.stochastic_depth is not None: + mlp_output = self.stochastic_depth(mlp_output, self.training) + tokens_to_process = residual + self.residual_dropout(mlp_output) + if self.training: + tokens_to_process = torch.clamp(tokens_to_process, min=-1e4, max=1e4) + if process_mask is not None: + hidden_states[process_mask] = tokens_to_process + else: + hidden_states = tokens_to_process + return hidden_states, present_key_value, aux_loss, z_loss + + def get_gradient_stats(self): + return self.gradient_stats.copy() + +class VisionEncoder(nn.Module): + def __init__(self, config): + super().__init__() + vision_config = config.vision_config + self.patch_size = vision_config.get("patch_size", 14) + self.image_size = vision_config.get("image_size", 224) + self.num_channels = vision_config.get("num_channels", 3) + self.hidden_size = vision_config.get("hidden_size", 1024) + self.num_layers = vision_config.get("num_layers", 24) + self.num_heads = vision_config.get("num_heads", 16) + self.intermediate_size = vision_config.get("intermediate_size", 4096) + self.layer_norm_eps = vision_config.get("layer_norm_eps", 1e-6) + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.patch_embed = nn.Sequential( + nn.Conv2d( + self.num_channels, + self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ), + nn.Dropout(p=vision_config.get("dropout", 0.0)), + ) + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size)) + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches + 1, self.hidden_size) + ) + self.pos_drop = nn.Dropout(p=vision_config.get("dropout", 0.0)) + self.blocks = nn.ModuleList( + [ + VisionTransformerBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_ratio=self.intermediate_size / self.hidden_size, + dropout=vision_config.get("dropout", 0.0), + layer_norm_eps=self.layer_norm_eps, + ) + for _ in range(self.num_layers) + ] + ) + self.norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + self._init_weights() + + def _init_weights(self): + nn.init.trunc_normal_(self.pos_embed, std=0.02) + nn.init.trunc_normal_(self.cls_token, std=0.02) + nn.init.trunc_normal_(self.patch_embed[0].weight, std=0.02) + + def forward(self, pixel_values): + batch_size = pixel_values.shape[0] + x = self.patch_embed(pixel_values) + x = x.flatten(2).transpose(1, 2) + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + x = torch.cat([cls_tokens, x], dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + for block in self.blocks: + x = block(x) + x = self.norm(x) + return x + +class VisionTransformerBlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0, layer_norm_eps=1e-6): + super().__init__() + self.norm1 = nn.LayerNorm(dim, eps=layer_norm_eps) + self.attn = nn.MultiheadAttention( + dim, num_heads, dropout=dropout, batch_first=True + ) + self.drop_path1 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(dim, eps=layer_norm_eps) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = nn.Sequential( + nn.Linear(dim, mlp_hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(mlp_hidden_dim, dim), + nn.Dropout(dropout), + ) + self.drop_path2 = nn.Dropout(dropout) + + def forward(self, x): + residual = x + x = self.norm1(x) + x = self.attn(x, x, x, need_weights=False)[0] + x = self.drop_path1(x) + x = residual + x + residual = x + x = self.norm2(x) + x = self.mlp(x) + x = 
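# Usage sketch for VisionEncoder above (hypothetical small vision_config): a
# 224x224 image with 14x14 patches yields (224 // 14) ** 2 = 256 patch tokens
# plus one CLS token.
import torch
from caca_transformers import CacaConfig, VisionEncoder

cfg = CacaConfig(use_multimodal=True,
                 vision_config={"image_size": 224, "patch_size": 14, "hidden_size": 256,
                                "num_layers": 2, "num_heads": 8, "intermediate_size": 512})
enc = VisionEncoder(cfg).eval()
print(enc(torch.randn(1, 3, 224, 224)).shape)  # (1, 257, 256)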
self.drop_path2(x) + x = residual + x + return x + +class AudioEncoder(nn.Module): + def __init__(self, config): + super().__init__() + audio_config = config.audio_config + self.num_mel_bins = audio_config.get("num_mel_bins", 80) + self.hidden_size = audio_config.get("hidden_size", 1024) + self.num_layers = audio_config.get("num_layers", 12) + self.num_heads = audio_config.get("num_heads", 16) + self.intermediate_size = audio_config.get("intermediate_size", 4096) + self.max_audio_length = audio_config.get("max_audio_length", 3000) + self.dropout = audio_config.get("dropout", 0.0) + self.conv1 = nn.Sequential( + nn.Conv1d(self.num_mel_bins, self.hidden_size, kernel_size=3, padding=1), + nn.GELU(), + nn.Dropout(p=self.dropout), + ) + self.conv2 = nn.Sequential( + nn.Conv1d( + self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1 + ), + nn.GELU(), + nn.Dropout(p=self.dropout), + ) + self.pos_embed = nn.Parameter( + torch.zeros(1, self.max_audio_length // 2, self.hidden_size) + ) + self.pos_drop = nn.Dropout(p=self.dropout) + self.blocks = nn.ModuleList( + [ + AudioTransformerBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_ratio=self.intermediate_size / self.hidden_size, + dropout=self.dropout, + ) + for _ in range(self.num_layers) + ] + ) + self.norm = nn.LayerNorm(self.hidden_size) + self._init_weights() + + def _init_weights(self): + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + def forward(self, audio_features): + x = F.gelu(self.conv1(audio_features)) + x = F.gelu(self.conv2(x)) + x = x.transpose(1, 2) + seq_len = x.shape[1] + if seq_len <= self.pos_embed.shape[1]: + x = x + self.pos_embed[:, :seq_len, :] + else: + pos_embed_interp = F.interpolate( + self.pos_embed.transpose(1, 2), + size=seq_len, + mode="linear", + align_corners=False, + ).transpose(1, 2) + x = x + pos_embed_interp + x = self.pos_drop(x) + for block in self.blocks: + x = block(x) + x = self.norm(x) + return x + +class AudioTransformerBlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn = nn.MultiheadAttention( + dim, num_heads, dropout=dropout, batch_first=True + ) + self.drop_path1 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = nn.Sequential( + nn.Linear(dim, mlp_hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(mlp_hidden_dim, dim), + nn.Dropout(dropout), + ) + self.drop_path2 = nn.Dropout(dropout) + + def forward(self, x): + residual = x + x = self.norm1(x) + x = self.attn(x, x, x, need_weights=False)[0] + x = self.drop_path1(x) + x = residual + x + residual = x + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path2(x) + x = residual + x + return x + +class MultiModalProjector(nn.Module): + def __init__(self, input_size, output_size, projector_type="mlp", num_layers=2): + super().__init__() + self.projector_type = projector_type + if projector_type == "linear": + self.projector = nn.Linear(input_size, output_size) + elif projector_type == "mlp": + layers = [] + current_size = input_size + for i in range(num_layers - 1): + layers.extend( + [nn.Linear(current_size, output_size), nn.GELU(), nn.Dropout(0.1)] + ) + current_size = output_size + layers.append(nn.Linear(current_size, output_size)) + self.projector = nn.Sequential(*layers) + elif projector_type == "perceiver": + self.projector = PerceiverResampler( + input_size, output_size, num_latents=64, num_layers=2 + ) + elif projector_type == "qformer": + 
self.projector = QFormerProjector( + input_size, output_size, num_queries=32, num_layers=2 + ) + else: + raise ValueError(f"projector_type tidak dikenal: {projector_type}") + + def forward(self, x): + return self.projector(x) + +class PerceiverResampler(nn.Module): + def __init__(self, input_size, output_size, num_latents=64, num_layers=2): + super().__init__() + self.num_latents = num_latents + self.latents = nn.Parameter(torch.randn(num_latents, output_size)) + self.layers = nn.ModuleList( + [ + PerceiverLayer(output_size, input_size if i == 0 else output_size) + for i in range(num_layers) + ] + ) + self.norm = nn.LayerNorm(output_size) + + def forward(self, x): + batch_size = x.shape[0] + latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1) + for i, layer in enumerate(self.layers): + if i == 0: + latents = layer(latents, x) + else: + latents = layer(latents, latents) + return self.norm(latents) + +class PerceiverLayer(nn.Module): + def __init__(self, query_dim, key_dim): + super().__init__() + self.cross_attn = nn.MultiheadAttention( + query_dim, num_heads=8, kdim=key_dim, vdim=key_dim, batch_first=True + ) + self.mlp = nn.Sequential( + nn.LayerNorm(query_dim), + nn.Linear(query_dim, query_dim * 4), + nn.GELU(), + nn.Linear(query_dim * 4, query_dim), + ) + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + def forward(self, query, key): + query = ( + query + self.cross_attn(self.norm1(query), key, key, need_weights=False)[0] + ) + query = query + self.mlp(self.norm2(query)) + return query + +class QFormerProjector(nn.Module): + def __init__(self, input_size, output_size, num_queries=32, num_layers=2): + super().__init__() + self.num_queries = num_queries + self.query_embeds = nn.Parameter(torch.randn(num_queries, output_size)) + self.query_layers = nn.ModuleList( + [ + nn.TransformerEncoderLayer( + d_model=output_size, + nhead=8, + dim_feedforward=output_size * 4, + batch_first=True, + ) + for _ in range(num_layers) + ] + ) + self.cross_attn_layers = nn.ModuleList( + [ + nn.MultiheadAttention( + output_size, + num_heads=8, + kdim=input_size, + vdim=input_size, + batch_first=True, + ) + for _ in range(num_layers) + ] + ) + self.norm = nn.LayerNorm(output_size) + + def forward(self, x): + batch_size = x.shape[0] + queries = self.query_embeds.unsqueeze(0).expand(batch_size, -1, -1) + for query_layer, cross_attn_layer in zip( + self.query_layers, self.cross_attn_layers + ): + queries = query_layer(queries) + queries = queries + cross_attn_layer(queries, x, x, need_weights=False)[0] + return self.norm(queries) + +class CacaPreTrainedModel(PreTrainedModel): + config_class = CacaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["CacaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CacaModel): + module.gradient_checkpointing = value + +class CacaModel(CacaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.embed_tokens = nn.Embedding(config.vocab_size, 
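# Usage sketch for the projectors above (illustrative sizes): "mlp" keeps the
# token count and changes the feature dimension, while "perceiver" resamples
# the sequence down to a fixed set of 64 latent tokens.
import torch
from caca_transformers import MultiModalProjector

proj = MultiModalProjector(input_size=768, output_size=1024, projector_type="mlp")
print(proj(torch.randn(2, 50, 768)).shape)        # (2, 50, 1024)

resampler = MultiModalProjector(768, 1024, projector_type="perceiver")
print(resampler(torch.randn(2, 50, 768)).shape)   # (2, 64, 1024)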
config.hidden_size) + self.layers = nn.ModuleList( + [ + CacaDecoderLayer(config, layer_idx=idx) + for idx in range(config.num_hidden_layers) + ] + ) + self.norm = CacaRMSNorm(config.hidden_size, config.rms_norm_eps) + self.gradient_checkpointing = False + self.metrics = MetricsTracker() + self._last_memory_check = 0 + self._memory_check_interval = 5 + if config.use_multimodal: + if config.vision_config: + self.vision_encoder = VisionEncoder(config) + vision_hidden_size = config.vision_config.get("hidden_size", 768) + self.vision_projector = MultiModalProjector( + vision_hidden_size, + config.hidden_size, + projector_type=config.vision_config.get("projector_type", "mlp"), + ) + else: + self.vision_encoder = None + self.vision_projector = None + if config.audio_config: + self.audio_encoder = AudioEncoder(config) + audio_hidden_size = config.audio_config.get("hidden_size", 768) + self.audio_projector = MultiModalProjector( + audio_hidden_size, + config.hidden_size, + projector_type=config.audio_config.get("projector_type", "mlp"), + ) + else: + self.audio_encoder = None + self.audio_projector = None + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_attention_mask(self, attention_mask, input_shape, dtype): + if attention_mask is None: + return None + batch_size, seq_length = input_shape + if attention_mask.dim() == 2: + attention_mask = attention_mask[:, None, None, :] + elif attention_mask.dim() == 3: + attention_mask = attention_mask[:, None, :, :] + attention_mask = attention_mask.to(dtype=dtype) + attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min + return attention_mask + + def _check_memory_usage(self, device, layer_idx): + if not torch.cuda.is_available() or device.type != "cuda": + return + if layer_idx - self._last_memory_check >= self._memory_check_interval: + allocated_gb = torch.cuda.memory_allocated(device) / 1024**3 + reserved_gb = torch.cuda.memory_reserved(device) / 1024**3 + self.metrics.log("gpu_memory_allocated_gb", allocated_gb) + self.metrics.log("gpu_memory_reserved_gb", reserved_gb) + if allocated_gb > 12: + logger.warning( + f"Layer {layer_idx}: High GPU memory - " + f"Allocated: {allocated_gb:.2f}GB, Reserved: {reserved_gb:.2f}GB" + ) + if allocated_gb > 14: + logger.info(f"Clearing CUDA cache at layer {layer_idx}") + torch.cuda.empty_cache() + self._last_memory_check = layer_idx + + def forward( + self, + input_ids=None, + pixel_values=None, + audio_features=None, + attention_mask=None, + past_key_values=None, + use_cache=None, + output_hidden_states=False, + return_dict=True, + **kwargs, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + if input_ids is not None: + batch_size, seq_length = input_ids.shape + device = input_ids.device + self.metrics.log("batch_size", batch_size) + self.metrics.log("seq_length", seq_length) + hidden_states = self.embed_tokens(input_ids) + else: + raise ValueError("input_ids cannot be None") + encoder_hidden_states = None + encoder_attention_mask = None + if self.config.use_multimodal: + multimodal_embeds = [] + if pixel_values is not None and self.vision_encoder is not None: + if pixel_values.device != device: + pixel_values = pixel_values.to(device) + try: + vision_features = self.vision_encoder(pixel_values) + vision_embeds = self.vision_projector(vision_features) + multimodal_embeds.append(vision_embeds) + self.metrics.log("vision_tokens", vision_embeds.shape[1]) + 
+    def _check_memory_usage(self, device, layer_idx):
+        if not torch.cuda.is_available() or device.type != "cuda":
+            return
+        if layer_idx - self._last_memory_check >= self._memory_check_interval:
+            allocated_gb = torch.cuda.memory_allocated(device) / 1024**3
+            reserved_gb = torch.cuda.memory_reserved(device) / 1024**3
+            self.metrics.log("gpu_memory_allocated_gb", allocated_gb)
+            self.metrics.log("gpu_memory_reserved_gb", reserved_gb)
+            # NOTE: the hardcoded 12/14 GB thresholds implicitly assume a
+            # ~16 GB device; adjust for other hardware.
+            if allocated_gb > 12:
+                logger.warning(
+                    f"Layer {layer_idx}: High GPU memory - "
+                    f"Allocated: {allocated_gb:.2f}GB, Reserved: {reserved_gb:.2f}GB"
+                )
+                if allocated_gb > 14:
+                    logger.info(f"Clearing CUDA cache at layer {layer_idx}")
+                    torch.cuda.empty_cache()
+            self._last_memory_check = layer_idx
+
+    def forward(
+        self,
+        input_ids=None,
+        pixel_values=None,
+        audio_features=None,
+        attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_hidden_states=False,
+        return_dict=True,
+        **kwargs,
+    ):
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+            device = input_ids.device
+            self.metrics.log("batch_size", batch_size)
+            self.metrics.log("seq_length", seq_length)
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            raise ValueError("input_ids cannot be None")
+        encoder_hidden_states = None
+        encoder_attention_mask = None
+        if self.config.use_multimodal:
+            multimodal_embeds = []
+            if pixel_values is not None and self.vision_encoder is not None:
+                if pixel_values.device != device:
+                    pixel_values = pixel_values.to(device)
+                try:
+                    vision_features = self.vision_encoder(pixel_values)
+                    vision_embeds = self.vision_projector(vision_features)
+                    multimodal_embeds.append(vision_embeds)
+                    self.metrics.log("vision_tokens", vision_embeds.shape[1])
+                except RuntimeError as e:
+                    logger.error(f"Vision encoder failed: {e}")
+            if audio_features is not None and self.audio_encoder is not None:
+                if audio_features.device != device:
+                    audio_features = audio_features.to(device)
+                try:
+                    audio_encoded = self.audio_encoder(audio_features)
+                    audio_embeds = self.audio_projector(audio_encoded)
+                    multimodal_embeds.append(audio_embeds)
+                    self.metrics.log("audio_tokens", audio_embeds.shape[1])
+                except RuntimeError as e:
+                    logger.error(f"Audio encoder failed: {e}")
+            if multimodal_embeds:
+                if self.config.use_cross_attention:
+                    encoder_hidden_states = torch.cat(multimodal_embeds, dim=1)
+                    encoder_seq_len = encoder_hidden_states.shape[1]
+                    encoder_attention_mask = torch.ones(
+                        batch_size,
+                        encoder_seq_len,
+                        dtype=hidden_states.dtype,
+                        device=device,
+                    )
+                else:
+                    multimodal_concat = torch.cat(multimodal_embeds, dim=1)
+                    max_multimodal_tokens = self.config.max_position_embeddings // 4
+                    if multimodal_concat.shape[1] > max_multimodal_tokens:
+                        logger.warning(
+                            f"Truncating multimodal tokens from "
+                            f"{multimodal_concat.shape[1]} to {max_multimodal_tokens}"
+                        )
+                        multimodal_concat = multimodal_concat[:, :max_multimodal_tokens]
+                    hidden_states = torch.cat([multimodal_concat, hidden_states], dim=1)
+                    seq_length = hidden_states.shape[1]
+                    if attention_mask is not None:
+                        multimodal_mask = torch.ones(
+                            batch_size,
+                            multimodal_concat.shape[1],
+                            dtype=attention_mask.dtype,
+                            device=device,
+                        )
+                        attention_mask = torch.cat(
+                            [multimodal_mask, attention_mask], dim=1
+                        )
+                    else:
+                        attention_mask = torch.ones(
+                            batch_size,
+                            seq_length,
+                            dtype=hidden_states.dtype,
+                            device=device,
+                        )
+        if attention_mask is not None:
+            attention_mask = self._prepare_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states.dtype
+            )
+        if encoder_attention_mask is not None and self.config.use_cross_attention:
+            encoder_attention_mask = self._prepare_attention_mask(
+                encoder_attention_mask,
+                (batch_size, encoder_hidden_states.shape[1]),
+                hidden_states.dtype,
+            )
+        if use_cache:
+            if past_key_values is None:
+                past_key_values = tuple([None] * len(self.layers))
+        present_key_values = [] if use_cache else None
+        all_hidden_states = [] if output_hidden_states else None
+        total_aux_loss = 0.0
+        total_z_loss = 0.0
+        for idx, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states.append(hidden_states)
+            if self.training:
+                self._check_memory_usage(device, idx)
+            past_key_value = (
+                past_key_values[idx] if past_key_values is not None else None
+            )
+            if self.gradient_checkpointing and self.training and not use_cache:
+                hidden_states, aux_loss, z_loss = self._gradient_checkpointing_forward(
+                    layer,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+                present_key_value = None
+            else:
+                hidden_states, present_key_value, aux_loss, z_loss = layer(
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    past_key_value=past_key_value,
+                    use_cache=use_cache,
+                )
+            if use_cache:
+                present_key_values.append(present_key_value)
+            total_aux_loss += aux_loss
+            total_z_loss += z_loss
+            if self.training:
+                self.metrics.log(f"layer_{idx}_aux_loss", aux_loss)
+                self.metrics.log(f"layer_{idx}_z_loss", z_loss)
+        hidden_states = self.norm(hidden_states)
+        if output_hidden_states:
+            all_hidden_states.append(hidden_states)
+        self.metrics.step()
+        # Always return a 3-tuple (outputs, aux_loss, z_loss) so that
+        # CacaForCausalLM can unpack the result regardless of return_dict.
+        if not return_dict:
+            outputs = tuple(
+                v
+                for v in [hidden_states, present_key_values, all_hidden_states]
+                if v is not None
+            )
+            return outputs, total_aux_loss, total_z_loss
+        return (
+            BaseModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=tuple(present_key_values) if use_cache else None,
+                hidden_states=all_hidden_states,
+                attentions=None,
+            ),
+            total_aux_loss,
+            total_z_loss,
+        )
+
+    def _gradient_checkpointing_forward(
+        self,
+        layer,
+        hidden_states,
+        attention_mask,
+        encoder_hidden_states,
+        encoder_attention_mask,
+    ):
+        from torch.utils.checkpoint import checkpoint
+
+        def custom_forward(hidden_states):
+            output, _, aux_loss, z_loss = layer(
+                hidden_states,
+                attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=None,
+                use_cache=False,
+            )
+            return output, aux_loss, z_loss
+
+        hidden_states, aux_loss, z_loss = checkpoint(
+            custom_forward,
+            hidden_states,
+            use_reentrant=False,
+        )
+        return hidden_states, aux_loss, z_loss
+
+    def get_metrics_summary(self):
+        return self.metrics.get_summary()
+
+    def get_attention_cache_stats(self):
+        stats = {}
+        for idx, layer in enumerate(self.layers):
+            stats[f"layer_{idx}"] = layer.self_attn.get_cache_stats()
+        return stats
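+# Minimal usage sketch for the backbone (illustrative only: the tiny config
+# below is an assumption chosen so the example is cheap to run; any CacaConfig
+# that passes _validate_config would do). Note that CacaModel always returns a
+# 3-tuple so callers can recover the MoE auxiliary losses even when
+# return_dict=False.
+def _demo_backbone_forward():
+    config = CacaConfig(
+        vocab_size=1000,
+        hidden_size=256,
+        intermediate_size=512,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        num_key_value_heads=4,
+    )
+    model = CacaModel(config)
+    input_ids = torch.randint(0, config.vocab_size, (1, 8))
+    outputs, aux_loss, z_loss = model(input_ids)
+    assert outputs.last_hidden_state.shape == (1, 8, config.hidden_size)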
+class CacaForCausalLM(CacaPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = CacaModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def forward(
+        self,
+        input_ids=None,
+        pixel_values=None,
+        audio_features=None,
+        attention_mask=None,
+        labels=None,
+        past_key_values=None,
+        inputs_embeds=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        **kwargs,
+    ):
+        if input_ids is not None:
+            if input_ids.dtype.is_floating_point:
+                raise TypeError(
+                    f"input_ids must be integer dtype, got {input_ids.dtype}. "
+                    f"Use input_ids.long() to convert."
+                )
+            if (input_ids < 0).any():
+                neg_count = (input_ids < 0).sum().item()
+                raise ValueError(
+                    f"input_ids contains {neg_count} negative values. "
+                    f"All token IDs must be >= 0."
+                )
+            max_id = input_ids.max().item()
+            if max_id >= self.config.vocab_size:
+                raise ValueError(
+                    f"input_ids contains ID {max_id} which is >= vocab_size "
+                    f"({self.config.vocab_size}). Valid range is [0, {self.config.vocab_size-1}]."
+                )
+            if self.config.pad_token_id is not None:
+                pad_only_sequences = (input_ids == self.config.pad_token_id).all(dim=1)
+                if pad_only_sequences.any():
+                    pad_count = pad_only_sequences.sum().item()
+                    logger.warning(
+                        f"{pad_count} sequences contain only padding tokens. "
+                        f"This may cause unexpected behavior."
+                    )
+        if labels is not None:
+            # torch.long/torch.int are aliases of int64/int32, so two checks suffice.
+            if labels.dtype not in (torch.int32, torch.int64):
+                raise TypeError(f"labels must be integer dtype, got {labels.dtype}")
+            if labels.shape != input_ids.shape:
+                raise ValueError(
+                    f"labels shape {labels.shape} doesn't match "
+                    f"input_ids shape {input_ids.shape}"
+                )
+            valid_labels = labels[labels != -100]
+            if valid_labels.numel() > 0:
+                invalid_negs = valid_labels[valid_labels < 0]
+                if invalid_negs.numel() > 0:
+                    raise ValueError(
+                        f"labels contains {invalid_negs.numel()} invalid negative "
+                        f"values (not -100)"
+                    )
+                max_label = valid_labels.max().item()
+                if max_label >= self.config.vocab_size:
+                    raise ValueError(
+                        f"labels contains ID {max_label} which is >= vocab_size "
+                        f"({self.config.vocab_size})"
+                    )
+            else:
+                logger.warning(
+                    "All labels are -100 (ignored). No loss will be computed."
+                )
+        if attention_mask is not None:
+            if attention_mask.shape[0] != input_ids.shape[0]:
+                raise ValueError(
+                    f"attention_mask batch size ({attention_mask.shape[0]}) != "
+                    f"input_ids batch size ({input_ids.shape[0]})"
+                )
+            if attention_mask.shape[1] != input_ids.shape[1]:
+                raise ValueError(
+                    f"attention_mask seq length ({attention_mask.shape[1]}) != "
+                    f"input_ids seq length ({input_ids.shape[1]})"
+                )
+            unique_vals = attention_mask.unique()
+            if not torch.all((unique_vals == 0) | (unique_vals == 1)):
+                logger.warning(
+                    f"attention_mask contains values other than 0 and 1: {unique_vals.tolist()}. "
+                    f"This may cause unexpected behavior."
+                )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+        use_amp = (
+            self.training
+            and torch.cuda.is_available()
+            and hasattr(self.config, "use_amp")
+            and self.config.use_amp
+        )
+        if use_amp:
+            from torch.cuda.amp import autocast
+
+            with autocast(
+                dtype=torch.bfloat16
+                if torch.cuda.is_bf16_supported()
+                else torch.float16
+            ):
+                outputs, aux_loss, z_loss = self.model(
+                    input_ids,
+                    pixel_values=pixel_values,
+                    audio_features=audio_features,
+                    attention_mask=attention_mask,
+                    past_key_values=past_key_values,
+                    use_cache=use_cache,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                )
+        else:
+            outputs, aux_loss, z_loss = self.model(
+                input_ids,
+                pixel_values=pixel_values,
+                audio_features=audio_features,
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        if return_dict:
+            hidden_states = outputs.last_hidden_state
+        else:
+            hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        if self.config.final_logit_softcapping:
+            logits = soft_cap_logits(logits, self.config.final_logit_softcapping)
+        loss = None
+        if labels is not None:
+            # Standard causal shift: the logit at position t is scored against
+            # the label at position t + 1; -100 positions are ignored.
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_logits_flat = shift_logits.view(-1, shift_logits.size(-1))
+            shift_labels_flat = shift_labels.view(-1)
+            if (
+                hasattr(self.config, "label_smoothing")
+                and self.config.label_smoothing > 0
+            ):
+                loss_fct = nn.CrossEntropyLoss(
+                    ignore_index=-100, label_smoothing=self.config.label_smoothing
+                )
+            else:
+                loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
+            lm_loss = loss_fct(shift_logits_flat, shift_labels_flat)
+            if self.training:
+                with torch.no_grad():
+                    perplexity = torch.exp(lm_loss)
+                    self.model.metrics.log("perplexity", perplexity.item())
+                    self.model.metrics.log("lm_loss", lm_loss.item())
+            if self.config.use_moe:
+                total_loss = (
+                    lm_loss
+                    + (self.config.router_aux_loss_coef * aux_loss)
+                    + (self.config.router_z_loss_coef * z_loss)
+                )
+                if self.training:
+                    # float() handles both tensor and plain-float aux losses.
+                    self.model.metrics.log("aux_loss", float(aux_loss))
+                    self.model.metrics.log("z_loss", float(z_loss))
+            else:
+                total_loss = lm_loss
+            loss = total_loss
+        if not return_dict:
+            # `outputs` is a plain tuple here; append its cache/hidden-state
+            # entries after the logits rather than re-reading dataclass fields.
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=None,
+        )
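+    # During incremental decoding, once a non-empty KV cache is detected the
+    # model only needs the newest token as input; multimodal features are also
+    # dropped after the first step because they are already folded into the
+    # cache. E.g. with past_key_values present, input_ids [[5, 7, 9]] is
+    # trimmed to [[9]] before the forward pass.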
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        pixel_values=None,
+        audio_features=None,
+        **kwargs,
+    ):
+        has_past = False
+        if past_key_values is not None:
+            try:
+                if len(past_key_values) > 0 and past_key_values[0] is not None:
+                    if isinstance(past_key_values[0], (tuple, list)):
+                        if (
+                            len(past_key_values[0]) > 0
+                            and past_key_values[0][0] is not None
+                            and past_key_values[0][0].numel() > 0
+                        ):
+                            has_past = True
+            except (TypeError, IndexError, AttributeError):
+                has_past = False
+        if has_past:
+            input_ids = input_ids[:, -1:]
+            pixel_values = None
+            audio_features = None
+        if inputs_embeds is not None and not has_past:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values if has_past else None,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "pixel_values": pixel_values,
+                "audio_features": audio_features,
+            }
+        )
+        return model_inputs
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            if layer_past is not None and len(layer_past) > 0:
+                reordered_layer_past = ()
+                for past_state in layer_past:
+                    if past_state is not None and past_state.numel() > 0:
+                        reordered_layer_past += (
+                            past_state.index_select(0, beam_idx.to(past_state.device)),
+                        )
+                    else:
+                        reordered_layer_past += (None,)
+                reordered_past += (reordered_layer_past,)
+            else:
+                reordered_past += (None,)
+        return reordered_past
+
+    def save_pretrained(self, save_directory, **kwargs):
+        has_quant_config = hasattr(self.config, "quantization_config")
+        quantization_config_backup = getattr(self.config, "quantization_config", None)
+        if has_quant_config and quantization_config_backup is None:
+            delattr(self.config, "quantization_config")
+        try:
+            super().save_pretrained(save_directory, **kwargs)
+            if self.training:
+                metrics_summary = self.model.get_metrics_summary()
+                cache_stats = self.model.get_attention_cache_stats()
+                import json
+                import os
+
+                stats_path = os.path.join(save_directory, "training_stats.json")
+                with open(stats_path, "w") as f:
+                    json.dump(
+                        {"metrics": metrics_summary, "cache_stats": cache_stats},
+                        f,
+                        indent=2,
+                    )
+                logger.info(f"Saved training statistics to {stats_path}")
+        finally:
+            if has_quant_config:
+                self.config.quantization_config = quantization_config_backup
+
+    def get_model_stats(self):
+        stats = {
+            "metrics": self.model.get_metrics_summary(),
+            "cache_stats": self.model.get_attention_cache_stats(),
+            "gradient_stats": {},
+        }
+        for idx, layer in enumerate(self.model.layers):
+            stats["gradient_stats"][f"layer_{idx}"] = layer.get_gradient_stats()
+        if self.config.use_moe:
+            expert_usage = {}
+            for idx, layer in enumerate(self.model.layers):
+                if hasattr(layer.mlp, "expert_usage_count"):
+                    usage = layer.mlp.expert_usage_count.cpu().numpy().tolist()
+                    expert_usage[f"layer_{idx}"] = usage
+            stats["expert_usage"] = expert_usage
+        return stats
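+# Sketch of the dict-style quantization config consumed by the class below.
+# The keys mirror the bitsandbytes-style options this class actually reads;
+# the values shown are illustrative defaults, not requirements.
+def _demo_quantization_configs():
+    int8_config = {
+        "load_in_8bit": True,
+        "llm_int8_threshold": 6.0,  # outlier threshold for LLM.int8()
+    }
+    nf4_config = {
+        "load_in_4bit": True,
+        "bnb_4bit_quant_type": "nf4",
+        "bnb_4bit_compute_dtype": "bfloat16",
+        "bnb_4bit_use_double_quant": True,
+    }
+    return int8_config, nf4_config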
+class CacaForCausalLMQuantized(CacaForCausalLM):
+    def __init__(self, config, quantization_config=None):
+        super().__init__(config)
+        self.quantization_config = quantization_config
+        if quantization_config:
+            self._apply_quantization()
+
+    def _apply_quantization(self):
+        if self.quantization_config.get("load_in_8bit"):
+            self._quantize_8bit()
+        elif self.quantization_config.get("load_in_4bit"):
+            self._quantize_4bit()
+
+    def _quantize_8bit(self):
+        try:
+            import bitsandbytes as bnb
+
+            # Snapshot the module list so replacements don't perturb iteration.
+            for name, module in list(self.named_modules()):
+                if isinstance(module, nn.Linear):
+                    has_bias = module.bias is not None
+                    new_module = bnb.nn.Linear8bitLt(
+                        module.in_features,
+                        module.out_features,
+                        has_bias,
+                        threshold=self.quantization_config.get(
+                            "llm_int8_threshold", 6.0
+                        ),
+                    )
+                    new_module.weight = module.weight
+                    if has_bias:
+                        new_module.bias = module.bias
+                    parent_name = ".".join(name.split(".")[:-1])
+                    child_name = name.split(".")[-1]
+                    if parent_name:
+                        parent = self.get_submodule(parent_name)
+                        setattr(parent, child_name, new_module)
+                    else:
+                        setattr(self, child_name, new_module)
+            logger.info("8-bit quantization applied successfully")
+        except ImportError:
+            logger.error("bitsandbytes is not installed! pip install bitsandbytes")
+
+    def _quantize_4bit(self):
+        try:
+            import bitsandbytes as bnb
+
+            compute_dtype = torch.float16
+            if self.quantization_config.get("bnb_4bit_compute_dtype"):
+                compute_dtype = getattr(
+                    torch, self.quantization_config["bnb_4bit_compute_dtype"]
+                )
+            for name, module in list(self.named_modules()):
+                if isinstance(module, nn.Linear):
+                    has_bias = module.bias is not None
+                    new_module = bnb.nn.Linear4bit(
+                        module.in_features,
+                        module.out_features,
+                        bias=has_bias,
+                        compute_dtype=compute_dtype,
+                        quant_type=self.quantization_config.get(
+                            "bnb_4bit_quant_type", "nf4"
+                        ),
+                        use_double_quant=self.quantization_config.get(
+                            "bnb_4bit_use_double_quant", True
+                        ),
+                    )
+                    new_module.weight = module.weight
+                    if has_bias:
+                        new_module.bias = module.bias
+                    parent_name = ".".join(name.split(".")[:-1])
+                    child_name = name.split(".")[-1]
+                    if parent_name:
+                        parent = self.get_submodule(parent_name)
+                        setattr(parent, child_name, new_module)
+                    else:
+                        setattr(self, child_name, new_module)
+            logger.info("4-bit quantization applied successfully")
+        except ImportError:
+            logger.error("bitsandbytes is not installed!")
+
+    @classmethod
+    def from_pretrained_quantized(cls, model_path, quantization_config):
+        config = CacaConfig.from_pretrained(model_path)
+        model = cls(config, quantization_config=quantization_config)
+        state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location="cpu")
+        model.load_state_dict(state_dict, strict=False)
+        return model
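+# Loading sketch: from_pretrained_quantized builds the quantized model shell
+# first and then loads weights non-strictly. The checkpoint path below is
+# hypothetical, for illustration only.
+def _demo_quantized_load():
+    qconfig = {"load_in_4bit": True, "bnb_4bit_quant_type": "nf4"}
+    model = CacaForCausalLMQuantized.from_pretrained_quantized(
+        "./caca-checkpoint", qconfig  # hypothetical local path
+    )
+    return model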
+class CacaTrainer:
+    def __init__(
+        self,
+        model,
+        optimizer,
+        scheduler=None,
+        gradient_accumulation_steps=1,
+        max_grad_norm=1.0,
+        use_amp=False,
+    ):
+        from collections import defaultdict
+
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.gradient_accumulation_steps = gradient_accumulation_steps
+        self.max_grad_norm = max_grad_norm
+        self.use_amp = use_amp
+        self.global_step = 0
+        self.epoch = 0
+        if self.use_amp:
+            from torch.cuda.amp import GradScaler
+
+            self.scaler = GradScaler()
+        else:
+            self.scaler = None
+        self.train_metrics = defaultdict(list)
+
+    def training_step(self, batch):
+        self.model.train()
+        if self.use_amp:
+            from torch.cuda.amp import autocast
+
+            with autocast(
+                dtype=torch.bfloat16
+                if torch.cuda.is_bf16_supported()
+                else torch.float16
+            ):
+                outputs = self.model(**batch)
+                loss = outputs.loss
+                # Scale down so N accumulated micro-batches sum to one batch's gradient.
+                loss = loss / self.gradient_accumulation_steps
+        else:
+            outputs = self.model(**batch)
+            loss = outputs.loss
+            loss = loss / self.gradient_accumulation_steps
+        if self.use_amp:
+            self.scaler.scale(loss).backward()
+        else:
+            loss.backward()
+        self.train_metrics["loss"].append(
+            loss.item() * self.gradient_accumulation_steps
+        )
+        return loss.item() * self.gradient_accumulation_steps
+
+    def optimizer_step(self):
+        if self.use_amp:
+            self.scaler.unscale_(self.optimizer)
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                self.model.parameters(), self.max_grad_norm
+            )
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+        else:
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                self.model.parameters(), self.max_grad_norm
+            )
+            self.optimizer.step()
+        self.optimizer.zero_grad(set_to_none=True)
+        if self.scheduler is not None:
+            self.scheduler.step()
+        self.train_metrics["grad_norm"].append(grad_norm.item())
+        self.global_step += 1
+        return grad_norm.item()
+
+    def train_epoch(self, dataloader, log_interval=10):
+        self.model.train()
+        epoch_loss = 0.0
+        num_batches = 0
+        for step, batch in enumerate(dataloader):
+            loss = self.training_step(batch)
+            epoch_loss += loss
+            num_batches += 1
+            if (step + 1) % self.gradient_accumulation_steps == 0:
+                grad_norm = self.optimizer_step()
+                if self.global_step % log_interval == 0:
+                    avg_loss = epoch_loss / num_batches
+                    lr = self.optimizer.param_groups[0]["lr"]
+                    logger.info(
+                        f"Epoch {self.epoch} | Step {self.global_step} | "
+                        f"Loss: {avg_loss:.4f} | Grad Norm: {grad_norm:.4f} | "
+                        f"LR: {lr:.2e}"
+                    )
+                    if hasattr(self.model, "get_model_stats"):
+                        stats = self.model.get_model_stats()
+                        if "metrics" in stats and "perplexity" in stats["metrics"]:
+                            ppl = stats["metrics"]["perplexity"]["last"]
+                            logger.info(f"Perplexity: {ppl:.2f}")
+        self.epoch += 1
+        return epoch_loss / num_batches
+
+    def get_metrics(self):
+        import numpy as np
+
+        metrics = {}
+        for key, values in self.train_metrics.items():
+            if values:
+                metrics[key] = {
+                    "mean": np.mean(values),
+                    "std": np.std(values),
+                    "min": np.min(values),
+                    "max": np.max(values),
+                }
+        return metrics
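+# End-to-end training sketch (illustrative only): the optimizer choice,
+# scheduler horizon, and hyperparameters are assumptions for the example,
+# not recommendations. Effective batch size = dataloader batch size *
+# gradient_accumulation_steps.
+def _demo_trainer_loop(model, dataloader):
+    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
+    trainer = CacaTrainer(
+        model,
+        optimizer,
+        scheduler=scheduler,
+        gradient_accumulation_steps=4,  # 4 micro-batches per optimizer step
+        max_grad_norm=1.0,
+        use_amp=torch.cuda.is_available(),
+    )
+    avg_loss = trainer.train_epoch(dataloader, log_interval=10)
+    return avg_loss, trainer.get_metrics()
\ No newline at end of file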