import torch
import torch.nn as nn
from transformers import PreTrainedModel

from configuration_tinytransformer import TinyTransformerConfig


class TinyTransformerModel(PreTrainedModel):
    """
    A very small Transformer encoder plus a classification head, used for a sentiment-classification demo.
    """
    config_class = TinyTransformerConfig

    def __init__(self, config: TinyTransformerConfig):
        super().__init__(config)
        self.config = config

        # Token embeddings; fall back to index 0 for padding if the config has no pad_token_id
        self.embedding = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id if hasattr(config, 'pad_token_id') else 0
        )

        # Learned absolute position embeddings
        self.pos_embedding = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size
        )

        # Stack of standard PyTorch Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=False
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_hidden_layers
        )

        # Classification head
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # No weights are tied in this model
        self._tied_weights_keys = []

        print("DEBUG: _tied_weights_keys set to", self._tied_weights_keys)

        # Standard Hugging Face post-construction hook; applies _init_weights to all submodules
        self.post_init()

    def _init_weights(self, module=None):
        """Simple weight initialization."""
        if module is None:
            module = self
        for m in module.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        **kwargs
    ):
        batch_size, seq_len = input_ids.shape

        # Build position ids [0, seq_len) and broadcast them across the batch
        position_ids = torch.arange(
            0, seq_len, dtype=torch.long, device=input_ids.device
        )
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Token + position embeddings, then dropout
        x = self.embedding(input_ids) + self.pos_embedding(position_ids)
        x = self.dropout(x)

        # Convert the Hugging Face attention_mask (1 = real token, 0 = padding)
        # into PyTorch's key padding mask (True = position to ignore)
        if attention_mask is not None:
            src_key_padding_mask = (attention_mask == 0)
        else:
            src_key_padding_mask = None

        # Run the encoder stack
        x = self.encoder(
            x,
            src_key_padding_mask=src_key_padding_mask,
        )

        # CLS-style pooling: use the hidden state of the first token
        pooled = x[:, 0, :]

        # Project the pooled representation to class logits
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return {
            "loss": loss,
            "logits": logits
        }
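

# --- Usage sketch (not part of the model file) ------------------------------
# A minimal smoke test for the class above. It assumes TinyTransformerConfig
# accepts/provides the fields referenced in __init__ (vocab_size, hidden_size,
# num_hidden_layers, num_attention_heads, intermediate_size, dropout,
# max_position_embeddings, num_labels); adjust the constructor arguments to
# whatever your config class actually defines.

config = TinyTransformerConfig(
    vocab_size=1000,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    max_position_embeddings=128,
    dropout=0.1,
    num_labels=2,
)
model = TinyTransformerModel(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))   # batch of 2, 16 tokens each
attention_mask = torch.ones_like(input_ids)                # no padding in this toy batch
labels = torch.tensor([0, 1])

outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
print(outputs["loss"].item(), outputs["logits"].shape)     # logits: torch.Size([2, 2])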