# NOTE(review): removed non-source scrape artifacts that preceded the code
# (GitHub blob-view file size, commit hashes, and gutter line numbers 1-54).
import torch
import tiktoken
from model import GPTConfig, GPT

# --- Config ---
ckpt_path = '/home/user/350m_SmaLLMPro_Final/SmaLLMPro_iter_1500.pt'
# Fall back to CPU so the script still starts on machines without a GPU;
# on CUDA hosts this is identical to the previous hardcoded 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
enc = tiktoken.get_encoding("gpt2")

print("Loading SmaLLMPro...")
# NOTE(review): checkpoint is assumed to follow the nanoGPT layout —
# 'model_args' (GPTConfig kwargs) and 'model' (state dict) — confirm if
# the training script changes.
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

# torch.compile() wraps the model and prefixes every parameter key with
# '_orig_mod.'; strip that prefix so the keys match the uncompiled model.
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

model.load_state_dict(state_dict)
model.eval()
model.to(device)
print(f"Model {ckpt_path} ready!\n")

def run_chat():
    """Run an interactive chat REPL against the loaded model.

    Reads prompts from stdin, wraps them in the Instruction/Response
    template the model was fine-tuned on, and prints the generated reply.
    Loops until the user types 'exit'/'quit' or closes the session with
    Ctrl-D / Ctrl-C. Relies on the module-level ``model``, ``enc``, and
    ``device`` globals; returns None.
    """
    print("--- SmaLLMPro Chatbot (Type 'exit' to quit) ---")

    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt is a clean exit, not a crash.
            print()
            break
        if user_input.lower() in ["exit", "quit"]:
            break
        if not user_input.strip():
            # Nothing to answer — skip the (expensive) generation pass.
            continue

        # Mirror the fine-tuning template so the model recognizes the task.
        prompt = f"Instruction:\n{user_input}\n\nResponse:\n"

        x = torch.tensor(enc.encode(prompt), dtype=torch.long, device=device)[None, ...]

        print("SmaLLMPro: ", end="", flush=True)
        with torch.no_grad():
            # Use the module-level device so autocast matches wherever the
            # model actually lives (cuda or cpu).
            with torch.amp.autocast(device_type=device, dtype=torch.bfloat16):
                y = model.generate(x, max_new_tokens=500, temperature=0.65, top_k=25)

                full_text = enc.decode(y[0].tolist())

                # generate() returns prompt + continuation; keep only the
                # text after the final "Response:" marker.
                if "Response:\n" in full_text:
                    response = full_text.split("Response:\n")[-1]
                else:
                    response = full_text

                # Trim at the end-of-text token or at a hallucinated
                # follow-up "Instruction:", whichever appears first.
                response = response.split("<|endoftext|>")[0].split("Instruction:")[0].strip()
                print(response + "\n")

if __name__ == "__main__":
    # Start the interactive loop only when executed as a script, not on import.
    run_chat()