import os
import time
import warnings

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
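
# Chunked text generation with a locally fine-tuned distilgpt2 checkpoint.
# The model is loaded once at import time and reused for every prompt.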

class Inference:
    def __init__(self, silent=False) -> None:
        start_time = time.perf_counter()
        # The distilgpt2 tokenizer is paired with a fine-tuned checkpoint
        # saved next to this script in the "SaveState" directory.
        self.tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        self.model = GPT2LMHeadModel.from_pretrained(self.local_file_path("SaveState"))
        self.model.eval()
        if not silent:
            print(f"Model loading took {time.perf_counter() - start_time:.2f} seconds")

    def local_file_path(self, path):
        # Resolve a path relative to this script's directory.
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)

    def generate(self, prompt, max_length=2000, temperature=0.5, do_sample=True,
                 stop_token=None, callback=None, silent=True):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            start_time = time.perf_counter()
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
            generated_text = input_ids
            while generated_text.shape[1] < max_length:
                # Generate in chunks of up to 50 new tokens so the callback can
                # stream output and the stop token can be checked between chunks.
                length = min(50, max_length - generated_text.shape[1])
                with torch.no_grad():
                    # max_new_tokens counts only freshly generated tokens;
                    # max_length would include the prompt and cut generation short.
                    outputs = self.model.generate(
                        input_ids,
                        max_new_tokens=length,
                        temperature=temperature,
                        do_sample=do_sample,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )
                # Keep only the freshly generated tokens, not the echoed prompt.
                new_tokens = outputs[0][input_ids.shape[1]:]
                if callback is not None:
                    for token in new_tokens:
                        callback(self.tokenizer.decode([token]))
                generated_text = torch.cat((generated_text, new_tokens.unsqueeze(0)), dim=-1)
                # Feed only the latest chunk back in: a sliding window that keeps
                # memory bounded at the cost of dropping earlier context.
                input_ids = new_tokens.unsqueeze(0)
                if stop_token is not None and stop_token in self.tokenizer.decode(generated_text[0]):
                    break
            if not silent:
                print(f"Generation took {time.perf_counter() - start_time:.2f} seconds")
            return self.tokenizer.decode(generated_text[0], skip_special_tokens=True)

inference = Inference()

def spec(stre):
    # Example callback for generate(): stream each token fragment as it arrives.
    print(stre, end="", flush=True)

if __name__ == "__main__":
    # Minimal REPL: read a prompt, generate up to 100 tokens, print the result.
    while True:
        print(inference.generate(input(">>> "), max_length=100, temperature=0.8, silent=True))
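
# A minimal streaming sketch (hypothetical usage, assuming the "SaveState"
# checkpoint exists): pass spec as the callback so generate() prints each
# token fragment as it is produced instead of waiting for the full string.
#
#     inference.generate("Once upon a time", max_length=60,
#                        temperature=0.8, callback=spec, stop_token="\n")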