|
|
import torch |
|
|
import json |
|
|
from torch.nn import functional as F |
|
|
from tokenizers import Tokenizer |
|
|
from pathlib import Path |
|
|
import os |
|
|
|
|
|
|
|
|
class LightweightGPT(torch.nn.Module):
    """Compact GPT-like language model with causal masking and positional encoding.

    Token embeddings plus learned position embeddings feed a stack of pre-norm
    ``torch.nn.TransformerDecoderLayer`` blocks, followed by a final LayerNorm
    and a bias-free linear head that produces next-token logits.

    Args:
        vocab_size: Number of token ids covered by the embedding and the head.
        block_size: Maximum sequence length; rows of the position table.
        n_embd: Model width (embedding / hidden size).
        n_head: Attention heads per layer (PyTorch requires it to divide n_embd).
        n_layer: Number of decoder layers.
    """

    def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = torch.nn.Embedding(vocab_size, n_embd)
        self.position_embedding = torch.nn.Embedding(block_size, n_embd)

        self.blocks = torch.nn.ModuleList([
            torch.nn.TransformerDecoderLayer(
                d_model=n_embd,
                nhead=n_head,
                dim_feedforward=4 * n_embd,
                dropout=0.1,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            )
            for _ in range(n_layer)
        ])

        self.ln_f = torch.nn.LayerNorm(n_embd)
        self.lm_head = torch.nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute logits (and optionally the cross-entropy loss).

        Args:
            idx: Integer token ids of shape (B, T) with T <= block_size.
            targets: Optional token ids with the same number of elements as
                ``idx``; entries equal to -1 are ignored by the loss.

        Returns:
            ``(logits, loss)`` where logits has shape (B, T, vocab_size) and
            loss is a scalar tensor, or None when ``targets`` is None.

        Raises:
            ValueError: if T exceeds ``block_size`` (the position table has
                only ``block_size`` rows).
        """
        _, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        device = idx.device

        # True above the diagonal = "may not attend": position t sees 0..t only.
        causal_mask = torch.triu(
            torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1
        )

        pos = torch.arange(0, T, dtype=torch.long, device=device)
        x = self.token_embedding(idx) + self.position_embedding(pos)

        for block in self.blocks:
            # Each decoder layer uses the sequence itself as its cross-attention
            # memory. That cross-attention must be causally masked too, or every
            # position can read the tokens that follow it.
            # NOTE(review): checkpoints trained with the previous unmasked
            # forward will behave differently and should be retrained.
            x = block(x, x, tgt_mask=causal_mask, memory_mask=causal_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # reshape (not view) tolerates non-contiguous targets.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1,
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=50, stop_token=None):
        """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

        Args:
            idx: Prompt token ids of shape (B, T), with T >= 1.
            max_new_tokens: Maximum number of tokens to append.
            temperature: Logit divisor. 0 selects greedy (argmax) decoding.
            top_k: Keep only the k most likely tokens before sampling.
                None or a value <= 0 disables the filter.
            stop_token: Token id that ends generation. The loop stops, without
                appending it, once every sequence in the batch sampled it in
                the same step.

        Returns:
            The prompt followed by the generated tokens, shape (B, T + n).
        """
        for _ in range(max_new_tokens):
            # Position embeddings only cover block_size positions, so condition
            # on the most recent block_size tokens.
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]

            if temperature == 0:
                # Dividing by zero would produce inf/NaN probabilities.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    k = min(top_k, logits.size(-1))
                    kth_best = torch.topk(logits, k).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth_best, -float('Inf'))
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

            # .all() instead of .item() so batches with B > 1 do not crash.
            if stop_token is not None and bool((idx_next == stop_token).all()):
                break

            idx = torch.cat((idx, idx_next), dim=1)

        return idx
|
|
|
|
|
|
|
|
class ChatInterface:
    """Interactive terminal chat loop around a saved LightweightGPT checkpoint.

    ``load_model`` reads ``config.json``, ``tokenizer.json`` and ``model.pt``
    from ``model_dir``.
    """

    def __init__(self, model_dir="/"):
        # NOTE(review): the default "/" is the filesystem root; the script
        # entry point passes an explicit directory.
        self.model_dir = Path(model_dir)
        # Prefer Apple MPS, then CUDA, then fall back to CPU.
        self.device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model()

    def load_model(self):
        """Load config, tokenizer and weights from ``self.model_dir``.

        Sets ``self.config``, ``self.tokenizer``, ``self.end_token_id`` and
        ``self.model``. The model is moved to ``self.device`` and put in eval mode.
        """
        with open(self.model_dir / "config.json", "r", encoding="utf-8") as f:
            self.config = json.load(f)

        self.tokenizer = Tokenizer.from_file(str(self.model_dir / "tokenizer.json"))

        self.end_token_id = self.config.get("end_token_id")
        if self.end_token_id is None:
            # Without a stop id, generation always runs to max_new_tokens.
            # Fall back to the tokenizer's own end marker; token_to_id returns
            # None when the token is absent, which keeps the old behaviour.
            self.end_token_id = self.tokenizer.token_to_id("<|endoftext|>")

        self.model = LightweightGPT(
            vocab_size=self.config["vocab_size"],
            block_size=self.config["block_size"],
            n_embd=self.config["n_embd"],
            n_head=self.config["n_head"],
            n_layer=self.config["n_layer"],
        ).to(self.device)

        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # you trust. On torch >= 1.13 consider weights_only=True for state dicts.
        self.model.load_state_dict(torch.load(self.model_dir / "model.pt", map_location=self.device))
        self.model.eval()
        print("✅ Model loaded successfully!")

    def chat(self):
        """Read user lines from stdin and print model replies until the user quits.

        The session ends on 'quit'/'exit' (case- and whitespace-insensitive),
        EOF (Ctrl-D) or Ctrl-C at the prompt.
        """
        print("\n===== AI Assistant Ready =====")
        print("Type 'quit' or 'exit' to end the chat.\n")

        while True:
            try:
                user_input = input("user: ")
            except (EOFError, KeyboardInterrupt):
                # End the session cleanly instead of dumping a traceback.
                print()
                break

            if user_input.strip().lower() in ["quit", "exit"]:
                break

            # Same "user: ...\nai:" turn format the model is prompted with.
            prompt = f"user: {user_input}\nai:"
            input_ids = self.tokenizer.encode(prompt).ids
            input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)

            with torch.no_grad():
                output_ids = self.model.generate(
                    input_tensor,
                    max_new_tokens=150,
                    temperature=0.7,
                    top_k=40,
                    stop_token=self.end_token_id,
                )

            # generate returns prompt + continuation; keep only the new tokens.
            response_ids = output_ids[0, len(input_ids):].tolist()
            response = self.tokenizer.decode(response_ids)
            response = response.replace("<|endoftext|>", "").strip()

            print(f"ai: {response}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    model_folder = Path("aglm")
    # Every file ChatInterface.load_model opens; checking them up front gives a
    # clear message instead of a traceback halfway through loading.
    required_files = ("model.pt", "config.json", "tokenizer.json")
    missing = [name for name in required_files if not (model_folder / name).is_file()]

    if not missing:
        chat_bot = ChatInterface(model_dir=model_folder)
        chat_bot.chat()
    elif not model_folder.is_dir():
        print(f"\nERROR: Model directory '{model_folder}' not found. Please ensure the model is trained and the directory contains 'model.pt', 'config.json' and 'tokenizer.json'.")
    else:
        print(f"\nERROR: Model directory '{model_folder}' is missing: {', '.join(missing)}. Please ensure the model is trained and saved completely.")