"""Interactive completion chat driven by a trained GPT checkpoint.

Loads ``out/ckpt.pt`` (nanoGPT-style checkpoint), then runs a simple REPL:
each user prompt is completed by the model and printed back.
"""
import os

import torch
import tiktoken

from model import GPTConfig, GPT

# --- configuration ----------------------------------------------------------
out_dir = 'out'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# GPT-2 BPE tokenizer; 50256 is the '<|endoftext|>' token id.
enc = tiktoken.get_encoding("gpt2")
EOS_TOKEN_ID = 50256


def _load_model():
    """Load the checkpoint from ``out_dir`` and return the model in eval mode."""
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    net = GPT(gptconf)
    state_dict = checkpoint['model']
    # Checkpoints saved from a torch.compile()d model prefix every key with
    # '_orig_mod.'; strip it so keys match the un-compiled module.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    return net


model = _load_model()


def ask_gpt(prompt, max_new_tokens=150, temperature=0.7, top_k=25):
    """Generate a continuation for *prompt*.

    Args:
        prompt: Text to complete.
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Sampling temperature passed to ``model.generate``.
        top_k: Top-k sampling cutoff passed to ``model.generate``.

    Returns:
        Only the newly generated text (prompt not included), truncated at
        the first end-of-text token if the model emitted one.
    """
    start_ids = enc.encode(prompt)
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
    with torch.no_grad():
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
    full_ids = y[0].tolist()
    new_ids = full_ids[len(start_ids):]  # drop the echoed prompt tokens
    # Truncate at the EOS token id rather than string-splitting the decoded
    # text on '<|endoftext|>' — exact, and avoids accidental matches when
    # several tokens happen to decode to pieces of the marker.
    if EOS_TOKEN_ID in new_ids:
        new_ids = new_ids[:new_ids.index(EOS_TOKEN_ID)]
    return enc.decode(new_ids)


def main():
    """Run the interactive prompt/completion loop until 'exit' or 'quit'."""
    print("--- Crest Completion Chat started ---")
    while True:
        user_input = input("\nYour Prompt: ")
        if user_input.lower() in ['exit', 'quit']:
            break
        completion = ask_gpt(user_input)
        print(f"\nCrest Completion: {user_input}{completion}")
        print("-" * 30)


if __name__ == "__main__":
    main()