# AgLMExperiment1 / gpt_chat.py
# AGofficial's picture
# Upload 4 files
# 285857b verified
import torch
import json
from torch.nn import functional as F
from tokenizers import Tokenizer
from pathlib import Path
import os
# LightweightGPT Model
class LightweightGPT(torch.nn.Module):
    """Compact GPT-like language model with learned positional embeddings.

    The network is: token embedding + position embedding, a stack of
    ``torch.nn.TransformerDecoderLayer`` blocks (pre-norm, GELU), a final
    ``LayerNorm`` and a bias-free linear head projecting to the vocabulary.

    NOTE(review): each block is called as ``block(x, x, tgt_mask=causal_mask)``,
    i.e. the layer input is also used as the cross-attention ``memory`` and no
    ``memory_mask`` is passed. Only the self-attention sub-layer is causally
    masked. This wiring is kept unchanged so the module's outputs stay the same
    for existing checkpoints -- confirm against the training code before
    changing it.

    Args:
        vocab_size: Number of entries in the token embedding / output logits.
        block_size: Maximum sequence length (size of the position embedding).
        n_embd: Embedding / model width.
        n_head: Number of attention heads per layer.
        n_layer: Number of stacked decoder layers.
    """

    def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = torch.nn.Embedding(vocab_size, n_embd)
        self.position_embedding = torch.nn.Embedding(block_size, n_embd)
        self.blocks = torch.nn.ModuleList([
            torch.nn.TransformerDecoderLayer(
                d_model=n_embd,
                nhead=n_head,
                dim_feedforward=4 * n_embd,
                dropout=0.1,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            )
            for _ in range(n_layer)
        ])
        self.ln_f = torch.nn.LayerNorm(n_embd)
        self.lm_head = torch.nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``;
                ``T`` must not exceed ``block_size``.
            targets: Optional integer tensor with as many elements as ``idx``.
                Positions equal to ``-1`` are ignored by the loss.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape ``(B, T, vocab_size)``
            and ``loss`` is the cross-entropy (or ``None`` without targets).

        Raises:
            ValueError: If ``T`` is larger than ``block_size``.
        """
        _, seq_len = idx.shape
        if seq_len > self.block_size:
            # Fail early with a clear message instead of an out-of-range
            # lookup inside the position embedding.
            raise ValueError(
                f"Sequence length {seq_len} exceeds block size {self.block_size}"
            )
        device = idx.device

        # True above the diagonal == "position i may not attend to j > i".
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool),
            diagonal=1,
        )

        token_emb = self.token_embedding(idx)
        pos = torch.arange(0, seq_len, dtype=torch.long, device=device)
        pos_emb = self.position_embedding(pos)
        x = token_emb + pos_emb

        for block in self.blocks:
            # See the class-level NOTE(review) about memory=x.
            x = block(x, x, tgt_mask=causal_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1,
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=50, stop_token=None):
        """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

        The context fed to the model is cropped to the last ``block_size``
        tokens at every step.

        Args:
            idx: Integer tensor of prompt token ids with shape ``(B, T)``.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Softmax temperature. ``0`` selects greedy (argmax)
                decoding instead of sampling.
            top_k: If not ``None``, sample only among the ``top_k`` most
                likely tokens.
            stop_token: Optional token id that ends generation. The stop
                token that ends the whole generation is not appended. With a
                batch, rows that already stopped are padded with
                ``stop_token`` until every row has stopped.

        Returns:
            Tensor of shape ``(B, T + n)`` with ``n <= max_new_tokens``.

        Raises:
            ValueError: If ``temperature`` is negative.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")
        greedy = temperature == 0

        # Per-row flag: has this sequence already produced the stop token?
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)

        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]

            if greedy:
                # Avoids dividing by zero; argmax is the temperature -> 0 limit.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    k = min(top_k, logits.size(-1))
                    kth_value = torch.topk(logits, k).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth_value, float('-inf'))
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

            if stop_token is not None:
                finished |= idx_next.squeeze(1) == stop_token
                if bool(finished.all()):
                    break
                # Keep already-finished rows pinned to the stop token.
                idx_next = idx_next.masked_fill(finished.unsqueeze(1), stop_token)

            idx = torch.cat((idx, idx_next), dim=1)
        return idx
# Chat Interface
class ChatInterface:
    """Console chat loop around a saved ``LightweightGPT`` checkpoint.

    The model directory must contain ``config.json`` (model hyper-parameters
    and optional ``end_token_id``), ``tokenizer.json`` and ``model.pt``.
    """

    def __init__(self, model_dir="/"):
        """Pick a device (mps, then cuda, then cpu) and load the model.

        Args:
            model_dir: Directory holding the config, tokenizer and weights.
        """
        self.model_dir = Path(model_dir)
        self.device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model()

    def load_model(self):
        """Read config and tokenizer, build the model and load its weights."""
        with open(self.model_dir / "config.json", "r", encoding="utf-8") as f:
            self.config = json.load(f)

        self.tokenizer = Tokenizer.from_file(str(self.model_dir / "tokenizer.json"))
        # May be None, in which case generation only stops at the token limit.
        self.end_token_id = self.config.get("end_token_id")

        self.model = LightweightGPT(
            vocab_size=self.config["vocab_size"],
            block_size=self.config["block_size"],
            n_embd=self.config["n_embd"],
            n_head=self.config["n_head"],
            n_layer=self.config["n_layer"],
        ).to(self.device)

        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source; on
        # PyTorch versions that support it, consider weights_only=True.
        self.model.load_state_dict(torch.load(self.model_dir / "model.pt", map_location=self.device))
        self.model.eval()
        print("✅ Model loaded successfully!")

    def _generate_reply(self, user_input):
        """Return the model's decoded reply to a single user message."""
        prompt = f"user: {user_input}\nai:"
        input_ids = self.tokenizer.encode(prompt).ids
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                input_tensor,
                max_new_tokens=150,
                temperature=0.7,
                top_k=40,
                stop_token=self.end_token_id,
            )

        # Drop the prompt tokens; keep only what the model generated.
        response_ids = output_ids[0, len(input_ids):].tolist()
        response = self.tokenizer.decode(response_ids)
        return response.replace("<|endoftext|>", "").strip()

    def chat(self):
        """Run the interactive loop until 'quit'/'exit', EOF or Ctrl-C."""
        print("\n===== AI Assistant Ready =====")
        print("Type 'quit' or 'exit' to end the chat.\n")

        while True:
            try:
                user_input = input("user: ")
            except (EOFError, KeyboardInterrupt):
                # Closed stdin or Ctrl-C: leave the chat without a traceback.
                print()
                break

            if user_input.strip().lower() in ("quit", "exit"):
                break

            response = self._generate_reply(user_input)
            print(f"ai: {response}")
# Main execution
if __name__ == "__main__":
    model_folder = "aglm"
    # Every file ChatInterface.load_model reads; check them all up front so a
    # partially populated directory gives a clear message, not a traceback.
    required_files = ("model.pt", "config.json", "tokenizer.json")

    if not os.path.isdir(model_folder):
        print(f"\nERROR: Model directory '{model_folder}' not found. Please ensure the model is trained and the directory contains 'model.pt' and 'config.json'.")
    else:
        missing = [
            name for name in required_files
            if not os.path.isfile(os.path.join(model_folder, name))
        ]
        if missing:
            print(f"\nERROR: Model directory '{model_folder}' is missing required file(s): {', '.join(missing)}.")
        else:
            chat_bot = ChatInterface(model_dir=model_folder)
            chat_bot.chat()