ChatGCLM-330M / sample.py
AGofficial's picture
Upload 6 files
238d08f verified
import codecs
import os

import tiktoken
import torch
import torch.nn.functional as F

from model import ChatGCLM, MAX_SEQ_LEN
def _find_checkpoint(directory="."):
    """Return the first ``ChatGCLM_*.pt`` file name in *directory*, or None.

    Candidates are sorted so the pick is deterministic: ``os.listdir`` returns
    entries in arbitrary, filesystem-dependent order.
    """
    candidates = sorted(
        name
        for name in os.listdir(directory)
        if name.startswith("ChatGCLM_") and name.endswith(".pt")
    )
    return candidates[0] if candidates else None


MODEL_PATH = _find_checkpoint()
if MODEL_PATH is None:
    print("Error: No model checkpoint found!")
    print("Please train the model first with: python3 train_chatgclm.py")
    # The bare ``exit`` builtin is injected by the ``site`` module and is
    # missing under ``python -S`` or in frozen apps; SystemExit always works.
    raise SystemExit(1)

TOKENIZER_NAME = "gpt2"
# NOTE(review): in tiktoken's "gpt2" encoding id 2 is the ordinary "#" token,
# while <|endoftext|> is 50256 (``tok.eot_token``). Confirm 2 matches the EOS
# id used by the training script; otherwise sampling stops at every "#".
EOS_ID = 2
def load_model(device):
    """Build a ChatGCLM sized to the tokenizer and load weights from MODEL_PATH.

    Args:
        device: torch device (or device string) the model is moved to.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode, or ``(None, None)``
        when ``MODEL_PATH`` is not a non-empty regular file.
    """
    # Validate the checkpoint before allocating a full model on the device.
    if not (os.path.isfile(MODEL_PATH) and os.path.getsize(MODEL_PATH) > 0):
        print(f"Error: Could not load model from {MODEL_PATH}")
        return None, None
    tok = tiktoken.get_encoding(TOKENIZER_NAME)
    model = ChatGCLM(tok.n_vocab).to(device)
    print(f"Loading model from: {MODEL_PATH}")
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .pt file cannot run arbitrary code; a state_dict needs
    # nothing more (this is also torch's default from 2.6 on).
    state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model, tok
@torch.no_grad()
def generate(model, prompt, tokenizer, device, max_new_tokens=200, temperature=0.8, top_k=50,
             *, eos_id=None, max_seq_len=None):
    """Sample a continuation of *prompt*, streaming it to stdout as it is produced.

    Args:
        model: torch module mapping a ``(1, T)`` long tensor of token ids to
            logits; ``logits[:, -1, :]`` is used as the next-token scores.
        prompt: text to condition on; must encode to at least one token.
        tokenizer: tiktoken ``Encoding`` (uses ``encode``, ``decode`` and
            ``decode_single_token_bytes``).
        device: device the input ids are placed on.
        max_new_tokens: upper bound on the number of generated tokens.
        temperature: softmax temperature; ``<= 0`` selects greedy (argmax)
            decoding instead of sampling.
        top_k: sample only among the ``top_k`` highest logits; ``None`` or
            ``<= 0`` disables the filter.
        eos_id: token id that ends generation (it is not emitted); defaults
            to the module-level ``EOS_ID``.
        max_seq_len: number of trailing tokens fed to the model each step;
            defaults to ``MAX_SEQ_LEN``.

    Returns:
        Text decoded from the newly generated tokens only (prompt excluded).

    Raises:
        ValueError: if *prompt* encodes to no tokens.
    """
    eos_id = EOS_ID if eos_id is None else eos_id
    max_seq_len = MAX_SEQ_LEN if max_seq_len is None else max_seq_len
    input_ids = tokenizer.encode(prompt)
    if not input_ids:
        # An empty context has no last position to read logits from.
        raise ValueError("prompt must encode to at least one token")
    model.eval()
    x = torch.tensor([input_ids], dtype=torch.long, device=device)
    print(f"\n{'='*70}")
    print(f"PROMPT: {prompt}")
    print(f"{'='*70}")
    print("GENERATED TEXT:")
    print(prompt, end="", flush=True)
    # BPE tokens are raw byte strings, and one multi-byte UTF-8 character can
    # span several tokens. Decoding each token on its own would print U+FFFD
    # for every fragment, so stream the bytes through an incremental decoder.
    utf8 = codecs.getincrementaldecoder("utf-8")(errors="replace")
    generated_tokens = []
    for _ in range(max_new_tokens):
        # Keep only the trailing context window; a no-op when x is shorter.
        logits = model(x[:, -max_seq_len:])
        next_token_logits = logits[:, -1, :]
        if temperature <= 0:
            # Greedy decoding; dividing by a zero temperature would give inf/nan.
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        else:
            next_token_logits = next_token_logits / temperature
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
                next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()
        if idx == eos_id:
            break
        x = torch.cat((x, next_token), dim=1)
        generated_tokens.append(idx)
        print(utf8.decode(tokenizer.decode_single_token_bytes(idx)), end="", flush=True)
    # Emit any incomplete trailing byte sequence (as U+FFFD).
    print(utf8.decode(b"", final=True), end="")
    print(f"\n{'='*70}\n")
    return tokenizer.decode(generated_tokens)
if __name__ == "__main__":
    # Prefer CUDA, then Apple-silicon MPS, then CPU.
    device = (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using device: {device}")
    model, tokenizer = load_model(device)
    if model is None:
        # SystemExit does not depend on the site-injected ``exit`` builtin.
        raise SystemExit(1)
    test_prompts = [
        "Once upon a time",
        "The future of AI is",
        "In a world where",
    ]
    print("\n" + "="*70)
    print("ChatGCLM Text Generation Demo")
    print("="*70)
    for prompt in test_prompts:
        generate(model, prompt, tokenizer, device, max_new_tokens=150, temperature=0.8, top_k=50)
    print("\n" + "="*70)
    print("Interactive Mode - Enter your own prompts!")
    print("="*70)
    while True:
        try:
            user_prompt = input("\nEnter prompt (or 'exit' to quit): ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D, end of piped input, or Ctrl-C at the prompt: leave
            # quietly instead of dumping a traceback.
            print()
            break
        # Strip first so " exit " is recognized too.
        if user_prompt.strip().lower() == 'exit':
            break
        if user_prompt.strip():
            generate(model, user_prompt, tokenizer, device, max_new_tokens=200, temperature=0.8, top_k=50)