from typing import Dict, Optional

import torch
from torch import nn
from torch.nn import functional as F


class SymbolFIMModel(nn.Module):
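    """Causal Transformer language model.

    A stack of ``nn.TransformerEncoder`` layers run autoregressively through a
    causal attention mask (optionally restricted to a sliding window of
    ``window`` tokens), with learned absolute position embeddings and an
    untied linear LM head.
    """
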
    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        n_layers: int,
        n_heads: int,
        window: int,
        max_len: int,
        ast_head_cfg=None,  # accepted but currently unused by this model
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.window = window

        # Token embeddings plus learned absolute position embeddings.
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=0.1,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.dropout = nn.Dropout(0.1)

        self._init_weights()

    def _init_weights(self):
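        """Initialize Linear/Embedding weights from N(0, 0.02); reset LayerNorm to identity."""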
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                torch.nn.init.ones_(module.weight)
                torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
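        """Run a forward pass.

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            attention_mask: optional ``(batch, seq_len)`` mask; 1 marks real tokens, 0 padding.
            labels: optional ``(batch, seq_len)`` targets; positions equal to -100 are ignored.

        Returns:
            Dict with ``"logits"`` of shape ``(batch, seq_len, vocab_size)`` and
            ``"lm_loss"`` (``None`` when ``labels`` is not given).
        """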
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Sum token and position embeddings, then apply dropout.
        positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        hidden_states = self.token_emb(input_ids) + self.pos_emb(positions)
        hidden_states = self.dropout(hidden_states)

        # nn.TransformerEncoder expects True at padding positions, so invert the mask.
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = ~attention_mask.to(torch.bool)

        # Encode with a causal (and optionally windowed) attention mask.
        attn_mask = self._causal_mask(seq_len, device)

        hidden_states = self.encoder(
            hidden_states,
            mask=attn_mask,
            src_key_padding_mask=key_padding_mask,
        )
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        # Guard against numerical blow-ups: clamp extremes and zero out any remaining NaN/Inf.
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            logits = torch.clamp(logits, min=-50.0, max=50.0)
            logits = torch.where(torch.isnan(logits), torch.zeros_like(logits), logits)
            logits = torch.where(torch.isinf(logits), torch.zeros_like(logits), logits)

        lm_loss = None
        if labels is not None:
            # Standard next-token objective: predict token t+1 from position t.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            valid_mask = shift_labels != -100
            if valid_mask.sum() == 0:
                # No supervised positions in this batch; fall back to a zero loss.
                lm_loss = torch.tensor(0.0, device=device, requires_grad=True)
            else:
                lm_loss = F.cross_entropy(
                    shift_logits.view(-1, self.vocab_size),
                    shift_labels.view(-1),
                    ignore_index=-100,
                )

            # Defensive fallback: if the masked cross-entropy came back NaN/Inf,
            # recompute it over the valid positions only.
            if torch.isnan(lm_loss) or torch.isinf(lm_loss):
                valid_logits = shift_logits[valid_mask]
                valid_labels = shift_labels[valid_mask]
                if valid_logits.numel() > 0 and valid_labels.numel() > 0:
                    lm_loss = F.cross_entropy(
                        valid_logits.view(-1, self.vocab_size),
                        valid_labels.view(-1),
                    )

        return {"logits": logits, "lm_loss": lm_loss}

    def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
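        """Build an additive float mask: 0.0 where attention is allowed, -inf where it is blocked."""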
        if 0 < self.window < seq_len:
            # Sliding-window causal mask: position i may attend to positions
            # max(0, i - window + 1) .. i.
            row_indices = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(1)
            col_indices = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)

            causal_mask = col_indices > row_indices
            window_start = torch.clamp(row_indices - self.window + 1, min=0)
            window_mask = col_indices < window_start

            mask = torch.zeros((seq_len, seq_len), device=device, dtype=torch.float32)
            mask[causal_mask | window_mask] = float("-inf")
        else:
            # Plain causal mask: -inf strictly above the diagonal.
            mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=torch.float32),
                diagonal=1,
            )

        return mask
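

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): the hyperparameter values below are
    # arbitrary small placeholders, not taken from any real training configuration.
    model = SymbolFIMModel(
        vocab_size=100,
        d_model=32,
        n_layers=2,
        n_heads=4,
        window=16,
        max_len=64,
    )
    input_ids = torch.randint(0, 100, (2, 10))
    attention_mask = torch.ones(2, 10, dtype=torch.long)
    out = model(input_ids, attention_mask=attention_mask, labels=input_ids)
    print(out["logits"].shape)  # expected: torch.Size([2, 10, 100])
    print(out["lm_loss"])       # finite scalar loss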