File size: 6,799 Bytes
038a912 baf2ee7 038a912 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 038a912 45553f3 038a912 45553f3 038a912 45553f3 038a912 45553f3 038a912 45553f3 038a912 45553f3 038a912 45553f3 038a912 45553f3 038a912 45553f3 038a912 45553f3 038a912 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 baf2ee7 45553f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import torch
import torch.nn as nn
from torch.nn import functional as F
import json
import os
# --- Hyperparameters ---
# Module-level settings read by the data pipeline, model, and training loop below.
batch_size = 64 # number of sequences processed in parallel per optimizer step
block_size = 32 # maximum context length (in characters) of each training sequence
max_iters = 15000 # total number of optimizer steps in the training loop
eval_interval = 500 # run loss estimation every this many steps
learning_rate = 3e-4 # AdamW learning rate; lower values suit larger models
device = 'cuda' if torch.cuda.is_available() else 'cpu' # use GPU if available
eval_iters = 200 # number of batches averaged per split in estimate_loss()
n_embd = 64 # token embedding dimension (also used as the LSTM hidden size)
n_layer = 4 # number of stacked LSTM layers
dropout = 0.0 # dropout rate; NOTE(review): defined but never passed to nn.LSTM below — confirm intent
# --- Data Preparation ---
# This code now expects a 'dataset.jsonl' file to be present in the same directory.
file_path = 'dataset.jsonl'
# Collect pieces in a list and join once at the end: repeated `corpus += ...`
# rebuilds the whole string each time and is quadratic in total corpus size.
_corpus_parts = []
try:
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Skip blank lines so a trailing newline in the file doesn't crash json.loads.
            if not line.strip():
                continue
            data_point = json.loads(line)
            # Each record contributes its 'header' and 'formal_statement' fields.
            _corpus_parts.append(data_point['header'] + '\n' + data_point['formal_statement'] + '\n')
except FileNotFoundError:
    print(f"Error: The file '{file_path}' was not found. Please create it and run again.")
    raise SystemExit(1)  # non-zero exit code: this is an error, not a clean exit
except (json.JSONDecodeError, KeyError) as e:
    print(f"Error: There was a problem parsing a line in '{file_path}'. Details: {e}")
    raise SystemExit(1)
corpus = ''.join(_corpus_parts)
if not corpus:
    print("Error: The corpus is empty. The dataset file might be empty or incorrectly formatted.")
    raise SystemExit(1)
# Create a simple character-level tokenizer over the unique characters in the corpus.
chars = sorted(set(corpus))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> integer id
itos = {i: ch for i, ch in enumerate(chars)}  # integer id -> char
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join(itos[i] for i in l)
# Convert the entire text into a PyTorch tensor of token ids.
data = torch.tensor(encode(corpus), dtype=torch.long)
# Create a simple 90/10 train/validation split.
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
# --- Helper Functions ---
# This function gets a random batch of data from either the training or validation set.
def get_batch(split):
    """Sample a random batch of (input, target) sequences from one split.

    `split` selects 'train' or validation data; targets are the inputs
    shifted one position to the right (next-character prediction).
    Returns (x, y), each of shape (batch_size, block_size), on `device`.
    """
    source = train_data if split == 'train' else val_data
    # Random starting offsets, leaving room for a full block plus the shifted target.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    x = torch.stack([source[s:s + block_size] for s in starts])
    y = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return x.to(device), y.to(device)
# This function is used to estimate the model's loss.
@torch.no_grad()
def estimate_loss():
    """Estimate mean loss over `eval_iters` random batches for each split.

    Returns a dict {'train': tensor, 'val': tensor} of averaged losses.
    Temporarily switches the model to eval mode, restoring train mode after.
    """
    results = {}
    model.eval()  # disable training-mode behaviors while measuring
    for split in ('train', 'val'):
        batch_losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            batch_losses[i] = loss.item()
        results[split] = batch_losses.mean()
    model.train()  # restore training mode
    return results
# --- The Main LSTM Language Model ---
class LanguageModel(nn.Module):
def __init__(self):
super().__init__()
# An embedding table to convert tokens to dense vectors.
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
# An LSTM layer to process the sequence.
self.lstm = nn.LSTM(n_embd, n_embd, num_layers=n_layer, batch_first=True)
# A final linear layer to project the LSTM's output to the vocabulary size.
self.lm_head = nn.Linear(n_embd, vocab_size)
def forward(self, idx, targets=None):
# Get the token embeddings.
tok_emb = self.token_embedding_table(idx) # (B, T, n_embd)
# Pass the embeddings through the LSTM layer.
lstm_out, _ = self.lstm(tok_emb) # lstm_out shape: (B, T, n_embd)
# Project the LSTM's output to the vocabulary size to get logits.
logits = self.lm_head(lstm_out) # (B, T, vocab_size)
loss = None
if targets is not None:
# Reshape for cross-entropy loss calculation.
B, T, C = logits.shape
logits = logits.view(B * T, C)
targets = targets.view(B * T)
loss = F.cross_entropy(logits, targets)
return logits, loss
def generate(self, idx, max_new_tokens):
# The `generate` method for LSTMs needs to handle hidden and cell states.
h_and_c = None # Start with no hidden state.
for _ in range(max_new_tokens):
# We only need the last token to predict the next one.
idx_cond = idx[:, -1].unsqueeze(1) # (B, 1)
tok_emb = self.token_embedding_table(idx_cond) # (B, 1, n_embd)
# Pass the single token through the LSTM, along with the previous hidden state.
lstm_out, h_and_c = self.lstm(tok_emb, h_and_c)
# Focus on the output of the last time step.
logits = self.lm_head(lstm_out[:, -1, :]) # (B, vocab_size)
# Apply softmax to get probabilities.
probs = F.softmax(logits, dim=-1)
# Sample from the distribution.
idx_next = torch.multinomial(probs, num_samples=1)
# Append the new token to the sequence.
idx = torch.cat((idx, idx_next), dim=1)
return idx
# --- Training and Generation ---
model = LanguageModel()
m = model.to(device)  # `m` aliases `model`: .to() moves parameters in place
# Create a PyTorch optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Main training loop. `step` is used instead of shadowing the builtin `iter`.
for step in range(max_iters):
    # Periodically evaluate the loss on both splits; also report at the final
    # step so the end-of-training loss is always printed.
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    # Sample a batch of data.
    xb, yb = get_batch('train')
    # Forward pass: compute loss.
    logits, loss = model(xb, yb)
    # Backward pass: compute gradients.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # Update the model parameters.
    optimizer.step()
# --- Generate new text from the trained model ---
# Fix: run generation in eval mode and without autograd — the old code built
# an (unused) computation graph for every sampled token.
m.eval()
context = torch.zeros((1, 1), dtype=torch.long, device=device)
with torch.no_grad():
    generated_text_indices = m.generate(context, max_new_tokens=20)
print("\nGenerated text:")
print(decode(generated_text_indices[0].tolist()))
# Save the model's state dictionary after training.
torch.save(m.state_dict(), 'model.pt')
print("Model saved to model.pt")
|