File size: 6,466 Bytes
0cf2212
 
 
 
 
 
 
b101b15
0cf2212
b101b15
 
 
 
 
 
 
0cf2212
b101b15
 
 
 
 
 
 
 
 
 
 
 
 
 
0cf2212
 
 
 
b101b15
0cf2212
 
 
 
 
b101b15
0cf2212
 
 
 
 
 
 
 
b101b15
 
 
0cf2212
 
b101b15
 
 
 
 
 
0cf2212
 
 
 
b101b15
0cf2212
 
 
 
 
b101b15
0cf2212
 
b101b15
 
 
 
 
 
 
 
 
 
 
 
 
0cf2212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b101b15
 
0cf2212
 
 
 
 
 
 
 
 
b101b15
 
 
 
 
 
 
 
0cf2212
 
b101b15
 
 
 
 
 
 
0cf2212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b101b15
0cf2212
 
 
 
 
 
b101b15
0cf2212
 
 
 
 
b101b15
0cf2212
b101b15
0cf2212
 
 
 
 
 
b101b15
0cf2212
 
b101b15
 
 
 
0cf2212
 
 
 
 
 
b101b15
0cf2212
 
 
 
 
 
 
 
 
b101b15
 
 
 
 
 
 
 
0cf2212
 
b101b15
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import gradio as gr
import numpy as np
import random
# import spaces #[uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
import torch

# --- Model and Device Configuration ---

# Process-wide cache of loaded pipelines, keyed by Hugging Face repo id, so a
# model is only loaded the first time it is requested.
pipelines = {}

# Display name shown in the UI -> Hugging Face model repository id.
MODEL_MAP = {
    "SDXL-Turbo": "stabilityai/sdxl-turbo",
    "Nano-Banana": "emilianJR/nano-banana-base-1.0",
}

# Prefer the GPU with half-precision weights; otherwise run on the CPU in
# full precision.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

# This function loads a model if it's not already in our cache
def get_pipeline(model_name: str):
    """Load (on first use) and return the diffusion pipeline for ``model_name``.

    Loaded pipelines are cached in the module-level ``pipelines`` dict, keyed
    by Hugging Face repo id, so each model is loaded and moved to ``device``
    at most once per process.

    Args:
        model_name: A display name from ``MODEL_MAP`` (e.g. ``"SDXL-Turbo"``).

    Returns:
        The cached ``DiffusionPipeline`` instance, already moved to ``device``.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_MAP``.
    """
    try:
        repo_id = MODEL_MAP[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model {model_name!r}; expected one of {sorted(MODEL_MAP)}"
        ) from None

    pipe = pipelines.get(repo_id)
    if pipe is None:
        print(f"Loading model: {repo_id}...")
        # Only ask for the "fp16" weight variant when running in half
        # precision. There is no conventional "fp32" variant: the default
        # (variant=None) weights are already full precision, and requesting a
        # nonexistent variant can make from_pretrained fail.
        variant = "fp16" if torch_dtype == torch.float16 else None
        try:
            pipe = DiffusionPipeline.from_pretrained(
                repo_id, torch_dtype=torch_dtype, variant=variant
            )
        except (OSError, ValueError):
            if variant is None:
                raise
            # Not every repository publishes an fp16 variant; fall back to the
            # default weights and let torch_dtype cast them on load.
            pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype)
        pipe = pipe.to(device)
        pipelines[repo_id] = pipe
        print("Model loaded successfully.")
    return pipe

# Upper bound for seeds offered by the UI and by random seeding:
# the largest signed 32-bit integer.
MAX_SEED: int = int(np.iinfo(np.int32).max)
# Upper bound, in pixels, for the width/height sliders.
MAX_IMAGE_SIZE: int = 1024

# --- Inference Function ---

# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    model_selection,  # Display name of the model (a key of MODEL_MAP)
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from a text prompt using the selected model.

    Args:
        prompt: Text prompt describing the desired image.
        negative_prompt: Negative prompt, forwarded to the pipeline unchanged.
        model_selection: Model display name, resolved via ``get_pipeline``.
        seed: Seed for the torch generator; replaced when ``randomize_seed``.
        randomize_seed: If true, draw a fresh seed from ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale; overridden with 0.0 for SDXL-Turbo.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker; not used directly in this body.

    Returns:
        A ``(image, seed)`` tuple: the first image produced by the pipeline
        and the integer seed actually used, so the UI can show and reuse it.
    """
    pipe = get_pipeline(model_selection)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Gradio sliders and API clients may deliver whole numbers as floats
    # (e.g. 42.0). torch.Generator.manual_seed only accepts integers, and the
    # size/step arguments are meant to be integers, so normalize them here.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator(device=device).manual_seed(seed)

    # SDXL-Turbo does not use guidance, so the slider value is overridden with
    # 0.0 for it; other models use the user's setting.
    effective_guidance_scale = (
        0.0 if model_selection == "SDXL-Turbo" else float(guidance_scale)
    )

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=effective_guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )
    image = output.images[0]

    return image, seed

# --- UI Helper Function ---

def update_settings_for_model(model_selection: str):
    """Return slider values holding the recommended settings for a model.

    Args:
        model_selection: Model display name chosen in the UI.

    Returns:
        A ``(guidance_scale, num_inference_steps)`` pair of ``gr.Slider``
        objects carrying the preset values. For an unrecognised name, both
        are argument-less ``gr.Slider()`` objects (an empty update).
    """
    # Preset (guidance_scale, num_inference_steps) per model.
    recommended = {
        # SDXL-Turbo works best with low steps and no guidance
        "SDXL-Turbo": (0.0, 2),
        # A more standard SDXL setup
        "Nano-Banana": (7.5, 25),
    }
    preset = recommended.get(model_selection)
    if preset is None:
        return gr.Slider(), gr.Slider()  # Default empty update
    guidance, steps = preset
    return gr.Slider(value=guidance), gr.Slider(value=steps)

# --- Gradio UI Layout ---

# Sample prompts handed to gr.Examples in the UI layout.
_EXAMPLE_PROMPTS = (
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
)
examples = list(_EXAMPLE_PROMPTS)

# Custom CSS: centers the main column and caps its width at 640px.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 640px;\n"
    "}\n"
)

# Build the Gradio UI. Components are laid out in the order they are created
# below; the event handlers at the end wire them to `infer` and
# `update_settings_for_model`.
with gr.Blocks(css=css) as demo:
    # Main column; `elem_id` targets the `#col-container` rule in `css`.
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Text-to-Image with Model Switching")
        
        # Prompt box and Run button share one row.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
            
        # Choices are the display names in MODEL_MAP. SDXL-Turbo is
        # preselected, which matches the slider defaults further down.
        model_selection = gr.Radio(
            label="Select Model",
            choices=list(MODEL_MAP.keys()),
            value="SDXL-Turbo",
        )
            
        # Output image; `infer` returns it as the first element of its tuple.
        result = gr.Image(label="Result", show_label=False, type="pil")

        with gr.Accordion("Advanced Settings", open=False):
            # Gemini API key input (masked as a password field).
            # NOTE(review): in the code shown here this value is never read.
            # It is not among the `gr.on` inputs or the model-change inputs
            # below. Either wire it into a handler or remove the field, so
            # users are not asked for a secret that goes unused.
            gemini_api_key = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key here",
                type="password",
                visible=True,  # Explicitly shown inside the accordion
            )
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            # Also an output: `infer` writes the seed it actually used back here.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            # step=32 keeps width/height on multiples of 32 pixels.
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # 512px default chosen for SDXL-Turbo
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # 512px default chosen for SDXL-Turbo
                )
            # These two sliders are reset by `update_settings_for_model`
            # whenever the selected model changes.
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=0.0,  # Default for SDXL-Turbo
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=2,  # Default for SDXL-Turbo
                )
        # Clickable sample prompts that fill the prompt box.
        gr.Examples(examples=examples, inputs=[prompt])

    # --- Event Handlers ---
    
    # Main inference trigger: clicking Run or submitting the prompt box calls
    # `infer`. The `inputs` order must match `infer`'s positional parameters.
    # The seed used is written back to the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            model_selection,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )
    
    # When the model selection changes, load that model's recommended
    # guidance scale and step count into the sliders (width/height untouched).
    model_selection.change(
        fn=update_settings_for_model,
        inputs=model_selection,
        outputs=[guidance_scale, num_inference_steps]
    )


# Launch the Gradio app only when this file is run as a script, not on import.
# NOTE(review): debug=True is passed straight to `demo.launch`; check the
# Gradio `launch` docs for its exact effect in the deployed version.
if __name__ == "__main__":
    demo.launch(debug=True)