# HanForge-base / modeling_hanforge.py
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
try:
from configuration_hanforge import HanForgeConfig
except ImportError:
from .configuration_hanforge import HanForgeConfig
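# Rotary position embedding (RoPE) helpers. rotate_half splits the last dimension into two
# halves and swaps them with a sign flip, so apply_rotary_pos_emb implements the standard
# rotation q' = q * cos + rotate_half(q) * sin (and likewise for k).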
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
q = (q * cos) + (rotate_half(q) * sin)
k = (k * cos) + (rotate_half(k) * sin)
return q, k
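# Grouped-query attention helper: each of the num_key_value_heads K/V heads is repeated
# n_rep times so the K/V tensors match the (batch, num_heads, seq_len, head_dim) shape of
# the query tensor.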
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
if n_rep == 1:
return hidden_states
batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)
# DISABLED (refactor 20260423, §4.2): YaRN body disabled; unnecessary for the from-scratch 4k context.
# The signature is kept so it can be referenced when the context is extended later; only the body is commented out.
def _compute_yarn_parameters(config: HanForgeConfig, device=None):
raise NotImplementedError(
"YaRN is disabled in this refactor (see research/refactor_plan_20260423.md ยง4.2)."
)
# <<< disabled (refactor 20260423, §4.2)
# rope_params = dict(config.rope_scaling or {})
# dim = config.head_dim
# base = config.rope_theta
# if not rope_params or rope_params.get("rope_type", "default") == "default":
# inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
# return inv_freq, 1.0
#
# factor = float(rope_params["factor"])
# beta_fast = float(rope_params.get("beta_fast", 32.0))
# beta_slow = float(rope_params.get("beta_slow", 1.0))
# mscale = rope_params.get("mscale")
# mscale_all_dim = rope_params.get("mscale_all_dim")
# original_max = int(rope_params["original_max_position_embeddings"])
#
# def get_mscale(scale, scale_factor=1.0):
# if scale <= 1:
# return 1.0
# return 0.1 * scale_factor * math.log(scale) + 1.0
#
# if mscale is not None and mscale_all_dim is not None:
# attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
# else:
# attention_factor = float(get_mscale(factor))
#
# def find_correction_dim(num_rotations, local_dim, local_base, max_position_embeddings):
# return (local_dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
# 2 * math.log(local_base)
# )
#
# def find_correction_range(low_rot, high_rot, local_dim, local_base, max_position_embeddings):
# low = math.floor(find_correction_dim(low_rot, local_dim, local_base, max_position_embeddings))
# high = math.ceil(find_correction_dim(high_rot, local_dim, local_base, max_position_embeddings))
# return max(low, 0), min(high, local_dim - 1)
#
# def linear_ramp_factor(min_idx, max_idx, local_dim):
# if min_idx == max_idx:
# max_idx += 0.001
# linear_func = (torch.arange(local_dim, dtype=torch.float32, device=device) - min_idx) / (max_idx - min_idx)
# return torch.clamp(linear_func, 0, 1)
#
# pos_freqs = base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
# inv_freq_extrapolation = 1.0 / pos_freqs
# inv_freq_interpolation = 1.0 / (factor * pos_freqs)
# low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max)
# ramp = 1.0 - linear_ramp_factor(low, high, dim // 2)
# inv_freq = (inv_freq_interpolation * (1.0 - ramp)) + (inv_freq_extrapolation * ramp)
# return inv_freq, attention_factor
# >>> end disabled
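# Plain RoPE frequencies: inv_freq[i] = 1 / base ** (2i / head_dim) for i in [0, head_dim // 2).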
def _compute_rope_parameters(config: HanForgeConfig, device=None):
dim = config.head_dim
base = config.rope_theta
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
return inv_freq
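# RMSNorm: y = weight * x / sqrt(mean(x^2) + eps), computed in float32 and cast back to the
# input dtype. Unlike LayerNorm, no mean is subtracted and there is no bias term.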
class HanForgeRMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
return self.weight * hidden_states.to(input_dtype)
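# Builds cos/sin tables of shape (batch, seq_len, head_dim) from position_ids; the
# frequencies are duplicated across the two halves of the head dimension to line up with
# rotate_half above.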
class HanForgeRotaryEmbedding(nn.Module):
def __init__(self, config: HanForgeConfig):
super().__init__()
inv_freq = _compute_rope_parameters(config)
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
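# Eager multi-head attention with grouped K/V projections (GQA): q_proj keeps the full
# hidden size while k_proj/v_proj project to num_key_value_heads * head_dim, and repeat_kv
# broadcasts K/V back up before the scaled dot-product.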
class HanForgeAttention(nn.Module):
def __init__(self, config: HanForgeConfig, layer_idx: int):
super().__init__()
self.layer_idx = layer_idx
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = config.num_key_value_groups
self.head_dim = config.head_dim
# DISABLED (refactor 20260423, §4.1): hybrid local/global attention disabled.
# self.is_global = config.is_global_layer(layer_idx)
# self.sliding_window = config.sliding_window
self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
kv_hidden = config.num_key_value_heads * self.head_dim
self.k_proj = nn.Linear(config.hidden_size, kv_hidden, bias=False)
self.v_proj = nn.Linear(config.hidden_size, kv_hidden, bias=False)
self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.dropout = nn.Dropout(config.attention_dropout)
def forward(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
batch_size, seq_len, hidden_size = hidden_states.shape
q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
k = repeat_kv(k, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if attention_mask is not None:
scores = scores.masked_fill(~attention_mask, torch.finfo(scores.dtype).min)
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = attn @ v
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
return self.o_proj(out)
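# SwiGLU feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x)).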
class HanForgeMLP(nn.Module):
def __init__(self, config: HanForgeConfig):
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
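# Pre-norm transformer block: RMSNorm -> self-attention -> residual add, then
# RMSNorm -> SwiGLU MLP -> residual add.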
class HanForgeDecoderLayer(nn.Module):
def __init__(self, config: HanForgeConfig, layer_idx: int):
super().__init__()
# DISABLED (refactor 20260423, §4.1): hybrid local/global layer branching disabled.
# Every layer runs through the causal full-attention path.
# self.is_global = config.is_global_layer(layer_idx)
self.input_layernorm = HanForgeRMSNorm(config.hidden_size, config.rms_norm_eps)
self.self_attn = HanForgeAttention(config, layer_idx)
self.post_attention_layernorm = HanForgeRMSNorm(config.hidden_size, config.rms_norm_eps)
self.mlp = HanForgeMLP(config)
def forward(self, hidden_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, attention_mask: torch.Tensor):
hidden_states = hidden_states + self.self_attn(self.input_layernorm(hidden_states), cos, sin, attention_mask)
hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
return hidden_states
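# PreTrainedModel plumbing: Linear/Embedding weights are initialized from
# N(0, initializer_range); biases would be zeroed, but every Linear here is bias-free.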
class HanForgePreTrainedModel(PreTrainedModel):
config_class = HanForgeConfig
base_model_prefix = "model"
_no_split_modules = ["HanForgeDecoderLayer"]
def _init_weights(self, module):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
class HanForgeModel(HanForgePreTrainedModel):
def __init__(self, config: HanForgeConfig):
super().__init__(config)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([HanForgeDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)])
self.norm = HanForgeRMSNorm(config.hidden_size, config.rms_norm_eps)
self.rotary_emb = HanForgeRotaryEmbedding(config)
self.post_init()
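# Boolean mask of shape (batch, 1, seq_len, seq_len); True marks key positions a query may attend to.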
def _build_causal_mask(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
base = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool))
return base.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len)
# DISABLED (refactor 20260423, §4.1): sliding-window local mask disabled.
# def _build_local_mask(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
# row = torch.arange(seq_len, device=device)[:, None]
# col = torch.arange(seq_len, device=device)[None, :]
# causal = col <= row
# window = col >= (row - self.config.sliding_window + 1)
# mask = (causal & window).to(torch.bool)
# return mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
return_dict: bool = True,
**_: dict,
):
batch_size, seq_len = input_ids.shape
hidden_states = self.embed_tokens(input_ids)
if position_ids is None:
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
cos, sin = self.rotary_emb(hidden_states, position_ids)
full_mask = self._build_causal_mask(batch_size, seq_len, hidden_states.device)
if attention_mask is not None:
key_mask = attention_mask[:, None, None, :].to(torch.bool)
full_mask = full_mask & key_mask
# DISABLED (refactor 20260423, §4.1): every layer uses the full causal mask.
# The local_mask branch is only used if hybrid attention is reintroduced.
for layer in self.layers:
hidden_states = layer(hidden_states, cos, sin, full_mask)
hidden_states = self.norm(hidden_states)
if not return_dict:
return (hidden_states,)
return BaseModelOutputWithPast(last_hidden_state=hidden_states)
class HanForgeForCausalLM(HanForgePreTrainedModel, GenerationMixin):
# refactor 20260507 (§format/EOS): _tied_weights_keys removed entirely.
# During Phase 1 debugging, the _tied_weights_keys mechanism in transformers 5.x caused
# from_pretrained to silently ignore the trained weights in the .bin file and keep the
# random init. Combined with config tie_word_embeddings=False, the two weights are now
# handled explicitly as separate tensors.
# (Where possible, save trained models with tie_word_embeddings=False; the base model is
# temporarily at risk.)
_tied_weights_keys = None
def __init__(self, config: HanForgeConfig):
super().__init__(config)
self.model = HanForgeModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
# refactor 20260423 (§9): tie lm_head.weight to embed_tokens.weight.
# PreTrainedModel.tie_weights() attempts the same thing inside post_init, but we do it
# explicitly here to guarantee the parameter saving for a small model with a 32k vocab.
if getattr(config, "tie_word_embeddings", True):
self.lm_head.weight = self.model.embed_tokens.weight
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
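# There is no KV cache in this implementation, so generation re-runs the full input_ids at
# every step; position_ids are rebuilt from the attention mask (cumsum - 1, clamped at 0)
# so left padding does not shift positions.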
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids = position_ids.clamp_min(0)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
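# Standard causal-LM objective: logits are shifted left and labels right by one position so
# each token predicts the next one; label values of -100 are ignored in the loss.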
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
**kwargs,
):
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True, **kwargs)
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
)
if not return_dict:
result = (logits,)
if loss is not None:
result = (loss,) + result
return result
return CausalLMOutputWithPast(loss=loss, logits=logits)
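# Minimal smoke-test sketch (not part of the modeling code): builds a tiny random-weight
# model and runs one forward pass with a loss plus a single greedy next-token step. It
# assumes HanForgeConfig accepts these fields as keyword arguments; if some of them (e.g.
# num_key_value_groups) are derived inside configuration_hanforge.py, adjust accordingly.
if __name__ == "__main__":
    config = HanForgeConfig(
        vocab_size=256,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_key_value_groups=2,
        head_dim=16,
        rope_theta=10000.0,
        rms_norm_eps=1e-6,
        attention_dropout=0.0,
        initializer_range=0.02,
        tie_word_embeddings=True,
    )
    model = HanForgeForCausalLM(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=input_ids, labels=input_ids)
        next_token = out.logits[:, -1, :].argmax(dim=-1)  # greedy next-token prediction
    print("logits:", tuple(out.logits.shape), "loss:", float(out.loss), "next token:", int(next_token))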