"""
Custom HuggingFace model for Open LM checkpoints.
Open LM uses LayerNorm (not RMSNorm) and QK norm, which standard
LlamaForCausalLM does not support. This module provides:
- OpenLMConfig: LlamaConfig subclass with qk_norm flag
- OpenLMForCausalLM: LlamaForCausalLM subclass with LayerNorm + QK norm
Usage:
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)
"""
from typing import Callable, Optional

import torch
import torch.nn as nn

from transformers import LlamaConfig, LlamaForCausalLM
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
    ALL_ATTENTION_FUNCTIONS,
    LlamaAttention,
    LlamaRMSNorm,
    apply_rotary_pos_emb,
    eager_attention_forward,
)

try:
    # Optional typing helpers: Unpack needs Python 3.11+ and TransformersKwargs a
    # recent transformers release; neither is required at runtime, so a failed
    # import is ignored.
    from typing import Unpack
    from transformers.utils.generic import TransformersKwargs
except ImportError:
    pass


class OpenLMConfig(LlamaConfig):
    """LlamaConfig with an extra qk_norm flag controlling Open LM's QK normalization."""

    model_type = "open_lm"

    def __init__(self, qk_norm: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.qk_norm = qk_norm
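
# For trust_remote_code loading to resolve these classes, the checkpoint's config.json
# is expected to point its Auto* entries at this module. A sketch of the relevant
# snippet (assumption: this file is saved as modeling_open_lm_hf.py next to the weights):
#
#   "model_type": "open_lm",
#   "auto_map": {
#       "AutoConfig": "modeling_open_lm_hf.OpenLMConfig",
#       "AutoModelForCausalLM": "modeling_open_lm_hf.OpenLMForCausalLM"
#   }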


class OpenLMAttention(LlamaAttention):
    """LlamaAttention with QK norm applied before the head reshape (matching Open LM)."""

    def __init__(self, config: OpenLMConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        if getattr(config, "qk_norm", False):
            # Open LM normalizes the full projected query/key vectors, so the LayerNorm
            # width is the flat projection size. This assumes q_proj and k_proj both
            # project to hidden_size (standard multi-head attention, no grouped-query
            # attention).
            self.q_norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, bias=False)
            self.k_norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, bias=False)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # QK norm is applied to the flat projected vectors BEFORE the head reshape
        # (matching Open LM); only afterwards are the tensors split into heads.
        query_states = self.q_norm(self.q_proj(hidden_states)).view(hidden_shape).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states)).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Dispatch to the configured attention backend (eager, sdpa, flash_attention_2, ...).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
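
# Illustrative note (comment only, not executed): where the QK norm sits relative to
# the head reshape. Open LM normalizes the flat projection and then splits into heads;
# a per-head variant (used by some other architectures, not by this module) would
# instead normalize each head_dim-sized slice after the reshape:
#
#   open_lm_style = q_norm(q_proj(x)).view(bsz, seq, n_heads, head_dim)     # this module
#   per_head_alt  = q_norm_hd(q_proj(x).view(bsz, seq, n_heads, head_dim))  # not used here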


class OpenLMForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM with LayerNorm (instead of RMSNorm) and QK norm support."""

    config_class = OpenLMConfig

    def __init__(self, config: OpenLMConfig):
        super().__init__(config)
        # Replace every LlamaRMSNorm with nn.LayerNorm(bias=False), reusing the
        # checkpoint's rms_norm_eps as the LayerNorm epsilon.
        eps = config.rms_norm_eps
        hidden_size = config.hidden_size
        self.model.norm = nn.LayerNorm(hidden_size, eps=eps, bias=False)
        for layer in self.model.layers:
            layer.input_layernorm = nn.LayerNorm(hidden_size, eps=eps, bias=False)
            layer.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=eps, bias=False)
            # Replace the attention module with the QK-norm-aware version.
            layer.self_attn = OpenLMAttention(config, layer.self_attn.layer_idx)
        # Re-run post_init so weight tying and initialization cover the replaced modules.
        self.post_init()
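

# Optional usage sketch (illustration only; kept as a comment so importing this module
# has no side effects). When using this file directly instead of trust_remote_code,
# the classes can be registered with the Auto* factories so model_type "open_lm"
# resolves to them:
#
#   from transformers import AutoConfig, AutoModelForCausalLM
#
#   AutoConfig.register("open_lm", OpenLMConfig)
#   AutoModelForCausalLM.register(OpenLMConfig, OpenLMForCausalLM)
#   model = AutoModelForCausalLM.from_pretrained("path/to/open_lm_checkpoint")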