"""Interactive greedy text generation with a pre-trained character-level RNN.

Reads ``parameter.json`` (keys ``model_path`` and ``hidden_size``) and
``vocab.json`` (the list of characters) from the current working directory and
loads the pickled model referenced by ``model_path``.
"""
import json

import torch
from tqdm import tqdm, trange

# NOTE(review): SimpleRNN is not referenced directly here; presumably it is
# imported so that torch.load can resolve the pickled model class — keep it.
from SimpleRNN import SimpleRNN

# Use a context manager so the config file handle is closed promptly.
with open("parameter.json", "r", encoding="utf-8") as f:
    parameters = json.load(f)
model_path = parameters["model_path"]
# SECURITY: weights_only=False unpickles arbitrary Python objects from the
# checkpoint, which can execute code. Only load model files from a trusted source.
model = torch.load(model_path, weights_only=False)

# JSON is UTF-8 by specification; be explicit so the platform default
# encoding cannot mis-decode non-ASCII vocabulary characters.
with open("vocab.json", "r", encoding="utf-8") as f:
    chars = json.load(f)
# A character's position in the vocabulary list is its token index.
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = dict(enumerate(chars))
print("Loaded pre-trained model.")

input_size = len(chars)
hidden_size = parameters["hidden_size"]
output_size = len(chars)
def generate_text(start_text, length):
    """Greedily extend ``start_text`` by ``length`` characters using ``model``.

    The prompt is run through the RNN once to build up the hidden state. Every
    later step feeds only the previously predicted character together with the
    carried hidden state, and appends the highest-scoring next character.

    Args:
        start_text: Non-empty prompt whose characters all appear in the
            vocabulary loaded from ``vocab.json``.
        length: Number of characters to generate; values <= 0 generate none.

    Returns:
        ``start_text`` followed by the generated characters.

    Raises:
        ValueError: If ``start_text`` is empty.
        KeyError: If ``start_text`` contains characters outside the vocabulary.
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    unknown = sorted({ch for ch in start_text if ch not in char_to_idx})
    if unknown:
        raise KeyError(f"characters not in vocabulary: {unknown}")

    model.eval()
    # Build inputs on the same device as the model's weights, so that a
    # checkpoint saved on a GPU does not hit a CPU/GPU device mismatch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    generated = [start_text]
    # Inference only: without no_grad, autograd would record every step
    # through the carried hidden state for the whole generation.
    with torch.no_grad():
        # NOTE(review): assumes the model expects h_0 shaped
        # (num_layers=1, batch=1, hidden_size), exactly as before — confirm.
        hidden = torch.zeros(1, 1, hidden_size, device=device)
        input_seq = torch.tensor(
            [char_to_idx[ch] for ch in start_text], device=device
        )
        for _ in trange(length):
            output, hidden = model(input_seq, hidden)
            # Score only the last time step. This matches output.argmax() when
            # the model returns a single step of logits, and avoids producing
            # an index beyond the vocabulary if it returns per-step logits.
            logits = output.reshape(-1, output.shape[-1])[-1]
            predicted_idx = logits.argmax().item()
            generated.append(idx_to_char[predicted_idx])
            # The hidden state already summarises everything seen so far, so
            # feed only the new character. Re-feeding the shifted window would
            # make the RNN consume each character again on every step.
            input_seq = torch.tensor([predicted_idx], device=device)
    return "".join(generated)
if __name__ == "__main__":
    # Interactive loop: read a prompt and a length, then print the continuation.
    # Ctrl-D (EOF) or Ctrl-C ends the session cleanly instead of with a traceback.
    try:
        while True:
            prompt = input("Ask LLM: ")
            try:
                length = int(input("Length of text: "))
            except ValueError:
                print("Length must be a whole number.")
                continue
            try:
                print("LLM Output: ", generate_text(prompt, length))
            except (KeyError, ValueError) as err:
                # A bad prompt (e.g. a character outside the vocabulary) should
                # not end the session; report it and ask again.
                print(f"Could not generate text: {err}")
    except (EOFError, KeyboardInterrupt):
        print()