# symbol-fim-model/models/symbol_fim_model.py
from typing import Dict, Optional
import torch
from torch import nn
from torch.nn import functional as F
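
# SymbolFIMModel: a decoder-style Transformer language model; the "FIM" in the
# name presumably refers to fill-in-the-middle training, per the repository
# naming. Token embeddings plus learned absolute positional embeddings feed a
# stack of nn.TransformerEncoder layers constrained by a causal (optionally
# sliding-window) attention mask, followed by LayerNorm and a bias-free linear
# LM head over the vocabulary.
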
class SymbolFIMModel(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        n_layers: int,
        n_heads: int,
        window: int,
        max_len: int,
        ast_head_cfg=None,  # accepted for config compatibility; unused in this module
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.window = window
        # Token embeddings plus learned absolute positional embeddings.
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=0.1,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.dropout = nn.Dropout(0.1)
        self._init_weights()

    def _init_weights(self) -> None:
        # Small-normal (std 0.02) init for linear/embedding weights, zero biases,
        # identity scale and zero shift for LayerNorm.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        hidden_states = self.token_emb(input_ids) + self.pos_emb(positions)
        hidden_states = self.dropout(hidden_states)
        # attention_mask follows the usual 1 = keep / 0 = pad convention; the
        # encoder expects True at positions that should be ignored.
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = ~attention_mask.to(torch.bool)
        attn_mask = self._causal_mask(seq_len, device)
        hidden_states = self.encoder(
            hidden_states,
            mask=attn_mask,
            src_key_padding_mask=key_padding_mask,
        )
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        # Numerical-stability guard: clamp extreme values and zero out any
        # remaining NaN/Inf entries.
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            logits = torch.clamp(logits, min=-50.0, max=50.0)
            logits = torch.where(torch.isnan(logits), torch.zeros_like(logits), logits)
            logits = torch.where(torch.isinf(logits), torch.zeros_like(logits), logits)
        lm_loss = None
        if labels is not None:
            # Next-token objective: shift by one; label -100 marks ignored positions.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            valid_mask = shift_labels != -100
            if valid_mask.sum() == 0:
                lm_loss = torch.tensor(0.0, device=device, requires_grad=True)
            else:
                lm_loss = F.cross_entropy(
                    shift_logits.view(-1, self.vocab_size),
                    shift_labels.view(-1),
                    ignore_index=-100,
                )
            if torch.isnan(lm_loss) or torch.isinf(lm_loss):
                # Fallback: recompute the loss over only the valid positions.
                valid_logits = shift_logits[valid_mask]
                valid_labels = shift_labels[valid_mask]
                if valid_logits.numel() > 0 and valid_labels.numel() > 0:
                    lm_loss = F.cross_entropy(
                        valid_logits.view(-1, self.vocab_size),
                        valid_labels.view(-1),
                    )
        return {"logits": logits, "lm_loss": lm_loss}

    def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        # Additive float mask: 0.0 where attention is allowed, -inf where blocked.
        if 0 < self.window < seq_len:
            # Sliding-window causal mask: position i attends only to positions
            # in [max(0, i - window + 1), i].
            row_indices = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(1)
            col_indices = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
            causal_mask = col_indices > row_indices
            window_start = torch.clamp(row_indices - self.window + 1, min=0)
            window_mask = col_indices < window_start
            mask = torch.zeros((seq_len, seq_len), device=device, dtype=torch.float32)
            mask[causal_mask | window_mask] = float("-inf")
        else:
            # Plain causal mask: block the strict upper triangle.
            mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=torch.float32),
                diagonal=1,
            )
        return mask
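

# --- Minimal usage sketch (illustrative; not part of the upstream file) ---
# The hyperparameters below are placeholders chosen only for a quick shape and
# loss smoke test; they are not the configuration this checkpoint was trained
# with.
if __name__ == "__main__":
    model = SymbolFIMModel(
        vocab_size=512,
        d_model=128,
        n_layers=2,
        n_heads=4,
        window=64,
        max_len=256,
    )
    input_ids = torch.randint(0, 512, (2, 32))
    attention_mask = torch.ones(2, 32, dtype=torch.long)
    labels = input_ids.clone()
    labels[:, :16] = -100  # e.g. mask a context prefix out of the loss
    out = model(input_ids, attention_mask=attention_mask, labels=labels)
    print(out["logits"].shape)  # -> torch.Size([2, 32, 512])
    print(out["lm_loss"])       # -> scalar cross-entropy over unmasked positions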