# chat.py — interactive chat REPL for SmaLLMPro
# (source: LH-Tech-AI Hugging Face repo, "Update chat.py", commit b215b7c verified)
import torch
import tiktoken
from model import GPTConfig, GPT
# --- Config ---
# NOTE(review): checkpoint path is machine-specific — consider taking it from
# a CLI argument or environment variable.
ckpt_path = '/home/user/350m_SmaLLMPro_Final/SmaLLMPro_iter_1500.pt'
# Fall back to CPU so the script still runs on machines without a GPU
# (the original hard-coded 'cuda' and crashed on CPU-only hosts).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
enc = tiktoken.get_encoding("gpt2")

print("Loading SmaLLMPro...")
# WARNING: torch.load unpickles arbitrary objects — only load checkpoints
# from a trusted source.
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

# Checkpoints saved from a torch.compile()'d model prefix every key with
# '_orig_mod.'; strip it so the keys match the plain GPT module.
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict):  # snapshot keys: we mutate the dict while iterating
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.eval()
model.to(device)
print(f"Model {ckpt_path} ready!\n")
def run_chat():
    """Interactive chat loop: read a prompt, generate, print the response.

    Relies on module-level globals ``model``, ``enc`` and ``device`` set up
    during model loading. Exit with 'exit'/'quit', Ctrl-C, or Ctrl-D.
    """
    print("--- SmaLLMPro Chatbot (Type 'exit' to quit) ---")
    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on Ctrl-D / Ctrl-C instead of dumping a traceback.
            print()
            break
        if user_input.lower() in ["exit", "quit"]:
            break
        if not user_input.strip():
            continue  # nothing to generate from an empty prompt

        # Same instruction-tuning template the model presumably saw during
        # fine-tuning — TODO confirm against the training script.
        prompt = f"Instruction:\n{user_input}\n\nResponse:\n"
        x = torch.tensor(enc.encode(prompt), dtype=torch.long, device=device)[None, ...]

        print("SmaLLMPro: ", end="", flush=True)
        # Bug fix: autocast device_type was hard-coded to 'cuda'; derive it
        # from `device` so generation also works on a CPU fallback.
        device_type = 'cuda' if 'cuda' in device else 'cpu'
        with torch.no_grad(), torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
            y = model.generate(x, max_new_tokens=500, temperature=0.65, top_k=25)

        full_text = enc.decode(y[0].tolist())
        # Keep only the text after the last "Response:\n" marker; if the
        # marker got truncated out, fall back to the whole decoded text.
        if "Response:\n" in full_text:
            response = full_text.split("Response:\n")[-1]
        else:
            response = full_text
        # Trim at EOS or at a hallucinated follow-up "Instruction:" block.
        response = response.split("<|endoftext|>")[0].split("Instruction:")[0].strip()
        print(response + "\n")
# Start the REPL only when executed as a script, not on import.
if __name__ == "__main__":
    run_chat()