# chatbot/components/model.py
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import time
import os
from torch.utils.data import Dataset, DataLoader
import tiktoken
# torch.cuda.amp.autocast / GradScaler are deprecated; use the torch.amp versions below.
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from tqdm import tqdm
from datasets import load_dataset
# class GPTModel(nn.Module):
# def __init__(
# self, vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size
# ):
# super(GPTModel, self).__init__()
# self.token_embedding = nn.Embedding(vocab_size, n_embedding)
# self.position_embedding = nn.Embedding(block_size, n_embedding)
# self.layers = nn.ModuleList(
# [
# nn.TransformerEncoderLayer(
# d_model=n_embedding, nhead=n_heads, dropout=dropout_p
# )
# for _ in range(n_layers)
# ]
# )
# self.ln_f = nn.LayerNorm(n_embedding)
# self.head = nn.Linear(n_embedding, vocab_size)
# self.dropout = nn.Dropout(dropout_p)
# self.block_size = block_size
# def forward(self, x):
# bsz, seq_len = x.size()
# positions = (
# torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(bsz, seq_len)
# )
# x = self.token_embedding(x) + self.position_embedding(positions)
# x = self.dropout(x)
# for layer in self.layers:
# x = layer(x)
# x = self.ln_f(x)
# logits = self.head(x)
# return logits
class MultiHeadAttention(nn.Module):
def __init__(self, n_embedding, n_heads, dropout_p):
super().__init__()
        assert n_embedding % n_heads == 0, "n_embedding must be divisible by n_heads"
self.n_heads = n_heads
self.head_dim = n_embedding // n_heads
self.q_proj = nn.Linear(n_embedding, n_embedding)
self.k_proj = nn.Linear(n_embedding, n_embedding)
self.v_proj = nn.Linear(n_embedding, n_embedding)
self.out_proj = nn.Linear(n_embedding, n_embedding)
self.dropout = nn.Dropout(dropout_p)
def forward(self, x, attn_mask=None):
B, T, C = x.shape # batch size, seq length, embedding dim
q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # C (the embedding dim) is split into n_heads heads of size head_dim.
        # Use PyTorch's fused scaled_dot_product_attention for efficiency.
        # SDPA does not allow an explicit attn_mask together with is_causal=True,
        # so fall back to the built-in causal masking only when no mask is given.
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=attn_mask is None,
        )
attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C)
return self.out_proj(attn_out)
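# A minimal shape/sanity sketch for MultiHeadAttention (defined but never
# called; the sizes here are illustrative assumptions, not project settings).
def _mha_shape_check():
    mha = MultiHeadAttention(n_embedding=64, n_heads=4, dropout_p=0.0)
    x = torch.randn(2, 16, 64)   # (batch, seq_len, embedding dim)
    out = mha(x)                 # causal by default when no attn_mask is given
    assert out.shape == x.shape  # attention preserves the input shape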
class FeedForward(nn.Module):
def __init__(self, n_embedding, dropout_p):
super().__init__()
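        # Position-wise MLP: expand to 4x the embedding width, apply GELU,
        # then project back down (the standard GPT-style feed-forward block).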
self.net = nn.Sequential(
nn.Linear(n_embedding, 4 * n_embedding),
nn.GELU(),
nn.Linear(4 * n_embedding, n_embedding),
nn.Dropout(dropout_p),
)
def forward(self, x):
return self.net(x)
class TransformerBlock(nn.Module):
def __init__(self, n_embedding, n_heads, dropout_p):
super().__init__()
self.ln1 = nn.LayerNorm(n_embedding)
self.ln2 = nn.LayerNorm(n_embedding)
self.attn = MultiHeadAttention(n_embedding, n_heads, dropout_p)
self.ff = FeedForward(n_embedding, dropout_p)
def forward(self, x, attn_mask=None):
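        # Pre-LayerNorm residual connections: normalize the input to each
        # sub-layer, then add the sub-layer's output back to the residual stream.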
x = x + self.attn(self.ln1(x), attn_mask)
x = x + self.ff(self.ln2(x))
return x
class GPTModel(nn.Module):
def __init__(self, vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size):
super().__init__()
self.token_embed = nn.Embedding(vocab_size, n_embedding)
self.pos_embed = nn.Embedding(block_size, n_embedding)
self.blocks = nn.ModuleList([
TransformerBlock(n_embedding, n_heads, dropout_p)
for _ in range(n_layers)
])
self.ln_f = nn.LayerNorm(n_embedding)
self.head = nn.Linear(n_embedding, vocab_size, bias=False)
self.dropout = nn.Dropout(dropout_p)
self.block_size = block_size
def forward(self, idx):
B, T = idx.shape
assert T <= self.block_size, "Sequence exceeds block size."
pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
x = self.token_embed(idx) + self.pos_embed(pos)
x = self.dropout(x)
        # Causal mask for the decoder (True = may attend): each position can
        # only attend to itself and earlier positions.
        attn_mask = torch.ones(T, T, device=idx.device, dtype=torch.bool).tril()
for block in self.blocks:
x = block(x, attn_mask)
x = self.ln_f(x)
logits = self.head(x)
return logits
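# A minimal end-to-end sketch of running the model (defined but never called;
# the hyperparameters are illustrative assumptions, not values used elsewhere
# in this project). 50257 is the GPT-2 BPE vocabulary size produced by
# tiktoken's "gpt2" encoding.
def _gpt_smoke_test():
    model = GPTModel(vocab_size=50257, n_embedding=128, n_layers=2,
                     n_heads=4, dropout_p=0.1, block_size=64)
    idx = torch.randint(0, 50257, (2, 32))   # (batch, seq_len) of token ids
    logits = model(idx)
    assert logits.shape == (2, 32, 50257)    # one logit per vocab entry per position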