# app.py
# Simple Gradio app to run inference with Flan-T5 models (text2text-generation)
import os

import gradio as gr
import torch
from transformers import pipeline

# Cache pipelines for models so we don't reload on every request.
# Keyed by the resolved model name (the value from DEFAULT_MODELS, or the raw name).
PIPES = {}

DEFAULT_MODELS = {
    "flan-t5-small": "google/flan-t5-small",
    "flan-t5-base": "google/flan-t5-base",
    "flan-t5-large": "google/flan-t5-large",
    # you can add "flan-t5-xl" or others if your Space has enough RAM/GPU
}

# Placed between multiple generated outputs in the result textbox.
OUTPUT_SEPARATOR = "\n\n---\n\n"


def get_device():
    """Return the pipeline device: 0 (first CUDA GPU) when CUDA is available, else -1 (CPU)."""
    return 0 if torch.cuda.is_available() else -1


def get_pipeline(model_key_or_name):
    """
    Return a (cached) transformers text2text-generation pipeline for a model.

    model_key_or_name: either a key from DEFAULT_MODELS or a full model name.
    The first call for a given model builds the pipeline (downloading weights if
    needed) and stores it in PIPES; later calls reuse the cached instance.
    Any exception raised while building the pipeline propagates to the caller.
    """
    model_name = DEFAULT_MODELS.get(model_key_or_name, model_key_or_name)
    if model_name in PIPES:
        return PIPES[model_name]

    # pipeline will handle tokenizer/model download
    pipe = pipeline(
        "text2text-generation",
        model=model_name,
        tokenizer=model_name,
        device=get_device(),
        # trust_remote_code=False by default; for official Flan-T5 models this is fine
    )
    PIPES[model_name] = pipe
    return pipe


def _generation_kwargs(max_length, temperature, num_return_sequences):
    """
    Translate the UI control values into keyword arguments for the pipeline call.

    - temperature > 0 enables sampling at that temperature.
    - temperature == 0 means deterministic decoding. ``temperature`` is then
      omitted because it only applies to sampling. When more than one output is
      requested, beam search with one beam per output is used, because greedy
      decoding cannot return more than one sequence (transformers raises a
      ValueError in that case).
    Values are coerced to int/float because Gradio sliders may deliver floats.
    """
    temperature = float(temperature)
    num_return_sequences = int(num_return_sequences)
    kwargs = {
        "max_length": int(max_length),
        "num_return_sequences": num_return_sequences,
    }
    if temperature > 0.0:
        kwargs["do_sample"] = True
        kwargs["temperature"] = temperature
    else:
        kwargs["do_sample"] = False
        if num_return_sequences > 1:
            kwargs["num_beams"] = num_return_sequences
    return kwargs


def generate(prompt: str, model_choice: str, max_length: int, temperature: float, num_return_sequences: int):
    """
    Generate text from the prompt using the selected Flan-T5 model.

    Returns the generated text. Multiple outputs are joined with OUTPUT_SEPARATOR.
    Errors are not raised. They are returned as a message string so the UI
    textbox can show them.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    try:
        pipe = get_pipeline(model_choice)
    except Exception as e:  # UI boundary: report load failures instead of crashing the request
        return f"Failed to load model {model_choice}: {e}"

    try:
        gen_kwargs = _generation_kwargs(max_length, temperature, num_return_sequences)
        # return_full_text=False is default for text2text-generation
        outputs = pipe(prompt, **gen_kwargs)
    except Exception as e:  # UI boundary: report generation failures as text
        return f"Generation failed: {e}"

    # outputs is a list of dicts with key 'generated_text'
    texts = [o.get("generated_text", "") for o in outputs]
    return OUTPUT_SEPARATOR.join(texts)


with gr.Blocks(title="Flan-T5 Inference (Text2Text)") as demo:
    gr.Markdown(
        """
        # Flan-T5 Text2Text Inference
        Type your prompt and pick a Flan-T5 model. For best performance, enable a GPU in the Space settings.
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                lines=8,
                label="Input prompt",
                placeholder="e.g. Summarize the following article in 2 sentences: ...",
            )
            examples = gr.Examples(
                examples=[
                    ["Summarize the key points of the American Declaration of Independence."],
                    ["Translate the following English sentence to French: 'The weather is nice today.'"],
                    ["Explain in simple terms how photosynthesis works."],
                ],
                inputs=prompt,
            )
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(list(DEFAULT_MODELS.keys()), value="flan-t5-base", label="Model")
            max_length = gr.Slider(32, 1024, value=256, step=1, label="Max length (tokens)")
            temperature = gr.Slider(0.0, 1.5, value=0.0, step=0.01, label="Temperature (0.0 = deterministic)")
            num_return_sequences = gr.Slider(1, 3, value=1, step=1, label="Number of outputs")
            run_btn = gr.Button("Generate")

    output = gr.Textbox(label="Model output", lines=12)

    run_btn.click(
        fn=generate,
        inputs=[prompt, model_choice, max_length, temperature, num_return_sequences],
        outputs=output,
    )

    gr.Markdown(
        """
        Notes:
        - If you want faster generation, set Space to use a GPU (in Settings → Hardware).
        - Larger models (flan-t5-large / flan-t5-xl) need more RAM — they may OOM on CPU or free GPU tiers.
        - To offer other models, add them to DEFAULT_MODELS in app.py (any text2text-generation model name from the Hub).
        """
    )


if __name__ == "__main__":
    demo.launch()