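"""Text-generation demo for ChatGCLM.

Locates a ChatGCLM_*.pt checkpoint in the current directory, runs a few
sample prompts through the model, then drops into an interactive prompt
loop. Decoding uses temperature scaling with top-k sampling.
"""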
import os
import torch
import torch.nn.functional as F
import tiktoken
from model import ChatGCLM, MAX_SEQ_LEN

# Locate a ChatGCLM checkpoint in the working directory. Sorting the
# listing makes the choice deterministic when several checkpoints exist.
MODEL_PATH = None
for f in sorted(os.listdir(".")):
    if f.startswith("ChatGCLM_") and f.endswith(".pt"):
        MODEL_PATH = f
        break

if MODEL_PATH is None:
    print("Error: No model checkpoint found!")
    print("Please train the model first with: python3 train_chatgclm.py")
    raise SystemExit(1)

TOKENIZER_NAME = "gpt2"
# End-of-sequence ID the model was trained with. Note that this is
# model-specific: the stock GPT-2 tokenizer's <|endoftext|> token is 50256.
EOS_ID = 2

def load_model(device):
    tok = tiktoken.get_encoding(TOKENIZER_NAME)
    vocab_size = tok.n_vocab

    model = ChatGCLM(vocab_size).to(device)
    # MODEL_PATH is known to exist at this point; the size check guards
    # against a truncated or empty checkpoint file.
    if os.path.getsize(MODEL_PATH) > 0:
        print(f"Loading model from: {MODEL_PATH}")
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        model.eval()
        return model, tok
    print(f"Error: Could not load model from {MODEL_PATH}")
    return None, None

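# Generation runs under torch.no_grad() because no gradients are needed at
# inference time. The context is a sliding window: once the sequence grows
# past MAX_SEQ_LEN tokens, only the most recent MAX_SEQ_LEN are fed back in.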
@torch.no_grad()
def generate(model, prompt, tokenizer, device, max_new_tokens=200, temperature=0.8, top_k=50):
    model.eval()
    input_ids = tokenizer.encode(prompt)
    x = torch.tensor([input_ids], dtype=torch.long, device=device)

    print(f"\n{'='*70}")
    print(f"PROMPT: {prompt}")
    print(f"{'='*70}")
    print("GENERATED TEXT:")
    print(prompt, end="", flush=True)

    generated_tokens = []
    for _ in range(max_new_tokens):
        # Keep at most the last MAX_SEQ_LEN tokens as context.
        ctx = x[:, -MAX_SEQ_LEN:]
        logits = model(ctx)
        # Temperature scaling: values below 1 sharpen the distribution,
        # values above 1 flatten it.
        next_token_logits = logits[:, -1, :] / temperature

        # Top-k filtering: mask everything below the k-th largest logit.
        if top_k is not None:
            v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
            next_token_logits[next_token_logits < v[:, [-1]]] = -float("Inf")

        # Sample the next token from the renormalized distribution.
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()

        if idx == EOS_ID:
            break

        x = torch.cat((x, next_token), dim=1)
        generated_tokens.append(idx)
        # Stream each decoded token to stdout as it is produced.
        print(tokenizer.decode([idx]), end="", flush=True)

    print(f"\n{'='*70}\n")
    return tokenizer.decode(generated_tokens)

if __name__ == "__main__":
    # Prefer CUDA, fall back to Apple-silicon MPS, then CPU.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    model, tokenizer = load_model(device)

    if model is None:
        raise SystemExit(1)
    test_prompts = [
        "Once upon a time",
        "The future of AI is",
        "In a world where",
    ]
    
    print("\n" + "="*70)
    print("ChatGCLM Text Generation Demo")
    print("="*70)
    
    for prompt in test_prompts:
        generate(model, prompt, tokenizer, device, max_new_tokens=150, temperature=0.8, top_k=50)
    
    print("\n" + "="*70)
    print("Interactive Mode - Enter your own prompts!")
    print("="*70)
    
    while True:
        try:
            user_prompt = input("\nEnter prompt (or 'exit' to quit): ")
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on Ctrl-D / Ctrl-C instead of printing a traceback.
            break
        if user_prompt.strip().lower() == "exit":
            break
        if user_prompt.strip():
            generate(model, user_prompt, tokenizer, device, max_new_tokens=200, temperature=0.8, top_k=50)