# residual-tiny / modeling_residualnet.py
"""
「差分→Attn→Dense を前半層で繰り返し、後半層は“Denseでのアップスケール(=長さ+1)”→Attn→Dense を繰り返して最終的に元の seq_len に戻す」アーキテクチャ
"""
from typing import Optional, Tuple, List
import torch
from torch import nn
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerationMixin
from transformers.models.phi3.configuration_phi3 import Phi3Config
from transformers.models.phi3.modeling_phi3 import (
Phi3PreTrainedModel,
Phi3RotaryEmbedding,
Phi3RMSNorm,
Phi3Attention,
    # Phi3SdpaAttention,  # default SDPA attention
Phi3MLP,
)
#from models.phi3_config import Phi3Config
#from models.phi3 import (
# Phi3PreTrainedModel,
# Phi3RMSNorm,
# Phi3MLP,
# # Phi3SdpaAttention,
# Phi3Attention,
# Phi3RotaryEmbedding,
#)
class ResidualNetConfig(Phi3Config):
model_type = "ResidualNetConfig"
def __init__(self, **kwargs):
super().__init__(**kwargs)
# ---------- Preprocessors for length conversion ----------
class DiffPreprocessor(nn.Module):
"""一次差分: (B, L, H) -> (B, L-1, H) と 2D mask の AND 縮約"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask_2d: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# hidden_states: (B, L, H)
x1 = hidden_states[:, 1:, :]
x0 = hidden_states[:, :-1, :]
diff = x1 - x0 # (B, L-1, H)
if attention_mask_2d is not None:
m = (attention_mask_2d[:, 1:].bool() & attention_mask_2d[:, :-1].bool()).to(attention_mask_2d.dtype)
else:
m = None
return diff, m
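# Illustrative usage sketch for DiffPreprocessor (kept as a comment so nothing runs
# at import time; shapes and mask values are worked out by hand):
#
#   pre = DiffPreprocessor()
#   x = torch.randn(2, 5, 8)               # (B=2, L=5, H=8)
#   m = torch.tensor([[1, 1, 1, 1, 0],
#                     [1, 1, 0, 0, 0]])    # right-padded 2D mask
#   d, dm = pre(x, m)
#   # d.shape == (2, 4, 8)
#   # dm      == [[1, 1, 1, 0],
#   #             [1, 0, 0, 0]]            # AND of adjacent mask positions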
class IntegratePreprocessor(nn.Module):
"""
学習可能な“積分”で (B, m, H) -> (B, m+1, H)
1) seed y0 = MLP(mean_pool(z))
2) y = cumsum([y0, z], dim=1)
"""
def __init__(self, hidden_size: int):
super().__init__()
self.seed_mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask_2d: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# hidden_states: (B, m, H)
if attention_mask_2d is not None:
denom = attention_mask_2d.sum(dim=1, keepdim=True).clamp_min(1)
pooled = (hidden_states * attention_mask_2d.unsqueeze(-1)).sum(dim=1) / denom # (B, H)
batch_valid = (attention_mask_2d.sum(dim=1) > 0).to(attention_mask_2d.dtype) # (B,)
else:
pooled = hidden_states.mean(dim=1)
batch_valid = None
y0 = self.seed_mlp(pooled).unsqueeze(1) # (B,1,H)
y = torch.cumsum(torch.cat([y0, hidden_states], dim=1), dim=1) # (B, m+1, H)
if attention_mask_2d is not None:
new_first = batch_valid.unsqueeze(1) # (B,1)
mask = torch.cat([new_first, attention_mask_2d], dim=1)
else:
mask = None
return y, mask
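# Why the cumsum restores the length: if z were the exact first-order difference of
# some sequence y (z[t] = y[t+1] - y[t]), then a cumulative sum over [y[0], z]
# reproduces y. Here the seed y0 is learned from the mean-pooled input instead of
# being the true first element, so the layer is a learned approximation of that
# inverse. Illustrative shape sketch (comment only, nothing runs at import time):
#
#   up = IntegratePreprocessor(hidden_size=8)
#   z = torch.randn(2, 4, 8)               # (B=2, m=4, H=8)
#   y, ym = up(z, torch.ones(2, 4))
#   # y.shape  == (2, 5, 8)
#   # ym.shape == (2, 5)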
# ---------- Layer blocks (built from Phi3 components) ----------
class ResidualDiffLayer(nn.Module):
"""
(差分で L-1) -> Attn -> MLP
- RoPE は Phi-3 と同様に Attention 内で適用
- 各層で position_ids を 0..len-1 に張り直す
"""
def __init__(self, config: ResidualNetConfig, layer_idx: int, rotary_emb: Phi3RotaryEmbedding):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.input_norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre = DiffPreprocessor()
# self.attn = Phi3SdpaAttention(config, layer_idx=layer_idx)
self.attn = Phi3Attention(config, layer_idx=layer_idx)
self.dropout_attn = nn.Dropout(config.resid_pdrop)
self.post_norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = Phi3MLP(config)
self.dropout_mlp = nn.Dropout(config.resid_pdrop)
        self.rotary_emb = rotary_emb  # shared RoPE instance
def _to_4d_mask(
self, mask2d: Optional[torch.Tensor], bsz: int, seqlen: int, hidden_states: torch.Tensor
) -> Optional[torch.Tensor]:
if mask2d is None:
return None
return _prepare_4d_causal_attention_mask(
mask2d, (bsz, seqlen), hidden_states, past_key_values_length=0, sliding_window=self.config.sliding_window
)
def forward(
self,
hidden_states: torch.Tensor, # (B, L, H)
attention_mask_2d: Optional[torch.Tensor], # (B, L)
position_ids: Optional[torch.LongTensor], # (B, L)
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
x = self.input_norm(hidden_states)
# L -> L-1
x, mask2d = self.pre(x, attention_mask_2d)
bsz, seqlen, _ = x.shape
        # regenerate position_ids (0..seqlen-1) for the shortened sequence
        device = x.device
        pos_ids = torch.arange(seqlen, device=device).unsqueeze(0).expand(bsz, -1)
        position_embeddings = self.rotary_emb(x, pos_ids)
attn_mask_4d = self._to_4d_mask(mask2d, bsz, seqlen, x)
attn_out, attn_weights = self.attn(
hidden_states=x,
attention_mask=attn_mask_4d,
position_ids=pos_ids,
position_embeddings=position_embeddings,
past_key_value=None,
output_attentions=output_attentions,
use_cache=False,
)
x = x + self.dropout_attn(attn_out)
h = self.post_norm(x)
h = self.mlp(h)
x = x + self.dropout_mlp(h)
return x, mask2d, attn_weights if output_attentions else None
class IntegrateUpscaleLayer(nn.Module):
"""
(積分で L+1) -> Attn -> MLP
"""
def __init__(self, config: ResidualNetConfig, layer_idx: int, rotary_emb: Phi3RotaryEmbedding):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.input_norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre = IntegratePreprocessor(config.hidden_size)
# self.attn = Phi3SdpaAttention(config, layer_idx=layer_idx)
self.attn = Phi3Attention(config, layer_idx=layer_idx)
self.dropout_attn = nn.Dropout(config.resid_pdrop)
self.post_norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = Phi3MLP(config)
self.dropout_mlp = nn.Dropout(config.resid_pdrop)
self.rotary_emb = rotary_emb
def _to_4d_mask(
self, mask2d: Optional[torch.Tensor], bsz: int, seqlen: int, hidden_states: torch.Tensor
) -> Optional[torch.Tensor]:
if mask2d is None:
return None
return _prepare_4d_causal_attention_mask(
mask2d, (bsz, seqlen), hidden_states, past_key_values_length=0, sliding_window=self.config.sliding_window
)
def forward(
self,
hidden_states: torch.Tensor, # (B, L, H)
attention_mask_2d: Optional[torch.Tensor], # (B, L)
position_ids: Optional[torch.LongTensor], # (B, L)
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
x = self.input_norm(hidden_states)
# L -> L+1
x, mask2d = self.pre(x, attention_mask_2d)
bsz, seqlen, _ = x.shape
        # regenerate position_ids (0..seqlen-1) for the lengthened sequence
        device = x.device
        pos_ids = torch.arange(seqlen, device=device).unsqueeze(0).expand(bsz, -1)
        position_embeddings = self.rotary_emb(x, pos_ids)
attn_mask_4d = self._to_4d_mask(mask2d, bsz, seqlen, x)
attn_out, attn_weights = self.attn(
hidden_states=x,
attention_mask=attn_mask_4d,
position_ids=pos_ids,
position_embeddings=position_embeddings,
past_key_value=None,
output_attentions=output_attentions,
use_cache=False,
)
x = x + self.dropout_attn(attn_out)
h = self.post_norm(x)
h = self.mlp(h)
x = x + self.dropout_mlp(h)
return x, mask2d, attn_weights if output_attentions else None
# ---------- Main model (inherits from Phi3PreTrainedModel) ----------
class ResidualNetModel(Phi3PreTrainedModel):
"""
前半: ResidualDiffLayer × (N/2) で系列長を縮約
後半: IntegrateUpscaleLayer × (N/2) で系列長を復元
"""
config_class = ResidualNetConfig
def __init__(self, config: ResidualNetConfig):
super().__init__(config)
        assert config.num_hidden_layers % 2 == 0, "num_hidden_layers must be even."
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Phi3RotaryEmbedding(config=config)
self.gradient_checkpointing = False
half = config.num_hidden_layers // 2
        # first half (down)
self.down_layers = nn.ModuleList(
[ResidualDiffLayer(config, layer_idx=i, rotary_emb=self.rotary_emb) for i in range(half)]
)
        # second half (up)
self.up_layers = nn.ModuleList(
[IntegrateUpscaleLayer(config, layer_idx=half + i, rotary_emb=self.rotary_emb) for i in range(half)]
)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, # (B, L) in {0,1}
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,  # not supported (forced to False)
**kwargs
):
output_attentions = output_attentions if output_attentions is not None else False
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
return_dict = True if return_dict is None else return_dict
if input_ids is None and inputs_embeds is None:
raise ValueError("You must specify either input_ids or inputs_embeds.")
if inputs_embeds is None:
hidden_states = self.embed_tokens(input_ids) # (B, L, H)
else:
hidden_states = inputs_embeds
mask2d = attention_mask
bsz, orig_len, _ = hidden_states.shape
        all_hidden_states: Optional[List[torch.Tensor]] = [] if output_hidden_states else None
        all_attns: Optional[List[torch.Tensor]] = [] if output_attentions else None
        # ---- first half: shrink the sequence via diff ----
for layer in self.down_layers:
if output_hidden_states:
all_hidden_states.append(hidden_states)
hidden_states, mask2d, attn = layer(
hidden_states, mask2d, position_ids, output_attentions=output_attentions
)
if output_attentions:
all_attns.append(attn)
        # ---- second half: restore the sequence via integration ----
for layer in self.up_layers:
if output_hidden_states:
all_hidden_states.append(hidden_states)
hidden_states, mask2d, attn = layer(
hidden_states, mask2d, position_ids, output_attentions=output_attentions
)
if output_attentions:
all_attns.append(attn)
        # sanity check: the final length must match the original
        if hidden_states.size(1) != orig_len:
            raise RuntimeError(f"seq_len was not restored: got {hidden_states.size(1)} vs {orig_len}")
hidden_states = self.norm(hidden_states)
if not return_dict:
out = (hidden_states,)
if output_hidden_states:
out = out + (all_hidden_states,)
if output_attentions:
out = out + (all_attns,)
return out
return {
"last_hidden_state": hidden_states,
"hidden_states": all_hidden_states,
"attentions": all_attns,
}
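# Worked length example for ResidualNetModel (illustrative numbers only): with
# num_hidden_layers = 4 and an input of length 16, the sequence length goes
# 16 -> 15 -> 14 through the two down layers and 14 -> 15 -> 16 through the two
# up layers, so last_hidden_state again has shape (B, 16, H).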
# ---------- CausalLM head (Phi3PreTrainedModel + GenerationMixin) ----------
class ResidualNetForCausalLM(Phi3PreTrainedModel, GenerationMixin):
config_class = ResidualNetConfig
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config: ResidualNetConfig):
super().__init__(config)
self.model = ResidualNetModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# weight tying
# self.lm_head.weight = self.model.embed_tokens.weight
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,  # not supported
        past_key_values: Optional[List[torch.Tensor]] = None,  # not supported
**kwargs
) -> CausalLMOutputWithPast:
return_dict = True if return_dict is None else return_dict
model_out = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
use_cache=False,
**kwargs,
)
hidden_states = model_out["last_hidden_state"] # (B, L, H)
logits = self.lm_head(hidden_states).float()
loss = None
if labels is not None:
            # causal language modeling loss (standard next-token shift)
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
if not return_dict:
return (logits, loss)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
            past_key_values=None,  # not supported
hidden_states=model_out["hidden_states"],
attentions=model_out["attentions"],
)
@property
def base_model(self):
return self.model
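if __name__ == "__main__":
    # Minimal smoke-test sketch. The config values below are small, arbitrary
    # choices for a quick shape check only; they are not the hyperparameters of
    # any released checkpoint.
    config = ResidualNetConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=4,      # must be even: 2 down layers + 2 up layers
        num_attention_heads=4,
        num_key_value_heads=4,
        max_position_embeddings=128,
        pad_token_id=0,           # keep padding_idx inside the small vocab
    )
    model = ResidualNetForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 1000])
    print(out.loss)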