# VISDOM / src/model.py
# Uploaded to the Hugging Face Hub by VishalPreetham via huggingface_hub (commit 18be545).
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class GPTConfig:
    """Hyperparameters that fix the size and regularization of the GPT model."""

    # Number of token ids handled by the input embedding and the output head.
    vocab_size: int = 32_000
    # Maximum sequence length; also the size of the learned position table.
    block_size: int = 256
    # Number of stacked transformer blocks.
    n_layer: int = 6
    # Number of attention heads per block.
    n_head: int = 8
    # Embedding / residual-stream width; must be divisible by n_head.
    n_embd: int = 384
    # Drop probability shared by every dropout layer in the model.
    dropout: float = 0.1
    # Whether Linear and LayerNorm layers carry additive bias terms.
    bias: bool = False
class CausalSelfAttention(nn.Module):
    """Multi-head causal (masked) self-attention followed by an output projection.

    Each position may only attend to itself and earlier positions.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        # Raise instead of assert so the check is not stripped under ``python -O``.
        if config.n_embd % config.n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # A single projection yields queries, keys and values concatenated on the last dim.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # Lower-triangular causal mask, read only by the manual fallback path in
        # forward(). Non-persistent, so it is not written to the state_dict.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention to ``x`` of shape (batch, seq_len, n_embd).

        Returns a tensor of the same shape.
        """
        b, t, c = x.size()
        head_dim = c // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (b, t, c) -> (b, n_head, t, head_dim)
        q = q.view(b, t, self.n_head, head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_head, head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_head, head_dim).transpose(1, 2)
        # Prefer PyTorch's fused scaled-dot-product attention when available.
        if hasattr(F, "scaled_dot_product_attention"):
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))
            # Block attention to future positions using the precomputed mask.
            att = att.masked_fill(self.bias[:, :, :t, :t] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v
        # Re-merge heads: (b, n_head, t, head_dim) -> (b, t, c)
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.resid_dropout(self.c_proj(y))
class MLP(nn.Module):
    """Position-wise feed-forward network: widen to 4x, apply GELU, project back."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden_width, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transform the last dimension of ``x``; the output has the same shape."""
        expanded = self.c_fc(x)
        activated = self.gelu(expanded)
        projected = self.c_proj(activated)
        return self.dropout(projected)
class Block(nn.Module):
    """One pre-LayerNorm transformer layer: attention, then MLP, each as a residual branch."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` updated by the attention and MLP residual branches."""
        attention_update = self.attn(self.ln_1(x))
        x = x + attention_update
        mlp_update = self.mlp(self.ln_2(x))
        return x + mlp_update
class GPTLanguageModel(nn.Module):
    """Decoder-only GPT language model.

    Token embeddings plus learned position embeddings feed a stack of
    pre-LayerNorm transformer blocks. A final LayerNorm and a linear head
    follow; the head shares its weight matrix with the token embedding.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.block_size, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=nn.LayerNorm(config.n_embd, bias=config.bias),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the token embedding and the output head share one matrix,
        # which saves parameters and is common in GPT-style models.
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)
        # Re-initialize every residual-branch output projection (attention and MLP
        # ``c_proj``) with a smaller std. Each block adds two residual branches.
        residual_std = 0.02 / (2 * config.n_layer) ** 0.5
        for name, param in self.named_parameters():
            if name.endswith("c_proj.weight"):
                nn.init.normal_(param, mean=0.0, std=residual_std)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize Linear/Embedding weights ~ N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward_hidden(self, idx: torch.Tensor) -> torch.Tensor:
        """Run the transformer trunk and return final-LayerNorm hidden states.

        Args:
            idx: token ids of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, n_embd).

        Raises:
            ValueError: if seq_len exceeds ``config.block_size``.
        """
        b, t = idx.size()
        if t > self.config.block_size:
            raise ValueError(f"Sequence length {t} exceeds block_size {self.config.block_size}")
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)  # broadcasts over the batch dim
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        return self.transformer.ln_f(x)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: token ids of shape (batch, seq_len).
            targets: optional target ids with the same number of elements as
                ``idx``. Positions equal to -1 are ignored by the loss.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (batch, seq_len, vocab_size);
            ``loss`` is a scalar tensor, or None when ``targets`` is None.
        """
        x = self.forward_hidden(idx)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. slices, also work.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-1
            )
        return logits, loss

    @staticmethod
    def _apply_repetition_penalty(
        logits: torch.Tensor, history: torch.Tensor, penalty: float
    ) -> torch.Tensor:
        """Return ``logits`` with every token id present in ``history`` penalized.

        Positive logits are divided by ``penalty`` and negative logits are
        multiplied by it, so a seen token always becomes less likely. Dividing a
        negative logit by ``penalty > 1`` would move it toward zero and *raise*
        its probability.
        """
        index = history.long()  # gather/scatter require int64 indices
        seen = torch.gather(logits, 1, index)
        seen = torch.where(seen > 0, seen / penalty, seen * penalty)
        # Duplicate ids in ``history`` all write the same value, so scatter is safe.
        return logits.scatter(1, index, seen)

    @staticmethod
    def _top_k_filter(logits: torch.Tensor, top_k: int) -> torch.Tensor:
        """Mask to -inf every logit strictly below the k-th largest in its row."""
        kth_largest = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
        return logits.masked_fill(logits < kth_largest, -float("inf"))

    @staticmethod
    def _top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
        """Nucleus filtering: keep the smallest top set of tokens whose probability reaches ``top_p``."""
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_to_remove = cumulative_probs > top_p
        # Shift right so the token that first crosses top_p is kept; the most
        # likely token is never removed.
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False
        to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
        return logits.masked_fill(to_remove, -float("inf"))

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: prompt token ids of shape (batch, t).
            max_new_tokens: number of tokens to sample.
            temperature: logits are divided by ``max(temperature, 1e-8)``.
            top_k: if a positive int, only the ``top_k`` highest logits are kept.
            top_p: if strictly between 0 and 1, nucleus (top-p) filtering is applied.
            repetition_penalty: if > 1, penalizes every token already present in
                the sequence generated so far (prompt included).

        Returns:
            Token ids of shape (batch, t + max_new_tokens).

        Note: this does not switch the module to eval mode; call ``.eval()``
        first if dropout should be disabled during sampling.
        """
        for _ in range(max_new_tokens):
            # Only the last block_size tokens fit the position embedding table.
            idx_cond = idx[:, -self.config.block_size :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-8)
            if repetition_penalty > 1.0:
                logits = self._apply_repetition_penalty(logits, idx, repetition_penalty)
            if top_k is not None and top_k > 0:
                logits = self._top_k_filter(logits, top_k)
            if 0.0 < top_p < 1.0:
                logits = self._top_p_filter(logits, top_p)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
def _as_bool(value: object) -> bool:
    """Interpret ``value`` as a boolean, parsing common string spellings.

    ``bool("false")`` is True in Python, so string values from text-based
    configs are parsed explicitly. Unrecognized strings raise ValueError.
    """
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in ("true", "1", "yes", "on"):
            return True
        if lowered in ("false", "0", "no", "off", ""):
            return False
        raise ValueError(f"Cannot interpret {value!r} as a boolean")
    return bool(value)


def config_from_dict(cfg: dict) -> GPTConfig:
    """Build a GPTConfig from a plain mapping.

    Args:
        cfg: mapping that must contain ``vocab_size``, ``block_size``,
            ``n_layer``, ``n_head`` and ``n_embd``. It may contain ``dropout``
            and ``bias``; missing optional keys fall back to GPTConfig's defaults.

    Returns:
        The constructed GPTConfig.

    Raises:
        KeyError: if a required key is missing.
        ValueError: if a value cannot be converted to the expected type.
    """
    return GPTConfig(
        vocab_size=int(cfg["vocab_size"]),
        block_size=int(cfg["block_size"]),
        n_layer=int(cfg["n_layer"]),
        n_head=int(cfg["n_head"]),
        n_embd=int(cfg["n_embd"]),
        # Defaults come from the dataclass itself so the two cannot drift apart.
        dropout=float(cfg.get("dropout", GPTConfig.dropout)),
        bias=_as_bool(cfg.get("bias", GPTConfig.bias)),
    )