# Globe-1 / model.py
# (Source page header: author Ai128474, commit "Update model.py", 4d4d4bf, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tokenizers import ByteLevelBPETokenizer
import random
# === CONFIG ===
class Globe1Config:
    """Hyperparameters for the Globe-1 transformer language model.

    Defaults live on the class, so ``Globe1Config.d_model`` works without an
    instance. Any default can be overridden per instance by keyword, e.g.
    ``Globe1Config(n_layers=2, d_model=128)``; ``Globe1Config()`` keeps all
    defaults.

    Raises:
        TypeError: if an override names a field that is not defined here.
        ValueError: if ``d_model`` is not divisible by ``n_heads`` (attention
            splits ``d_model`` evenly across heads).
    """

    vocab_size = 30000  # should equal the tokenizer's vocabulary size
    n_layers = 6  # number of stacked transformer blocks
    n_heads = 8  # attention heads per block
    d_model = 512  # embedding / residual-stream width
    d_ff = 2048  # hidden width of the feed-forward MLP
    max_seq_len = 512  # number of learned position embeddings

    def __init__(self, **overrides):
        for name, value in overrides.items():
            # Only allow overriding declared hyperparameters; catches typos.
            if name.startswith("_") or not hasattr(type(self), name):
                raise TypeError(f"unknown Globe1Config field: {name!r}")
            setattr(self, name, value)
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )
# === TRANSFORMER MODULES ===
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        config: object providing ``d_model`` and ``n_heads``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``; otherwise
            the per-head reshape in ``forward`` would fail with an opaque error.
    """

    def __init__(self, config):
        super().__init__()
        if config.d_model % config.n_heads != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by n_heads ({config.n_heads})"
            )
        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.d_head = self.d_model // self.n_heads
        # One fused projection produces queries, keys and values together.
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model)
        self.out = nn.Linear(config.d_model, config.d_model)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, T, d_model) and return the same shape.

        Args:
            x: input activations, last dimension ``d_model``.
            mask: optional tensor broadcastable to (B, n_heads, T, T); scores
                where ``mask == 0`` are set to -inf before the softmax.
                ``None`` applies no masking.
        """
        B, T, C = x.size()
        # Head-major layout: each head owns a contiguous 3*d_head slice of the
        # fused projection, split as [q | k | v]. Kept as-is so existing
        # checkpoints load unchanged.
        qkv = self.qkv(x).view(B, T, self.n_heads, 3 * self.d_head)
        q, k, v = qkv.split(self.d_head, dim=-1)  # each (B, T, H, d_head)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # each (B, H, T, d_head)
        att = (q @ k.transpose(-2, -1)) / (self.d_head ** 0.5)  # (B, H, T, T)
        if mask is not None:
            att = att.masked_fill(mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        out = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out(out)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model -> d_ff), GELU, Linear(d_ff -> d_model).

    The layers live in ``self.net`` (indices 0, 1, 2), so parameter names are
    ``net.0.*`` and ``net.2.*``.
    """

    def __init__(self, config):
        super().__init__()
        width, hidden = config.d_model, config.d_ff
        layers = [
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer with residual connections.

    Computes ``h = x + attn(ln1(x), mask)`` and returns ``h + ff(ln2(h))``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        # Registration order (ln1, attn, ln2, ff) fixes the state_dict layout
        # and the order in which random initialisation consumes the RNG.
        self.ln1 = nn.LayerNorm(width)
        self.attn = MultiHeadAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.ff = FeedForward(config)

    def forward(self, x, mask=None):
        attended = self.attn(self.ln1(x), mask)
        hidden = x + attended
        mixed = self.ff(self.ln2(hidden))
        return hidden + mixed
class Globe1(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus learned position embeddings, ``config.n_layers``
    TransformerBlocks, a final LayerNorm, and a linear head producing logits
    over ``config.vocab_size`` tokens.
    """

    def __init__(self, config):
        super().__init__()
        # Plain int (not a parameter): longest sequence pos_emb can index.
        self.max_seq_len = config.max_seq_len
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size)

    def forward(self, idx, mask=None, causal=True):
        """Compute per-position logits.

        Args:
            idx: LongTensor of token ids with shape (B, T), T <= max_seq_len.
            mask: optional attention mask broadcastable to (B, n_heads, T, T);
                entries equal to 0/False are blocked. When given it is used
                as-is and ``causal`` is ignored.
            causal: when ``mask`` is None, apply a lower-triangular mask so each
                position attends only to itself and earlier positions. This is
                required for next-token training: without it, position t sees
                token t+1, which is exactly the target it must predict. Pass
                ``causal=False`` for unmasked attention.

        Returns:
            Logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds ``max_seq_len``.
        """
        _, T = idx.size()
        if T > self.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {self.max_seq_len}"
            )
        if mask is None and causal:
            mask = torch.tril(
                torch.ones(T, T, dtype=torch.bool, device=idx.device)
            ).view(1, 1, T, T)
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)  # (1, T)
        x = self.token_emb(idx) + self.pos_emb(pos)
        for block in self.blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        return self.head(x)
# === DATASET ===
class TextDataset(Dataset):
    """Next-token-prediction windows over the concatenated tokens of text files.

    Item ``i`` is ``(x, y)`` where ``x = tokens[i:i+seq_len]`` and ``y`` is the
    same window shifted by one token. Windows start at every offset, so
    consecutive items overlap.

    Args:
        tokenizer: object whose ``encode(text)`` returns something with an
            ``.ids`` list of token ids.
        files: paths of UTF-8 text files; their token ids are appended in
            order with nothing inserted between files.
        seq_len: window length of both ``x`` and ``y`` (must be >= 1).

    Raises:
        ValueError: if ``seq_len`` < 1.
    """

    def __init__(self, tokenizer, files, seq_len=512):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.seq_len = seq_len
        self.tokens = []
        for file in files:
            with open(file, "r", encoding="utf-8") as f:
                text = f.read()
            self.tokens.extend(tokenizer.encode(text).ids)

    def __len__(self):
        # One extra token is needed for the shifted target, so a corpus with
        # <= seq_len tokens yields no windows (never a negative length).
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
        return x, y
# === TRAINING FUNCTION ===
def train_globe1(model, dataset, config, epochs=5, batch_size=16, lr=3e-4, device="cuda"):
    """Train ``model`` for next-token prediction with AdamW and cross-entropy.

    Args:
        model: module mapping (B, T) token ids to (B, T, vocab_size) logits.
        dataset: map-style dataset yielding ``(x, y)`` LongTensor pairs.
        config: object providing ``vocab_size``.
        epochs: passes over the shuffled dataset.
        batch_size: examples per optimisation step.
        lr: AdamW learning rate.
        device: device the model and batches are moved to.

    Returns:
        The trained model (moved to ``device``, left in training mode).

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("cannot train on an empty dataset")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model = model.to(device)
    # The model may have been left in eval mode (e.g. by generate_text).
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        total_loss = 0.0
        for step, (x, y) in enumerate(loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits.reshape(-1, config.vocab_size), y.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Single host sync per step; reused for the running total and the log.
            loss_value = loss.item()
            total_loss += loss_value
            if step % 100 == 0:
                print(f"Epoch {epoch+1}, Step {step}, Loss {loss_value:.4f}")
        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch+1} complete. Avg Loss: {avg_loss:.4f}")
    return model
# === GENERATION FUNCTION ===
@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_len=200, device="cuda", temperature=1.0, top_k=50):
    """Autoregressively extend ``prompt`` with top-k temperature sampling.

    Args:
        model: module mapping (1, T) token ids to (1, T, vocab) logits; it is
            put in eval mode and left there.
        tokenizer: object with ``encode(text).ids`` and ``decode(ids)``.
        prompt: text to continue; must encode to at least one token.
        max_len: number of new tokens to sample.
        device: device for the token tensor (should match the model's).
        temperature: logits are divided by this before sampling; must be > 0.
        top_k: sample only among the k most likely tokens (clamped to the
            vocabulary size); ``None`` samples from the full vocabulary.

    Returns:
        The decoded prompt plus the sampled continuation.

    Raises:
        ValueError: on non-positive temperature or top_k, or an empty prompt.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be >= 1 or None, got {top_k}")
    model.eval()
    prompt_ids = tokenizer.encode(prompt).ids
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")
    tokens = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)  # (1, T)
    # If the model has a learned position table, never feed it more positions
    # than it has rows; keep only the most recent tokens as context.
    max_ctx = getattr(getattr(model, "pos_emb", None), "num_embeddings", None)
    for _ in range(max_len):
        context = tokens if max_ctx is None else tokens[:, -max_ctx:]
        logits = model(context)[:, -1, :] / temperature  # (1, vocab)
        vocab = logits.size(-1)
        k = vocab if top_k is None else min(top_k, vocab)
        top_logits, top_indices = torch.topk(logits, k)  # each (1, k)
        probs = F.softmax(top_logits, dim=-1)
        choice = torch.multinomial(probs, num_samples=1)  # (1, 1) position within top-k
        next_token = torch.gather(top_indices, -1, choice)  # (1, 1) vocabulary id
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokenizer.decode(tokens[0].tolist())
# === MAIN ===
if __name__ == "__main__":
    import os

    data_files = ["wiki.txt", "openwebtext.txt"]
    tokenizer_dir = "globe_tokenizer"
    vocab_path = os.path.join(tokenizer_dir, "vocab.json")
    merges_path = os.path.join(tokenizer_dir, "merges.txt")
    # Fall back to CPU instead of crashing on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1️⃣ Load or train tokenizer
    if os.path.exists(vocab_path) and os.path.exists(merges_path):
        tokenizer = ByteLevelBPETokenizer.from_file(vocab_path, merges_path)
    else:
        tokenizer = ByteLevelBPETokenizer()
        tokenizer.train(files=data_files, vocab_size=30000)
        # save_model writes into an existing directory; create it first.
        os.makedirs(tokenizer_dir, exist_ok=True)
        tokenizer.save_model(tokenizer_dir)

    # 2️⃣ Dataset
    dataset = TextDataset(tokenizer, data_files)

    # 3️⃣ Model: size the embedding and output head to the tokenizer's actual
    # vocabulary, which can be smaller than the requested 30000.
    config = Globe1Config()
    config.vocab_size = tokenizer.get_vocab_size()
    model = Globe1(config)

    # 4️⃣ Train
    trained_model = train_globe1(model, dataset, config, epochs=3, batch_size=8, device=device)

    # 5️⃣ Generate
    prompt = "def fibonacci(n):"
    print(generate_text(trained_model, tokenizer, prompt, device=device))