Spaces:
Sleeping
Sleeping
File size: 4,293 Bytes
ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd 910ef03 ad13ccd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
# app.py
# Simple Gradio app to run inference with Flan-T5 models (text2text-generation)
import gradio as gr
from transformers import pipeline
import torch
import os
# Cache pipelines for models so we don't reload on every request.
# Keys are fully-resolved Hub model ids; values are transformers pipelines.
PIPES: dict = {}

# Friendly dropdown keys mapped to their Hugging Face Hub model ids.
DEFAULT_MODELS = {
    "flan-t5-small": "google/flan-t5-small",
    "flan-t5-base": "google/flan-t5-base",
    "flan-t5-large": "google/flan-t5-large",
    # you can add "flan-t5-xl" or others if your Space has enough RAM/GPU
}
def get_device():
    """Return the transformers pipeline device id: 0 (first CUDA GPU) or -1 (CPU)."""
    if torch.cuda.is_available():
        return 0
    return -1
def get_pipeline(model_key_or_name):
    """
    Return a (cached) transformers text2text-generation pipeline.

    model_key_or_name: either a key from DEFAULT_MODELS or a full model name.
    """
    # Resolve a friendly key (e.g. "flan-t5-base") to its Hub id; unknown
    # keys are treated as full model names and passed through unchanged.
    resolved = DEFAULT_MODELS.get(model_key_or_name, model_key_or_name)

    cached = PIPES.get(resolved)
    if cached is not None:
        return cached

    # pipeline() handles the tokenizer/model download itself.
    # trust_remote_code stays at its default (False); that is fine for the
    # official Flan-T5 checkpoints.
    new_pipe = pipeline(
        "text2text-generation",
        model=resolved,
        tokenizer=resolved,
        device=get_device(),
    )
    PIPES[resolved] = new_pipe
    return new_pipe
def generate(prompt: str, model_choice: str, max_length: int, temperature: float, num_return_sequences: int):
    """
    Generate text from the prompt using the selected Flan-T5 model.

    Returns the generated text (multiple outputs joined by a separator),
    or a human-readable error message on failure.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    try:
        pipe = get_pipeline(model_choice)
    except Exception as e:
        return f"Failed to load model {model_choice}: {e}"

    # A temperature of 0.0 means deterministic (greedy) decoding.
    do_sample = temperature > 0.0

    gen_kwargs = {
        "max_length": int(max_length),
        "do_sample": do_sample,
        # Greedy decoding (do_sample=False, no beam search) can only produce
        # one sequence; asking for more raises a ValueError in generate(),
        # so clamp to 1 when not sampling.
        "num_return_sequences": int(num_return_sequences) if do_sample else 1,
    }
    # Only pass temperature when sampling: with do_sample=False it has no
    # effect and newer transformers versions warn (or error) about it.
    if do_sample:
        gen_kwargs["temperature"] = float(temperature)

    try:
        outputs = pipe(prompt, **gen_kwargs)
    except Exception as e:
        return f"Generation failed: {e}"

    # outputs is a list of dicts with key 'generated_text'.
    texts = [o.get("generated_text", "") for o in outputs]
    # Join multiple outputs with separators.
    return "\n\n---\n\n".join(texts)
# ---------------------------------------------------------------------------
# Gradio UI: prompt input + model/decoding controls, wired to generate().
# ---------------------------------------------------------------------------
with gr.Blocks(title="Flan-T5 Inference (Text2Text)") as demo:
    gr.Markdown(
        """
# Flan-T5 Text2Text Inference
Type your prompt and pick a Flan-T5 model. For best performance, enable a GPU in the Space settings.
"""
    )
    with gr.Row():
        # Left column: the prompt box plus a few clickable example prompts.
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                lines=8,
                label="Input prompt",
                placeholder="e.g. Summarize the following article in 2 sentences: ..."
            )
            examples = gr.Examples(
                examples=[
                    ["Summarize the key points of the American Declaration of Independence."],
                    ["Translate the following English sentence to French: 'The weather is nice today.'"],
                    ["Explain in simple terms how photosynthesis works."],
                ],
                inputs=prompt
            )
        # Right column: model selection and generation parameters.
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(list(DEFAULT_MODELS.keys()), value="flan-t5-base", label="Model")
            max_length = gr.Slider(32, 1024, value=256, step=1, label="Max length (tokens)")
            temperature = gr.Slider(0.0, 1.5, value=0.0, step=0.01, label="Temperature (0.0 = deterministic)")
            num_return_sequences = gr.Slider(1, 3, value=1, step=1, label="Number of outputs")
            run_btn = gr.Button("Generate")
    output = gr.Textbox(label="Model output", lines=12)
    # Clicking the button calls generate() with the current widget values
    # and writes the result into the output textbox.
    run_btn.click(
        fn=generate,
        inputs=[prompt, model_choice, max_length, temperature, num_return_sequences],
        outputs=output,
    )
    gr.Markdown(
        """
Notes:
- If you want faster generation, set Space to use a GPU (in Settings → Hardware).
- Larger models (flan-t5-large / flan-t5-xl) need more RAM — they may OOM on CPU or free GPU tiers.
- You can add other models to DEFAULT_MODELS above or input a full model name from the Hub.
"""
    )

# Script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()