# Hugging Face Space file: percobaan_baru / app.py (author: devagonal)
# Last commit: "Update app.py" (ad13ccd, verified)
# app.py
# Simple Gradio app to run inference with Flan-T5 models (text2text-generation)
import gradio as gr
from transformers import pipeline
import torch
import os
# Loaded pipelines keyed by resolved model name. get_pipeline() fills this
# lazily so a model is not reloaded on every request.
PIPES = {}

# Dropdown label -> Hugging Face Hub repo id. The label is the part of the
# repo id after the last "/". Append more repo ids here (e.g.
# "google/flan-t5-xl") if the Space has enough RAM/GPU.
DEFAULT_MODELS = {
    repo_id.rsplit("/", 1)[-1]: repo_id
    for repo_id in (
        "google/flan-t5-small",
        "google/flan-t5-base",
        "google/flan-t5-large",
    )
}
def get_device():
    """Return the ``device`` index handed to the transformers pipeline.

    ``0`` (the first CUDA device) when CUDA is available, otherwise ``-1``,
    which the pipeline treats as CPU.
    """
    if torch.cuda.is_available():
        return 0
    return -1
def get_pipeline(model_key_or_name):
    """Return a (cached) text2text-generation pipeline for a model.

    Args:
        model_key_or_name: either a key of DEFAULT_MODELS (e.g.
            "flan-t5-base") or a full Hub model name, which is used as-is.

    Returns:
        The transformers pipeline. It is created on first use, then stored in
        PIPES under the resolved model name and reused on later calls.

    Any exception raised while downloading or loading the model propagates to
    the caller and nothing is cached in that case.
    """
    resolved = DEFAULT_MODELS.get(model_key_or_name, model_key_or_name)
    cached = PIPES.get(resolved)
    if cached is None:
        # The pipeline downloads the tokenizer and model on first use.
        # trust_remote_code stays at its default (False), which is fine for
        # the official Flan-T5 checkpoints.
        cached = pipeline(
            "text2text-generation",
            model=resolved,
            tokenizer=resolved,
            device=get_device(),
        )
        PIPES[resolved] = cached
    return cached
def _generation_kwargs(max_length, temperature, num_return_sequences):
    """Build the keyword arguments for the text2text pipeline call.

    A temperature > 0 enables sampling. Otherwise decoding is deterministic.
    Greedy decoding can only return one sequence, so when several outputs are
    requested without sampling, beam search with
    ``num_beams == num_return_sequences`` is used instead.
    """
    n_outputs = max(1, int(num_return_sequences))
    temperature = float(temperature)
    kwargs = {
        # Slider values may arrive as floats; generation lengths must be ints.
        "max_length": int(max_length),
        "num_return_sequences": n_outputs,
    }
    if temperature > 0.0:
        kwargs["do_sample"] = True
        kwargs["temperature"] = temperature
    else:
        # temperature only matters when sampling. Passing it with
        # do_sample=False is ignored by transformers and triggers a warning.
        kwargs["do_sample"] = False
        if n_outputs > 1:
            # transformers rejects num_return_sequences > 1 for plain greedy
            # search, so use beam search to get several deterministic outputs.
            kwargs["num_beams"] = n_outputs
    return kwargs


def generate(prompt: str, model_choice: str, max_length: int, temperature: float, num_return_sequences: int) -> str:
    """Generate text from the prompt using the selected Flan-T5 model.

    Args:
        prompt: input text. Empty or whitespace-only prompts are rejected.
        model_choice: a DEFAULT_MODELS key or a full Hub model name.
        max_length: maximum length of each generated sequence, in tokens.
        temperature: 0.0 selects deterministic decoding (greedy, or beam
            search when several outputs are requested); > 0 enables sampling.
        num_return_sequences: how many outputs to produce.

    Returns:
        The generated text, with multiple outputs separated by ``---``. On
        invalid input or any failure, a human-readable message is returned
        instead of raising, so it is shown in the output textbox.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    try:
        pipe = get_pipeline(model_choice)
    except Exception as e:
        return f"Failed to load model {model_choice}: {e}"
    try:
        # Building the kwargs happens inside the try so that bad slider
        # values are reported the same way as generation errors.
        outputs = pipe(prompt, **_generation_kwargs(max_length, temperature, num_return_sequences))
    except Exception as e:
        return f"Generation failed: {e}"
    # The pipeline returns a list of dicts with key 'generated_text'.
    texts = [o.get("generated_text", "") for o in outputs]
    return "\n\n---\n\n".join(texts)
# Gradio UI, built at import time so `demo` is available to the launcher.
# Layout: prompt and examples on the left, model/generation controls on the
# right, then the output box and usage notes below.
with gr.Blocks(title="Flan-T5 Inference (Text2Text)") as demo:
    gr.Markdown(
        "\n"
        "# Flan-T5 Text2Text Inference\n"
        "Type your prompt and pick a Flan-T5 model. For best performance, enable a GPU in the Space settings.\n"
    )
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                lines=8,
                label="Input prompt",
                placeholder="e.g. Summarize the following article in 2 sentences: ...",
            )
            examples = gr.Examples(
                examples=[
                    ["Summarize the key points of the American Declaration of Independence."],
                    ["Translate the following English sentence to French: 'The weather is nice today.'"],
                    ["Explain in simple terms how photosynthesis works."],
                ],
                inputs=prompt,
            )
        with gr.Column(scale=1):
            # Only the DEFAULT_MODELS keys are offered as choices.
            model_choice = gr.Dropdown(list(DEFAULT_MODELS.keys()), value="flan-t5-base", label="Model")
            max_length = gr.Slider(32, 1024, value=256, step=1, label="Max length (tokens)")
            # generate() decodes deterministically at 0.0 and samples above it.
            temperature = gr.Slider(0.0, 1.5, value=0.0, step=0.01, label="Temperature (0.0 = deterministic)")
            num_return_sequences = gr.Slider(1, 3, value=1, step=1, label="Number of outputs")
            run_btn = gr.Button("Generate")
    output = gr.Textbox(label="Model output", lines=12)
    run_btn.click(
        fn=generate,
        inputs=[prompt, model_choice, max_length, temperature, num_return_sequences],
        outputs=output,
    )
    # The last note previously claimed users could type a full Hub model name.
    # The dropdown above does not allow that, so the note now points to the
    # way models can actually be added.
    gr.Markdown(
        "\n"
        "Notes:\n"
        "- If you want faster generation, set Space to use a GPU (in Settings → Hardware).\n"
        "- Larger models (flan-t5-large / flan-t5-xl) need more RAM — they may OOM on CPU or free GPU tiers.\n"
        "- To offer other models (e.g. flan-t5-xl), add them to DEFAULT_MODELS in app.py; they will then appear in the Model dropdown.\n"
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script
    # (e.g. `python app.py`), not when the module is imported.
    demo.launch()