| import torch | |
| from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config | |
| from app.core.config import settings | |
class StoryGenerator:
    """Wrap a small GPT-2 style language model for prompt-based story generation.

    The architecture is fixed in ``__init__`` (256 positions, 256-dim
    embeddings, 4 layers, 8 heads), so the checkpoint loaded from
    ``model_path`` must have been saved from a model with that same
    configuration.
    """

    def __init__(self, model_path=settings.MODEL_PATH):
        """Build the tokenizer and model, then load weights from ``model_path``.

        Args:
            model_path: Path to a checkpoint readable by ``torch.load`` that
                contains a ``'model_state_dict'`` entry.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        # The pad token is set to the end-of-text token so that
        # ``pad_token_id`` is defined when passed to ``generate``.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.config = GPT2Config(
            vocab_size=self.tokenizer.vocab_size,
            n_positions=256,
            n_ctx=256,
            n_embd=256,
            n_layer=4,
            n_head=8,
        )
        self.model = GPT2LMHeadModel(self.config)
        self.load_model(model_path)
        self.model.to(self.device)

        # Not read by any method in this class; presumably bookkeeping for
        # training code elsewhere -- TODO confirm.
        self.best_loss = float('inf')

    def load_model(self, path):
        """Load model weights from the checkpoint at ``path``.

        Args:
            path: Checkpoint file whose top-level object is a mapping with a
                ``'model_state_dict'`` key.

        Raises:
            KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        """
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources, or pass
        # ``weights_only=True`` if the installed torch version supports it.
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

    def generate_story(self, prompt, max_length=200, temperature=0.7):
        """Sample a continuation of ``prompt`` and return it as text.

        Args:
            prompt: Text to continue. An empty prompt is allowed; generation
                is then seeded with the end-of-text token.
            max_length: Maximum total length in tokens (prompt plus generated
                text). Clamped to the model's context size
                (``config.n_positions``).
            temperature: Sampling temperature passed to ``generate``.

        Returns:
            The decoded output sequence with special tokens removed. If the
            prompt had to be truncated to fit, only its retained tail appears
            in the returned text.
        """
        self.model.eval()

        # The position-embedding table has exactly n_positions rows, so the
        # model cannot process sequences longer than that.
        max_length = min(max_length, self.config.n_positions)

        token_ids = self.tokenizer.encode(prompt)
        if not token_ids:
            # An empty prompt yields no tokens and generate() needs at least
            # one to condition on. The seed is a special token, so it is
            # stripped again when decoding.
            token_ids = [self.tokenizer.eos_token_id]

        # Leave room for at least one new token; keep the tail of the prompt
        # because it is the most recent context for the continuation.
        prompt_budget = max(max_length - 1, 1)
        if len(token_ids) > prompt_budget:
            token_ids = token_ids[-prompt_budget:]

        input_ids = torch.tensor([token_ids], dtype=torch.long, device=self.device)
        # pad_token == eos_token, so the mask cannot be inferred from the ids;
        # the single sequence is unpadded, hence an all-ones mask.
        attention_mask = torch.ones_like(input_ids)

        # early_stopping / length_penalty were dropped: they only apply to
        # beam search and have no effect with num_beams=1.
        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_return_sequences=1,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                num_beams=1,
                temperature=temperature,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                no_repeat_ngram_size=3,
            )[0]

        return self.tokenizer.decode(output_ids, skip_special_tokens=True)