import os
import json
import argparse
from pathlib import Path

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace


class LightweightGPT(nn.Module):
    """A small decoder-only transformer language model (GPT-style)."""

    def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)

        # A decoder-only LM needs masked self-attention but no cross-attention,
        # so TransformerEncoderLayer plus a causal mask (applied in forward) is
        # the right building block; TransformerDecoderLayer would add a
        # cross-attention sublayer that expects a separate memory sequence.
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=n_embd,
                nhead=n_head,
                dim_feedforward=4 * n_embd,
                dropout=0.1,
                activation='gelu',
                batch_first=True,
                norm_first=True
            )
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        device = idx.device
        # Boolean attention mask where True means "may not attend":
        # triu with diagonal=1 blocks every position from seeing tokens
        # to its right, giving standard causal masking.
        causal_mask = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)

        token_emb = self.token_embedding(idx)
        pos = torch.arange(0, T, dtype=torch.long, device=device)
        pos_emb = self.position_embedding(pos)

        x = token_emb + pos_emb

        for block in self.blocks:
            x = block(x, src_mask=causal_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=50, stop_token=None):
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens.
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            # Keep only the logits at the final position, rescaled by temperature.
            logits = logits[:, -1, :]
            logits = logits / temperature

            # Top-k filtering: push everything below the k-th largest logit
            # to -inf so it gets zero probability after the softmax.
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # Note: .item() assumes a batch size of 1, which is how the chat
            # interface calls this method.
            if stop_token is not None and idx_next.item() == stop_token:
                break

            idx = torch.cat((idx, idx_next), dim=1)

        return idx


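# A minimal smoke test of the model above, assuming the default configuration
# used by AIBuilder below (vocab_size=8000, block_size=128, n_embd=128,
# n_head=4, n_layer=4); the shapes are illustrative only:
#
#   model = LightweightGPT(vocab_size=8000, block_size=128, n_embd=128, n_head=4, n_layer=4)
#   x = torch.randint(0, 8000, (2, 128))   # (batch, time)
#   logits, loss = model(x, targets=x)     # logits: (2, 128, 8000), loss: scalar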
class ConversationDataset(Dataset):
    """Splits a token stream into per-conversation segments at each
    end-of-text token and serves (input, target) pairs for training."""

    def __init__(self, tokens, block_size, end_token_id):
        self.end_token = end_token_id
        self.block_size = block_size
        self.segments = []
        current_start = 0
        for i, token in enumerate(tokens):
            if token == end_token_id:
                segment = tokens[current_start:i + 1]
                # Pad short segments up to block_size + 1 with the end token,
                # so padded positions simply train the model to keep emitting
                # <|endoftext|>.
                if len(segment) < block_size + 1:
                    padding = [end_token_id] * (block_size + 1 - len(segment))
                    segment.extend(padding)
                self.segments.append(segment)
                current_start = i + 1
        print(f"Created {len(self.segments)} conversation segments.")

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        segment = self.segments[idx]
        # Sample a random window so long conversations contribute different
        # chunks across epochs.
        start_pos = torch.randint(0, max(1, len(segment) - self.block_size), (1,)).item()
        chunk = segment[start_pos:start_pos + self.block_size + 1]

        # Standard next-token setup: y is x shifted left by one position.
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y


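# Illustration of the (x, y) shift in ConversationDataset.__getitem__ above,
# assuming block_size=4 and a (padded) segment [5, 9, 2, 7, 3]:
#   x = [5, 9, 2, 7]
#   y = [9, 2, 7, 3]   # the next-token target for each position of x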
class AIBuilder:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.output_folder = model_name.replace(" ", "_").lower()
        self.device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        self.model_config = {
            "block_size": 128,
            "n_embd": 128,
            "n_head": 4,
            "n_layer": 4,
            "vocab_size": 8000,
            "batch_size": 8,
            "grad_accum": 4,
            "max_epochs": 3,
        }

    def _build_tokenizer(self, training_data: str):
        # Train a BPE tokenizer from scratch on the training text, reserving
        # the role markers and the end-of-conversation marker as special tokens.
        tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = BpeTrainer(
            special_tokens=["[UNK]", "[PAD]", "user:", "ai:", "<|endoftext|>"],
            vocab_size=self.model_config["vocab_size"]
        )
        tokenizer.train_from_iterator(self._get_text_iterator(training_data), trainer)
        return tokenizer

    def _get_text_iterator(self, text, chunk_size=1000):
        # Feed the trainer fixed-size chunks instead of one giant string.
        for i in range(0, len(text), chunk_size):
            yield text[i:i + chunk_size]

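    # A sketch of the training-data format this pipeline assumes: "user:" and
    # "ai:" turns with each conversation terminated by <|endoftext|>. The exact
    # layout beyond these special tokens is a convention, not enforced here:
    #
    #   user: hello
    #   ai: hi there!<|endoftext|>
    #   user: how are you?
    #   ai: doing well, thanks.<|endoftext|>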
    def _prepare_dataloader(self, tokenizer, text):
        tokens = tokenizer.encode(text).ids
        end_token_id = tokenizer.token_to_id("<|endoftext|>")
        dataset = ConversationDataset(tokens, self.model_config["block_size"], end_token_id)

        def collate_fn(batch):
            # Each element is an (x, y) pair of LongTensors shaped
            # (block_size,); stacking yields (batch_size, block_size).
            xs, ys = zip(*batch)
            return torch.stack(xs), torch.stack(ys)

        return DataLoader(dataset, batch_size=self.model_config["batch_size"], shuffle=True, collate_fn=collate_fn)

    def train(self, training_data: str):
        os.makedirs(self.output_folder, exist_ok=True)

        print("Building and saving tokenizer...")
        tokenizer = self._build_tokenizer(training_data)
        tokenizer.save(os.path.join(self.output_folder, "tokenizer.json"))

        print("Saving configuration file...")
        self._save_config(tokenizer)

        print("Preparing data for training...")
        dataloader = self._prepare_dataloader(tokenizer, training_data)

        model = LightweightGPT(
            vocab_size=tokenizer.get_vocab_size(),
            block_size=self.model_config["block_size"],
            n_embd=self.model_config["n_embd"],
            n_head=self.model_config["n_head"],
            n_layer=self.model_config["n_layer"]
        ).to(self.device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
        model_path = os.path.join(self.output_folder, "model.pt")

        print("\n--- Starting Model Training ---")
        model.train()
        best_loss = float('inf')

        for epoch in range(self.model_config["max_epochs"]):
            # Discard any partially accumulated gradients from the last epoch.
            optimizer.zero_grad()
            for batch_idx, (x, y) in enumerate(dataloader):
                x, y = x.to(self.device), y.to(self.device)
                _, loss = model(x, y)

                # Scale the loss so gradients accumulated over grad_accum
                # batches average out to one effective large batch.
                loss = loss / self.model_config["grad_accum"]
                loss.backward()

                if (batch_idx + 1) % self.model_config["grad_accum"] == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                # Undo the scaling for logging and checkpoint comparison.
                current_loss = loss.detach().item() * self.model_config["grad_accum"]

                if batch_idx % 50 == 0:
                    print(f"Epoch {epoch+1} | Batch {batch_idx} | Loss: {current_loss:.4f}")

                if current_loss < best_loss:
                    best_loss = current_loss
                    torch.save(model.state_dict(), model_path)
                    print(f"🎉 New best model saved with loss: {best_loss:.4f}")

        print(f"✅ Training complete. Final best loss: {best_loss:.4f}")

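    # Gradient-accumulation arithmetic, from the defaults above: with
    # batch_size=8 and grad_accum=4 the optimizer steps once per
    # 8 * 4 = 32 sequences, giving an effective batch of 32 while only
    # holding 8 sequences in memory at a time.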
    def _save_config(self, tokenizer):
        config = {
            "model_name": self.model_name,
            **self.model_config,
            # Overwrite the requested vocab_size with the actual trained size,
            # and record the end-of-text id so chat can stop generation on it.
            "vocab_size": tokenizer.get_vocab_size(),
            "end_token_id": tokenizer.token_to_id("<|endoftext|>")
        }
        with open(os.path.join(self.output_folder, "config.json"), "w") as f:
            json.dump(config, f, indent=2)
        print(f"Configuration saved to {os.path.join(self.output_folder, 'config.json')}")


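# Sketch of the config.json written by _save_config above, assuming the
# default model_config and a fully populated 8000-token vocabulary; the
# end_token_id depends on the tokenizer run (<|endoftext|> is the fifth
# special token, so it typically gets id 4):
#
#   {
#     "model_name": "AgLM",
#     "block_size": 128,
#     "n_embd": 128,
#     "n_head": 4,
#     "n_layer": 4,
#     "vocab_size": 8000,
#     "batch_size": 8,
#     "grad_accum": 4,
#     "max_epochs": 3,
#     "end_token_id": 4
#   }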
class ChatInterface:
    def __init__(self, model_dir="aglm"):
        self.model_dir = Path(model_dir)
        self.device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model()

    def load_model(self):
        with open(self.model_dir / "config.json", "r") as f:
            self.config = json.load(f)

        self.tokenizer = Tokenizer.from_file(str(self.model_dir / "tokenizer.json"))
        self.end_token_id = self.config.get("end_token_id")

        self.model = LightweightGPT(
            vocab_size=self.config["vocab_size"],
            block_size=self.config["block_size"],
            n_embd=self.config["n_embd"],
            n_head=self.config["n_head"],
            n_layer=self.config["n_layer"]
        ).to(self.device)

        self.model.load_state_dict(torch.load(self.model_dir / "model.pt", map_location=self.device))
        self.model.eval()
        print("✅ Model loaded successfully!")

    def chat(self):
        print("\n===== AI Assistant Ready =====")
        print("Type 'quit' or 'exit' to end the chat.\n")

        while True:
            user_input = input("user: ")
            if user_input.lower() in ["quit", "exit"]:
                break

            # Frame the input with the same role markers used in training.
            prompt = f"user: {user_input}\nai:"
            input_ids = self.tokenizer.encode(prompt).ids
            input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)

            with torch.no_grad():
                output_ids = self.model.generate(
                    input_tensor,
                    max_new_tokens=150,
                    temperature=0.7,
                    top_k=40,
                    stop_token=self.end_token_id
                )

            # Strip the prompt tokens, decode, and drop any end-of-text marker.
            response_ids = output_ids[0, len(input_ids):].tolist()
            response = self.tokenizer.decode(response_ids)
            response = response.replace("<|endoftext|>", "").strip()

            print(f"ai: {response}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train or chat with an AgLM model.")
    parser.add_argument('action', choices=['train', 'chat'], nargs='?', default='train', help="Choose 'train' (default) or 'chat'.")
    args = parser.parse_args()

    model_folder = "aglm"

    if args.action == 'train':
        print("--- Starting Setup for AgLM ---")
        builder = AIBuilder("AgLM")
        try:
            with open("train.txt", "r", encoding="utf-8") as f:
                data = f.read()
            builder.train(data)
            print("\n✅ Training finished. You can now run with the 'chat' argument.")
            print(f"To chat, run: python {os.path.basename(__file__)} chat")
        except FileNotFoundError:
            print("\nERROR: train.txt not found. Please create train.txt with your conversational data to train the model.")

    elif args.action == 'chat':
        print("--- Starting Chat Interface for AgLM ---")
        if os.path.exists(os.path.join(model_folder, "model.pt")):
            chat_bot = ChatInterface(model_dir=model_folder)
            chat_bot.chat()
        else:
            print(f"\nERROR: No trained model found in '{model_folder}'. Please run training first.")
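
# Usage, assuming this file is saved as aglm.py (the filename is not fixed
# anywhere in the script):
#
#   python aglm.py train   # reads ./train.txt, writes tokenizer.json,
#                          # config.json, and model.pt into ./aglm/
#   python aglm.py chat    # loads ./aglm/ and starts an interactive session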