File size: 1,205 Bytes
8e36426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model(model_path):
    """Load a causal language model and its tokenizer from ``model_path``.

    Both objects are read from the same location, so the tokenizer files are
    expected to have been saved alongside the model.

    Args:
        model_path: Location handed unchanged to ``from_pretrained()`` for
            both the model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    causal_lm = AutoModelForCausalLM.from_pretrained(model_path)
    # Same path as the weights: the tokenizer is assumed to live next to them.
    tok = AutoTokenizer.from_pretrained(model_path)
    return causal_lm, tok

def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
    """Generate a continuation of ``prompt`` with a causal language model.

    Args:
        model: Causal language model exposing ``generate()`` and ``device``
            (e.g. the model returned by ``load_model``).
        tokenizer: Tokenizer matching ``model``.
        prompt: Text to continue.
        max_length: Maximum total length in tokens, *including* the prompt
            tokens (``generate()``'s ``max_length`` semantics).
        temperature: Sampling temperature. A value > 0 enables sampling at
            that temperature; ``None`` or a value <= 0 selects greedy decoding.

    Returns:
        The decoded text (prompt plus continuation) with special tokens removed.
    """
    # Calling the tokenizer (instead of ``encode``) also returns an
    # attention_mask, so generate() does not have to infer it. Move the
    # tensors to wherever the model's weights live (CPU or GPU).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # ``temperature`` only takes effect in sampling mode: without
    # do_sample=True, generate() decodes greedily and ignores it. A
    # non-positive temperature is not valid for sampling, so treat it as greedy.
    do_sample = temperature is not None and temperature > 0
    sampling_kwargs = {"do_sample": do_sample}
    if do_sample:
        sampling_kwargs["temperature"] = temperature

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            # Reuse EOS as the pad token (many causal-LM tokenizers define none).
            pad_token_id=tokenizer.eos_token_id,
            **sampling_kwargs,
        )

    # Only one sequence is requested, so decode the first (and only) row.
    return tokenizer.decode(output[0], skip_special_tokens=True)

if __name__ == "__main__":
    # Local import: argparse is only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate text from a causal language model."
    )
    # NOTE(review): from_pretrained() loads from a directory written by
    # save_pretrained() (or a Hub model ID); a raw ``.pth`` file is likely
    # not loadable this way. Confirm how this model was saved.
    parser.add_argument(
        "--model-path",
        default="enhanced_transformer_model_500M_final.pth",
        help="Model directory or Hub ID passed to from_pretrained().",
    )
    parser.add_argument(
        "--prompt",
        default="Once upon a time",
        help="Text for the model to continue.",
    )
    args = parser.parse_args()

    model, tokenizer = load_model(args.model_path)
    generated_text = generate_text(model, tokenizer, args.prompt)
    print(generated_text)