| import torch | |
| from models.gem_model import GEM | |
| from utils.data_preprocessing import load_tokenizer | |
| from configs.config import MODEL_CONFIG | |
def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
    """Generate a text continuation of *prompt* with *model*.

    Args:
        model: trained model exposing a HuggingFace-style ``generate`` method.
        tokenizer: tokenizer exposing ``encode``/``decode``.
        prompt: seed string to continue.
        max_length: maximum total token length passed to ``model.generate``.
        temperature: sampling temperature forwarded to ``model.generate``.

    Returns:
        The decoded generated string, with special tokens stripped.
    """
    # Derive the device from the model's own parameters instead of the global
    # config, so the helper works wherever the caller actually placed the model.
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # Inference only: no_grad skips autograd bookkeeping and saves memory
    # across the (potentially long) generation loop.
    with torch.no_grad():
        generated = model.generate(input_ids, max_length=max_length, temperature=temperature)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
def main():
    """Load the trained GEM checkpoint and print a sample continuation."""
    device = torch.device(MODEL_CONFIG['DEVICE'])
    tokenizer = load_tokenizer()

    # Rebuild the model with the same hyperparameters used during training;
    # the checkpoint only stores weights, not architecture.
    model = GEM(
        vocab_size=MODEL_CONFIG['VOCAB_SIZE'],
        d_model=MODEL_CONFIG['D_MODEL'],
        n_heads=MODEL_CONFIG['N_HEADS'],
        d_ff=MODEL_CONFIG['D_FF'],
        n_layers=MODEL_CONFIG['N_LAYERS'],
        max_seq_len=MODEL_CONFIG['MAX_SEQ_LEN'],
        dropout=MODEL_CONFIG['DROPOUT']
    ).to(device)

    # map_location keeps a GPU-saved checkpoint loadable on a CPU-only host;
    # without it torch.load tries to restore tensors onto the saving device.
    checkpoint = torch.load('final_model/model.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # disable dropout for deterministic inference

    prompt = "Once upon a time"
    generated_text = generate_text(model, tokenizer, prompt, max_length=100)
    print(f"Generated text:\n{generated_text}")


if __name__ == "__main__":
    main()