import torch
import tiktoken

from model import GPTConfig, GPT

# --- Config ---
ckpt_path = '/home/user/350m_SmaLLMPro_Final/SmaLLMPro_iter_1500.pt'
device = 'cuda'
enc = tiktoken.get_encoding("gpt2")

# Checkpoints saved from a torch.compile()'d model prefix every state-dict
# key with this; it must be stripped before load_state_dict().
_UNWANTED_PREFIX = '_orig_mod.'


def load_model():
    """Load the SmaLLMPro checkpoint and return an eval-mode GPT on `device`.

    Kept out of module import time so importing this file does not require
    the checkpoint (or a GPU) to be present.
    """
    print("Loading SmaLLMPro...")
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Strip the torch.compile wrapper prefix in place (iterate a snapshot of
    # the keys because we mutate the dict while walking it).
    for k in list(state_dict):
        if k.startswith(_UNWANTED_PREFIX):
            state_dict[k[len(_UNWANTED_PREFIX):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    print(f"Model {ckpt_path} ready!\n")
    return model


def _extract_response(full_text):
    """Pull the model's answer out of the decoded generation.

    Takes everything after the last 'Response:\\n' marker (the prompt itself
    contains one, so the last occurrence is the generated answer), then
    truncates at the end-of-text token or at a hallucinated follow-up
    'Instruction:' block, and strips surrounding whitespace.
    """
    if "Response:\n" in full_text:
        response = full_text.split("Response:\n")[-1]
    else:
        response = full_text
    return response.split("<|endoftext|>")[0].split("Instruction:")[0].strip()


def run_chat(model=None):
    """Interactive chat REPL: read an instruction, generate, print the reply.

    Args:
        model: an already-loaded GPT instance; loaded on demand when None
            (keeps the zero-argument call pattern working unchanged).
    """
    if model is None:
        model = load_model()
    print("--- SmaLLMPro Chatbot (Type 'exit' to quit) ---")
    while True:
        user_input = input("You: ")
        if user_input.lower() in ["exit", "quit"]:
            break
        prompt = f"Instruction:\n{user_input}\n\nResponse:\n"
        # [None, ...] adds the batch dimension generate() expects.
        x = torch.tensor(enc.encode(prompt), dtype=torch.long, device=device)[None, ...]
        print("SmaLLMPro: ", end="", flush=True)
        with torch.no_grad():
            # device_type follows the configured device instead of
            # hard-coding 'cuda'; bfloat16 autocast for cheaper generation.
            with torch.amp.autocast(device_type=device, dtype=torch.bfloat16):
                y = model.generate(x, max_new_tokens=500, temperature=0.65, top_k=25)
        full_text = enc.decode(y[0].tolist())
        print(_extract_response(full_text) + "\n")


if __name__ == "__main__":
    run_chat()