# modeling_tinytransformer.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from configuration_tinytransformer import TinyTransformerConfig
class TinyTransformerModel(PreTrainedModel):
    """A very small Transformer encoder + classification head, for sentiment-classification demos.

    Embeds tokens plus learned absolute positions, runs them through a stack
    of post-norm ``nn.TransformerEncoderLayer`` blocks, and classifies from
    the first ([CLS]) position.
    """

    config_class = TinyTransformerConfig

    def __init__(self, config: TinyTransformerConfig):
        super().__init__(config)
        self.config = config

        # Token embeddings. getattr reproduces the original fallback exactly:
        # use config.pad_token_id when the attribute exists (even if None),
        # otherwise default to 0.
        self.embedding = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=getattr(config, "pad_token_id", 0),
        )
        # Learned (not sinusoidal) position embeddings.
        self.pos_embedding = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
        )

        # Classic post-norm Transformer encoder stack.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=False,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_hidden_layers,
        )

        # Dropout and classification head.
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # This model ties no weights; declare that explicitly so the HF
        # save/load machinery does not attempt to re-tie anything.
        self._tied_weights_keys = []

        # Standard HF finalization: applies _init_weights to all submodules.
        # (The original constructor never triggered weight initialization.)
        self.post_init()

    def _init_weights(self, module=None):
        """Initialize weights: Xavier-uniform for Linear, N(0, 0.02) for Embedding.

        Accepts either a single (sub)module — as called by the HF init
        machinery — or ``None`` to initialize the whole model.
        """
        if module is None:
            module = self
        for m in module.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                # normal_ overwrites the padding row; re-zero it so padding
                # tokens stay at the zero vector nn.Embedding guarantees.
                if m.padding_idx is not None:
                    with torch.no_grad():
                        m.weight[m.padding_idx].zero_()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        **kwargs,
    ):
        """Run a forward pass.

        Args:
            input_ids: LongTensor of shape (batch, seq_len). Required.
            attention_mask: optional tensor of the same shape; 1 = attend,
                0 = padding position to ignore.
            labels: optional LongTensor of class ids; when given, a
                cross-entropy loss is computed.

        Returns:
            dict with keys ``"loss"`` (None when no labels) and ``"logits"``
            of shape (batch, num_labels).

        Raises:
            ValueError: if ``input_ids`` is not provided.
        """
        if input_ids is None:
            # The original crashed with an opaque AttributeError here.
            raise ValueError("input_ids must be provided")
        batch_size, seq_len = input_ids.shape

        # Position ids: [0, seq_len) broadcast over the batch.
        position_ids = (
            torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )

        # Token + position embeddings, then dropout.
        x = self.embedding(input_ids) + self.pos_embedding(position_ids)
        x = self.dropout(x)

        # nn.TransformerEncoder expects True at positions to IGNORE,
        # the inverse of the HF attention_mask convention.
        if attention_mask is not None:
            src_key_padding_mask = attention_mask == 0
        else:
            src_key_padding_mask = None

        x = self.encoder(
            x,
            src_key_padding_mask=src_key_padding_mask,
        )

        # Pool the first ([CLS]) position as the sentence representation.
        pooled = x[:, 0, :]
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        # HF-style dict return.
        return {
            "loss": loss,
            "logits": logits,
        }