File size: 1,171 Bytes
ce1e01a
2925c44
 
 
6d81201
2925c44
 
 
 
 
 
75a978b
13855c2
4aaf034
2925c44
 
 
4aaf034
75a978b
2925c44
 
 
 
ce1e01a
2925c44
 
 
a7be45a
2925c44
 
ce1e01a
 
2925c44
ce1e01a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import gradio as gr
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer

# Initialize the seq2seq model and its tokenizer once at startup.
# Alternatives that drop in here: "t5-base", "google/flan-t5-large".
model_name = "google/flan-t5-large" # t5-base ; google/flan-t5-large
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Define a function to generate text using T5
def generate_text(prompt):
    # Tokenize input and generate output
    input_ids = tokenizer.encode(prompt, return_tensors="pt", max_length=1024, truncation=True)
    #input_ids = tokenizer.encode(prompt, return_tensors="pt").input_ids
    
    output_ids = model.generate(input_ids)
    
    # Decode the generated output
    #generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    
    return generated_text

# Wire the generation function into a simple text-in / text-out web UI.
iface = gr.Interface(
    title="T5 Text Generation",
    description="Enter a prompt, and the model will generate text based on it.",
    fn=generate_text,
    inputs=gr.Textbox(),
    outputs=gr.Textbox(),
    live=False,
)

# Start the Gradio server (blocks until shut down).
iface.launch()