import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput


class EasySellConfig(PretrainedConfig):
    # HF-style configuration: GPT hyperparameters plus the special-token ids.
    model_type = "easysell"

    def __init__(self, vocab_size=856, block_size=1024, n_layer=14,
                 n_head=12, n_embd=768, dropout=0.1,
                 pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        super().__init__(pad_token_id=pad_token_id,
                         bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id, **kwargs)


@dataclass
class GPTConfig:
    # Plain dataclass mirror of EasySellConfig, consumed by the standalone GPT module.
    vocab_size: int
    block_size: int
    n_layer: int = 14
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.1


class CausalSelfAttention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        assert cfg.n_embd % cfg.n_head == 0
        self.cfg = cfg
        self.c_attn = nn.Linear(cfg.n_embd, 3 * cfg.n_embd)   # fused q, k, v projection
        self.c_proj = nn.Linear(cfg.n_embd, cfg.n_embd)
        self.attn_dropout = nn.Dropout(cfg.dropout)
        self.resid_dropout = nn.Dropout(cfg.dropout)
        # Lower-triangular mask enforcing causality for sequences up to block_size.
        self.register_buffer("bias",
            torch.tril(torch.ones(cfg.block_size, cfg.block_size))
                 .view(1, 1, cfg.block_size, cfg.block_size))

    def forward(self, x):
        B, T, C = x.size()                          # batch, sequence length, embedding dim
        nh, hs = self.cfg.n_head, C // self.cfg.n_head
        q, k, v = self.c_attn(x).split(C, dim=2)
        q = q.view(B, T, nh, hs).transpose(1, 2)    # (B, nh, T, hs)
        k = k.view(B, T, nh, hs).transpose(1, 2)
        v = v.view(B, T, nh, hs).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(hs)
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = self.attn_dropout(F.softmax(att, dim=-1))
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))


class MLP(nn.Module):
    # Position-wise feed-forward block: 4x expansion, GELU, projection back down.
    def __init__(self, cfg):
        super().__init__()
        self.fc = nn.Linear(cfg.n_embd, 4 * cfg.n_embd)
        self.proj = nn.Linear(4 * cfg.n_embd, cfg.n_embd)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        return self.dropout(self.proj(F.gelu(self.fc(x))))


class Block(nn.Module):
    # Pre-norm transformer block: x + attn(ln(x)), then x + mlp(ln(x)).
    def __init__(self, cfg):
        super().__init__()
        self.ln1 = nn.LayerNorm(cfg.n_embd)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(cfg.n_embd)
        self.mlp = MLP(cfg)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.wte = nn.Embedding(cfg.vocab_size, cfg.n_embd)    # token embeddings
        self.wpe = nn.Embedding(cfg.block_size, cfg.n_embd)    # learned position embeddings
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        if isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, 0.0, 0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.cfg.block_size, "sequence length exceeds block_size"
        pos = torch.arange(T, device=idx.device).unsqueeze(0)
        x = self.drop(self.wte(idx) + self.wpe(pos))
        for blk in self.blocks:
            x = blk(x)
        logits = self.lm_head(self.ln_f(x))
        loss = None
        if targets is not None:
            # Targets are compared position-for-position with the logits, so they are
            # expected to be pre-shifted by one token; padding (id 0) is ignored.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1), ignore_index=0)
        return logits, loss


class EasySellForCausalLM(PreTrainedModel, GenerationMixin):
    def __init__(self, config: EasySellConfig):
        super().__init__(config)
        gpt_cfg = GPTConfig(
            vocab_size=config.vocab_size, block_size=config.block_size,
            n_layer=config.n_layer, n_head=config.n_head,
            n_embd=config.n_embd, dropout=config.dropout)
        self.model = GPT(gpt_cfg)

    @property
    def all_tied_weights_keys(self):
        # No weight tying: wte and lm_head are separate parameters.
        return {}

    def forward(self, input_ids, labels=None, **kwargs):
        # Extra kwargs (e.g. attention_mask) are accepted but ignored; labels are
        # passed through unchanged and assumed to be pre-shifted (see GPT.forward).
        logits, loss = self.model(input_ids, targets=labels)
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    def generate(self, input_ids, max_new_tokens=120,
                 temperature=0.1, top_k=30, **kwargs):
        # Custom sampler (overrides GenerationMixin.generate); assumes batch size 1.
        self.eval()
        device = next(self.parameters()).device
        x = input_ids.to(device)
        eos_id = self.config.eos_token_id
        generated = []
        with torch.no_grad():
            for _ in range(max_new_tokens):
                x_cond = x[:, -self.config.block_size:]       # crop to the context window
                logits, _ = self.model(x_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-6)
                if len(generated) == 0:
                    logits[:, eos_id] = float("-inf")         # never emit EOS as the first token
                if top_k > 0:
                    v, ix = torch.topk(logits, top_k)
                    probs = torch.zeros_like(logits).scatter_(
                        1, ix, F.softmax(v, dim=-1))
                else:
                    probs = F.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                if next_id.item() == eos_id and len(generated) > 0:
                    break
                generated.append(next_id.item())
                # Stop early if the last 8 tokens show almost no variety (degenerate loop).
                if len(generated) >= 8 and len(set(generated[-8:])) <= 2:
                    break
                x = torch.cat([x, next_id], dim=1)
        gen_tensor = torch.tensor([generated], dtype=torch.long, device=device)
        return torch.cat([input_ids.to(device), gen_tensor], dim=1)


EasySellForCausalLM.config_class = EasySellConfig
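

# A minimal usage sketch (not part of the original module): it builds a randomly
# initialised EasySellForCausalLM and runs the custom generate() on a dummy prompt.
# The token ids below are arbitrary placeholders, not output of a real tokenizer.
if __name__ == "__main__":
    config = EasySellConfig()
    model = EasySellForCausalLM(config)
    prompt = torch.tensor([[config.bos_token_id, 17, 42, 99]], dtype=torch.long)
    out = model.generate(prompt, max_new_tokens=20, temperature=0.8, top_k=30)
    print(out.shape)  # (1, prompt length + number of generated tokens)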