DiZH797 committed on
Commit
f85cea0
·
verified ·
1 Parent(s): bf47080

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -0
app.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import numpy as np
import random

# import spaces  # [uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
import torch
from typing import Optional

# Cache of loaded pipelines so the model is not re-downloaded/reloaded
# on every request.
PIPE_CACHE: dict[str, DiffusionPipeline] = {}
DEFAULT_MODEL = "CompVis/stable-diffusion-v1-4"
BASE_MODEL_FOR_LORA = "CompVis/stable-diffusion-v1-4"  # Base model used for LoRA training
LORA_MODEL_ID = "DiZH797/SberDiffusionModelsLora"  # Your uploaded LoRA model ID
MODEL_OPTIONS = [
    "CompVis/stable-diffusion-v1-4",
    "stabilityai/stable-diffusion-2-1",
    "stabilityai/sdxl-turbo",
    LORA_MODEL_ID,
]

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use

# Half precision saves GPU memory; CPU inference requires full float32.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
# pipe = pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
37
+
38
def get_pipe(model_id: str, lora_scale: float = 1.0) -> DiffusionPipeline:
    """Return a (cached) diffusion pipeline for *model_id*.

    If the selected model is the LoRA adapter, the base model is loaded
    first and the LoRA weights are fused into it at *lora_scale*.

    Args:
        model_id: Hugging Face repo ID of a full model, or the LoRA adapter ID.
        lora_scale: Strength used when fusing LoRA weights; only meaningful
            when *model_id* is the LoRA adapter.

    Returns:
        A ``DiffusionPipeline`` moved to the active device.
    """
    is_lora = model_id == LORA_MODEL_ID
    # FIX: the original keyed ALL models by (model_id, lora_scale), so a
    # plain model requested at two different scale values was loaded twice.
    # lora_scale only matters when the LoRA weights are fused in.
    cache_key = f"{model_id}_{lora_scale}" if is_lora else model_id

    if cache_key in PIPE_CACHE:
        return PIPE_CACHE[cache_key]

    if is_lora:
        # Load the base model, then fuse the LoRA weights at the given scale.
        pipe = DiffusionPipeline.from_pretrained(
            BASE_MODEL_FOR_LORA,
            torch_dtype=torch_dtype,
        ).to(device)
        pipe.load_lora_weights(LORA_MODEL_ID)
        pipe.fuse_lora(lora_scale=lora_scale)
    else:
        # Standard model without LoRA.
        pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
        ).to(device)

    PIPE_CACHE[cache_key] = pipe
    return pipe
67
+
68
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    model_id: Optional[str] = DEFAULT_MODEL,
    prompt: str = "",
    negative_prompt: str = "",
    seed: int = 42,
    randomize_seed: bool = False,
    width: int = 512,
    height: int = 512,
    guidance_scale: float = 7.0,
    num_inference_steps: int = 20,
    scheduler_name: Optional[str] = None,
    lora_scale: float = 1.0,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from a text prompt.

    Args:
        model_id: Repo ID of the model (or LoRA adapter) to use; falls back
            to DEFAULT_MODEL when empty/None.
        prompt: Positive text prompt.
        negative_prompt: Text describing what to avoid.
        seed: RNG seed; ignored when *randomize_seed* is set.
        randomize_seed: Draw a fresh random seed in [0, MAX_SEED].
        width / height: Output image size in pixels.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        scheduler_name: Optional scheduler override ("DDIM", "EulerAncestral",
            "PNDM"); falsy values keep the pipeline default.
        lora_scale: LoRA fusion strength, forwarded to get_pipe().
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Tuple of (PIL image, seed actually used).
    """
    # Guard against a None/empty dropdown value.
    model_id = model_id or DEFAULT_MODEL

    # BUG FIX: the UI wires 11 inputs to this function, but the original
    # signature had no `lora_scale` parameter, so the slider value was bound
    # to `progress` and never reached get_pipe() — the LoRA scale slider had
    # no effect. It is now accepted and forwarded.
    pipe = get_pipe(model_id, lora_scale)

    # Optionally swap the scheduler by name.
    if scheduler_name:
        # Map of supported names -> scheduler classes; extend as needed.
        try:
            from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler, PNDMScheduler
            sched_map = {
                "DDIM": DDIMScheduler,
                "EulerAncestral": EulerAncestralDiscreteScheduler,
                "PNDM": PNDMScheduler,
            }
            if scheduler_name in sched_map:
                pipe.scheduler = sched_map[scheduler_name].from_config(pipe.scheduler.config)
        except Exception:
            # Best-effort: on any failure fall back to the default scheduler.
            pass

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # CPU generator is accepted by diffusers pipelines regardless of device.
    generator = torch.Generator().manual_seed(int(seed))

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    return image, seed
118
+
119
+
120
+
121
# Prompt suggestions shown beneath the input box.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Center the app column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""
133
+
134
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Image Gradio Template")

        # Dropdown selecting which checkpoint (or the LoRA adapter) to run.
        model_select = gr.Dropdown(
            label="Model",
            choices=MODEL_OPTIONS,
            value=DEFAULT_MODEL,
            interactive=True,
        )

        # Optional scheduler override; the empty choice keeps the default.
        scheduler_select = gr.Dropdown(
            label="Scheduler (optional)",
            choices=["", "DDIM", "EulerAncestral", "PNDM"],
            value="",
        )

        # LoRA strength control, revealed only when the LoRA model is chosen.
        lora_scale_slider = gr.Slider(
            label="LoRA Scale (Only for LoRA model)",
            minimum=0.0,
            maximum=2.0,
            step=0.1,
            value=1.0,
            visible=False,  # hidden until the LoRA adapter is selected
        )

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,  # NOTE(review): SD 1.x trains at 512 — confirm default
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,  # NOTE(review): SD 1.x trains at 512 — confirm default
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=20,
                )

        gr.Examples(examples=examples, inputs=[prompt])

    # Show the LoRA-scale slider only while the LoRA adapter is selected.
    def toggle_lora_scale_slider(model_id):
        return gr.Slider(visible=model_id == LORA_MODEL_ID)

    model_select.change(
        fn=toggle_lora_scale_slider,
        inputs=model_select,
        outputs=lora_scale_slider,
    )

    # Run inference on button click or prompt submit.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            model_select,
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            scheduler_select,
            lora_scale_slider,
        ],
        outputs=[result, seed],
    )
261
+
262
if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    demo.launch()