import torch
import tiktoken
from model import GPTConfig, GPT
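
# NOTE: GPTConfig and GPT are assumed to come from a nanoGPT-style model.py
# sitting next to this script; any module exposing those two names would
# work the same way.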

ckpt_path = '/home/user/350m_SmaLLMPro_Final/SmaLLMPro_iter_1500.pt'
device = 'cuda'
enc = tiktoken.get_encoding("gpt2")  # GPT-2 BPE tokenizer
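
# Rebuild the model from the hyperparameters saved in the checkpoint.
# (Newer PyTorch releases default torch.load to weights_only=True; pass
# weights_only=False here if the checkpoint fails to unpickle.)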
| | print("Loading SmaLLMPro...") |
| | checkpoint = torch.load(ckpt_path, map_location=device) |
| | gptconf = GPTConfig(**checkpoint['model_args']) |
| | model = GPT(gptconf) |
| |
|
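
# Checkpoints saved from a torch.compile()'d model prefix every parameter
# name with '_orig_mod.'; strip it so the keys match the uncompiled model.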
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout
model.to(device)
print(f"Model {ckpt_path} ready!\n")

def run_chat():
    print("--- SmaLLMPro Chatbot (Type 'exit' to quit) ---")

    while True:
        user_input = input("You: ")
        if user_input.lower() in ["exit", "quit"]:
            break
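
        # Wrap the user's message in the Instruction/Response template the
        # model was fine-tuned on, then encode it to token ids with a batch
        # dimension: shape (1, T).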
        prompt = f"Instruction:\n{user_input}\n\nResponse:\n"

        x = torch.tensor(enc.encode(prompt), dtype=torch.long, device=device)[None, ...]
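
        # Sample up to 500 new tokens; a low temperature (0.65) and a small
        # top-k (25) keep the output focused. bfloat16 autocast speeds up
        # inference on CUDA.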
| | print("SmaLLMPro: ", end="", flush=True) |
| | with torch.no_grad(): |
| | with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16): |
| | y = model.generate(x, max_new_tokens=500, temperature=0.65, top_k=25) |
| | |
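
        # generate() returns the prompt plus the completion; decode the whole
        # sequence, then keep only the text after the final "Response:" marker.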
        full_text = enc.decode(y[0].tolist())

        if "Response:\n" in full_text:
            response = full_text.split("Response:\n")[-1]
        else:
            response = full_text

        # Stop at the end-of-text token or at any hallucinated follow-up turn.
        response = response.split("<|endoftext|>")[0].split("Instruction:")[0].strip()
        print(response + "\n")

if __name__ == "__main__":
    run_chat()
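
# Example session (the script filename is hypothetical):
#   $ python chat.py
#   Loading SmaLLMPro...
#   Model /home/user/350m_SmaLLMPro_Final/SmaLLMPro_iter_1500.pt ready!
#
#   --- SmaLLMPro Chatbot (Type 'exit' to quit) ---
#   You: exit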