# caca-1M-untrained / caca_transformers.py
# Lyon28 — "Add custom modeling file" (commit 1f2acb4, verified)
import logging
import math
from collections import OrderedDict, defaultdict
from functools import lru_cache
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
logger = logging.getLogger(__name__)

# Optional fused-attention backends. Each HAS_* flag records whether the
# corresponding import succeeded; attention code checks the flags before
# picking a kernel.
try:
    from flash_attn import flash_attn_func
except ImportError:
    HAS_FLASH_ATTN = False
else:
    HAS_FLASH_ATTN = True

try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    HAS_XFORMERS = False
else:
    HAS_XFORMERS = True

# Fused scaled_dot_product_attention is only present in newer PyTorch builds.
HAS_SDPA = hasattr(F, 'scaled_dot_product_attention')
# --- config ---
class CacaConfig(PretrainedConfig):
    """Configuration for Caca causal language models.

    Holds every switch read by the Caca modeling code (attention variants,
    RoPE/ALiBi, mixture-of-experts, mixture-of-depths, layer scale,
    multimodal hooks, parallelism hints) and validates the core dimensions on
    construction. ``to_dict`` injects an ``auto_map`` so the custom classes can
    be loaded through the ``transformers`` Auto* factories.

    Notes:
        * ``num_key_value_heads=None`` falls back to ``num_attention_heads``
          (plain multi-head attention).
        * ``head_dim`` defaults to ``hidden_size // num_attention_heads``.
        * ``chat_template=None`` installs a simple "System:/User:/Assistant:"
          Jinja template.
    """

    model_type = "caca"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=8,
        head_dim=None,
        max_position_embeddings=8192,
        rms_norm_eps=1e-6,
        qk_norm_eps=1e-6,
        use_qk_norm=True,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_rotary_embeddings=True,
        use_alibi=False,
        attention_bias=False,
        attention_dropout=0.0,
        use_flash_attn=True,
        use_grouped_query_attention=False,
        use_multi_query_attention=False,
        sliding_window=None,
        use_longformer_attention=False,
        longformer_attention_window=512,
        attn_logit_softcapping=None,
        final_logit_softcapping=None,
        attention_sink_size=4,
        attention_sink_window=1024,
        use_attention_sink=False,
        attention_pattern="all_global",
        global_attention_every_n_layers=2,
        mlp_bias=False,
        hidden_dropout=0.1,
        residual_dropout=0.1,
        use_moe=False,
        num_experts=8,
        num_experts_per_tok=2,
        use_expert_choice=False,
        expert_choice_k=0.125,
        router_aux_loss_coef=0.01,
        router_z_loss_coef=0.001,
        moe_layer_frequency=2,
        expert_capacity_factor=1.0,
        use_grouped_moe=False,
        num_expert_groups=1,
        use_layer_scale=False,
        layer_scale_init=1e-5,
        use_stochastic_depth=False,
        stochastic_depth_prob=0.1,
        use_mixture_of_depths=False,
        mod_capacity_factor=0.5,
        mod_route_method="learned",
        use_cross_attention=False,
        cross_attention_frequency=4,
        use_multimodal=False,
        vision_config=None,
        audio_config=None,
        projector_hidden_size=None,
        use_soft_merging=False,
        merge_threshold=0.5,
        pretraining_tp=1,
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
        chat_template=None,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # HF convention: a missing KV-head count means one KV head per query head.
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.head_dim = head_dim or (
            hidden_size // num_attention_heads
            if hidden_size and num_attention_heads
            else None
        )
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.qk_norm_eps = qk_norm_eps
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.tie_word_embeddings = tie_word_embeddings
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.use_rotary_embeddings = use_rotary_embeddings
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_qk_norm = use_qk_norm
        self.use_alibi = use_alibi
        self.use_flash_attn = use_flash_attn
        self.use_grouped_query_attention = use_grouped_query_attention
        self.use_multi_query_attention = use_multi_query_attention
        self.sliding_window = sliding_window
        self.use_longformer_attention = use_longformer_attention
        self.longformer_attention_window = longformer_attention_window
        self.attn_logit_softcapping = attn_logit_softcapping
        self.final_logit_softcapping = final_logit_softcapping
        self.attention_sink_size = attention_sink_size
        self.attention_sink_window = attention_sink_window
        self.use_attention_sink = use_attention_sink
        self.attention_pattern = attention_pattern
        self.global_attention_every_n_layers = global_attention_every_n_layers
        self.mlp_bias = mlp_bias
        self.hidden_dropout = hidden_dropout
        self.residual_dropout = residual_dropout
        self.use_moe = use_moe
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.use_expert_choice = use_expert_choice
        self.expert_choice_k = expert_choice_k
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_z_loss_coef = router_z_loss_coef
        self.moe_layer_frequency = moe_layer_frequency
        self.expert_capacity_factor = expert_capacity_factor
        self.use_grouped_moe = use_grouped_moe
        self.num_expert_groups = num_expert_groups
        self.use_layer_scale = use_layer_scale
        self.layer_scale_init = layer_scale_init
        self.use_stochastic_depth = use_stochastic_depth
        self.stochastic_depth_prob = stochastic_depth_prob
        self.use_mixture_of_depths = use_mixture_of_depths
        self.mod_capacity_factor = mod_capacity_factor
        self.mod_route_method = mod_route_method
        self.use_cross_attention = use_cross_attention
        self.cross_attention_frequency = cross_attention_frequency
        self.use_multimodal = use_multimodal
        self.vision_config = vision_config or {}
        self.audio_config = audio_config or {}
        self.projector_hidden_size = projector_hidden_size or hidden_size
        self.use_soft_merging = use_soft_merging
        self.merge_threshold = merge_threshold
        self.pretraining_tp = pretraining_tp
        self.tensor_parallel_size = tensor_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        if chat_template is None:
            self.chat_template = (
                "{% for message in messages %}"
                "{% if message['role'] == 'system' %}"
                "System: {{ message['content'] }}\n"
                "{% elif message['role'] == 'user' %}"
                "User: {{ message['content'] }}\n"
                "{% elif message['role'] == 'assistant' %}"
                "Assistant: {{ message['content'] }}\n"
                "{% endif %}"
                "{% endfor %}"
                "{% if add_generation_prompt %}Assistant:{% endif %}"
            )
        else:
            self.chat_template = chat_template
        self._validate_config()
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )

    def _validate_config(self):
        """Validate core dimensions.

        Raises ``ValueError`` for configurations that cannot build a model and
        logs warnings for ones that are merely unusual. Fills in ``head_dim``
        when it is still ``None``.
        """
        # Guard first so the modulo checks below cannot divide by zero.
        if self.num_attention_heads <= 0 or self.num_key_value_heads <= 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) and "
                f"num_key_value_heads ({self.num_key_value_heads}) must be > 0"
            )
        # Checked before divisibility: when this fails the modulo test would
        # fail too, but with a less specific message.
        if self.num_key_value_heads > self.num_attention_heads:
            raise ValueError(
                f"num_key_value_heads ({self.num_key_value_heads}) cannot be > "
                f"num_attention_heads ({self.num_attention_heads})"
            )
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )
        if self.hidden_size <= 0:
            raise ValueError(f"hidden_size must be > 0, got {self.hidden_size}")
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )
        expected_head_dim = self.hidden_size // self.num_attention_heads
        if self.head_dim is None:
            self.head_dim = expected_head_dim
        else:
            total_dim = self.head_dim * self.num_attention_heads
            if total_dim != self.hidden_size:
                logger.warning(
                    f"head_dim ({self.head_dim}) * num_attention_heads ({self.num_attention_heads}) "
                    f"= {total_dim} != hidden_size ({self.hidden_size}). "
                    f"This may cause dimension mismatch issues."
                )
        if self.intermediate_size <= 0:
            raise ValueError(f"intermediate_size must be > 0, got {self.intermediate_size}")
        ffn_ratio = self.intermediate_size / self.hidden_size
        if ffn_ratio < 1.5 or ffn_ratio > 10:
            logger.warning(
                f"intermediate_size/hidden_size ratio ({ffn_ratio:.2f}) is unusual. "
                f"Typical range is 2-8x. Got intermediate_size={self.intermediate_size}, "
                f"hidden_size={self.hidden_size}"
            )
        if self.vocab_size <= 0:
            raise ValueError(f"vocab_size must be > 0, got {self.vocab_size}")
        if self.vocab_size > 1000000:
            logger.warning(
                f"vocab_size ({self.vocab_size}) is very large. "
                f"This may cause memory issues."
            )
        if self.use_flash_attn and not HAS_FLASH_ATTN:
            logger.warning(
                "use_flash_attn=True but flash-attn not installed. "
                "Will fallback to SDPA/standard attention."
            )
        if self.sliding_window is not None:
            if self.sliding_window > self.max_position_embeddings:
                raise ValueError(
                    f"sliding_window ({self.sliding_window}) cannot be > "
                    f"max_position_embeddings ({self.max_position_embeddings})"
                )
            if self.sliding_window < 128:
                logger.warning(
                    f"sliding_window ({self.sliding_window}) is very small. "
                    f"This may limit context significantly."
                )
        if self.use_moe:
            if self.num_experts < self.num_experts_per_tok:
                raise ValueError(
                    f"num_experts ({self.num_experts}) must be >= "
                    f"num_experts_per_tok ({self.num_experts_per_tok})"
                )
            if self.moe_layer_frequency <= 0:
                raise ValueError("moe_layer_frequency must be > 0")
            if self.moe_layer_frequency > self.num_hidden_layers:
                logger.warning(
                    f"moe_layer_frequency ({self.moe_layer_frequency}) > "
                    f"num_hidden_layers ({self.num_hidden_layers}). "
                    f"MoE will not be used."
                )
            if self.expert_capacity_factor <= 0:
                raise ValueError("expert_capacity_factor must be > 0")

    def to_dict(self):
        """Serialise to a dict and add the ``auto_map`` for remote-code loading.

        A ``quantization_config`` attribute that is present but ``None`` is
        removed for the duration of the base ``to_dict`` call and then
        restored, so the base implementation never sees it.
        """
        quantization_config_backup = getattr(self, 'quantization_config', None)
        if quantization_config_backup is None and hasattr(self, 'quantization_config'):
            delattr(self, 'quantization_config')
            temp_removed = True
        else:
            temp_removed = False
        try:
            output = super().to_dict()
            output['auto_map'] = {
                "AutoConfig": "caca_transformers.CacaConfig",
                "AutoModel": "caca_transformers.CacaModel",
                "AutoModelForCausalLM": "caca_transformers.CacaForCausalLM"
            }
            return output
        finally:
            if temp_removed:
                self.quantization_config = None
            elif quantization_config_backup is not None:
                self.quantization_config = quantization_config_backup

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """Build a config from a plain dict, ignoring any stored ``auto_map``.

        Extra ``kwargs`` override entries of ``config_dict`` and are passed to
        the constructor. ``return_unused_kwargs`` follows the
        ``PretrainedConfig.from_dict`` contract that ``from_pretrained``
        relies on: when truthy, a ``(config, unused_kwargs)`` tuple is
        returned. Because every other kwarg is consumed by the constructor,
        ``unused_kwargs`` is always empty.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        # Loader bookkeeping flags, not configuration values.
        kwargs.pop("_from_auto", None)
        kwargs.pop("_from_pipeline", None)
        config_dict = {k: v for k, v in config_dict.items() if k != 'auto_map'}
        config_dict.update(kwargs)
        config = cls(**config_dict)
        if return_unused_kwargs:
            return config, {}
        return config
# --- Model Architecture ---
class MetricsTracker:
    """Buffer scalar metrics per name and summarise them.

    Every ``reset_interval`` calls to ``step()`` all buffers are cleared, so a
    summary covers at most the values logged since the last reset.

    Args:
        reset_interval: number of ``step()`` calls between automatic clears
            (default 100). A falsy value (0 / None) disables auto-clearing.
    """

    def __init__(self, reset_interval=100):
        self.metrics = defaultdict(list)
        self.reset_interval = reset_interval
        self.step_count = 0

    def log(self, name, value):
        """Append ``value`` under ``name``; tensors are converted with ``.item()``."""
        if isinstance(value, torch.Tensor):
            value = value.item()
        self.metrics[name].append(value)

    def step(self):
        """Advance the step counter and clear buffers on every reset boundary."""
        self.step_count += 1
        # Guard avoids ZeroDivisionError when auto-reset is disabled.
        if self.reset_interval and self.step_count % self.reset_interval == 0:
            self.clear()

    def get_summary(self):
        """Return ``{name: {"mean", "std", "min", "max", "last"}}`` for non-empty metrics.

        ``std`` is the population standard deviation (``np.std`` default).
        """
        summary = {}
        for name, values in self.metrics.items():
            if not values:
                continue
            summary[name] = {
                "mean": np.mean(values),
                "std": np.std(values),
                "min": np.min(values),
                "max": np.max(values),
                "last": values[-1],
            }
        return summary

    def clear(self):
        """Drop all buffered values."""
        self.metrics.clear()
class CacaRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    The statistic is computed in float32 and the result is cast back to the
    input dtype.

    Args:
        hidden_size: size of the normalised (last) dimension.
        eps: value added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Both names hold the same value; forward() reads variance_epsilon.
        self.eps = eps
        self.variance_epsilon = eps

    def forward(self, x):
        out_dtype = x.dtype
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        normed = x32 * inv_rms
        return (self.weight * normed).to(out_dtype)
class LayerScale(nn.Module):
    """Scale activations by a learnable per-channel vector ``gamma``.

    Args:
        dim: length of ``gamma`` (broadcast against the last dimension).
        init_value: initial value of every entry of ``gamma``.
    """

    def __init__(self, dim, init_value=1e-5):
        super().__init__()
        initial = torch.ones(dim).mul(init_value)
        self.gamma = nn.Parameter(initial)

    def forward(self, x):
        return x * self.gamma
class StochasticDepth(nn.Module):
    """Per-sample drop-path: zero whole samples with probability ``drop_prob``.

    Surviving samples are scaled by ``1 / (1 - drop_prob)`` so the expected
    value is unchanged.

    Args:
        drop_prob: probability of dropping each sample (along dim 0).
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x, training=None):
        """Apply drop-path to ``x``.

        Args:
            x: tensor whose first dimension indexes samples.
            training: explicit override; ``None`` (default) follows
                ``self.training``, so the module is an identity after ``.eval()``.
        """
        if training is None:
            training = self.training
        if not training or self.drop_prob == 0.0:
            return x
        if self.drop_prob >= 1.0:
            # keep_prob would be 0 and x / 0 * 0 yields NaN; every sample is dropped.
            return torch.zeros_like(x)
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        # floor(keep_prob + U[0,1)) is 1 with probability keep_prob, else 0.
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor
class CacaRotaryEmbedding(nn.Module):
    """Provide rotary position embedding (RoPE) cos/sin tables.

    Args:
        dim: rotary width; expected to be even (tables are
            ``2 * ceil(dim / 2)`` wide).
        max_position_embeddings: nominal context length; the ``"dynamic"``
            rule compresses positions once the total length exceeds it, and
            ``"yarn"`` uses it to pick its frequency threshold.
        base: RoPE frequency base (theta).
        scaling_factor: factor for ``"linear"`` and ``"yarn"`` scaling.
        scaling_type: ``None``, ``"linear"``, ``"dynamic"`` or ``"yarn"``;
            any other value leaves the frequencies unscaled.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=8192,
        base=10000.0,
        scaling_factor=1.0,
        scaling_type=None,
    ):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.scaling_factor = scaling_factor
        self.scaling_type = scaling_type
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
        if scaling_type == "linear":
            inv_freq = inv_freq / scaling_factor
        elif scaling_type == "yarn":
            inv_freq = self._yarn_get_inv_freq(inv_freq)
        # "dynamic" keeps the base frequencies and rescales positions in forward().
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Unscaled tables for positions [0, _cached_seq_len), shape (1, 1, L, width).
        self._cos_cache = None
        self._sin_cache = None
        self._cached_seq_len = 0

    def _yarn_get_inv_freq(self, inv_freq):
        """Rescale ``inv_freq`` for the ``"yarn"`` scaling type.

        NOTE(review): frequencies above ``1 / (max_position_embeddings * 32)``
        are divided by ``scaling_factor``, the rest are blended. With typical
        base/context values nearly every frequency is above that threshold, so
        this behaves close to linear scaling; verify against the YaRN
        reference before relying on it.
        """
        if len(inv_freq) == 0:
            return inv_freq
        alpha = self.scaling_factor
        beta_fast = 32
        beta_slow = 1
        freq_threshold = 1 / (self.max_position_embeddings * beta_fast)
        low_freq_mask = inv_freq > freq_threshold
        high_freq_mask = ~low_freq_mask
        low_freq = inv_freq[low_freq_mask]
        high_freq = inv_freq[high_freq_mask]
        if len(low_freq) > 0:
            low_freq = low_freq / alpha
        if len(high_freq) > 0:
            smooth_factor = (
                self.max_position_embeddings * beta_slow / high_freq - beta_fast
            ) / (beta_slow - beta_fast)
            smooth_factor = torch.clamp(smooth_factor, 0.0, 1.0)
            high_freq = (1 - smooth_factor) * (
                high_freq / alpha
            ) + smooth_factor * high_freq
        result = torch.zeros_like(inv_freq)
        result[low_freq_mask] = low_freq
        result[high_freq_mask] = high_freq
        return result

    def _tables_for(self, positions):
        """Return (cos, sin) of shape (1, 1, len(positions), width) for float positions."""
        freqs = torch.outer(positions, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]

    def _build_cache(self, length):
        """(Re)build the unscaled tables for positions [0, length) on inv_freq's device."""
        positions = torch.arange(length, device=self.inv_freq.device).to(self.inv_freq.dtype)
        self._cos_cache, self._sin_cache = self._tables_for(positions)
        self._cached_seq_len = length

    def forward(self, x, seq_len, position_offset=0):
        """Return ``(cos, sin)`` for positions ``[position_offset, position_offset + seq_len)``.

        Both tensors have shape ``(1, 1, seq_len, width)`` and are returned on
        ``x``'s device in ``x``'s dtype; only ``x.device``/``x.dtype`` are read.
        """
        total_len = position_offset + seq_len
        if self.scaling_type == "dynamic" and total_len > self.max_position_embeddings:
            # Positions are compressed by a factor that depends on the total
            # length, so these tables are call-specific and must not be cached.
            positions = torch.arange(
                position_offset, total_len, device=self.inv_freq.device
            ).to(self.inv_freq.dtype)
            positions = positions / (total_len / self.max_position_embeddings)
            cos, sin = self._tables_for(positions)
        else:
            if (
                self._cos_cache is None
                or self._cached_seq_len < total_len
                or self._cos_cache.device != self.inv_freq.device
            ):
                # Grow geometrically so token-by-token decoding rebuilds rarely.
                self._build_cache(max(total_len, 2 * self._cached_seq_len))
            cos = self._cos_cache[:, :, position_offset:total_len, :]
            sin = self._sin_cache[:, :, position_offset:total_len, :]
        return (
            cos.to(device=x.device, dtype=x.dtype),
            sin.to(device=x.device, dtype=x.dtype),
        )
class ALiBiPositionalBias(nn.Module):
    """Generate ALiBi (attention with linear biases) additive attention biases.

    The bias for head ``h`` is ``slope_h * (key_pos - query_pos)``, which is
    ``<= 0`` for past keys. Future keys get a positive value and are expected
    to be removed by a separate causal mask.

    Args:
        num_heads: number of attention heads (one slope per head).
        max_positions: stored for reference; not used to bound ``forward``.
    """

    def __init__(self, num_heads, max_positions=8192):
        super().__init__()
        self.num_heads = num_heads
        self.max_positions = max_positions
        slopes = torch.tensor(self._get_slopes(num_heads))
        self.register_buffer("slopes", slopes, persistent=False)

    def _get_slopes(self, n):
        """Return the ``n`` geometric head slopes (interleaving for non-powers of 2)."""
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * (ratio**i) for i in range(n)]
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + self._get_slopes(2 * closest_power_of_2)[0::2][
                    : n - closest_power_of_2
                ]
            )

    def forward(self, seq_len, key_len=None):
        """Return a bias of shape ``(1, num_heads, seq_len, key_len)``.

        The ``seq_len`` queries are taken to be the LAST ``seq_len`` positions
        of the ``key_len`` keys, as in KV-cache decoding where
        ``key_len > seq_len``. ``key_len`` defaults to ``seq_len``.
        """
        if key_len is None:
            key_len = seq_len
        device = self.slopes.device
        # Offset queries to their absolute positions among the keys.
        query_pos = torch.arange(seq_len, device=device).unsqueeze(1) + (key_len - seq_len)
        key_pos = torch.arange(key_len, device=device).unsqueeze(0)
        relative_pos = key_pos - query_pos
        bias = relative_pos.unsqueeze(0) * self.slopes.unsqueeze(1).unsqueeze(2)
        return bias.unsqueeze(0)
def rotate_half(x):
    """Split the last dim into halves ``(a, b)`` and return ``cat(-b, a)``.

    For an odd width the first part is the shorter one (``width // 2``).
    """
    width = x.shape[-1]
    half = width // 2
    lower, upper = x.split([half, width - half], dim=-1)
    return torch.cat((upper.neg(), lower), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate ``q`` and ``k`` by the RoPE tables ``cos``/``sin``.

    ``cos`` and ``sin`` are first cast to ``q``'s dtype and must broadcast
    against both ``q`` and ``k``. Returns the rotated ``(q, k)`` pair.
    """
    cos_t = cos.to(q.dtype)
    sin_t = sin.to(q.dtype)

    def _rotate(t):
        return (t * cos_t) + (rotate_half(t) * sin_t)

    return _rotate(q), _rotate(k)
def soft_cap_logits(x, cap):
    """Smoothly squash logits into the open interval ``(-cap, cap)``.

    Computes ``cap * tanh(x / cap)``. This is near-identity for
    ``|x| << cap`` and, unlike a hard clamp, keeps a non-zero gradient for
    large logits.

    Args:
        x: logits tensor.
        cap: positive soft-cap value; ``None`` or ``<= 0`` disables capping
            and returns ``x`` itself unchanged.
    """
    if cap is None or cap <= 0:
        return x
    return cap * torch.tanh(x / cap)
class TopKRouter(nn.Module):
    """Token-choice router: every token selects its top-k experts.

    Args:
        hidden_size: width of incoming token representations.
        num_experts: number of experts to score.
        num_experts_per_tok: ``k``, experts kept per token.
        jitter_noise: std of Gaussian noise added to the logits in training
            mode (default 0.01; 0 disables it).
    """

    def __init__(self, hidden_size, num_experts, num_experts_per_tok, jitter_noise=0.01):
        super().__init__()
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        self.gate_norm = nn.LayerNorm(hidden_size)
        self.temperature = nn.Parameter(torch.ones(1))
        self.jitter_noise = jitter_noise

    def forward(self, hidden_states):
        """Route ``hidden_states`` (any leading shape, last dim ``hidden_size``).

        Returns:
            top_k_weights: (tokens, k) float32 weights renormalised to sum to ~1.
            top_k_indices: (tokens, k) selected expert ids.
            total_aux_loss: coefficient of variation of the mean per-expert
                probability, minus ``0.01 *`` the mean routing entropy.
            z_loss: mean squared logsumexp of the router logits.
        """
        hidden_size = hidden_states.shape[-1]
        # reshape (not view) so non-contiguous inputs are accepted.
        flat = self.gate_norm(hidden_states.reshape(-1, hidden_size))
        router_logits = torch.clamp(self.gate(flat), min=-20, max=20)
        temperature = torch.clamp(self.temperature, min=0.1, max=10.0)
        router_logits = router_logits / temperature
        if self.training and self.jitter_noise > 0:
            noise = torch.randn_like(router_logits) * self.jitter_noise
            router_logits = router_logits + noise
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        top_k_weights, top_k_indices = torch.topk(
            routing_weights, self.num_experts_per_tok, dim=-1
        )
        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
        expert_usage = routing_weights.mean(dim=0)
        if self.num_experts > 1:
            aux_loss = expert_usage.std() / (expert_usage.mean() + 1e-10)
        else:
            # std() of a single element is NaN; one expert is trivially balanced.
            aux_loss = expert_usage.new_zeros(())
        z_loss = torch.logsumexp(router_logits.to(torch.float32), dim=-1).pow(2).mean()
        entropy = -(routing_weights * torch.log(routing_weights + 1e-10)).sum(dim=-1).mean()
        entropy_loss = -0.01 * entropy
        total_aux_loss = aux_loss + entropy_loss
        return top_k_weights, top_k_indices, total_aux_loss, z_loss
class ExpertChoiceRouter(nn.Module):
    """Expert-choice router: every expert selects its top-capacity tokens.

    Args:
        hidden_size: width of incoming token representations.
        num_experts: number of experts.
        expert_choice_k: capacity multiplier; each expert takes
            ``max(1, int(expert_choice_k * tokens / num_experts))`` tokens
            (capped at the token count).
    """

    def __init__(self, hidden_size, num_experts, expert_choice_k):
        super().__init__()
        self.num_experts = num_experts
        self.expert_choice_k = expert_choice_k
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states):
        """Route ``hidden_states`` (any leading shape, last dim ``hidden_size``).

        Returns:
            routing_weights: (tokens, num_experts) float32; the router
                probability where an expert chose the token, 0 elsewhere.
            aux_loss: ``num_experts * sum(mean_probs ** 2)``.
            z_loss: mean squared logsumexp of the router logits (same
                definition as ``TopKRouter``).
        """
        hidden_size = hidden_states.shape[-1]
        hidden_states_flat = hidden_states.reshape(-1, hidden_size)
        total_tokens = hidden_states_flat.shape[0]
        router_logits = self.gate(hidden_states_flat)
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        capacity = max(1, int(self.expert_choice_k * total_tokens / self.num_experts))
        _, chosen_tokens = torch.topk(
            router_probs.t(), k=min(capacity, total_tokens), dim=-1
        )
        expert_mask = torch.zeros(
            self.num_experts, total_tokens, device=hidden_states.device
        )
        # One vectorised scatter instead of a Python loop over experts.
        expert_mask.scatter_(1, chosen_tokens, 1.0)
        routing_weights = expert_mask.t() * router_probs
        aux_loss = (router_probs.mean(dim=0) ** 2).sum() * self.num_experts
        z_loss = torch.logsumexp(router_logits.float(), dim=-1).pow(2).mean()
        return routing_weights, aux_loss, z_loss
class Expert(nn.Module):
    """Gated SiLU feed-forward expert: ``down(dropout(silu(gate(x)) * up(x)))``.

    Args:
        config: object providing ``hidden_size``, ``intermediate_size``,
            ``mlp_bias`` and ``hidden_dropout``.
    """

    def __init__(self, config):
        super().__init__()
        model_dim = config.hidden_size
        inner_dim = config.intermediate_size
        use_bias = config.mlp_bias
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=use_bias)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=use_bias)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=use_bias)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, x):
        activated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(activated))
class MixtureOfExperts(nn.Module):
    """Sparse mixture-of-experts feed-forward layer.

    Tokens are dispatched by a token-choice ``TopKRouter`` or, when
    ``config.use_expert_choice`` is set, by an ``ExpertChoiceRouter``. Each
    selected ``Expert`` output is scaled by its routing weight and summed per
    token. In eval mode, the number of tokens sent to each expert is added to
    the ``expert_usage_count`` buffer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.use_expert_choice = config.use_expert_choice
        self.experts = nn.ModuleList([Expert(config) for _ in range(self.num_experts)])
        if self.use_expert_choice:
            self.router = ExpertChoiceRouter(
                config.hidden_size, config.num_experts, config.expert_choice_k
            )
        else:
            self.router = TopKRouter(
                config.hidden_size, config.num_experts, config.num_experts_per_tok
            )
        self.register_buffer(
            "expert_usage_count", torch.zeros(self.num_experts, dtype=torch.long)
        )

    def forward(self, hidden_states):
        """Apply the MoE layer to a (batch, seq, hidden) tensor.

        Returns ``(output, aux_loss, z_loss)``, where ``output`` has the input's
        shape. Best-effort fallback: if the input or the combined output
        contains NaN/Inf, the input tensor itself is returned as ``output``.
        """
        batch_size, seq_len, hidden_size = hidden_states.shape
        hidden_states_flat = hidden_states.reshape(-1, hidden_size)
        if not torch.isfinite(hidden_states_flat).all():
            logger.error("NaN or Inf detected in MoE input. Using fallback.")
            return (
                hidden_states,
                torch.tensor(0.0, device=hidden_states.device),
                torch.tensor(0.0, device=hidden_states.device),
            )
        if self.use_expert_choice:
            final_output, aux_loss, z_loss = self._forward_expert_choice(
                hidden_states, hidden_states_flat
            )
        else:
            final_output, aux_loss, z_loss = self._forward_top_k(
                hidden_states, hidden_states_flat
            )
        final_output = final_output.view(batch_size, seq_len, hidden_size)
        if not torch.isfinite(final_output).all():
            logger.error("MoE output contains NaN/Inf. Returning input.")
            return hidden_states, aux_loss, z_loss
        return final_output, aux_loss, z_loss

    def _forward_expert_choice(self, hidden_states, hidden_states_flat):
        """Dispatch with the expert-choice router; returns (flat_output, aux_loss, z_loss)."""
        routing_weights, aux_loss, z_loss = self.router(hidden_states)
        final_output = torch.zeros_like(hidden_states_flat)
        for expert_idx, expert in enumerate(self.experts):
            expert_mask = routing_weights[:, expert_idx] > 1e-6
            if not expert_mask.any():
                continue
            if not self.training:
                self.expert_usage_count[expert_idx] += expert_mask.sum()
            try:
                expert_output = expert(hidden_states_flat[expert_mask])
                if not torch.isfinite(expert_output).all():
                    logger.warning(
                        f"Expert {expert_idx} produced NaN/Inf. Skipping."
                    )
                    continue
                final_output[expert_mask] += (
                    expert_output
                    * routing_weights[expert_mask, expert_idx : expert_idx + 1]
                )
            except RuntimeError as e:
                # Best-effort: a failing expert contributes nothing.
                logger.error(f"Expert {expert_idx} failed: {e}")
        return final_output, aux_loss, z_loss

    def _forward_top_k(self, hidden_states, hidden_states_flat):
        """Dispatch with the top-k router; returns (flat_output, aux_loss, z_loss)."""
        top_k_weights, top_k_indices, aux_loss, z_loss = self.router(hidden_states)
        final_output = torch.zeros_like(hidden_states_flat)
        for expert_idx, expert in enumerate(self.experts):
            # (token, slot) pairs routed here; topk indices are distinct per
            # token, so each token appears at most once, in ascending order.
            token_indices, slot_indices = torch.where(top_k_indices == expert_idx)
            if token_indices.numel() == 0:
                continue
            if not self.training:
                self.expert_usage_count[expert_idx] += token_indices.numel()
            try:
                expert_output = expert(hidden_states_flat[token_indices])
                if not torch.isfinite(expert_output).all():
                    logger.warning(
                        f"Expert {expert_idx} produced NaN/Inf. Skipping."
                    )
                    continue
                weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
                # Vectorised scatter-add replaces the former per-token Python loop.
                final_output.index_add_(
                    0, token_indices, (expert_output * weights).to(final_output.dtype)
                )
            except RuntimeError as e:
                # Best-effort: a failing expert contributes nothing.
                logger.error(f"Expert {expert_idx} failed: {e}")
        return final_output, aux_loss, z_loss
class MixtureOfDepthsRouter(nn.Module):
    """Choose which tokens of each sequence a Mixture-of-Depths block processes.

    Args:
        hidden_size: input width (used by the learned scorer).
        capacity_factor: fraction of each sequence to select; at least one
            token and at most the whole sequence is selected.
        route_method: ``"learned"`` (linear scorer), ``"random"`` (uniform
            scores); any other value scores every token 0, leaving the
            choice to ``torch.topk`` tie-breaking.
    """

    def __init__(self, hidden_size, capacity_factor=0.5, route_method="learned"):
        super().__init__()
        self.capacity_factor = capacity_factor
        self.route_method = route_method
        if route_method == "learned":
            self.router = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states):
        """Return a (batch, seq) bool mask, True for tokens selected for processing."""
        batch_size, seq_len, _ = hidden_states.shape
        if self.route_method == "learned":
            routing_logits = self.router(hidden_states).squeeze(-1)
        elif self.route_method == "random":
            routing_logits = torch.rand(
                batch_size, seq_len, device=hidden_states.device
            )
        else:
            routing_logits = torch.zeros(
                batch_size, seq_len, device=hidden_states.device
            )
        # Clamp to seq_len: topk raises if k exceeds the dimension size.
        capacity = min(seq_len, max(1, int(seq_len * self.capacity_factor)))
        _, top_indices = torch.topk(routing_logits, k=capacity, dim=-1)
        process_mask = torch.zeros(
            batch_size, seq_len, dtype=torch.bool, device=hidden_states.device
        )
        process_mask.scatter_(1, top_indices, True)
        return process_mask
class CacaAttention(nn.Module):
def __init__(self, config, layer_idx=None):
    """Create the projections, optional QK-norm and positional modules for one layer.

    Args:
        config: model configuration; read for head counts, ``head_dim``,
            RoPE/ALiBi settings, dropout, sliding-window and sink options.
        layer_idx: index of this layer; passed to
            ``_determine_attention_type`` (``None`` yields local attention).
    """
    super().__init__()
    self.config = config
    self.layer_idx = layer_idx
    self.hidden_size = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.num_key_value_heads = config.num_key_value_heads
    self.head_dim = config.head_dim
    # Query heads sharing each key/value head (1 for plain multi-head attention).
    self.num_key_value_groups = self.num_heads // self.num_key_value_heads
    self.sliding_window = config.sliding_window
    self.attn_logit_softcapping = config.attn_logit_softcapping
    self.attention_sink_size = config.attention_sink_size
    self.attention_sink_window = config.attention_sink_window
    self.q_proj = nn.Linear(
        self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
    )
    self.k_proj = nn.Linear(
        self.hidden_size,
        self.num_key_value_heads * self.head_dim,
        bias=config.attention_bias,
    )
    self.v_proj = nn.Linear(
        self.hidden_size,
        self.num_key_value_heads * self.head_dim,
        bias=config.attention_bias,
    )
    self.o_proj = nn.Linear(
        self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
    )
    if config.use_qk_norm:
        self.q_norm = CacaRMSNorm(self.head_dim, eps=config.qk_norm_eps)
        self.k_norm = CacaRMSNorm(self.head_dim, eps=config.qk_norm_eps)
    else:
        self.q_norm = None
        self.k_norm = None
    if config.use_rotary_embeddings:
        scaling_factor = 1.0
        scaling_type = None
        if config.rope_scaling is not None:
            # Accept the legacy "type" key and the newer Hugging Face
            # "rope_type" key; fall back to linear when neither is present.
            scaling_type = config.rope_scaling.get(
                "type", config.rope_scaling.get("rope_type", "linear")
            )
            scaling_factor = config.rope_scaling.get("factor", 1.0)
        self.rotary_emb = CacaRotaryEmbedding(
            self.head_dim,
            config.max_position_embeddings,
            config.rope_theta,
            scaling_factor=scaling_factor,
            scaling_type=scaling_type,
        )
    else:
        self.rotary_emb = None
    if config.use_alibi:
        self.alibi = ALiBiPositionalBias(
            self.num_heads, config.max_position_embeddings
        )
    else:
        self.alibi = None
    self.attention_dropout = nn.Dropout(config.attention_dropout)
    self.is_global_attention = self._determine_attention_type(config, layer_idx)
    # Kernel availability: import succeeded and (for flash) enabled in config.
    self.has_flash_attn = HAS_FLASH_ATTN and config.use_flash_attn
    self.has_xformers = HAS_XFORMERS
    self.has_sdpa = HAS_SDPA
    # Bounded LRU cache of causal masks, plus hit/miss counters.
    self._mask_cache = OrderedDict()
    self._max_cache_size = 10
    self._cache_hits = 0
    self._cache_misses = 0
def _determine_attention_type(self, config, layer_idx):
if layer_idx is None:
return False
if config.attention_pattern == "all_global":
return True
elif config.attention_pattern == "all_local":
return False
elif config.attention_pattern == "interleaved":
return (layer_idx % config.global_attention_every_n_layers) == (
config.global_attention_every_n_layers - 1
)
return False
def _create_causal_mask(
self, query_length, key_length, dtype, device, use_sliding_window
):
device_key = (device.type, device.index if device.type == "cuda" else None)
cache_key = (
query_length,
key_length,
str(dtype),
device_key,
use_sliding_window,
self.sliding_window if use_sliding_window else None,
)
if cache_key in self._mask_cache:
self._cache_hits += 1
self._mask_cache.move_to_end(cache_key)
cached_mask = self._mask_cache[cache_key]
return cached_mask.to(dtype).to(device)
self._cache_misses += 1
if query_length > key_length:
key_length = query_length
query_pos = torch.arange(query_length, device=device) + (
key_length - query_length
)
key_pos = torch.arange(key_length, device=device)
distance = query_pos[:, None] - key_pos[None, :]
mask = distance < 0
if use_sliding_window and self.sliding_window is not None:
if self.config.use_attention_sink and self.attention_sink_size > 0:
is_sink = key_pos[None, :] < self.attention_sink_size
in_window = (distance >= 0) & (distance <= self.sliding_window)
mask = (distance < 0) | ((~is_sink) & (~in_window))
else:
too_far_mask = distance > self.sliding_window
mask = mask | too_far_mask
float_mask = torch.zeros(
1, 1, query_length, key_length, dtype=dtype, device=device
)
float_mask.masked_fill_(mask.unsqueeze(0).unsqueeze(0), -1e9)
if len(self._mask_cache) >= self._max_cache_size:
self._mask_cache.popitem(last=False)
self._mask_cache[cache_key] = float_mask.detach().cpu()
return float_mask
def get_cache_stats(self):
    """Report causal-mask cache usage.

    Returns a dict with ``hits``, ``misses``, ``hit_rate`` (0 when there have
    been no lookups yet) and ``cache_size`` (number of cached masks).
    """
    hits = self._cache_hits
    misses = self._cache_misses
    lookups = hits + misses
    return {
        "hits": hits,
        "misses": misses,
        "hit_rate": (hits / lookups) if lookups > 0 else 0,
        "cache_size": len(self._mask_cache),
    }
def forward(
    self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False
):
    """Run self-attention over ``hidden_states``.

    Args:
        hidden_states: (batch, seq, hidden_size) input.
        attention_mask: optional mask forwarded to the fallback kernels; when
            given, the flash-attention path is skipped.
        past_key_value: optional ``(key, value)`` pair of cached states, each
            (batch, num_key_value_heads, past_len, head_dim).
        use_cache: when True, also return the updated ``(key, value)`` pair.

    Returns:
        ``(attn_output, present_key_value)``; ``attn_output`` is
        (batch, seq, hidden_size) and ``present_key_value`` is ``None``
        unless ``use_cache``.

    Raises:
        ValueError: if ``hidden_states`` is None or has the wrong last dim.
    """
    if hidden_states is None:
        raise ValueError("hidden_states cannot be None")
    if hidden_states.shape[-1] != self.hidden_size:
        raise ValueError(
            f"Expected hidden_size {self.hidden_size}, got {hidden_states.shape[-1]}"
        )
    batch_size, seq_length, _ = hidden_states.size()
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)
    # Split into heads: (batch, heads, seq, head_dim).
    query_states = query_states.view(
        batch_size, seq_length, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        batch_size, seq_length, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        batch_size, seq_length, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    if self.q_norm is not None and self.k_norm is not None:
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)
    # Positions already cached; new tokens are rotated starting at this offset.
    position_offset = 0
    if past_key_value is not None:
        try:
            if (
                isinstance(past_key_value, (tuple, list))
                and len(past_key_value) >= 2
                and past_key_value[0] is not None
                and past_key_value[0].numel() > 0
            ):
                position_offset = past_key_value[0].shape[2]
        except (IndexError, AttributeError, TypeError) as e:
            logger.debug(f"Could not get position_offset: {e}")
            position_offset = 0
    if self.rotary_emb is not None:
        cos, sin = self.rotary_emb(query_states, seq_length, position_offset)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )
    if past_key_value is not None and past_key_value[0] is not None:
        try:
            if past_key_value[0].numel() > 0 and past_key_value[1].numel() > 0:
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
        except RuntimeError as e:
            logger.error(f"Failed to concatenate past_key_value: {e}")
    # The cache stores the un-repeated num_key_value_heads states.
    present_key_value = (key_states, value_states) if use_cache else None
    # Expand KV heads so each query head has a matching key/value head.
    key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
    value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
    kv_seq_len = key_states.shape[-2]
    use_sliding_window = (not self.is_global_attention) and (
        self.sliding_window is not None
    )
    softcap = self.attn_logit_softcapping
    # The _flash_attention call passes neither an ALiBi bias nor a soft-cap,
    # so flash is only eligible when both features are inactive.
    use_flash = (
        self.has_flash_attn
        and attention_mask is None
        and self.alibi is None
        and (softcap is None or softcap <= 0)
        and query_states.device.type == "cuda"
        and query_states.dtype in (torch.float16, torch.bfloat16)
    )
    attn_output = None
    if use_flash:
        try:
            attn_output = self._flash_attention(
                query_states, key_states, value_states, use_sliding_window
            )
        except Exception as e:
            logger.warning(f"Flash Attention failed: {e}")
    if attn_output is None:
        attn_output = self._fallback_attention(
            query_states,
            key_states,
            value_states,
            attention_mask,
            kv_seq_len,
            use_sliding_window,
        )
    attn_output = self.o_proj(attn_output)
    return attn_output, present_key_value
def _flash_attention(
    self, query_states, key_states, value_states, use_sliding_window
):
    """Causal attention via ``flash_attn_func``.

    Inputs are (batch, heads, seq, head_dim) with keys/values already
    expanded to the query head count. Returns
    (batch, seq, num_heads * head_dim) in the query's original dtype.

    Raises:
        ValueError: if the key batch/head dims differ from the query's.
    """
    batch_size, num_heads, seq_length, head_dim = query_states.shape
    kv_seq_len = key_states.shape[-2]
    if key_states.shape[0] != batch_size or key_states.shape[1] != num_heads:
        raise ValueError(
            f"Shape mismatch: query={query_states.shape}, key={key_states.shape}"
        )
    input_dtype = query_states.dtype
    # flash-attn requires fp16/bf16 inputs.
    if input_dtype in (torch.float16, torch.bfloat16):
        compute_dtype = input_dtype
    else:
        compute_dtype = torch.bfloat16
    if compute_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        logger.warning("BF16 not supported, using FP16")
        compute_dtype = torch.float16
    # flash_attn_func takes (batch, seq, heads, head_dim).
    query_states = query_states.transpose(1, 2).contiguous().to(compute_dtype)
    key_states = key_states.transpose(1, 2).contiguous().to(compute_dtype)
    value_states = value_states.transpose(1, 2).contiguous().to(compute_dtype)
    if use_sliding_window and self.sliding_window < kv_seq_len:
        window_size = (self.sliding_window, 0)
    else:
        window_size = (-1, 0)
    attn_output = flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout_p=self.config.attention_dropout if self.training else 0.0,
        softmax_scale=None,
        causal=True,
        window_size=window_size,
    )
    # Cast back to the caller's dtype (not the compute dtype) so o_proj sees
    # matching weights, and flatten heads to num_heads * head_dim, which is
    # o_proj's input width even when it differs from hidden_size.
    attn_output = attn_output.to(input_dtype)
    return attn_output.reshape(batch_size, seq_length, num_heads * head_dim)
def _fallback_attention(
self,
query_states,
key_states,
value_states,
attention_mask,
kv_seq_len,
use_sliding_window,
):
device_type = query_states.device.type
if self.has_xformers and device_type == "cuda" and attention_mask is None:
try:
return self._xformers_attention(
query_states,
key_states,
value_states,
kv_seq_len,
use_sliding_window,
)
except Exception as e:
logger.debug(f"xFormers failed: {e}")
if self.has_sdpa:
return self._sdpa_attention(
query_states,
key_states,
value_states,
attention_mask,
kv_seq_len,
use_sliding_window,
)
else:
return self._standard_attention(
query_states,
key_states,
value_states,
attention_mask,
kv_seq_len,
use_sliding_window,
)
def _sdpa_attention(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    kv_seq_len,
    use_sliding_window,
):
    """Attention via ``torch.nn.functional.scaled_dot_product_attention``.

    When ``attention_mask`` is None an additive causal (optionally
    sliding-window) mask is built with ``self._create_causal_mask``. When a
    mask is supplied it is used as-is (no causal mask is added here).
    NOTE(review): confirm the caller already folds causality into any mask
    it passes. If ``self.alibi`` is set, its bias is added to the mask.

    Args:
        query_states: tensor unpacked as ``(batch, heads, q_len, head_dim)``.
        key_states, value_states: key/value tensors passed straight to SDPA.
        attention_mask: additive mask or None.
        kv_seq_len: key length used to size the generated masks.
        use_sliding_window: forwarded to ``_create_causal_mask``.

    Returns:
        Tensor of shape ``(batch, q_len, heads * head_dim)``.
    """
    batch_size, num_heads, seq_length, head_dim = query_states.shape
    if attention_mask is None:
        attention_mask = self._create_causal_mask(
            seq_length,
            kv_seq_len,
            query_states.dtype,
            query_states.device,
            use_sliding_window,
        )
    if self.alibi is not None:
        alibi_bias = self.alibi(seq_length, kv_seq_len)
        attention_mask = attention_mask + alibi_bias
    # Causality is already encoded in attn_mask, hence is_causal=False.
    attn_output = F.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        dropout_p=self.config.attention_dropout if self.training else 0.0,
        is_causal=False,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    # Merge heads; -1 keeps this valid when heads * head_dim != hidden_size.
    return attn_output.reshape(batch_size, seq_length, -1)
def _xformers_attention(
    self, query_states, key_states, value_states, kv_seq_len, use_sliding_window
):
    """Attention via xFormers ``memory_efficient_attention``.

    Builds the additive causal (optionally sliding-window) bias with
    ``self._create_causal_mask`` and, to match the SDPA and eager paths, adds
    the ALiBi bias when ``self.alibi`` is configured.

    Args:
        query_states: tensor unpacked as ``(batch, heads, q_len, head_dim)``.
        key_states, value_states: key/value tensors in the same layout.
        kv_seq_len: key length used to size the bias.
        use_sliding_window: forwarded to ``_create_causal_mask``.

    Returns:
        Tensor of shape ``(batch, q_len, heads * head_dim)``.
    """
    batch_size, num_heads, seq_length, head_dim = query_states.shape
    attn_bias = self._create_causal_mask(
        seq_length,
        kv_seq_len,
        query_states.dtype,
        query_states.device,
        use_sliding_window,
    )
    if self.alibi is not None:
        alibi_bias = self.alibi(seq_length, kv_seq_len)
        # xFormers wants the bias in the query dtype.
        attn_bias = attn_bias + alibi_bias.to(attn_bias.dtype)
    # xFormers expects (batch, seq, heads, head_dim).
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)
    attn_output = memory_efficient_attention(
        query_states,
        key_states,
        value_states,
        attn_bias=attn_bias,
        p=self.config.attention_dropout if self.training else 0.0,
    )
    # Merge heads; -1 keeps this valid when heads * head_dim != hidden_size.
    return attn_output.reshape(batch_size, seq_length, -1)
def _standard_attention(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    kv_seq_len,
    use_sliding_window,
):
    """Eager (matmul + softmax) attention.

    This is the only path that applies ``self.attn_logit_softcapping``. Scaled
    logits are first clamped to [-50, 50] for numerical stability. When
    ``attention_mask`` is None a causal mask is built; ALiBi bias is added
    when configured. Softmax runs in float32 and is cast back to the query
    dtype.

    Args:
        query_states: tensor unpacked as ``(batch, heads, q_len, head_dim)``.
        key_states, value_states: key/value tensors in the same layout.
        attention_mask: additive mask or None.
        kv_seq_len: key length used to size the generated masks.
        use_sliding_window: forwarded to ``_create_causal_mask``.

    Returns:
        Tensor of shape ``(batch, q_len, heads * head_dim)``.
    """
    batch_size, num_heads, seq_length, head_dim = query_states.shape
    attn_weights = torch.matmul(
        query_states, key_states.transpose(2, 3)
    ) / math.sqrt(head_dim)
    attn_weights = torch.clamp(attn_weights, min=-50.0, max=50.0)
    attn_weights = soft_cap_logits(attn_weights, self.attn_logit_softcapping)
    if attention_mask is None:
        attention_mask = self._create_causal_mask(
            seq_length,
            kv_seq_len,
            attn_weights.dtype,
            attn_weights.device,
            use_sliding_window,
        )
    if self.alibi is not None:
        alibi_bias = self.alibi(seq_length, kv_seq_len)
        attention_mask = attention_mask + alibi_bias
    attn_weights = attn_weights + attention_mask
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        query_states.dtype
    )
    attn_weights = self.attention_dropout(attn_weights)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    # Merge heads; -1 keeps this valid when heads * head_dim != hidden_size.
    return attn_output.reshape(batch_size, seq_length, -1)
class CacaCrossAttention(nn.Module):
    """Grouped-query cross-attention from decoder states to encoder states.

    Queries come from ``hidden_states``; keys/values come from
    ``encoder_hidden_states``, with ``num_key_value_heads`` KV heads
    repeated to ``num_attention_heads``. If ``config.head_dim`` is unset
    (None/0), ``hidden_size // num_attention_heads`` is used.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Fall back to the conventional per-head width when not configured.
        self.head_dim = (
            getattr(config, "head_dim", None) or self.hidden_size // self.num_heads
        )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
        )
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def forward(self, hidden_states, encoder_hidden_states, attention_mask=None):
        """Attend from ``hidden_states`` to ``encoder_hidden_states``.

        Args:
            hidden_states: ``(batch, seq, hidden_size)`` decoder states.
            encoder_hidden_states: ``(batch, enc_seq, hidden_size)``.
            attention_mask: optional additive mask broadcastable to
                ``(batch, heads, seq, enc_seq)``.

        Returns:
            ``(batch, seq, hidden_size)`` tensor.
        """
        batch_size, seq_length, _ = hidden_states.size()
        encoder_seq_length = encoder_hidden_states.size(1)
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(encoder_hidden_states)
        value_states = self.v_proj(encoder_hidden_states)
        query_states = query_states.view(
            batch_size, seq_length, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, encoder_seq_length, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, encoder_seq_length, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        # Expand KV heads so every query head has a matching KV head.
        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )
        attn_weights = self.attention_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        # o_proj consumes num_heads * head_dim features, which need not equal
        # hidden_size.
        attn_output = attn_output.reshape(
            batch_size, seq_length, self.num_heads * self.head_dim
        )
        return self.o_proj(attn_output)
class CacaMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block.

    Computes ``down(dropout(silu(gate(x)) * up(x)))`` where gate/up project
    ``hidden_size -> intermediate_size`` and down projects back.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        use_bias = config.mlp_bias
        width_in, width_mid = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(width_in, width_mid, bias=use_bias)
        self.up_proj = nn.Linear(width_in, width_mid, bias=use_bias)
        self.down_proj = nn.Linear(width_mid, width_in, bias=use_bias)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, x):
        """Apply the gated MLP to ``x`` (last dim ``hidden_size``)."""
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(gated))
class CacaDecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention, optional cross-attention, MLP/MoE.

    Each sub-block is applied as
    ``x = x + residual_dropout(stochastic_depth(layer_scale(f(norm(x)))))``,
    where LayerScale and stochastic depth are optional. The feed-forward is a
    ``MixtureOfExperts`` on every ``moe_layer_frequency``-th layer when MoE is
    enabled, else ``CacaMLP``. Cross-attention exists on every
    ``cross_attention_frequency``-th layer when enabled. An optional
    Mixture-of-Depths router selects which tokens this layer processes.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.self_attn = CacaAttention(config, layer_idx=layer_idx)
        self.use_moe = config.use_moe and (layer_idx % config.moe_layer_frequency == 0)
        if self.use_moe:
            self.mlp = MixtureOfExperts(config)
        else:
            self.mlp = CacaMLP(config)
        self.use_cross_attention = config.use_cross_attention and (
            layer_idx % config.cross_attention_frequency == 0
        )
        if self.use_cross_attention:
            self.cross_attn = CacaCrossAttention(config)
            self.cross_attn_layernorm = CacaRMSNorm(
                config.hidden_size, config.rms_norm_eps
            )
        self.input_layernorm = CacaRMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = CacaRMSNorm(
            config.hidden_size, config.rms_norm_eps
        )
        self.residual_dropout = nn.Dropout(config.residual_dropout)
        if config.use_layer_scale:
            self.layer_scale_1 = LayerScale(config.hidden_size, config.layer_scale_init)
            self.layer_scale_2 = LayerScale(config.hidden_size, config.layer_scale_init)
            if self.use_cross_attention:
                self.layer_scale_cross = LayerScale(
                    config.hidden_size, config.layer_scale_init
                )
        else:
            self.layer_scale_1 = None
            self.layer_scale_2 = None
            self.layer_scale_cross = None
        if config.use_stochastic_depth:
            # Drop probability grows linearly with depth.
            drop_prob = (
                config.stochastic_depth_prob * layer_idx / config.num_hidden_layers
            )
            self.stochastic_depth = StochasticDepth(drop_prob)
        else:
            self.stochastic_depth = None
        if config.use_mixture_of_depths:
            self.mod_router = MixtureOfDepthsRouter(
                config.hidden_size, config.mod_capacity_factor, config.mod_route_method
            )
        else:
            self.mod_router = None
        self.gradient_stats = {
            "max_grad_norm": 0.0,
            "mean_grad_norm": 0.0,
            "grad_clip_count": 0,
        }

    def _gradient_monitor_hook(self, grad):
        """Backward hook: track gradient norms and clamp large gradients.

        Updates the running max and an exponential moving average (0.9/0.1)
        of the gradient norm. When the norm exceeds 10, every element is
        clamped to [-10, 10] and the clip counter is bumped; a warning is
        logged every 100 clips.
        """
        if grad is None:
            return grad
        # .item() forces a device sync on each backward call.
        grad_norm = grad.norm().item()
        self.gradient_stats["max_grad_norm"] = max(
            self.gradient_stats["max_grad_norm"], grad_norm
        )
        self.gradient_stats["mean_grad_norm"] = (
            0.9 * self.gradient_stats["mean_grad_norm"] + 0.1 * grad_norm
        )
        max_grad = 10.0
        if grad_norm > max_grad:
            self.gradient_stats["grad_clip_count"] += 1
            if self.gradient_stats["grad_clip_count"] % 100 == 0:
                logger.warning(
                    f"Layer {self.layer_idx}: High gradient norm {grad_norm:.2f} "
                    f"(clipped {self.gradient_stats['grad_clip_count']} times)"
                )
            # Element-wise clamp (not norm rescaling).
            grad = torch.clamp(grad, min=-max_grad, max=max_grad)
        return grad

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        use_cache=False,
    ):
        """Run the layer.

        Args:
            hidden_states: input activations.
            attention_mask: additive self-attention mask passed to ``self_attn``.
            encoder_hidden_states: states for cross-attention (used only when
                this layer has cross-attention).
            encoder_attention_mask: additive mask for cross-attention.
            past_key_value: cache entry forwarded to ``self_attn``.
            use_cache: whether ``self_attn`` should return a cache entry.

        Returns:
            ``(hidden_states, present_key_value, aux_loss, z_loss)``; the
            losses stay ``0.0`` unless this is an MoE layer.

        Raises:
            ValueError: if ``hidden_states`` is None.
        """
        if hidden_states is None:
            raise ValueError("hidden_states cannot be None")
        # Sanitize non-finite inputs instead of propagating them.
        if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any():
            logger.error(f"Layer {self.layer_idx}: NaN/Inf in input!")
            hidden_states = torch.nan_to_num(
                hidden_states, nan=0.0, posinf=1e4, neginf=-1e4
            )
        aux_loss = 0.0
        z_loss = 0.0
        if self.training and hidden_states.requires_grad:
            hidden_states.register_hook(self._gradient_monitor_hook)
        if self.mod_router is not None:
            process_mask = self.mod_router(hidden_states)
            # NOTE(review): boolean-mask indexing flattens the selected tokens;
            # confirm self_attn accepts the resulting shape.
            tokens_to_process = hidden_states[process_mask]
            if tokens_to_process.numel() == 0:
                present_key_value = past_key_value if use_cache else None
                return hidden_states, present_key_value, aux_loss, z_loss
        else:
            process_mask = None
            tokens_to_process = hidden_states

        # --- self-attention ---
        residual = tokens_to_process
        tokens_to_process = self.input_layernorm(tokens_to_process)
        attn_output, present_key_value = self.self_attn(
            tokens_to_process,
            attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        if torch.isnan(attn_output).any() or torch.isinf(attn_output).any():
            logger.warning(
                f"Layer {self.layer_idx}: NaN/Inf in attention. Using zero output."
            )
            attn_output = torch.zeros_like(attn_output)
        if self.layer_scale_1 is not None:
            attn_output = self.layer_scale_1(attn_output)
        if self.stochastic_depth is not None:
            attn_output = self.stochastic_depth(attn_output, self.training)
        tokens_to_process = residual + self.residual_dropout(attn_output)
        if self.training:
            tokens_to_process = torch.clamp(tokens_to_process, min=-1e4, max=1e4)

        # --- optional cross-attention ---
        if self.use_cross_attention and encoder_hidden_states is not None:
            residual = tokens_to_process
            tokens_to_process = self.cross_attn_layernorm(tokens_to_process)
            cross_attn_output = self.cross_attn(
                tokens_to_process,
                encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            if (
                torch.isnan(cross_attn_output).any()
                or torch.isinf(cross_attn_output).any()
            ):
                logger.warning(
                    f"Layer {self.layer_idx}: NaN/Inf in cross-attention. Skipping."
                )
                cross_attn_output = torch.zeros_like(cross_attn_output)
            if self.layer_scale_cross is not None:
                cross_attn_output = self.layer_scale_cross(cross_attn_output)
            if self.stochastic_depth is not None:
                cross_attn_output = self.stochastic_depth(
                    cross_attn_output, self.training
                )
            tokens_to_process = residual + self.residual_dropout(cross_attn_output)
            if self.training:
                tokens_to_process = torch.clamp(tokens_to_process, min=-1e4, max=1e4)

        # --- feed-forward (dense MLP or MoE) ---
        residual = tokens_to_process
        tokens_to_process = self.post_attention_layernorm(tokens_to_process)
        if self.use_moe:
            mlp_output, moe_aux_loss, moe_z_loss = self.mlp(tokens_to_process)
            aux_loss += moe_aux_loss
            z_loss += moe_z_loss
        else:
            mlp_output = self.mlp(tokens_to_process)
        if torch.isnan(mlp_output).any() or torch.isinf(mlp_output).any():
            logger.warning(
                f"Layer {self.layer_idx}: NaN/Inf in MLP output. Using zero output."
            )
            mlp_output = torch.zeros_like(mlp_output)
        if self.layer_scale_2 is not None:
            mlp_output = self.layer_scale_2(mlp_output)
        if self.stochastic_depth is not None:
            mlp_output = self.stochastic_depth(mlp_output, self.training)
        tokens_to_process = residual + self.residual_dropout(mlp_output)
        if self.training:
            tokens_to_process = torch.clamp(tokens_to_process, min=-1e4, max=1e4)

        if process_mask is not None:
            # Scatter into a copy: the incoming tensor may be held elsewhere
            # (e.g. CacaModel's all_hidden_states, or saved by the router for
            # backward), so an in-place write would corrupt it.
            hidden_states = hidden_states.clone()
            hidden_states[process_mask] = tokens_to_process
        else:
            hidden_states = tokens_to_process
        return hidden_states, present_key_value, aux_loss, z_loss

    def get_gradient_stats(self):
        """Return a shallow copy of the gradient statistics dict."""
        return self.gradient_stats.copy()
class VisionEncoder(nn.Module):
    """ViT-style image encoder driven by the ``config.vision_config`` dict.

    Images are cut into non-overlapping ``patch_size`` patches by a strided
    convolution, a learned CLS token is prepended, learned positional
    embeddings are added, and the sequence runs through
    ``VisionTransformerBlock`` layers followed by a final LayerNorm.
    The positional table has ``num_patches + 1`` rows, so inputs must
    produce exactly ``(image_size // patch_size) ** 2`` patches.
    """

    def __init__(self, config):
        super().__init__()
        vcfg = config.vision_config
        self.patch_size = vcfg.get("patch_size", 14)
        self.image_size = vcfg.get("image_size", 224)
        self.num_channels = vcfg.get("num_channels", 3)
        self.hidden_size = vcfg.get("hidden_size", 1024)
        self.num_layers = vcfg.get("num_layers", 24)
        self.num_heads = vcfg.get("num_heads", 16)
        self.intermediate_size = vcfg.get("intermediate_size", 4096)
        self.layer_norm_eps = vcfg.get("layer_norm_eps", 1e-6)
        drop = vcfg.get("dropout", 0.0)
        grid = self.image_size // self.patch_size
        self.num_patches = grid**2
        width = self.hidden_size
        patchify = nn.Conv2d(
            self.num_channels,
            width,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )
        self.patch_embed = nn.Sequential(patchify, nn.Dropout(p=drop))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, width))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, width))
        self.pos_drop = nn.Dropout(p=drop)
        ratio = self.intermediate_size / width
        self.blocks = nn.ModuleList(
            VisionTransformerBlock(
                dim=width,
                num_heads=self.num_heads,
                mlp_ratio=ratio,
                dropout=drop,
                layer_norm_eps=self.layer_norm_eps,
            )
            for _ in range(self.num_layers)
        )
        self.norm = nn.LayerNorm(width, eps=self.layer_norm_eps)
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal (std 0.02) init for positions, CLS and patch conv."""
        for tensor in (self.pos_embed, self.cls_token, self.patch_embed[0].weight):
            nn.init.trunc_normal_(tensor, std=0.02)

    def forward(self, pixel_values):
        """Encode ``(batch, channels, H, W)`` images to ``(batch, 1 + patches, hidden)``."""
        patches = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(pixel_values.shape[0], -1, -1)
        seq = self.pos_drop(torch.cat([cls, patches], dim=1) + self.pos_embed)
        for block in self.blocks:
            seq = block(seq)
        return self.norm(seq)
class VisionTransformerBlock(nn.Module):
    """Pre-norm transformer block (self-attention + GELU MLP) for the ViT encoder.

    ``drop_path1``/``drop_path2`` are plain ``nn.Dropout`` layers on each
    residual branch, not per-sample stochastic depth.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0, layer_norm_eps=1e-6):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=layer_norm_eps)
        self.attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, batch_first=True
        )
        self.drop_path1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dim, eps=layer_norm_eps)
        hidden = int(dim * mlp_ratio)
        expand = nn.Linear(dim, hidden)
        act = nn.GELU()
        inner_drop = nn.Dropout(dropout)
        contract = nn.Linear(hidden, dim)
        outer_drop = nn.Dropout(dropout)
        self.mlp = nn.Sequential(expand, act, inner_drop, contract, outer_drop)
        self.drop_path2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply attention and MLP residual branches to ``(batch, seq, dim)`` input."""
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + self.drop_path1(attended)
        return x + self.drop_path2(self.mlp(self.norm2(x)))
class AudioEncoder(nn.Module):
    """Convolutional-stem transformer encoder over mel-spectrogram features.

    Input is ``(batch, num_mel_bins, frames)``, the layout ``nn.Conv1d``
    expects. ``conv1`` keeps the frame count; ``conv2`` (stride 2) halves it
    (rounded up). Learned positional embeddings cover ``max_audio_length // 2``
    frames and are linearly interpolated for longer inputs.
    """

    def __init__(self, config):
        super().__init__()
        audio_config = config.audio_config
        self.num_mel_bins = audio_config.get("num_mel_bins", 80)
        self.hidden_size = audio_config.get("hidden_size", 1024)
        self.num_layers = audio_config.get("num_layers", 12)
        self.num_heads = audio_config.get("num_heads", 16)
        self.intermediate_size = audio_config.get("intermediate_size", 4096)
        self.max_audio_length = audio_config.get("max_audio_length", 3000)
        self.dropout = audio_config.get("dropout", 0.0)
        # Each conv stage already ends in GELU + dropout.
        self.conv1 = nn.Sequential(
            nn.Conv1d(self.num_mel_bins, self.hidden_size, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Dropout(p=self.dropout),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
            ),
            nn.GELU(),
            nn.Dropout(p=self.dropout),
        )
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.max_audio_length // 2, self.hidden_size)
        )
        self.pos_drop = nn.Dropout(p=self.dropout)
        self.blocks = nn.ModuleList(
            [
                AudioTransformerBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_ratio=self.intermediate_size / self.hidden_size,
                    dropout=self.dropout,
                )
                for _ in range(self.num_layers)
            ]
        )
        self.norm = nn.LayerNorm(self.hidden_size)
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal (std 0.02) init of the positional table."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, audio_features):
        """Encode ``(batch, num_mel_bins, frames)`` to ``(batch, ceil(frames/2), hidden)``."""
        # conv1/conv2 include their own GELU; applying F.gelu again here
        # would activate each stage twice.
        x = self.conv2(self.conv1(audio_features))
        x = x.transpose(1, 2)
        seq_len = x.shape[1]
        if seq_len <= self.pos_embed.shape[1]:
            x = x + self.pos_embed[:, :seq_len, :]
        else:
            # Stretch the learned table to cover longer inputs.
            pos_embed_interp = F.interpolate(
                self.pos_embed.transpose(1, 2),
                size=seq_len,
                mode="linear",
                align_corners=False,
            ).transpose(1, 2)
            x = x + pos_embed_interp
        x = self.pos_drop(x)
        for block in self.blocks:
            x = block(x)
        return self.norm(x)
class AudioTransformerBlock(nn.Module):
    """Pre-norm transformer block (self-attention + GELU MLP) for the audio encoder.

    Uses default-eps LayerNorms. ``drop_path1``/``drop_path2`` are plain
    ``nn.Dropout`` layers on each residual branch.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, batch_first=True
        )
        self.drop_path1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        stages = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*stages)
        self.drop_path2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply attention and MLP residual branches to ``(batch, seq, dim)`` input."""
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + self.drop_path1(attended)
        return x + self.drop_path2(self.mlp(self.norm2(x)))
class MultiModalProjector(nn.Module):
    """Map encoder features into the language model's hidden size.

    ``projector_type`` selects the architecture:
    ``"linear"`` – one Linear layer;
    ``"mlp"`` – ``num_layers - 1`` (Linear, GELU, Dropout(0.1)) stages plus a
    final Linear;
    ``"perceiver"`` – ``PerceiverResampler`` (64 latents, 2 layers);
    ``"qformer"`` – ``QFormerProjector`` (32 queries, 2 layers).

    Raises:
        ValueError: for any other ``projector_type``.
    """

    def __init__(self, input_size, output_size, projector_type="mlp", num_layers=2):
        super().__init__()
        self.projector_type = projector_type
        if projector_type == "linear":
            module = nn.Linear(input_size, output_size)
        elif projector_type == "mlp":
            module = self._build_mlp(input_size, output_size, num_layers)
        elif projector_type == "perceiver":
            module = PerceiverResampler(
                input_size, output_size, num_latents=64, num_layers=2
            )
        elif projector_type == "qformer":
            module = QFormerProjector(
                input_size, output_size, num_queries=32, num_layers=2
            )
        else:
            raise ValueError(f"projector_type tidak dikenal: {projector_type}")
        self.projector = module

    @staticmethod
    def _build_mlp(input_size, output_size, num_layers):
        """Build the stacked GELU MLP used by the ``"mlp"`` projector type."""
        stages = []
        width = input_size
        for _ in range(num_layers - 1):
            stages += [nn.Linear(width, output_size), nn.GELU(), nn.Dropout(0.1)]
            width = output_size
        stages.append(nn.Linear(width, output_size))
        return nn.Sequential(*stages)

    def forward(self, x):
        """Project ``x`` with the configured projector."""
        return self.projector(x)
class PerceiverResampler(nn.Module):
    """Resample a variable-length sequence to ``num_latents`` learned latents.

    The first ``PerceiverLayer`` cross-attends latents to the input ``x``;
    every later layer attends latents to themselves. A final LayerNorm is
    applied to the result, shaped ``(batch, num_latents, output_size)``.
    """

    def __init__(self, input_size, output_size, num_latents=64, num_layers=2):
        super().__init__()
        self.num_latents = num_latents
        self.latents = nn.Parameter(torch.randn(num_latents, output_size))
        self.layers = nn.ModuleList(
            PerceiverLayer(output_size, output_size if depth else input_size)
            for depth in range(num_layers)
        )
        self.norm = nn.LayerNorm(output_size)

    def forward(self, x):
        """Resample ``x`` (batch-first) into the latent array."""
        latents = self.latents.unsqueeze(0).expand(x.shape[0], -1, -1)
        for depth, layer in enumerate(self.layers):
            context = latents if depth else x
            latents = layer(latents, context)
        return self.norm(latents)
class PerceiverLayer(nn.Module):
    """Cross-attention + MLP residual block used by ``PerceiverResampler``.

    Queries are LayerNorm-ed before 8-head cross-attention (keys/values of
    width ``key_dim`` are used unnormalized), then an MLP branch is added.
    NOTE(review): the MLP branch applies ``norm2`` and then its own leading
    LayerNorm, i.e. two consecutive LayerNorms; one is likely redundant.
    """

    def __init__(self, query_dim, key_dim):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            query_dim, num_heads=8, kdim=key_dim, vdim=key_dim, batch_first=True
        )
        wide = query_dim * 4
        self.mlp = nn.Sequential(
            nn.LayerNorm(query_dim),
            nn.Linear(query_dim, wide),
            nn.GELU(),
            nn.Linear(wide, query_dim),
        )
        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

    def forward(self, query, key):
        """Update ``query`` by attending to ``key`` (also used as value)."""
        attended, _ = self.cross_attn(self.norm1(query), key, key, need_weights=False)
        updated = query + attended
        return updated + self.mlp(self.norm2(updated))
class QFormerProjector(nn.Module):
    """Q-Former-style projector with ``num_queries`` learned query tokens.

    Each of ``num_layers`` stages runs a ``TransformerEncoderLayer`` over the
    queries, then adds an 8-head cross-attention read from the input ``x``
    (keys/values of width ``input_size``). Output is LayerNorm-ed and shaped
    ``(batch, num_queries, output_size)``.
    """

    def __init__(self, input_size, output_size, num_queries=32, num_layers=2):
        super().__init__()
        self.num_queries = num_queries
        self.query_embeds = nn.Parameter(torch.randn(num_queries, output_size))
        self_blocks = [
            nn.TransformerEncoderLayer(
                d_model=output_size,
                nhead=8,
                dim_feedforward=output_size * 4,
                batch_first=True,
            )
            for _ in range(num_layers)
        ]
        self.query_layers = nn.ModuleList(self_blocks)
        cross_blocks = [
            nn.MultiheadAttention(
                output_size,
                num_heads=8,
                kdim=input_size,
                vdim=input_size,
                batch_first=True,
            )
            for _ in range(num_layers)
        ]
        self.cross_attn_layers = nn.ModuleList(cross_blocks)
        self.norm = nn.LayerNorm(output_size)

    def forward(self, x):
        """Compress batch-first features ``x`` into the learned query tokens."""
        queries = self.query_embeds.unsqueeze(0).expand(x.shape[0], -1, -1)
        for stage in range(len(self.query_layers)):
            queries = self.query_layers[stage](queries)
            read, _ = self.cross_attn_layers[stage](queries, x, x, need_weights=False)
            queries = queries + read
        return self.norm(queries)
class CacaPreTrainedModel(PreTrainedModel):
    """Common base for Caca models: config class, HF flags and weight init."""

    config_class = CacaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["CacaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights.

        Linear biases are zeroed; an Embedding's padding row is zeroed when
        ``padding_idx`` is set. Other module types are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if not isinstance(module, nn.Embedding):
            return
        module.weight.data.normal_(mean=0.0, std=std)
        pad_row = module.padding_idx
        if pad_row is not None:
            module.weight.data[pad_row].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Legacy HF hook: toggle checkpointing on ``CacaModel`` instances."""
        if not isinstance(module, CacaModel):
            return
        module.gradient_checkpointing = value
class CacaModel(CacaPreTrainedModel):
    """Caca decoder backbone: token embeddings, decoder layers, final RMSNorm.

    With ``config.use_multimodal``, optional vision/audio encoders are built
    from ``config.vision_config`` / ``config.audio_config``. Their projected
    features are either passed to the decoder as ``encoder_hidden_states``
    (``config.use_cross_attention``) or prepended to the token sequence.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [
                CacaDecoderLayer(config, layer_idx=idx)
                for idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = CacaRMSNorm(config.hidden_size, config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.metrics = MetricsTracker()
        self._last_memory_check = 0
        self._memory_check_interval = 5
        self._warned_checkpoint_cache = False
        if config.use_multimodal:
            if config.vision_config:
                self.vision_encoder = VisionEncoder(config)
                # Size the projector from the encoder's actual width; the
                # encoder defaults to 1024 when "hidden_size" is absent.
                vision_hidden_size = self.vision_encoder.hidden_size
                self.vision_projector = MultiModalProjector(
                    vision_hidden_size,
                    config.hidden_size,
                    projector_type=config.vision_config.get("projector_type", "mlp"),
                )
            else:
                self.vision_encoder = None
                self.vision_projector = None
            if config.audio_config:
                self.audio_encoder = AudioEncoder(config)
                # Same reasoning as for vision: follow the encoder's width.
                audio_hidden_size = self.audio_encoder.hidden_size
                self.audio_projector = MultiModalProjector(
                    audio_hidden_size,
                    config.hidden_size,
                    projector_type=config.audio_config.get("projector_type", "mlp"),
                )
            else:
                self.audio_encoder = None
                self.audio_projector = None
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_attention_mask(self, attention_mask, input_shape, dtype):
        """Turn a 0/1 keep-mask into an additive mask broadcastable over heads.

        2-D masks become ``(batch, 1, 1, seq)``, 3-D masks ``(batch, 1, q, k)``.
        Kept positions map to 0 and masked ones to ``finfo(dtype).min``.
        Returns None when ``attention_mask`` is None.
        """
        if attention_mask is None:
            return None
        batch_size, seq_length = input_shape
        if attention_mask.dim() == 2:
            attention_mask = attention_mask[:, None, None, :]
        elif attention_mask.dim() == 3:
            attention_mask = attention_mask[:, None, :, :]
        attention_mask = attention_mask.to(dtype=dtype)
        attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min
        return attention_mask

    def _check_memory_usage(self, device, layer_idx):
        """Log CUDA memory every ``_memory_check_interval`` layers.

        Warns above 12 GB allocated and empties the CUDA cache above 14 GB.
        No-op unless ``device`` is a CUDA device.
        """
        if not torch.cuda.is_available() or device.type != "cuda":
            return
        # A smaller index than the last check means a new forward pass has
        # started; reset so checks keep firing after the first pass.
        if layer_idx < self._last_memory_check:
            self._last_memory_check = 0
        if layer_idx - self._last_memory_check >= self._memory_check_interval:
            allocated_gb = torch.cuda.memory_allocated(device) / 1024**3
            reserved_gb = torch.cuda.memory_reserved(device) / 1024**3
            self.metrics.log("gpu_memory_allocated_gb", allocated_gb)
            self.metrics.log("gpu_memory_reserved_gb", reserved_gb)
            if allocated_gb > 12:
                logger.warning(
                    f"Layer {layer_idx}: High GPU memory - "
                    f"Allocated: {allocated_gb:.2f}GB, Reserved: {reserved_gb:.2f}GB"
                )
                if allocated_gb > 14:
                    logger.info(f"Clearing CUDA cache at layer {layer_idx}")
                    torch.cuda.empty_cache()
            self._last_memory_check = layer_idx

    def forward(
        self,
        input_ids=None,
        pixel_values=None,
        audio_features=None,
        attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_hidden_states=False,
        return_dict=True,
        **kwargs,
    ):
        """Run the decoder stack.

        Args:
            input_ids: ``(batch, seq)`` token ids (required).
            pixel_values: optional images for the vision encoder.
            audio_features: optional features for the audio encoder.
            attention_mask: optional 0/1 keep-mask (2-D or 3-D).
            past_key_values: optional per-layer cache entries.
            use_cache: return per-layer cache entries; defaults to
                ``config.use_cache``. Forced off (with a one-time warning)
                when gradient checkpointing is active in training.
            output_hidden_states: also return the input to every layer plus
                the final normalized output.
            return_dict: choose the first element's container (below).

        Returns:
            ``(outputs, total_aux_loss, total_z_loss)``. ``outputs`` is a
            ``BaseModelOutputWithPast`` when ``return_dict`` is true, else a
            tuple ``(last_hidden_state, [past_key_values], [hidden_states])``
            with None entries omitted.

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if self.gradient_checkpointing and self.training and use_cache:
            # The cache cannot be built under checkpointing; without this the
            # default use_cache=True would silently disable checkpointing.
            if not self._warned_checkpoint_cache:
                logger.warning(
                    "use_cache=True is incompatible with gradient checkpointing; "
                    "setting use_cache=False."
                )
                self._warned_checkpoint_cache = True
            use_cache = False
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
            device = input_ids.device
            self.metrics.log("batch_size", batch_size)
            self.metrics.log("seq_length", seq_length)
            hidden_states = self.embed_tokens(input_ids)
        else:
            raise ValueError("input_ids cannot be None")
        encoder_hidden_states = None
        encoder_attention_mask = None
        if self.config.use_multimodal:
            multimodal_embeds = []
            if pixel_values is not None and self.vision_encoder is not None:
                if pixel_values.device != device:
                    pixel_values = pixel_values.to(device)
                try:
                    vision_features = self.vision_encoder(pixel_values)
                    vision_embeds = self.vision_projector(vision_features)
                    multimodal_embeds.append(vision_embeds)
                    self.metrics.log("vision_tokens", vision_embeds.shape[1])
                except RuntimeError as e:
                    logger.error(f"Vision encoder failed: {e}")
            if audio_features is not None and self.audio_encoder is not None:
                if audio_features.device != device:
                    audio_features = audio_features.to(device)
                try:
                    audio_encoded = self.audio_encoder(audio_features)
                    audio_embeds = self.audio_projector(audio_encoded)
                    multimodal_embeds.append(audio_embeds)
                    self.metrics.log("audio_tokens", audio_embeds.shape[1])
                except RuntimeError as e:
                    logger.error(f"Audio encoder failed: {e}")
            if multimodal_embeds:
                if self.config.use_cross_attention:
                    encoder_hidden_states = torch.cat(multimodal_embeds, dim=1)
                    encoder_seq_len = encoder_hidden_states.shape[1]
                    encoder_attention_mask = torch.ones(
                        batch_size,
                        encoder_seq_len,
                        dtype=hidden_states.dtype,
                        device=device,
                    )
                else:
                    multimodal_concat = torch.cat(multimodal_embeds, dim=1)
                    max_multimodal_tokens = self.config.max_position_embeddings // 4
                    if multimodal_concat.shape[1] > max_multimodal_tokens:
                        logger.warning(
                            f"Truncating multimodal tokens from "
                            f"{multimodal_concat.shape[1]} to {max_multimodal_tokens}"
                        )
                        multimodal_concat = multimodal_concat[:, :max_multimodal_tokens]
                    # Prepend multimodal tokens to the text sequence.
                    hidden_states = torch.cat([multimodal_concat, hidden_states], dim=1)
                    seq_length = hidden_states.shape[1]
                    if attention_mask is not None:
                        multimodal_mask = torch.ones(
                            batch_size,
                            multimodal_concat.shape[1],
                            dtype=attention_mask.dtype,
                            device=device,
                        )
                        attention_mask = torch.cat(
                            [multimodal_mask, attention_mask], dim=1
                        )
                    else:
                        attention_mask = torch.ones(
                            batch_size,
                            seq_length,
                            dtype=hidden_states.dtype,
                            device=device,
                        )
        if attention_mask is not None:
            attention_mask = self._prepare_attention_mask(
                attention_mask, (batch_size, seq_length), hidden_states.dtype
            )
        if encoder_attention_mask is not None and self.config.use_cross_attention:
            encoder_attention_mask = self._prepare_attention_mask(
                encoder_attention_mask,
                (batch_size, encoder_hidden_states.shape[1]),
                hidden_states.dtype,
            )
        if use_cache:
            if past_key_values is None:
                past_key_values = tuple([None] * len(self.layers))
        present_key_values = [] if use_cache else None
        all_hidden_states = [] if output_hidden_states else None
        total_aux_loss = 0.0
        total_z_loss = 0.0
        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states.append(hidden_states)
            if self.training:
                self._check_memory_usage(device, idx)
            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )
            if self.gradient_checkpointing and self.training and not use_cache:
                hidden_states, aux_loss, z_loss = self._gradient_checkpointing_forward(
                    layer,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
                present_key_value = None
            else:
                hidden_states, present_key_value, aux_loss, z_loss = layer(
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                )
            if use_cache:
                present_key_values.append(present_key_value)
            total_aux_loss += aux_loss
            total_z_loss += z_loss
            if self.training:
                self.metrics.log(f"layer_{idx}_aux_loss", aux_loss)
                self.metrics.log(f"layer_{idx}_z_loss", z_loss)
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states.append(hidden_states)
        self.metrics.step()
        past_out = tuple(present_key_values) if use_cache else None
        hidden_out = tuple(all_hidden_states) if output_hidden_states else None
        if not return_dict:
            # Keep the same (outputs, aux, z) shape as the dict path so the
            # causal-LM head can unpack both; outputs[0] is the last state.
            legacy = tuple(
                v for v in (hidden_states, past_out, hidden_out) if v is not None
            )
            return legacy, total_aux_loss, total_z_loss
        return (
            BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=past_out,
                hidden_states=hidden_out,
                attentions=None,
            ),
            total_aux_loss,
            total_z_loss,
        )

    def _gradient_checkpointing_forward(
        self,
        layer,
        hidden_states,
        attention_mask,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run ``layer`` under non-reentrant activation checkpointing (no cache).

        Returns ``(hidden_states, aux_loss, z_loss)``.
        """
        from torch.utils.checkpoint import checkpoint

        def custom_forward(hidden_states):
            output, _, aux_loss, z_loss = layer(
                hidden_states,
                attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=None,
                use_cache=False,
            )
            return output, aux_loss, z_loss

        hidden_states, aux_loss, z_loss = checkpoint(
            custom_forward,
            hidden_states,
            use_reentrant=False,
        )
        return hidden_states, aux_loss, z_loss

    def get_metrics_summary(self):
        """Return ``self.metrics.get_summary()``."""
        return self.metrics.get_summary()

    def get_attention_cache_stats(self):
        """Collect ``get_cache_stats()`` from each layer's self-attention."""
        stats = {}
        for idx, layer in enumerate(self.layers):
            stats[f"layer_{idx}"] = layer.self_attn.get_cache_stats()
        return stats
class CacaForCausalLM(CacaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = CacaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids=None,
pixel_values=None,
audio_features=None,
attention_mask=None,
labels=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
if input_ids is not None:
if input_ids.dtype.is_floating_point:
raise TypeError(
f"input_ids must be integer dtype, got {input_ids.dtype}. "
f"Use input_ids.long() to convert."
)
if (input_ids < 0).any():
neg_count = (input_ids < 0).sum().item()
raise ValueError(
f"input_ids contains {neg_count} negative values. "
f"All token IDs must be >= 0."
)
max_id = input_ids.max().item()
if max_id >= self.config.vocab_size:
raise ValueError(
f"input_ids contains ID {max_id} which is >= vocab_size "
f"({self.config.vocab_size}). Valid range is [0, {self.config.vocab_size-1}]."
)
if self.config.pad_token_id is not None:
pad_only_sequences = (input_ids == self.config.pad_token_id).all(dim=1)
if pad_only_sequences.any():
pad_count = pad_only_sequences.sum().item()
logger.warning(
f"{pad_count} sequences contain only padding tokens. "
f"This may cause unexpected behavior."
)
if labels is not None:
if not labels.dtype in [
torch.long,
torch.int,
torch.int32,
torch.int64,
]:
raise TypeError(f"labels must be integer dtype, got {labels.dtype}")
if labels.shape != input_ids.shape:
raise ValueError(
f"labels shape {labels.shape} doesn't match "
f"input_ids shape {input_ids.shape}"
)
valid_labels = labels[labels != -100]
if valid_labels.numel() > 0:
invalid_negs = valid_labels[valid_labels < 0]
if invalid_negs.numel() > 0:
raise ValueError(
f"labels contains {invalid_negs.numel()} invalid negative "
f"values (not -100)"
)
max_label = valid_labels.max().item()
if max_label >= self.config.vocab_size:
raise ValueError(
f"labels contains ID {max_label} which is >= vocab_size "
f"({self.config.vocab_size})"
)
else:
logger.warning(
"All labels are -100 (ignored). No loss will be computed."
)
if attention_mask is not None:
if attention_mask.shape[0] != input_ids.shape[0]:
raise ValueError(
f"attention_mask batch size ({attention_mask.shape[0]}) != "
f"input_ids batch size ({input_ids.shape[0]})"
)
if attention_mask.shape[1] != input_ids.shape[1]:
raise ValueError(
f"attention_mask seq length ({attention_mask.shape[1]}) != "
f"input_ids seq length ({input_ids.shape[1]})"
)
unique_vals = attention_mask.unique()
if not torch.all((unique_vals == 0) | (unique_vals == 1)):
logger.warning(
f"attention_mask contains values other than 0 and 1: {unique_vals.tolist()}. "
f"This may cause unexpected behavior."
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
use_amp = (
self.training
and torch.cuda.is_available()
and hasattr(self.config, "use_amp")
and self.config.use_amp
)
if use_amp:
from torch.cuda.amp import autocast
with autocast(
dtype=torch.bfloat16
if torch.cuda.is_bf16_supported()
else torch.float16
):
outputs, aux_loss, z_loss = self.model(
input_ids,
pixel_values=pixel_values,
audio_features=audio_features,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
else:
outputs, aux_loss, z_loss = self.model(
input_ids,
pixel_values=pixel_values,
audio_features=audio_features,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if return_dict:
hidden_states = outputs.last_hidden_state
else:
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
if self.config.final_logit_softcapping:
logits = soft_cap_logits(logits, self.config.final_logit_softcapping)
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_logits_flat = shift_logits.view(-1, shift_logits.size(-1))
shift_labels_flat = shift_labels.view(-1)
if (
hasattr(self.config, "label_smoothing")
and self.config.label_smoothing > 0
):
loss_fct = nn.CrossEntropyLoss(
ignore_index=-100, label_smoothing=self.config.label_smoothing
)
else:
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
lm_loss = loss_fct(shift_logits_flat, shift_labels_flat)
if self.training:
with torch.no_grad():
perplexity = torch.exp(lm_loss)
self.model.metrics.log("perplexity", perplexity.item())
self.model.metrics.log("lm_loss", lm_loss.item())
if self.config.use_moe:
total_loss = (
lm_loss
+ (self.config.router_aux_loss_coef * aux_loss)
+ (self.config.router_z_loss_coef * z_loss)
)
if self.training:
self.model.metrics.log("aux_loss", aux_loss.item())
self.model.metrics.log("z_loss", z_loss.item())
else:
total_loss = lm_loss
loss = total_loss
if not return_dict:
output = (logits,)
if return_dict:
output += tuple(
v
for v in [outputs.past_key_values, outputs.hidden_states]
if v is not None
)
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values if return_dict else None,
hidden_states=outputs.hidden_states if return_dict else None,
attentions=None,
)
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    pixel_values=None,
    audio_features=None,
    **kwargs,
):
    """Build the keyword arguments for the next forward call during generation.

    A cache counts as populated only when it is indexable, its first entry is a
    tuple/list, and that entry's first tensor is non-empty; any TypeError,
    IndexError or AttributeError raised while probing it means "not populated".

    With a populated cache, only the last column of ``input_ids`` is kept and
    ``pixel_values`` / ``audio_features`` are dropped (set to ``None``).
    ``inputs_embeds`` replaces ``input_ids`` only when the cache is not populated.

    Returns:
        dict with the token input (``input_ids`` or ``inputs_embeds``) followed by
        ``past_key_values`` (``None`` unless populated), ``use_cache`` (taken from
        ``kwargs``), ``attention_mask``, ``pixel_values`` and ``audio_features``.
    """

    def _cache_is_populated(cache):
        if cache is None:
            return False
        try:
            if len(cache) == 0 or cache[0] is None:
                return False
            first_layer = cache[0]
            if not isinstance(first_layer, (tuple, list)):
                return False
            return (
                len(first_layer) > 0
                and first_layer[0] is not None
                and first_layer[0].numel() > 0
            )
        except (TypeError, IndexError, AttributeError):
            return False

    cache_ready = _cache_is_populated(past_key_values)
    if cache_ready:
        # Earlier positions already live in the cache; feed only the newest token.
        input_ids = input_ids[:, -1:]
        pixel_values = None
        audio_features = None

    if inputs_embeds is None or cache_ready:
        token_key, token_value = "input_ids", input_ids
    else:
        token_key, token_value = "inputs_embeds", inputs_embeds

    return {
        token_key: token_value,
        "past_key_values": past_key_values if cache_ready else None,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
        "audio_features": audio_features,
    }
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
    """Reorder a tuple-style KV cache along dim 0 to follow the selected beams.

    Each layer entry is rebuilt as a tuple in which every non-empty tensor is
    ``index_select``-ed along dim 0 with ``beam_idx`` (moved to that tensor's
    device). ``None`` or empty tensors become ``None``. A layer entry that is
    ``None`` or has length 0 becomes ``None``.

    Returns:
        tuple with one entry per layer of ``past_key_values``.
    """

    def _select_beams(state):
        if state is None or state.numel() <= 0:
            return None
        return state.index_select(0, beam_idx.to(state.device))

    def _reorder_layer(layer_cache):
        if layer_cache is None or len(layer_cache) <= 0:
            return None
        return tuple(_select_beams(state) for state in layer_cache)

    return tuple(_reorder_layer(layer_cache) for layer_cache in past_key_values)
def save_pretrained(self, save_directory, **kwargs):
    """Save the model via the parent implementation, then optionally dump stats.

    If ``self.config`` has a ``quantization_config`` attribute whose value is
    ``None``, that attribute is removed for the duration of the parent save.
    It is restored afterwards, even if saving raises. (Presumably this keeps a
    null entry out of the serialized config; confirm intent.)

    When the model is in training mode (and, if the caller passes
    ``is_main_process``, only on the main process), the results of
    ``self.model.get_metrics_summary()`` and
    ``self.model.get_attention_cache_stats()`` are written as JSON to
    ``<save_directory>/training_stats.json``.

    Args:
        save_directory: target directory, forwarded to the parent method.
        **kwargs: forwarded unchanged to the parent ``save_pretrained``.
    """
    # Local imports: neither ``os`` nor ``json`` is imported at module level here.
    import json
    import os

    has_quant_config = hasattr(self.config, "quantization_config")
    quantization_config_backup = getattr(self.config, "quantization_config", None)
    if has_quant_config and quantization_config_backup is None:
        delattr(self.config, "quantization_config")
    try:
        super().save_pretrained(save_directory, **kwargs)
        # Mirror the parent's ``is_main_process`` flag so that in distributed
        # runs only one rank writes the stats file.
        if self.training and kwargs.get("is_main_process", True):
            training_stats = {
                "metrics": self.model.get_metrics_summary(),
                "cache_stats": self.model.get_attention_cache_stats(),
            }
            stats_path = os.path.join(save_directory, "training_stats.json")
            with open(stats_path, "w", encoding="utf-8") as f:
                json.dump(training_stats, f, indent=2)
            logger.info(f"Saved training statistics to {stats_path}")
    finally:
        if has_quant_config:
            self.config.quantization_config = quantization_config_backup
def get_model_stats(self):
    """Collect a diagnostic snapshot of the model.

    Returns:
        dict with keys:
          - ``"metrics"``: ``self.model.get_metrics_summary()``
          - ``"cache_stats"``: ``self.model.get_attention_cache_stats()``
          - ``"gradient_stats"``: ``"layer_<i>"`` -> ``layer.get_gradient_stats()``
            for every layer in ``self.model.layers``
          - ``"expert_usage"`` (only when ``self.config.use_moe``):
            ``"layer_<i>"`` -> list form of ``layer.mlp.expert_usage_count``.
            Layers whose ``mlp`` lacks that attribute are omitted.
    """
    stats = {
        "metrics": self.model.get_metrics_summary(),
        "cache_stats": self.model.get_attention_cache_stats(),
        "gradient_stats": {},
    }
    use_moe = self.config.use_moe
    expert_usage = {}
    for idx, layer in enumerate(self.model.layers):
        layer_key = f"layer_{idx}"
        stats["gradient_stats"][layer_key] = layer.get_gradient_stats()
        if use_moe and hasattr(layer.mlp, "expert_usage_count"):
            # Tensor.tolist() directly: routing through .numpy() fails for
            # dtypes NumPy lacks (e.g. bfloat16) and for tensors requiring grad.
            expert_usage[layer_key] = layer.mlp.expert_usage_count.cpu().tolist()
    if use_moe:
        stats["expert_usage"] = expert_usage
    return stats
class CacaForCausalLMQuantized(CacaForCausalLM):
    """``CacaForCausalLM`` whose ``nn.Linear`` layers are swapped for bitsandbytes layers.

    ``quantization_config`` may be a plain dict (as built for
    :meth:`from_pretrained_quantized`) or an object exposing the same names as
    attributes (e.g. a ``transformers.BitsAndBytesConfig``). Keys read:
    ``load_in_8bit``, ``load_in_4bit``, ``llm_int8_threshold`` (default 6.0),
    ``llm_int8_has_fp16_weight`` (default False), ``llm_int8_skip_modules``
    (default: none skipped), ``bnb_4bit_compute_dtype`` (str or ``torch.dtype``,
    default ``torch.float16``), ``bnb_4bit_quant_type`` (default ``"nf4"``) and
    ``bnb_4bit_use_double_quant`` (default True).

    Note: bitsandbytes defers the actual int8 / 4-bit packing of the swapped-in
    weights until the module is moved to a CUDA device.
    """

    def __init__(self, config, quantization_config=None):
        super().__init__(config)
        self.quantization_config = quantization_config
        if quantization_config:
            self._apply_quantization()

    def _qconfig_get(self, key, default=None):
        """Read ``key`` from the quantization config, whether it is a dict or an object."""
        cfg = self.quantization_config
        if cfg is None:
            return default
        if isinstance(cfg, dict):
            return cfg.get(key, default)
        return getattr(cfg, key, default)

    def _apply_quantization(self):
        """Dispatch to 8-bit or 4-bit quantization; 8-bit wins when both are requested."""
        if self._qconfig_get("load_in_8bit"):
            self._quantize_8bit()
        elif self._qconfig_get("load_in_4bit"):
            self._quantize_4bit()

    @staticmethod
    def _is_skipped(name, skip_modules):
        """True if dotted ``name`` matches, lies under, or has a component equal to a skip entry."""
        parts = name.split(".")
        return any(
            name == entry or name.startswith(entry + ".") or entry in parts
            for entry in skip_modules
        )

    def _linear_layers_to_replace(self, bnb):
        """Snapshot ``(qualified_name, module)`` for each plain ``nn.Linear`` to convert."""
        skip_modules = list(self._qconfig_get("llm_int8_skip_modules") or ())
        # bnb layers subclass nn.Linear; exclude them so re-running is a no-op.
        already_quantized = (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)
        # Materialise the list up front: modules are replaced afterwards, and
        # mutating the tree while walking named_modules() is fragile.
        return [
            (name, module)
            for name, module in list(self.named_modules())
            if isinstance(module, nn.Linear)
            and not isinstance(module, already_quantized)
            and not self._is_skipped(name, skip_modules)
        ]

    def _swap_module(self, name, new_module):
        """Replace the submodule at dotted path ``name`` with ``new_module``."""
        parent_name, _, child_name = name.rpartition(".")
        parent = self.get_submodule(parent_name) if parent_name else self
        setattr(parent, child_name, new_module)

    def _quantize_8bit(self):
        """Replace every eligible ``nn.Linear`` with ``bnb.nn.Linear8bitLt``."""
        try:
            import bitsandbytes as bnb
        except ImportError:
            logger.error("bitsandbytes tidak terinstall! pip install bitsandbytes")
            return
        threshold = self._qconfig_get("llm_int8_threshold", 6.0)
        has_fp16_weights = bool(self._qconfig_get("llm_int8_has_fp16_weight", False))
        for name, module in self._linear_layers_to_replace(bnb):
            has_bias = module.bias is not None
            new_module = bnb.nn.Linear8bitLt(
                module.in_features,
                module.out_features,
                bias=has_bias,
                has_fp16_weights=has_fp16_weights,
                threshold=threshold,
            )
            # Copy the data into the layer's own Int8Params. Assigning a plain
            # Parameter would drop the quantization state forward() relies on.
            new_module.weight.data = module.weight.data
            if has_bias:
                new_module.bias = module.bias
            self._swap_module(name, new_module)
        logger.info("Quantisasi 8-bit berhasil diterapkan")

    def _quantize_4bit(self):
        """Replace every eligible ``nn.Linear`` with ``bnb.nn.Linear4bit``."""
        try:
            import bitsandbytes as bnb
        except ImportError:
            logger.error("bitsandbytes tidak terinstall!")
            return
        compute_dtype = self._qconfig_get("bnb_4bit_compute_dtype") or torch.float16
        if isinstance(compute_dtype, str):
            compute_dtype = getattr(torch, compute_dtype)
        quant_type = self._qconfig_get("bnb_4bit_quant_type", "nf4")
        double_quant = self._qconfig_get("bnb_4bit_use_double_quant", True)
        for name, module in self._linear_layers_to_replace(bnb):
            has_bias = module.bias is not None
            # bitsandbytes names double quantization ``compress_statistics``;
            # Linear4bit accepts no ``use_double_quant`` keyword.
            new_module = bnb.nn.Linear4bit(
                module.in_features,
                module.out_features,
                bias=has_bias,
                compute_dtype=compute_dtype,
                compress_statistics=double_quant,
                quant_type=quant_type,
            )
            # Keep the layer's Params4bit wrapper; only swap in the source data.
            new_module.weight.data = module.weight.data
            if has_bias:
                new_module.bias = module.bias
            self._swap_module(name, new_module)
        logger.info("Quantisasi 4-bit berhasil diterapkan")

    @classmethod
    def from_pretrained_quantized(cls, model_path, quantization_config):
        """Build a quantized model from ``model_path`` and load ``pytorch_model.bin`` into it.

        Loading is non-strict. Missing or unexpected keys are logged as
        warnings instead of being silently ignored.
        """
        import os

        config = CacaConfig.from_pretrained(model_path)
        model = cls(config, quantization_config=quantization_config)
        weights_path = os.path.join(model_path, "pytorch_model.bin")
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location="cpu")
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            logger.warning(
                f"Missing keys when loading {weights_path}: {load_result.missing_keys}"
            )
        if load_result.unexpected_keys:
            logger.warning(
                f"Unexpected keys when loading {weights_path}: {load_result.unexpected_keys}"
            )
        return model
class CacaTrainer:
    """Minimal training loop with gradient accumulation, clipping and optional AMP.

    Args:
        model: module called as ``model(**batch)``; its output must expose ``.loss``.
        optimizer: torch optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler, stepped once per optimizer step.
        gradient_accumulation_steps: micro-batches accumulated per optimizer step (>= 1).
        max_grad_norm: max total gradient norm passed to ``clip_grad_norm_``.
        use_amp: run the forward pass under CUDA autocast (bf16 when supported,
            else fp16) and scale the loss with a ``GradScaler``.

    Raises:
        ValueError: if ``gradient_accumulation_steps`` < 1.
    """

    def __init__(
        self,
        model,
        optimizer,
        scheduler=None,
        gradient_accumulation_steps=1,
        max_grad_norm=1.0,
        use_amp=False,
    ):
        # Local import: defaultdict is not imported at module level in this file.
        from collections import defaultdict

        if gradient_accumulation_steps < 1:
            raise ValueError(
                "gradient_accumulation_steps must be >= 1, "
                f"got {gradient_accumulation_steps}"
            )
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.use_amp = use_amp
        self.global_step = 0
        self.epoch = 0
        if self.use_amp:
            from torch.cuda.amp import GradScaler

            self.scaler = GradScaler()
        else:
            self.scaler = None
        # History: "loss" per micro-batch, "grad_norm" per optimizer step.
        self.train_metrics = defaultdict(list)

    def training_step(self, batch):
        """Run forward + backward on one micro-batch and return its unscaled loss.

        The loss is divided by ``gradient_accumulation_steps`` before backward.
        The recorded/returned value multiplies that back out.

        Raises:
            ValueError: if the model output's ``loss`` is ``None``.
        """
        self.model.train()
        if self.use_amp:
            from torch.cuda.amp import autocast

            amp_dtype = (
                torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            )
            with autocast(dtype=amp_dtype):
                outputs = self.model(**batch)
        else:
            outputs = self.model(**batch)
        if outputs.loss is None:
            raise ValueError(
                "Model output has no loss; does the batch include labels?"
            )
        loss = outputs.loss / self.gradient_accumulation_steps
        if self.use_amp:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()
        loss_value = loss.item() * self.gradient_accumulation_steps
        self.train_metrics["loss"].append(loss_value)
        return loss_value

    def optimizer_step(self):
        """Clip gradients, step optimizer/scheduler, zero grads; return the grad norm."""
        if self.use_amp:
            # Unscale first so clipping sees the true gradient magnitudes.
            self.scaler.unscale_(self.optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.max_grad_norm
        )
        if self.use_amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)
        if self.scheduler is not None:
            self.scheduler.step()
        grad_norm_value = grad_norm.item()
        self.train_metrics["grad_norm"].append(grad_norm_value)
        self.global_step += 1
        return grad_norm_value

    def _maybe_log(self, avg_loss, grad_norm, log_interval):
        """Log progress when ``global_step`` is a multiple of ``log_interval`` (falsy disables)."""
        if not log_interval or self.global_step % log_interval != 0:
            return
        lr = self.optimizer.param_groups[0]["lr"]
        logger.info(
            f"Epoch {self.epoch} | Step {self.global_step} | "
            f"Loss: {avg_loss:.4f} | Grad Norm: {grad_norm:.4f} | "
            f"LR: {lr:.2e}"
        )
        if hasattr(self.model, "get_model_stats"):
            stats = self.model.get_model_stats()
            if "metrics" in stats and "perplexity" in stats["metrics"]:
                ppl = stats["metrics"]["perplexity"]["last"]
                logger.info(f"Perplexity: {ppl:.2f}")

    def train_epoch(self, dataloader, log_interval=10):
        """Train for one pass over ``dataloader`` and return the mean micro-batch loss.

        Gradients from a trailing partial accumulation window are applied with a
        final optimizer step, so they do not leak into the next epoch.
        Returns ``nan`` (and logs a warning) for an empty dataloader.
        """
        self.model.train()
        epoch_loss = 0.0
        num_batches = 0
        for step, batch in enumerate(dataloader):
            epoch_loss += self.training_step(batch)
            num_batches += 1
            if (step + 1) % self.gradient_accumulation_steps == 0:
                grad_norm = self.optimizer_step()
                self._maybe_log(epoch_loss / num_batches, grad_norm, log_interval)
        if num_batches % self.gradient_accumulation_steps != 0:
            grad_norm = self.optimizer_step()
            self._maybe_log(epoch_loss / num_batches, grad_norm, log_interval)
        self.epoch += 1
        if num_batches == 0:
            logger.warning("train_epoch received an empty dataloader; no steps taken.")
            return float("nan")
        return epoch_loss / num_batches

    def get_metrics(self):
        """Summarise each non-empty metric history as mean/std/min/max."""
        # Local import: numpy is not imported at module level in this file.
        import numpy as np

        return {
            key: {
                "mean": np.mean(values),
                "std": np.std(values),
                "min": np.min(values),
                "max": np.max(values),
            }
            for key, values in self.train_metrics.items()
            if values
        }