import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model(model_path):
    """Load a causal language model and its tokenizer from ``model_path``.

    Args:
        model_path: Path or identifier passed unchanged to
            ``AutoModelForCausalLM.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Load the model.
    model = AutoModelForCausalLM.from_pretrained(model_path)
    # Load the tokenizer (assuming it's saved alongside the model).
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer


def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
    """Generate a single continuation of ``prompt`` and return it as text.

    Args:
        model: A model exposing ``generate`` and ``device``
            (e.g. the one returned by ``load_model``).
        tokenizer: The tokenizer matching ``model``.
        prompt: Text to continue.
        max_length: Maximum total length in tokens. This count INCLUDES the
            prompt tokens, so long prompts leave less room for new text.
        temperature: Sampling temperature. A value > 0 enables sampling at
            that temperature. ``None`` or a value <= 0 selects greedy
            (deterministic) decoding.

    Returns:
        The decoded first generated sequence, which includes the prompt text,
        with special tokens stripped.
    """
    # Calling the tokenizer (rather than ``encode``) also yields the
    # attention mask. Without it, generate cannot distinguish padding from
    # content when pad == eos.
    encoded = tokenizer(prompt, return_tensors="pt")
    # Keep inputs on the same device as the model weights.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)

    # Many causal-LM tokenizers define no pad token. Fall back to EOS only
    # in that case.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    # ``temperature`` is ignored by generate unless sampling is enabled.
    # The original call ran greedy decoding regardless of the requested
    # temperature.
    do_sample = temperature is not None and temperature > 0
    gen_kwargs = {
        "max_length": max_length,
        "num_return_sequences": 1,
        "pad_token_id": pad_token_id,
        "do_sample": do_sample,
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature

    # Generate text without tracking gradients.
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_kwargs,
        )

    # Decode and return the generated text.
    return tokenizer.decode(output[0], skip_special_tokens=True)


if __name__ == "__main__":
    # NOTE(review): ``from_pretrained`` normally expects a directory (holding
    # config.json plus weights) or a hub model id, not a single ``.pth``
    # file. Confirm this path actually loads.
    model_path = "enhanced_transformer_model_500M_final.pth"
    model, tokenizer = load_model(model_path)

    prompt = "Once upon a time"
    generated_text = generate_text(model, tokenizer, prompt)
    print(generated_text)