# (Hugging Face Spaces page header removed — not part of the application code)
# app.py
# Simple Gradio app to run inference with Flan-T5 models (text2text-generation)
import os

import gradio as gr
import torch
from transformers import pipeline
# Pipeline cache keyed by resolved model name, so each model loads only once
# per process instead of on every request.
PIPES = {}

# Friendly dropdown keys mapped to their Hugging Face Hub model ids.
# Add "flan-t5-xl" or others here if your Space has enough RAM/GPU.
DEFAULT_MODELS = {
    "flan-t5-" + size: "google/flan-t5-" + size
    for size in ("small", "base", "large")
}
def get_device():
    """Device index for transformers pipelines: GPU 0 when CUDA is available, else CPU (-1)."""
    if torch.cuda.is_available():
        return 0
    return -1
def get_pipeline(model_key_or_name):
    """Return a (cached) text2text-generation pipeline for the requested model.

    model_key_or_name: either a key from DEFAULT_MODELS or a full Hub model name.
    """
    # Resolve dropdown keys to Hub ids; unknown keys pass through unchanged
    # so callers may supply any model name directly.
    resolved = DEFAULT_MODELS.get(model_key_or_name, model_key_or_name)

    cached = PIPES.get(resolved)
    if cached is not None:
        return cached

    # First request for this model: pipeline() downloads the tokenizer and
    # weights on demand; memoize the result for subsequent calls.
    new_pipe = pipeline(
        "text2text-generation",
        model=resolved,
        tokenizer=resolved,
        device=get_device(),
        # trust_remote_code defaults to False; fine for official Flan-T5 models.
    )
    PIPES[resolved] = new_pipe
    return new_pipe
def generate(prompt: str, model_choice: str, max_length: int, temperature: float, num_return_sequences: int):
    """
    Generate text from the prompt using the selected Flan-T5 model.

    Parameters mirror the UI controls: prompt text, model dropdown key (or a
    full Hub model name), max output length in tokens, sampling temperature
    (0.0 = deterministic), and how many candidate outputs to return.

    Returns the generated text(s) joined with separators, or a human-readable
    error string (the UI shows whatever string comes back).
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    try:
        pipe = get_pipeline(model_choice)
    except Exception as e:
        return f"Failed to load model {model_choice}: {e}"

    n_outputs = int(num_return_sequences)
    do_sample = temperature > 0.0
    gen_kwargs = {
        "max_length": max_length,
        "do_sample": do_sample,
        "num_return_sequences": n_outputs,
    }
    if do_sample:
        # Only pass temperature when sampling: with do_sample=False,
        # transformers ignores it and emits a warning.
        gen_kwargs["temperature"] = float(temperature)
    elif n_outputs > 1:
        # Greedy decoding produces a single sequence; beam search is required
        # for multiple return sequences when not sampling, otherwise
        # generate() raises a ValueError.
        gen_kwargs["num_beams"] = n_outputs

    try:
        outputs = pipe(prompt, **gen_kwargs)
    except Exception as e:
        return f"Generation failed: {e}"

    # outputs is a list of dicts with key 'generated_text'
    texts = [o.get("generated_text", "") for o in outputs]
    # Join multiple outputs with separators
    return "\n\n---\n\n".join(texts)
# ---- UI definition ----
with gr.Blocks(title="Flan-T5 Inference (Text2Text)") as demo:
    gr.Markdown(
        """
# Flan-T5 Text2Text Inference
Type your prompt and pick a Flan-T5 model. For best performance, enable a GPU in the Space settings.
"""
    )

    with gr.Row():
        # Left column: prompt box plus clickable example prompts.
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                lines=8,
                label="Input prompt",
                placeholder="e.g. Summarize the following article in 2 sentences: ..."
            )
            examples = gr.Examples(
                examples=[
                    ["Summarize the key points of the American Declaration of Independence."],
                    ["Translate the following English sentence to French: 'The weather is nice today.'"],
                    ["Explain in simple terms how photosynthesis works."],
                ],
                inputs=prompt
            )
        # Right column: model selection and decoding controls.
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(list(DEFAULT_MODELS.keys()), value="flan-t5-base", label="Model")
            max_length = gr.Slider(32, 1024, value=256, step=1, label="Max length (tokens)")
            temperature = gr.Slider(0.0, 1.5, value=0.0, step=0.01, label="Temperature (0.0 = deterministic)")
            num_return_sequences = gr.Slider(1, 3, value=1, step=1, label="Number of outputs")
            run_btn = gr.Button("Generate")

    output = gr.Textbox(label="Model output", lines=12)

    # Wire the button to the inference function; slider/dropdown values are
    # passed positionally in the same order generate() declares them.
    run_btn.click(
        fn=generate,
        inputs=[prompt, model_choice, max_length, temperature, num_return_sequences],
        outputs=output,
    )

    gr.Markdown(
        """
Notes:
- If you want faster generation, set Space to use a GPU (in Settings → Hardware).
- Larger models (flan-t5-large / flan-t5-xl) need more RAM — they may OOM on CPU or free GPU tiers.
- You can add other models to DEFAULT_MODELS above or input a full model name from the Hub.
"""
    )

if __name__ == "__main__":
    demo.launch()