File size: 4,293 Bytes
ad13ccd
 
 
910ef03
ad13ccd
 
 
 
 
 
 
 
 
 
 
 
 
910ef03
ad13ccd
 
910ef03
ad13ccd
910ef03
ad13ccd
 
910ef03
ad13ccd
 
 
 
 
 
 
 
 
 
 
 
 
 
910ef03
ad13ccd
 
 
 
 
 
 
 
 
 
910ef03
ad13ccd
 
 
 
 
 
 
 
 
 
 
 
 
910ef03
ad13ccd
 
 
 
910ef03
 
ad13ccd
 
 
 
 
 
 
910ef03
ad13ccd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
910ef03
ad13ccd
910ef03
ad13ccd
 
 
 
 
910ef03
ad13ccd
 
 
 
 
 
 
 
910ef03
 
ad13ccd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# app.py
# Simple Gradio app to run inference with Flan-T5 models (text2text-generation)

import gradio as gr
from transformers import pipeline
import torch
import os

# Cache pipelines for models so we don't reload on every request.
# Keyed by the fully-resolved Hub model id (see get_pipeline, which
# resolves short keys before storing here).
PIPES = {}

# Dropdown key -> Hugging Face Hub model id. get_pipeline also accepts a
# full Hub id directly, so entries here are conveniences, not a whitelist.
DEFAULT_MODELS = {
    "flan-t5-small": "google/flan-t5-small",
    "flan-t5-base": "google/flan-t5-base",
    "flan-t5-large": "google/flan-t5-large",
    # you can add "flan-t5-xl" or others if your Space has enough RAM/GPU
}

def get_device():
    """Return the transformers pipeline device index.

    0 selects the first CUDA GPU when one is available; -1 selects CPU.
    """
    if torch.cuda.is_available():
        return 0
    return -1

def get_pipeline(model_key_or_name):
    """Return a (cached) text2text-generation pipeline for a model.

    model_key_or_name: either a key from DEFAULT_MODELS or a full model name.
    The first request for a model loads it (downloading weights if needed);
    subsequent requests are served from the module-level PIPES cache.
    """
    resolved = DEFAULT_MODELS.get(model_key_or_name, model_key_or_name)
    cached = PIPES.get(resolved)
    if cached is not None:
        return cached
    # pipeline will handle tokenizer/model download
    pipe = pipeline(
        "text2text-generation",
        model=resolved,
        tokenizer=resolved,
        device=get_device(),
        # trust_remote_code=False by default; for official Flan-T5 models this is fine
    )
    PIPES[resolved] = pipe
    return pipe

def generate(prompt: str, model_choice: str, max_length: int, temperature: float, num_return_sequences: int):
    """Generate text from the prompt using the selected Flan-T5 model.

    Args:
        prompt: input text; blank/whitespace-only input short-circuits.
        model_choice: DEFAULT_MODELS key or a full Hub model id.
        max_length: maximum output length in tokens.
        temperature: 0.0 means deterministic decoding; >0 enables sampling.
        num_return_sequences: how many outputs to produce.

    Returns:
        The generated text(s) joined by a separator, or a human-readable
        error message (this function never raises to the UI).
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    try:
        pipe = get_pipeline(model_choice)
    except Exception as e:
        return f"Failed to load model {model_choice}: {e}"

    do_sample = temperature > 0.0
    n_outputs = int(num_return_sequences)
    gen_kwargs = {
        "max_length": int(max_length),
        "do_sample": do_sample,
        "num_return_sequences": n_outputs,
        # return_full_text=False is default for text2text-generation
    }
    if do_sample:
        # Only meaningful when sampling; passing it with do_sample=False
        # triggers a transformers warning.
        gen_kwargs["temperature"] = float(temperature)
    elif n_outputs > 1:
        # Greedy search can only return a single sequence — transformers
        # raises if num_return_sequences > 1 without sampling or beams.
        # Beam search keeps temperature-0 decoding deterministic while
        # still honoring the requested number of outputs.
        gen_kwargs["num_beams"] = n_outputs
    try:
        outputs = pipe(prompt, **gen_kwargs)
    except Exception as e:
        return f"Generation failed: {e}"

    # outputs is a list of dicts with key 'generated_text'
    texts = [o.get("generated_text", "") for o in outputs]
    # Join multiple outputs with separators
    return "\n\n---\n\n".join(texts)


# UI definition. NOTE: in the gr.Blocks DSL the order in which components
# are created *is* the page layout, so statement order here is significant.
with gr.Blocks(title="Flan-T5 Inference (Text2Text)") as demo:
    gr.Markdown(
        """
        # Flan-T5 Text2Text Inference
        Type your prompt and pick a Flan-T5 model. For best performance, enable a GPU in the Space settings.
        """
    )

    with gr.Row():
        # Left column (wider): prompt entry plus clickable example prompts.
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                lines=8,
                label="Input prompt",
                placeholder="e.g. Summarize the following article in 2 sentences: ..."
            )
            # Clicking an example fills the prompt textbox.
            examples = gr.Examples(
                examples=[
                    ["Summarize the key points of the American Declaration of Independence."],
                    ["Translate the following English sentence to French: 'The weather is nice today.'"],
                    ["Explain in simple terms how photosynthesis works."],
                ],
                inputs=prompt
            )
        # Right column (narrower): model choice and generation controls.
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(list(DEFAULT_MODELS.keys()), value="flan-t5-base", label="Model")
            max_length = gr.Slider(32, 1024, value=256, step=1, label="Max length (tokens)")
            temperature = gr.Slider(0.0, 1.5, value=0.0, step=0.01, label="Temperature (0.0 = deterministic)")
            num_return_sequences = gr.Slider(1, 3, value=1, step=1, label="Number of outputs")
            run_btn = gr.Button("Generate")

    output = gr.Textbox(label="Model output", lines=12)

    # Wire the button to generate(); input order must match generate()'s
    # positional parameters.
    run_btn.click(
        fn=generate,
        inputs=[prompt, model_choice, max_length, temperature, num_return_sequences],
        outputs=output,
    )

    gr.Markdown(
        """
        Notes:
        - If you want faster generation, set Space to use a GPU (in Settings → Hardware).
        - Larger models (flan-t5-large / flan-t5-xl) need more RAM — they may OOM on CPU or free GPU tiers.
        - You can add other models to DEFAULT_MODELS above or input a full model name from the Hub.
        """
    )

if __name__ == "__main__":
    # Start the Gradio server when this file is run as a script.
    demo.launch()