seawolf2357 committed on
Commit ea08698 · verified · Parent(s): 8de2f43

Update app.py

Files changed (1): app.py (+1146 -630)

app.py CHANGED
@@ -1,660 +1,1176 @@
  import os
  import json
- import copy
- import time
- import requests
- import random
- import logging
  import numpy as np
- import spaces
- from typing import Any, Dict, List, Optional, Union
- import torch
- from PIL import Image
- import gradio as gr
- from diffusers import (
-     DiffusionPipeline,
-     AutoencoderKL,
-     ZImagePipeline
- )
-
- from huggingface_hub import (
-     hf_hub_download,
-     HfFileSystem,
-     ModelCard,
-     snapshot_download)
-
- from diffusers.utils import load_image
- from typing import Iterable
- from gradio.themes import Soft
- from gradio.themes.utils import colors, fonts, sizes
-
- colors.orange_red = colors.Color(
-     name="orange_red",
-     c50="#FFF0E5",
-     c100="#FFE0CC",
-     c200="#FFC299",
-     c300="#FFA366",
-     c400="#FF8533",
-     c500="#FF4500",
-     c600="#E63E00",
-     c700="#CC3700",
-     c800="#B33000",
-     c900="#992900",
-     c950="#802200",
- )
-
- class OrangeRedTheme(Soft):
-     def __init__(
-         self,
-         *,
-         primary_hue: colors.Color | str = colors.gray,
-         secondary_hue: colors.Color | str = colors.orange_red,  # the custom color defined above
-         neutral_hue: colors.Color | str = colors.slate,
-         text_size: sizes.Size | str = sizes.text_lg,
-         font: fonts.Font | str | Iterable[fonts.Font | str] = (
-             fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
-         ),
-         font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
-             fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
-         ),
-     ):
-         super().__init__(
-             primary_hue=primary_hue,
-             secondary_hue=secondary_hue,
-             neutral_hue=neutral_hue,
-             text_size=text_size,
-             font=font,
-             font_mono=font_mono,
-         )
-         super().set(
-             background_fill_primary="*primary_50",
-             background_fill_primary_dark="*primary_900",
-             body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
-             body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
-             button_primary_text_color="white",
-             button_primary_text_color_hover="white",
-             button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
-             button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
-             button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
-             button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
-             button_secondary_text_color="black",
-             button_secondary_text_color_hover="white",
-             button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
-             button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
-             button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
-             button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
-             slider_color="*secondary_500",
-             slider_color_dark="*secondary_600",
-             block_title_text_weight="600",
-             block_border_width="3px",
-             block_shadow="*shadow_drop_lg",
-             button_primary_shadow="*shadow_drop_lg",
-             button_large_padding="11px",
-             color_accent_soft="*primary_100",
-             block_label_background_fill="*primary_200",
-         )
-
- orange_red_theme = OrangeRedTheme()
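The custom theme only takes effect if it is handed to the Blocks constructor; gr.Blocks() accepts theme and css, while demo.launch() does not. A minimal sketch of the intended wiring (the launch call at the bottom of this file is adjusted accordingly):

    with gr.Blocks(theme=orange_red_theme) as demo:
        ...  # components go here
    demo.launch()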
- loras = [
-     # Local jimin LoRA (requires jimin.safetensors in the same directory as app.py)
-     {
-         "image": "https://i.namu.wiki/i/umL8EZtn0hs-nMRYeFxIrkGrMe-R1u5c9fJE8ufrLjvXz52VcSIbG7TT9QJoL2rR7vsFww1lLrE4bwfn5uOBzfq9a90HGdNdlTLmr_KoqOchTovbVC3RDzhDbp7FI-Wq-esCu7_BYIptqethL4onBg.webp",
-         "title": "jimin Style",
-         "repo": "./",  # local path
-         "weights": "jimin.safetensors",
-         "trigger_word": "jimin"
-     },
-     {
-         "image": "https://huggingface.co/strangerzonehf/Flux-Ultimate-LoRA-Collection/resolve/main/images/1111111111.png",
-         "title": "AWPortrait Z",
-         "repo": "Shakker-Labs/AWPortrait-Z",  #1
-         "weights": "AWPortrait-Z.safetensors",
-         "trigger_word": "Portrait"
-     },
-     {
-         "image": "https://cdn-uploads.huggingface.co/production/uploads/653cd3049107029eb004f968/DLCGlF9uUnFo5zxR5uyx6.png",
-         "title": "50s Western",
-         "repo": "neph1/50s_western_lora_zit",
-         "weights": "50s_western_z_100.safetensors",
-         "trigger_word": "50s_western"
-     },
-     {
-         "image": "https://huggingface.co/Sumitc13/Z-image-Turbo_LogC4_lora/resolve/main/images/1764464517272__000005000_1.jpg",
-         "title": "LogC4",
-         "repo": "Sumitc13/Z-image-Turbo_LogC4_lora",  #30
-         "weights": "z-image-logc4_000005000.safetensors",
-         "trigger_word": "LogC4"
-     },
-     {
-         "image": "https://huggingface.co/neph1/80s_scifi_lora_zit/resolve/main/images/ComfyUI_10288_.png",
-         "title": "80s Scifi",
-         "repo": "neph1/80s_scifi_lora_zit",
-         "weights": "80s_scifi_z_80.safetensors",
-         "trigger_word": "80s_scifi"
-     },
-     # --------------------------------------------------------------------------------------------------------------------------------------
-     {
-         "image": "https://huggingface.co/Ttio2/Z-Image-Turbo-pencil-sketch/resolve/main/images/z-image_00097_.png",
-         "title": "Turbo Pencil",
-         "repo": "Ttio2/Z-Image-Turbo-pencil-sketch",  #0
-         "weights": "Zimage_pencil_sketch.safetensors",
-         "trigger_word": "pencil sketch"
-     },
-     {
-         "image": "https://huggingface.co/neph1/50s_scifi_lora_zit/resolve/main/images/ComfyUI_08067_.png",
-         "title": "50s Scifi",
-         "repo": "neph1/50s_scifi_lora_zit",
-         "weights": "50s_scifi_z_80.safetensors",
-         "trigger_word": "50s_scifi"
-     },
-     {
-         "image": "https://huggingface.co/strangerzonehf/Flux-Ultimate-LoRA-Collection/resolve/main/images/cookie-mons.png",
-         "title": "Yarn Art Style",
-         "repo": "linoyts/yarn-art-style",  #28
-         "weights": "yarn-art-style_000001250.safetensors",
-         "trigger_word": "yarn art style"
-     },
-     {
-         "image": "https://huggingface.co/Quorlen/Z-Image-Turbo-Behind-Reeded-Glass-Lora/resolve/main/images/ComfyUI_00391_.png",
-         "title": "Behind Reeded Glass",
-         "repo": "Quorlen/Z-Image-Turbo-Behind-Reeded-Glass-Lora",  #26
-         "weights": "Z_Image_Turbo_Behind_Reeded_Glass_Lora_TAV2_000002750.safetensors",
-         "trigger_word": "Act1vate!, Behind reeded glass"
-     },
-     {
-         "image": "https://huggingface.co/ostris/z_image_turbo_childrens_drawings/resolve/main/images/1764433619736__000003000_9.jpg",
-         "title": "Childrens Drawings",
-         "repo": "ostris/z_image_turbo_childrens_drawings",  #2
-         "weights": "z_image_turbo_childrens_drawings.safetensors",
-         "trigger_word": "Children Drawings"
-     },
-     {
-         "image": "https://huggingface.co/strangerzonehf/Flux-Ultimate-LoRA-Collection/resolve/main/images/xcxc.png",
-         "title": "Tarot Z",
-         "repo": "multimodalart/tarot-z-image-lora",  #22
-         "weights": "tarot-z-image_000001250.safetensors",
-         "trigger_word": "trtcrd"
-     },
-     {
-         "image": "https://huggingface.co/renderartist/Technically-Color-Z-Image-Turbo/resolve/main/images/ComfyUI_00917_.png",
-         "title": "Technically Color Z",
-         "repo": "renderartist/Technically-Color-Z-Image-Turbo",  #3
-         "weights": "Technically_Color_Z_Image_Turbo_v1_renderartist_2000.safetensors",
-         "trigger_word": "t3chnic4lly"
-     },
-     {
-         "image": "https://huggingface.co/SkyAsl/Tattoo-artist-Z/resolve/main/images/a%20dragon%20with%20flames.png",
-         "title": "Tattoo-artist-Z",
-         "repo": "SkyAsl/Tattoo-artist-Z",  #31
-         "weights": "adapter_model.safetensors",
-         "trigger_word": "a tattoo design"
-     },
-     {
-         "image": "https://huggingface.co/strangerzonehf/Flux-Ultimate-LoRA-Collection/resolve/main/images/z-image_00147_.png",
-         "title": "Turbo Ghibli",
-         "repo": "Ttio2/Z-Image-Turbo-Ghibli-Style",  #19
-         "weights": "ghibli_zimage_finetune.safetensors",
-         "trigger_word": "Ghibli Style"
-     },
-     {
-         "image": "https://huggingface.co/tarn59/pixel_art_style_lora_z_image_turbo/resolve/main/images/ComfyUI_00273_.png",
-         "title": "Pixel Art",
-         "repo": "tarn59/pixel_art_style_lora_z_image_turbo",  #4
-         "weights": "pixel_art_style_z_image_turbo.safetensors",
-         "trigger_word": "Pixel art style."
-     },
-     {
-         "image": "https://huggingface.co/renderartist/Saturday-Morning-Z-Image-Turbo/resolve/main/images/Saturday_Morning_Z_15.png",
-         "title": "Saturday Morning",
-         "repo": "renderartist/Saturday-Morning-Z-Image-Turbo",  #5
-         "weights": "Saturday_Morning_Z_Image_Turbo_v1_renderartist_1250.safetensors",
-         "trigger_word": "saturd4ym0rning"
-     },
-     {
-         "image": "https://huggingface.co/AIImageStudio/ReversalFilmGravure_z_Image_turbo/resolve/main/images/2025-12-01_173047-z_image_z_image_turbo_bf16-435125750859057-euler_10_hires.png",
-         "title": "ReversalFilmGravure",
-         "repo": "AIImageStudio/ReversalFilmGravure_z_Image_turbo",  #6
-         "weights": "z_image_turbo_ReversalFilmGravure_v1.0.safetensors",
-         "trigger_word": "Reversal Film Gravure, analog film photography"
-     },
-     {
-         "image": "https://huggingface.co/renderartist/Coloring-Book-Z-Image-Turbo-LoRA/resolve/main/images/CBZ_00274_.png",
-         "title": "Coloring Book Z",
-         "repo": "renderartist/Coloring-Book-Z-Image-Turbo-LoRA",  #7
-         "weights": "Coloring_Book_Z_Image_Turbo_v1_renderartist_2000.safetensors",
-         "trigger_word": "c0l0ringb00k"
-     },
-     {
-         "image": "https://huggingface.co/damnthatai/1950s_American_Dream/resolve/main/images/ZImage_20251129163459_135x_00001_.jpg",
-         "title": "1950s American Dream",
-         "repo": "damnthatai/1950s_American_Dream",  #8
-         "weights": "5os4m3r1c4n4_z.safetensors",
-         "trigger_word": "5os4m3r1c4n4, 1950s, painting, a painting of"
-     },
-     {
-         "image": "https://huggingface.co/wcde/Z-Image-Turbo-DeJPEG-Lora/resolve/main/images/01.png",
-         "title": "DeJPEG",
-         "repo": "wcde/Z-Image-Turbo-DeJPEG-Lora",  #9
-         "weights": "dejpeg_v3.safetensors",
-         "trigger_word": ""
-     },
-     {
-         "image": "https://huggingface.co/suayptalha/Z-Image-Turbo-Realism-LoRA/resolve/main/images/n4aSpqa-YFXYo4dtcIg4W.png",
-         "title": "Realism",  # fixed: was a copy-pasted "DeJPEG" duplicate
-         "repo": "suayptalha/Z-Image-Turbo-Realism-LoRA",  #10
-         "weights": "pytorch_lora_weights.safetensors",
-         "trigger_word": "Realism"
-     },
-     {
-         "image": "https://huggingface.co/renderartist/Classic-Painting-Z-Image-Turbo-LoRA/resolve/main/images/Classic_Painting_Z_00247_.png",
-         "title": "Classic Painting Z",
-         "repo": "renderartist/Classic-Painting-Z-Image-Turbo-LoRA",  #11
-         "weights": "Classic_Painting_Z_Image_Turbo_v1_renderartist_1750.safetensors",
-         "trigger_word": "class1cpa1nt"
-     },
-     {
-         "image": "https://huggingface.co/DK9/3D_MMORPG_style_z-image-turbo_lora/resolve/main/images/10_with_lora.png",
-         "title": "3D MMORPG",
-         "repo": "DK9/3D_MMORPG_style_z-image-turbo_lora",  #12
-         "weights": "lostark_v1.safetensors",
-         "trigger_word": ""
-     },
-     {
-         "image": "https://huggingface.co/Danrisi/Olympus_UltraReal_ZImage/resolve/main/images/Z-Image_01011_.png",
-         "title": "Olympus UltraReal",
-         "repo": "Danrisi/Olympus_UltraReal_ZImage",  #13
-         "weights": "Olympus.safetensors",
-         "trigger_word": "digital photography, early 2000s compact camera aesthetic, amateur candid shot, digital photography, early 2000s compact camera aesthetic, amateur candid shot, direct flash lighting, hard flash shadow, specular highlights, overexposed highlights"
-     },
-     {
-         "image": "https://huggingface.co/AiAF/D-ART_Z-Image-Turbo_LoRA/resolve/main/images/example_l3otpwzaz.png",
-         "title": "D ART Z Image",
-         "repo": "AiAF/D-ART_Z-Image-Turbo_LoRA",  #14
-         "weights": "D-ART_Z-Image-Turbo.safetensors",
-         "trigger_word": "D-ART"
-     },
-     {
-         "image": "https://huggingface.co/AlekseyCalvin/Marionette_Modernism_Z-image-Turbo_LoRA/resolve/main/bluebirdmandoll.webp",
-         "title": "Marionette Modernism",
-         "repo": "AlekseyCalvin/Marionette_Modernism_Z-image-Turbo_LoRA",  #15
-         "weights": "ZImageDadadoll_000003600.safetensors",
-         "trigger_word": "DADADOLL style"
-     },
-     {
-         "image": "https://huggingface.co/AlekseyCalvin/HistoricColor_Z-image-Turbo-LoRA/resolve/main/HSTZgen2.webp",
-         "title": "Historic Color Z",
-         "repo": "AlekseyCalvin/HistoricColor_Z-image-Turbo-LoRA",  #16
-         "weights": "ZImage1HST_000004000.safetensors",
-         "trigger_word": "HST style"
-     },
-     {
-         "image": "https://huggingface.co/tarn59/80s_air_brush_style_z_image_turbo/resolve/main/images/ComfyUI_00707_.png",
-         "title": "80s Air Brush",
-         "repo": "tarn59/80s_air_brush_style_z_image_turbo",  #17
-         "weights": "80s_air_brush_style_v2_z_image_turbo.safetensors",
-         "trigger_word": "80s Air Brush style."
-     },
-     {
-         "image": "https://huggingface.co/CedarC/Z-Image_360/resolve/main/images/1765505225357__000006750_6.jpg",
-         "title": "360panorama",
-         "repo": "CedarC/Z-Image_360",  #18
-         "weights": "Z-Image_360.safetensors",
-         "trigger_word": "360panorama"
-     },
-     {
-         "image": "https://huggingface.co/HAV0X1014/Z-Image-Turbo-KF-Bat-Eared-Fox-LoRA/resolve/main/images/ComfyUI_00132_.png",
-         "title": "KF-Bat-Eared",
-         "repo": "HAV0X1014/Z-Image-Turbo-KF-Bat-Eared-Fox-LoRA",  #21
-         "weights": "z-image-turbo-bat_eared_fox.safetensors",
-         "trigger_word": "bat_eared_fox_kemono_friends"
-     },
-     {
-         "image": "https://cdn-uploads.huggingface.co/production/uploads/653cd3049107029eb004f968/IHttgddXu6ZBMo7eyy8p6.png",
-         "title": "80s Horror",
-         "repo": "neph1/80s_horror_movies_lora_zit",  #23
-         "weights": "80s_horror_z_80.safetensors",
-         "trigger_word": "80s_horror"
-     },
-     {
-         "image": "https://huggingface.co/Quorlen/z_image_turbo_Sunbleached_Protograph_Style_Lora/resolve/main/images/ComfyUI_00024_.png",
-         "title": "Sunbleached Protograph",
-         "repo": "Quorlen/z_image_turbo_Sunbleached_Protograph_Style_Lora",  #24
-         "weights": "zimageturbo_Sunbleach_Photograph_Style_Lora_TAV2_000002750.safetensors",
-         "trigger_word": "Act1vate!"
-     },
-     {
-         "image": "https://huggingface.co/bunnycore/Z-Art-2.1/resolve/main/images/ComfyUI_00069_.png",
-         "title": "Z-Art-2.1",
-         "repo": "bunnycore/Z-Art-2.1",  #25
-         "weights": "Z-Image-Art2.1.safetensors",
-         "trigger_word": "anime art"
-     },
-     {
-         "image": "https://huggingface.co/cactusfriend/longfurby-z/resolve/main/images/1764658860954__000003000_1.jpg",
-         "title": "Longfurby",
-         "repo": "cactusfriend/longfurby-z",  #27
-         "weights": "longfurbyZ.safetensors",
-         "trigger_word": ""
-     },
-     {
-         "image": "https://huggingface.co/SkyAsl/Pixel-artist-Z/resolve/main/pixel-art-result.png",
-         "title": "Pixel Art",
-         "repo": "SkyAsl/Pixel-artist-Z",  #29
-         "weights": "adapter_model.safetensors",
-         "trigger_word": "a pixel art character"
-     },
- ]
-
- dtype = torch.bfloat16
- device = "cuda" if torch.cuda.is_available() else "cpu"
- base_model = "Tongyi-MAI/Z-Image-Turbo"
-
- print(f"Loading {base_model} pipeline...")
-
- # Initialize pipeline
- pipe = ZImagePipeline.from_pretrained(
-     base_model,
-     torch_dtype=dtype,
-     low_cpu_mem_usage=False,
- ).to(device)
-
- # ======== AoTI compilation + FA3 ========
- # As per reference for optimization
- try:
-     print("Applying AoTI compilation and FA3...")
-     pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
-     spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
-     print("Optimization applied successfully.")
- except Exception as e:
-     print(f"Optimization warning: {e}. Continuing with standard pipeline.")
-
- MAX_SEED = np.iinfo(np.int32).max
-
- class calculateDuration:
-     def __init__(self, activity_name=""):
-         self.activity_name = activity_name
-
-     def __enter__(self):
-         self.start_time = time.time()
-         return self
-
-     def __exit__(self, exc_type, exc_value, traceback):
-         self.end_time = time.time()
-         self.elapsed_time = self.end_time - self.start_time
-         if self.activity_name:
-             print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
-         else:
-             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
-
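calculateDuration is an ordinary context manager, so any slow stage can be timed without touching its body. A minimal sketch mirroring how run_lora uses it below:

    # Prints "Elapsed time for Loading LoRA weights: ... seconds" on exit
    with calculateDuration("Loading LoRA weights"):
        pipe.load_lora_weights("Shakker-Labs/AWPortrait-Z", weight_name="AWPortrait-Z.safetensors")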
- def update_selection(evt: gr.SelectData, width, height):
-     selected_lora = loras[evt.index]
-     new_placeholder = f"Type a prompt for {selected_lora['title']}"
-     lora_repo = selected_lora["repo"]
-     # Handle the local LoRA entry
-     if lora_repo == "./":
-         updated_text = f"### Selected: Local LoRA - {selected_lora['title']} ✅"
-     else:
-         updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✅"
-     if "aspect" in selected_lora:
-         if selected_lora["aspect"] == "portrait":
-             width = 768
-             height = 1024
-         elif selected_lora["aspect"] == "landscape":
-             width = 1024
-             height = 768
-         else:
-             width = 1024
-             height = 1024
-     return (
-         gr.update(placeholder=new_placeholder),
-         updated_text,
-         evt.index,
-         width,
-         height,
      )

- @spaces.GPU
- def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
-     # Clean up previously loaded LoRAs in either case
-     with calculateDuration("Unloading LoRA"):
-         pipe.unload_lora_weights()
-
-     # Check whether a LoRA is selected
-     if selected_index is not None and selected_index < len(loras):
-         selected_lora = loras[selected_index]
-         lora_path = selected_lora["repo"]
-         trigger_word = selected_lora["trigger_word"]
-
-         # Prepend or append the trigger word to the prompt
-         if trigger_word:
-             if "trigger_position" in selected_lora:
-                 if selected_lora["trigger_position"] == "prepend":
-                     prompt_mash = f"{trigger_word} {prompt}"
-                 else:
-                     prompt_mash = f"{prompt} {trigger_word}"
-             else:
-                 prompt_mash = f"{trigger_word} {prompt}"
          else:
-             prompt_mash = prompt
-
-         # Load LoRA weights
-         with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
-             weight_name = selected_lora.get("weights", None)
-             try:
-                 pipe.load_lora_weights(
-                     lora_path,
-                     weight_name=weight_name,
-                     adapter_name="default",
-                     low_cpu_mem_usage=True
-                 )
-                 # Set the adapter scale
-                 pipe.set_adapters(["default"], adapter_weights=[lora_scale])
-             except Exception as e:
-                 print(f"Error loading LoRA: {e}")
-                 gr.Warning("Failed to load LoRA weights. Generating with base model.")
      else:
-         # Base-model case
-         print("No LoRA selected. Running with base model.")
-         prompt_mash = prompt
-
-     with calculateDuration("Randomizing seed"):
-         if randomize_seed:
-             seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator(device=device).manual_seed(seed)
-
-     # Note: Z-Image-Turbo is strictly text-to-image in this reference implementation;
-     # img2img via image_input is ignored for this pipeline.
-
-     with calculateDuration("Generating image"):
-         # For Turbo models, guidance_scale is typically 0.0
-         forced_guidance = 0.0  # Turbo mode
-
-         final_image = pipe(
-             prompt=prompt_mash,
-             height=int(height),
-             width=int(width),
-             num_inference_steps=int(steps),
-             guidance_scale=forced_guidance,
-             generator=generator,
-         ).images[0]
-
-     yield final_image, seed, gr.update(visible=False)
-
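Because the seed is always threaded through an explicit torch.Generator, a result can be reproduced by re-entering the seed that run_lora yields back and unchecking "Randomize seed". A minimal sketch of the equivalent direct call (prompt and step count are illustrative):

    generator = torch.Generator(device=device).manual_seed(42)
    image = pipe(prompt="Portrait, a studio photo", height=1024, width=1024,
                 num_inference_steps=25, guidance_scale=0.0, generator=generator).images[0]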
- def get_huggingface_safetensors(link):
-     split_link = link.split("/")
-     if len(split_link) == 2:
-         model_card = ModelCard.load(link)
-         base_model = model_card.data.get("base_model")
-         print(base_model)
-
-         # Relaxed check: allow Z-Image or FLUX bases, assuming the user knows what they are doing.
-         if base_model not in ["Tongyi-MAI/Z-Image-Turbo", "black-forest-labs/FLUX.1-dev"]:
-             # Warn instead of erroring out, to allow experimentation.
-             print("Warning: base model might not match.")
-
-         image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
-         trigger_word = model_card.data.get("instance_prompt", "")
-         image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
-         fs = HfFileSystem()
-         try:
-             safetensors_name = None  # guard against repos with no *.safetensors file
-             list_of_files = fs.ls(link, detail=False)
-             for file in list_of_files:
-                 if file.endswith(".safetensors"):
-                     safetensors_name = file.split("/")[-1]
-                 if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
-                     image_elements = file.split("/")
-                     image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
-         except Exception as e:
-             print(e)
-             gr.Warning("Invalid link: provide a valid Hugging Face repository containing a *.safetensors LoRA")
-             raise Exception("Invalid link: provide a valid Hugging Face repository containing a *.safetensors LoRA")
-         return split_link[1], link, safetensors_name, trigger_word, image_url
-
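get_huggingface_safetensors expects a bare `user/repo` id; check_custom_model below normalizes full URLs down to that form first. A minimal usage sketch (the repo id is one from the gallery above):

    title, repo, weights, trigger, image = check_custom_model("Shakker-Labs/AWPortrait-Z")
    # the same repo via URL:
    # check_custom_model("https://huggingface.co/Shakker-Labs/AWPortrait-Z")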
- def check_custom_model(link):
-     if link.startswith("https://"):
-         if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
-             link_split = link.split("huggingface.co/")
-             return get_huggingface_safetensors(link_split[1])
-     else:
-         return get_huggingface_safetensors(link)
-
- def add_custom_lora(custom_lora):
-     global loras
-     if custom_lora:
          try:
-             title, repo, path, trigger_word, image = check_custom_model(custom_lora)
-             print(f"Loaded custom LoRA: {repo}")
-             card = f'''
-             <div class="custom_lora_card">
-               <span>Loaded custom LoRA:</span>
-               <div class="card_internal">
-                 <img src="{image}" />
-                 <div>
-                   <h3>{title}</h3>
-                   <small>{"Using: <code><b>"+trigger_word+"</b></code> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
-                 </div>
-               </div>
-             </div>
-             '''
-             existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
-             if existing_item_index is None:  # fixed: index 0 is falsy, so `not existing_item_index` would re-append the first entry
-                 new_item = {
-                     "image": image,
-                     "title": title,
-                     "repo": repo,
-                     "weights": path,
-                     "trigger_word": trigger_word
-                 }
-                 print(new_item)
-                 existing_item_index = len(loras)
-                 loras.append(new_item)
-
-             return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
          except Exception as e:
-             gr.Warning("Invalid LoRA: either you entered an invalid link or a non-supported LoRA")
-             return gr.update(visible=True, value="Invalid LoRA: either you entered an invalid link or a non-supported LoRA"), gr.update(visible=False), gr.update(), "", None, ""
-     else:
-         return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
-
- def remove_custom_lora():
-     return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
-
- run_lora.zerogpu = True
-
- css = '''
- #gen_btn{height: 100%}
- #gen_column{align-self: stretch}
- #title{text-align: center}
- #title h1{font-size: 3em; display:inline-flex; align-items:center}
- #title img{width: 100px; margin-right: 0.5em}
- #gallery .grid-wrap{height: 10vh}
- #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
- .card_internal{display: flex;height: 100px;margin-top: .5em}
- .card_internal img{margin-right: 1em}
- .styler{--form-gap-width: 0px !important}
- #progress{height:30px}
- #progress .generating{display:none}
- .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
- .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
- '''
-
- # theme and css belong on gr.Blocks(); launch() does not accept them
- with gr.Blocks(theme=orange_red_theme, css=css, delete_cache=(60, 60)) as demo:
-     title = gr.HTML(
-         """<h1>Z Image Turbo LoRA DLC 🧪</h1>""",
-         elem_id="title",
-     )
-     selected_index = gr.State(None)
      with gr.Row():
-         with gr.Column(scale=3):
-             prompt = gr.Textbox(label="Enter Prompt", lines=1, placeholder="✦︎ Choose the LoRA and type the prompt (LoRA = None → Base Model = Active)")
-         with gr.Column(scale=1, elem_id="gen_column"):
-             generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
      with gr.Row():
-         with gr.Column():
-             selected_info = gr.Markdown("### No LoRA Selected (Base Model)")
-             gallery = gr.Gallery(
-                 [(item["image"], item["title"]) for item in loras],
-                 label="Z-Image LoRAs",
-                 allow_preview=False,
-                 columns=3,
-                 elem_id="gallery",
              )
-             with gr.Group():
-                 custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="Paste the LoRA path and press Enter (e.g., Shakker-Labs/AWPortrait-Z).")
-                 gr.Markdown("[Check the list of Z-Image LoRAs](https://huggingface.co/models?other=base_model:adapter:Tongyi-MAI/Z-Image-Turbo)", elem_id="lora_list")
-             custom_lora_info = gr.HTML(visible=False)
-             custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
-         with gr.Column():
-             progress_bar = gr.Markdown(elem_id="progress", visible=False)
-             result = gr.Image(label="Generated Image", format="png", height=630)

-     with gr.Row():
-         with gr.Accordion("Advanced Settings", open=False):
              with gr.Row():
-                 input_image = gr.Image(label="Input image (ignored for Z-Image-Turbo)", type="filepath", visible=False)
-                 image_strength = gr.Slider(label="Denoise Strength", info="Ignored for Z-Image-Turbo", minimum=0.1, maximum=1.0, step=0.01, value=0.75, visible=False)
-             with gr.Column():
-                 with gr.Row():
-                     cfg_scale = gr.Slider(label="CFG Scale", info="Forced to 0.0 for Turbo", minimum=0, maximum=20, step=0.5, value=0.0, interactive=False)
-                     steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=25)
-
-                 with gr.Row():
-                     width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1536)
-                     height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1536)
-
-                 with gr.Row():
-                     randomize_seed = gr.Checkbox(True, label="Randomize seed")
-                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
-                     lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=0.95)
-
-     gallery.select(
-         update_selection,
-         inputs=[width, height],
-         outputs=[prompt, selected_info, selected_index, width, height]
-     )
-     custom_lora.input(
-         add_custom_lora,
-         inputs=[custom_lora],
-         outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
-     )
-     custom_lora_button.click(
-         remove_custom_lora,
-         outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
-     )
-     gr.on(
-         triggers=[generate_button.click, prompt.submit],
-         fn=run_lora,
-         inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
-         outputs=[result, seed, progress_bar]
      )

- demo.queue()
- demo.launch(mcp_server=True, ssr_mode=False, show_error=True)  # theme/css moved onto gr.Blocks above
  import os
+ import gc
+ import torch
+ import gradio as gr
+ from PIL import Image
+ from pathlib import Path
+ import shutil
  import json
+ from huggingface_hub import hf_hub_download, HfApi, create_repo, upload_file
+ import tempfile
+ import multiprocessing as mp
+
+ # Training imports
+ from peft import LoraConfig, get_peft_model
+ from tqdm.auto import tqdm
  import numpy as np

+ # Set memory allocation config BEFORE any CUDA operations
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,garbage_collection_threshold:0.6'
+
+ # Global state
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+ # HF token from the environment
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ # Model repo
+ MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
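PYTORCH_CUDA_ALLOC_CONF is read when PyTorch's caching allocator initializes, so it only has an effect if it is set before the first CUDA allocation; a minimal sketch of the required ordering (any CUDA call before the assignment would lock in the defaults):

    import os
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    import torch
    x = torch.zeros(1, device='cuda')  # first allocation picks up the config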
+
+
+ # ============================================
+ # Comic Style CSS
+ # ============================================
+ COMIC_CSS = """
+ @import url('https://fonts.googleapis.com/css2?family=Bangers&family=Comic+Neue:wght@400;700&display=swap');
+
+ .gradio-container {
+     background-color: #FEF9C3 !important;
+     background-image: radial-gradient(#1F2937 1px, transparent 1px) !important;
+     background-size: 20px 20px !important;
+     min-height: 100vh !important;
+     font-family: 'Comic Neue', cursive, sans-serif !important;
+ }
+
+ footer, .footer, .gradio-container footer, .built-with, [class*="footer"], .gradio-footer, a[href*="gradio.app"] {
+     display: none !important;
+     visibility: hidden !important;
+     height: 0 !important;
+ }
+
+ /* HOME button style */
+ .home-button-container {
+     display: flex;
+     justify-content: center;
+     align-items: center;
+     gap: 15px;
+     margin-bottom: 15px;
+     padding: 12px 20px;
+     background: linear-gradient(135deg, #10B981 0%, #059669 100%);
+     border: 4px solid #1F2937;
+     border-radius: 12px;
+     box-shadow: 6px 6px 0 #1F2937;
+ }
+
+ .home-button {
+     display: inline-flex;
+     align-items: center;
+     gap: 8px;
+     padding: 10px 25px;
+     background: linear-gradient(135deg, #FACC15 0%, #F59E0B 100%);
+     color: #1F2937;
+     font-family: 'Bangers', cursive;
+     font-size: 1.4rem;
+     letter-spacing: 2px;
+     text-decoration: none;
+     border: 3px solid #1F2937;
+     border-radius: 8px;
+     box-shadow: 4px 4px 0 #1F2937;
+     transition: all 0.2s ease;
+ }
+
+ .home-button:hover {
+     background: linear-gradient(135deg, #FDE047 0%, #FACC15 100%);
+     transform: translate(-2px, -2px);
+     box-shadow: 6px 6px 0 #1F2937;
+ }
+
+ .home-button:active {
+     transform: translate(2px, 2px);
+     box-shadow: 2px 2px 0 #1F2937;
+ }
+
+ .url-display {
+     font-family: 'Comic Neue', cursive;
+     font-size: 1.1rem;
+     font-weight: 700;
+     color: #FFF;
+     background: rgba(0,0,0,0.3);
+     padding: 8px 16px;
+     border-radius: 6px;
+     border: 2px solid rgba(255,255,255,0.3);
+ }
+
+ .header-container {
+     text-align: center;
+     padding: 25px 20px;
+     background: linear-gradient(135deg, #3B82F6 0%, #8B5CF6 100%);
+     border: 4px solid #1F2937;
+     border-radius: 12px;
+     margin-bottom: 20px;
+     box-shadow: 8px 8px 0 #1F2937;
+     position: relative;
+ }
+
+ .header-title {
+     font-family: 'Bangers', cursive !important;
+     color: #FFF !important;
+     font-size: 2.8rem !important;
+     text-shadow: 3px 3px 0 #1F2937 !important;
+     letter-spacing: 3px !important;
+     margin: 0 !important;
+ }
+
+ .header-subtitle {
+     font-family: 'Comic Neue', cursive !important;
+     font-size: 1.1rem !important;
+     color: #FEF9C3 !important;
+     margin-top: 8px !important;
+     font-weight: 700 !important;
+ }
+
+ .stats-badge {
+     display: inline-block;
+     background: #FACC15;
+     color: #1F2937;
+     padding: 6px 14px;
+     border-radius: 20px;
+     font-size: 0.9rem;
+     margin: 3px;
+     font-weight: 700;
+     border: 2px solid #1F2937;
+     box-shadow: 2px 2px 0 #1F2937;
+ }
+
+ .gr-panel, .gr-box, .gr-form, .block, .gr-group {
+     background: #FFF !important;
+     border: 3px solid #1F2937 !important;
+     border-radius: 8px !important;
+     box-shadow: 5px 5px 0 #1F2937 !important;
+ }
+
+ .gr-button-primary, button.primary, .gr-button.primary {
+     background: linear-gradient(135deg, #EF4444 0%, #F97316 100%) !important;
+     border: 3px solid #1F2937 !important;
+     border-radius: 8px !important;
+     color: #FFF !important;
+     font-family: 'Bangers', cursive !important;
+     font-size: 1.3rem !important;
+     letter-spacing: 2px !important;
+     padding: 12px 24px !important;
+     box-shadow: 4px 4px 0 #1F2937 !important;
+     text-shadow: 1px 1px 0 #1F2937 !important;
+     transition: all 0.2s ease !important;
+ }
+
+ .gr-button-primary:hover, button.primary:hover {
+     background: linear-gradient(135deg, #DC2626 0%, #EA580C 100%) !important;
+     transform: translate(-2px, -2px) !important;
+     box-shadow: 6px 6px 0 #1F2937 !important;
+ }
+
+ .gr-button-primary:active, button.primary:active {
+     transform: translate(2px, 2px) !important;
+     box-shadow: 2px 2px 0 #1F2937 !important;
+ }
+
+ textarea, input[type="text"], input[type="number"] {
+     background: #FFF !important;
+     border: 3px solid #1F2937 !important;
+     border-radius: 8px !important;
+     color: #1F2937 !important;
+     font-family: 'Comic Neue', cursive !important;
+     font-weight: 700 !important;
+ }
+
+ textarea:focus, input[type="text"]:focus {
+     border-color: #3B82F6 !important;
+     box-shadow: 3px 3px 0 #3B82F6 !important;
+ }
+
+ .info-box {
+     background: linear-gradient(135deg, #FACC15 0%, #FDE047 100%) !important;
+     border: 3px solid #1F2937 !important;
+     border-radius: 8px !important;
+     padding: 12px 15px !important;
+     margin: 10px 0 !important;
+     box-shadow: 4px 4px 0 #1F2937 !important;
+     font-family: 'Comic Neue', cursive !important;
+     font-weight: 700 !important;
+     color: #1F2937 !important;
+ }
+
+ .result-box textarea {
+     background: #1F2937 !important;
+     color: #10B981 !important;
+     font-family: 'Courier New', monospace !important;
+     border: 3px solid #10B981 !important;
+     border-radius: 8px !important;
+     box-shadow: 4px 4px 0 #10B981 !important;
+ }
+
+ label, .gr-input-label, .gr-block-label {
+     color: #1F2937 !important;
+     font-family: 'Comic Neue', cursive !important;
+     font-weight: 700 !important;
+ }
+
+ .gr-accordion {
+     background: #E0F2FE !important;
+     border: 3px solid #1F2937 !important;
+     border-radius: 8px !important;
+     box-shadow: 4px 4px 0 #1F2937 !important;
+ }
+
+ .tab-nav button {
+     font-family: 'Comic Neue', cursive !important;
+     font-weight: 700 !important;
+     border: 2px solid #1F2937 !important;
+     margin: 2px !important;
+ }
+
+ .tab-nav button.selected {
+     background: #3B82F6 !important;
+     color: #FFF !important;
+     box-shadow: 3px 3px 0 #1F2937 !important;
+ }
+
+ .footer-comic {
+     text-align: center;
+     padding: 20px;
+     background: linear-gradient(135deg, #3B82F6 0%, #8B5CF6 100%);
+     border: 4px solid #1F2937;
+     border-radius: 12px;
+     margin-top: 20px;
+     box-shadow: 6px 6px 0 #1F2937;
+ }
+
+ .footer-comic p {
+     font-family: 'Comic Neue', cursive !important;
+     color: #FFF !important;
+     margin: 5px 0 !important;
+     font-weight: 700 !important;
+ }
+
+ ::-webkit-scrollbar {
+     width: 12px;
+     height: 12px;
+ }
+
+ ::-webkit-scrollbar-track {
+     background: #FEF9C3;
+     border: 2px solid #1F2937;
+ }
+
+ ::-webkit-scrollbar-thumb {
+     background: #3B82F6;
+     border: 2px solid #1F2937;
+     border-radius: 6px;
+ }
+
+ ::-webkit-scrollbar-thumb:hover {
+     background: #EF4444;
+ }
+
+ ::selection {
+     background: #FACC15;
+     color: #1F2937;
+ }
+
+ /* Slider styling */
+ input[type="range"] {
+     accent-color: #3B82F6;
+ }
+
+ .gr-slider input[type="range"]::-webkit-slider-thumb {
+     background: #EF4444 !important;
+     border: 2px solid #1F2937 !important;
+ }
+
+ /* Image/gallery containers */
+ .gr-image, .gr-gallery {
+     border: 3px solid #1F2937 !important;
+     border-radius: 8px !important;
+     box-shadow: 4px 4px 0 #1F2937 !important;
+ }
+
+ /* Quality badge */
+ .quality-badge {
+     display: inline-block;
+     background: linear-gradient(135deg, #10B981 0%, #059669 100%);
+     color: white;
+     padding: 4px 12px;
+     border-radius: 15px;
+     font-size: 0.8rem;
+     font-weight: bold;
+     border: 2px solid #1F2937;
+     margin-left: 8px;
+ }
+
+ #col-container {
+     max-width: 1200px;
+     margin: 0 auto;
+ }
+
+ /* Hide Hugging Face elements */
+ .huggingface-space-link,
+ a[href*="huggingface.co/spaces"],
+ button[class*="share"],
+ .share-button,
+ [class*="hf-logo"],
+ .gr-share-btn,
+ #hf-logo,
+ .hf-icon,
+ svg[class*="hf"],
+ div[class*="huggingface"],
+ a[class*="huggingface"],
+ .svelte-1rjryqp,
+ header a[href*="huggingface"],
+ .space-header,
+ div.absolute.right-0 a[href*="huggingface"],
+ .gr-group > a[href*="huggingface"],
+ a[target="_blank"][href*="huggingface.co"] {
+     display: none !important;
+     visibility: hidden !important;
+     opacity: 0 !important;
+     pointer-events: none !important;
+     width: 0 !important;
+     height: 0 !important;
+     overflow: hidden !important;
+ }
+
+ /* Training-specific styles */
+ .training-section {
+     background: linear-gradient(135deg, #E0F2FE 0%, #DBEAFE 100%) !important;
+     border: 3px solid #1F2937 !important;
+     border-radius: 12px !important;
+     padding: 15px !important;
+     margin: 10px 0 !important;
+     box-shadow: 4px 4px 0 #1F2937 !important;
+ }
+
+ .tips-box {
+     background: linear-gradient(135deg, #D1FAE5 0%, #A7F3D0 100%) !important;
+     border: 3px solid #1F2937 !important;
+     border-radius: 8px !important;
+     padding: 12px 15px !important;
+     margin: 10px 0 !important;
+     box-shadow: 4px 4px 0 #1F2937 !important;
+     font-family: 'Comic Neue', cursive !important;
+     font-weight: 700 !important;
+     color: #1F2937 !important;
+ }
+ """
+
+
+ def aggressive_cleanup():
+     """Aggressively clean up GPU memory"""
+     gc.collect()
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+         torch.cuda.reset_peak_memory_stats()
+         torch.cuda.reset_accumulated_memory_stats()
+
+
+ def get_gpu_memory_info():
+     """Get the current GPU memory status"""
+     if torch.cuda.is_available():
+         allocated = torch.cuda.memory_allocated(0) / 1e9
+         reserved = torch.cuda.memory_reserved(0) / 1e9
+         total = torch.cuda.get_device_properties(0).total_memory / 1e9
+         free = total - allocated
+         return f"GPU: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved, {free:.1f}GB free of {total:.1f}GB"
+     return "No GPU"
+
+
+ def check_gpu():
+     """Check GPU availability and memory"""
+     if torch.cuda.is_available():
+         gpu_name = torch.cuda.get_device_name(0)
+         gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
+         return f"✅ GPU: {gpu_name} ({gpu_mem:.1f}GB total)"
+     return "❌ No GPU detected"
+
+
+ def check_hf_token():
+     """Check whether HF_TOKEN is configured"""
+     if HF_TOKEN:
+         try:
+             api = HfApi(token=HF_TOKEN)
+             user_info = api.whoami()
+             return f"✅ Logged in as: {user_info['name']}"
+         except Exception as e:
+             return f"⚠️ Token invalid: {str(e)}"
+     return "❌ HF_TOKEN not set"
+
+
+ def get_hf_username():
+     if HF_TOKEN:
+         try:
+             api = HfApi(token=HF_TOKEN)
+             return api.whoami()['name']
+         except Exception:  # fixed: avoid a bare except
+             return None
+     return None
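These helpers are diagnostics only; a hedged sketch of how they might be exercised at startup (the output strings depend on the host):

    print(check_gpu())       # e.g. "✅ GPU: NVIDIA A100-SXM4-40GB (42.5GB total)"
    print(check_hf_token())  # "❌ HF_TOKEN not set" when no token is configured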
+
+
+ # ============================================
+ # SUBPROCESS-BASED CAPTIONING
+ # ============================================
+
+ def _caption_worker(image_paths_queue, results_queue, trigger_word, is_person):
+     """Worker process for Florence-2 captioning - a completely isolated GPU context"""
+     import re
+     import torch
+     from PIL import Image
+
+     os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     try:
+         from transformers import AutoProcessor, AutoModelForCausalLM
+
+         print("[Caption Worker] Loading Florence-2-large...")
+         processor = AutoProcessor.from_pretrained(
+             "microsoft/Florence-2-large",
+             trust_remote_code=True
+         )
+         # FIX: attn_implementation="eager" avoids the _supports_sdpa attribute error
+         model = AutoModelForCausalLM.from_pretrained(
+             "microsoft/Florence-2-large",
+             torch_dtype=torch.float16,
+             trust_remote_code=True,
+             attn_implementation="eager"  # disable SDPA
+         ).to(device)
+         model.eval()
+         print("[Caption Worker] Florence-2 loaded!")
+
+         image_paths = image_paths_queue.get()
+         captions = []
+
+         for idx, img_path in enumerate(image_paths):
+             try:
+                 img = Image.open(img_path).convert("RGB")
+                 task = "<DETAILED_CAPTION>"
+                 inputs = processor(text=task, images=img, return_tensors="pt")
+                 inputs = {k: v.to(device) for k, v in inputs.items()}
+
+                 with torch.no_grad():
+                     generated_ids = model.generate(
+                         **inputs, max_new_tokens=256, num_beams=3, do_sample=False
+                     )
+
+                 generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+                 parsed = processor.post_process_generation(
+                     generated_text, task=task, image_size=(img.width, img.height)
+                 )
+
+                 raw_caption = parsed.get(task, "")
+
+                 if raw_caption:
+                     caption = raw_caption.strip()
+                     if is_person:
+                         # Swap the first generic person reference for the trigger word
+                         replacements = [
+                             "a young woman", "a woman", "the woman", "a young man", "a man", "the man",
+                             "a person", "the person", "a young person", "an individual",
+                             "a girl", "the girl", "a boy", "the boy",
+                             "a lady", "the lady", "a gentleman", "the gentleman",
+                             "someone", "a figure", "the figure"
+                         ]
+                         caption_lower = caption.lower()
+                         replaced = False
+                         for ref in replacements:
+                             if ref in caption_lower:
+                                 pattern = re.compile(re.escape(ref), re.IGNORECASE)
+                                 caption = pattern.sub(trigger_word, caption, count=1)
+                                 replaced = True
+                                 break
+                         if not replaced:
+                             caption = f"{trigger_word}, {caption}"
+                     else:
+                         caption = f"{trigger_word}, {caption}"
+                 else:
+                     caption = trigger_word
+
+                 captions.append(caption)
+                 print(f"[Caption Worker] [{idx+1}/{len(image_paths)}] {caption[:80]}...")
+
+                 del inputs, generated_ids, img
+                 torch.cuda.empty_cache()
+
+             except Exception as e:
+                 print(f"[Caption Worker] Error on image {idx}: {e}")
+                 captions.append(trigger_word)
+
+         results_queue.put(captions)
+         del model, processor
+         torch.cuda.empty_cache()
+
+     except Exception as e:
+         print(f"[Caption Worker] Fatal error: {e}")
+         import traceback
+         traceback.print_exc()
+         # Fall back to the bare trigger word; fetch image_paths first for the count
+         try:
+             image_paths = image_paths_queue.get(timeout=1)
+             results_queue.put([trigger_word] * len(image_paths))
+         except Exception:
+             results_queue.put([trigger_word] * 20)
+
+
+ def run_captioning_subprocess(image_paths, trigger_word, is_person=True):
+     """Run captioning in a subprocess for complete GPU memory isolation"""
+     print("[Main] Starting captioning subprocess...")
+     print(f"[Main] Before: {get_gpu_memory_info()}")
+
+     ctx = mp.get_context('spawn')
+     image_paths_queue = ctx.Queue()
+     results_queue = ctx.Queue()
+
+     worker = ctx.Process(
+         target=_caption_worker,
+         args=(image_paths_queue, results_queue, trigger_word, is_person)
+     )
+     worker.start()
+     image_paths_queue.put(image_paths)
+
+     try:
+         captions = results_queue.get(timeout=900)  # 15-minute timeout
+     except Exception as e:
+         print(f"[Main] Captioning error: {e}")
+         captions = [trigger_word] * len(image_paths)
+
+     worker.join(timeout=30)
+     if worker.is_alive():
+         worker.terminate()
+         worker.join()
+
+     print(f"[Main] After captioning: {get_gpu_memory_info()}")
+     return captions
+
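Because the worker runs in a spawned process, every Florence-2 weight lives in that process's own CUDA context and is returned to the driver when the process exits, which is the point of the isolation. A hedged usage sketch (paths are illustrative):

    captions = run_captioning_subprocess(
        ["/tmp/dataset/image_0000.jpg", "/tmp/dataset/image_0001.jpg"],
        trigger_word="sks_person",
        is_person=True,
    )
    # Falls back to [trigger_word, ...] if the worker dies or exceeds the timeout.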
+ def prepare_dataset(images, trigger_word, output_dir, use_auto_caption=True, is_person=True):
+     """Prepare the dataset, captioning in a subprocess"""
+     dataset_dir = Path(output_dir) / "dataset"
+     dataset_dir.mkdir(parents=True, exist_ok=True)
+
+     image_paths = []
+
+     for i, img in enumerate(images):
+         if img is None:
+             continue
+         if isinstance(img, tuple):
+             img = img[0]
+         if isinstance(img, str):
+             img_pil = Image.open(img)
+         elif isinstance(img, np.ndarray):
+             img_pil = Image.fromarray(img)
+         elif hasattr(img, 'mode'):
+             img_pil = img
+         else:
+             continue
+
+         if img_pil.mode != "RGB":
+             img_pil = img_pil.convert("RGB")
+
+         img_path = dataset_dir / f"image_{i:04d}.jpg"
+         img_pil.save(img_path, quality=95)
+         image_paths.append(str(img_path))
+
+     if use_auto_caption:
+         captions = run_captioning_subprocess(image_paths, trigger_word, is_person)
+     else:
+         captions = [trigger_word] * len(image_paths)
+
+     for img_path, caption in zip(image_paths, captions):
+         caption_path = Path(img_path).with_suffix('.txt')
+         caption_path.write_text(caption)
+
+     return image_paths, captions, str(dataset_dir)
+
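prepare_dataset leaves each image next to a same-named .txt caption, the layout most LoRA trainers expect; for two accepted inputs the directory would look like this (caption text illustrative):

    dataset/
      image_0000.jpg
      image_0000.txt   # "sks_person, a portrait of ..."
      image_0001.jpg
      image_0001.txt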
+ def compute_flow_matching_loss(model_output, target, timesteps):
+     """Compute the rectified-flow (velocity) loss; `timesteps` is currently unused."""
+     loss = torch.nn.functional.mse_loss(model_output, target, reduction="none")
+     loss = loss.mean(dim=list(range(1, len(loss.shape))))
+     return loss.mean()
+
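The rectified-flow objective regresses the constant velocity of the straight path x_t = (1 - t) * x0 + t * eps, whose target is v = eps - x0. A minimal self-contained check with toy tensors (shapes illustrative):

    x0 = torch.zeros(1, 4, 8, 8)              # "clean" latent
    eps = torch.randn_like(x0)                 # noise sample
    t = torch.tensor([0.25]).view(-1, 1, 1, 1)
    x_t = (1 - t) * x0 + t * eps               # interpolated model input
    v_target = eps - x0                        # velocity the model should predict
    # A model that outputs the exact velocity incurs zero loss:
    assert compute_flow_matching_loss(v_target, v_target, t).item() == 0.0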
+ def upload_to_hub(lora_path, repo_name, trigger_word, training_info, progress_callback=None):
+     """Upload the trained LoRA to the HF Hub"""
+     if not HF_TOKEN:
+         return False, "HF_TOKEN not configured"
+
+     try:
+         api = HfApi(token=HF_TOKEN)
+         username = get_hf_username()
+         if not username:
+             return False, "Could not get username"
+
+         repo_id = f"{username}/{repo_name}"
+
+         try:
+             create_repo(repo_id=repo_id, token=HF_TOKEN, private=True, repo_type="model", exist_ok=True)
+         except Exception:  # fixed: avoid a bare except
+             pass
+
+         api.upload_file(
+             path_or_fileobj=lora_path,
+             path_in_repo=f"{repo_name}.safetensors",
+             repo_id=repo_id,
+             token=HF_TOKEN
+         )
+
+         readme = f"""---
+ license: apache-2.0
+ base_model: Tongyi-MAI/Z-Image-Turbo
+ tags: [lora, z-image, text-to-image, diffusers]
+ ---
+ # {repo_name}
+ Trigger: `{trigger_word}`
+ {training_info}
+ """
+         api.upload_file(path_or_fileobj=readme.encode(), path_in_repo="README.md", repo_id=repo_id, token=HF_TOKEN)
+
+         return True, f"https://huggingface.co/{repo_id}"
+     except Exception as e:
+         return False, str(e)
+
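A hedged usage sketch of the uploader (names and paths are illustrative; requires HF_TOKEN in the environment):

    ok, result = upload_to_hub(
        lora_path="/tmp/z_image_lora.safetensors",
        repo_name="my-z-image-lora",
        trigger_word="sks_person",
        training_info="Trained for 1000 steps at lr 1e-4, rank 16.",
    )
    print(result)  # repo URL on success, the error string on failure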
+ def train_lora(
+     images, trigger_word, output_name, num_steps, learning_rate, lora_rank,
+     resolution, batch_size, upload_to_hub_flag, hub_repo_name,
+     use_auto_caption, is_person_training, progress=gr.Progress()
+ ):
+     """Train a LoRA with the correct ZImageTransformer2DModel forward signature"""
+
+     if not torch.cuda.is_available():
+         return None, "❌ No GPU available"
+
+     if not images or len(images) < 3:
+         return None, "❌ Please upload at least 3 images"
+
+     if not trigger_word:
+         return None, "❌ Please specify a trigger word"
+
+     if not output_name:
+         output_name = "z_image_lora"
+     output_name = output_name.replace(" ", "_").lower()
+
+     if upload_to_hub_flag and not HF_TOKEN:
+         return None, "❌ HF_TOKEN not configured"
+
+     progress(0, desc="Initializing...")
+     print(f"[Train] Start: {get_gpu_memory_info()}")
+     aggressive_cleanup()
+
+     with tempfile.TemporaryDirectory() as tmpdir:
+         try:
+             # ============================================
+             # PHASE 1: Captioning (subprocess)
+             # ============================================
+             progress(0.02, desc="Running Florence-2 captioning (subprocess)...")
+
+             image_paths, captions, dataset_dir = prepare_dataset(
+                 images, trigger_word, tmpdir, use_auto_caption, is_person_training
+             )
+
+             if len(image_paths) < 3:
+                 return None, "❌ Not enough valid images"
+
+             progress(0.12, desc=f"Captioning done: {len(image_paths)} images")
+             aggressive_cleanup()
+             print(f"[Train] After captioning cleanup: {get_gpu_memory_info()}")
+
+             # ============================================
+             # PHASE 2: Load the pipeline for encoding
+             # ============================================
+             progress(0.15, desc="Loading pipeline for encoding...")
+             print(f"[Train] Before pipeline: {get_gpu_memory_info()}")
+
+             from diffusers import ZImagePipeline
+
+             # Load the pipeline on the CPU first
+             pipe = ZImagePipeline.from_pretrained(
+                 MODEL_REPO,
+                 torch_dtype=DTYPE,
+             )
+
+             # VAE scaling factor
+             vae_scaling_factor = pipe.vae.config.scaling_factor
+
+             # ============================================
+             # PHASE 3: Encode captions with the text encoder
+             # ============================================
+             progress(0.20, desc="Encoding captions...")
+
+             # Move only the text encoder to the GPU
+             pipe.text_encoder.to(DEVICE)
+
+             cached_text_embeddings = []
+
+             with torch.no_grad():
+                 for idx, caption in enumerate(captions):
+                     text_inputs = pipe.tokenizer(
+                         caption,
+                         padding="max_length",
+                         max_length=256,
+                         truncation=True,
+                         return_tensors="pt"
+                     ).to(DEVICE)
+
+                     text_emb = pipe.text_encoder(**text_inputs)[0]
+                     cached_text_embeddings.append(text_emb.cpu())
+
+                     del text_inputs, text_emb
+
+                     if idx % 2 == 0:
+                         torch.cuda.empty_cache()
+                     progress(0.20 + 0.10 * (idx / len(captions)),
+                              desc=f"Encoding caption {idx+1}/{len(captions)}")
+
+             # Free the text encoder
+             pipe.text_encoder.to("cpu")
+             del pipe.text_encoder
+             aggressive_cleanup()
+             print(f"[Train] After text encoding: {get_gpu_memory_info()}")
+
+             # ============================================
+             # PHASE 4: Encode images with the VAE
+             # ============================================
+             progress(0.32, desc="Encoding images with VAE...")
+
+             pipe.vae.to(DEVICE)
+
+             cached_latents = []
+
+             with torch.no_grad():
+                 for idx, img_path in enumerate(image_paths):
+                     img = Image.open(img_path).convert("RGB")
+                     img = img.resize((int(resolution), int(resolution)), Image.LANCZOS)
+                     img_tensor = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
+                     img_tensor = img_tensor.unsqueeze(0).to(DEVICE, dtype=DTYPE)
+                     img_tensor = 2.0 * img_tensor - 1.0
+
+                     latent = pipe.vae.encode(img_tensor).latent_dist.sample()
+                     latent = latent * vae_scaling_factor
+                     cached_latents.append(latent.cpu())
+
+                     del img_tensor, latent, img
+
+                     if idx % 2 == 0:
+                         torch.cuda.empty_cache()
+                     progress(0.32 + 0.08 * (idx / len(image_paths)),
+                              desc=f"Encoding image {idx+1}/{len(image_paths)}")
+
+             # Free the VAE
+             pipe.vae.to("cpu")
+             del pipe.vae
+             aggressive_cleanup()
+             print(f"[Train] After VAE encoding: {get_gpu_memory_info()}")
774
+
775
+ # ============================================
776
+ # PHASE 5: Setup Transformer with Training Adapter
777
+ # ============================================
778
+ progress(0.42, desc="Setting up transformer with training adapter...")
779
+
780
+ # Download training adapter
781
+ try:
782
+ adapter_path = hf_hub_download(
783
+ repo_id="ostris/zimage_turbo_training_adapter",
784
+ filename="zimage_turbo_training_adapter_v1.safetensors",
785
+ local_dir=tmpdir
786
+ )
787
+ print(f"[Train] Training adapter downloaded: {adapter_path}")
788
+ except Exception as e:
789
+ return None, f"❌ Could not download training adapter: {e}"
790
+
791
+ # Get transformer (still on CPU from pipeline)
792
+ transformer = pipe.transformer
793
+
794
+ # Load adapter via pipe's load_lora_weights
795
+ from safetensors.torch import load_file, save_file
796
+
797
+ try:
798
+ pipe.load_lora_weights(adapter_path, adapter_name="training_adapter")
799
+ print("[Train] Training adapter loaded via load_lora_weights")
800
+ except Exception as e:
801
+ print(f"[Train] Warning: Could not load adapter via load_lora_weights: {e}")
+
+         # Configure our training LoRA
+         progress(0.45, desc="Configuring LoRA...")
+
+         lora_config = LoraConfig(
+             r=int(lora_rank),
+             lora_alpha=int(lora_rank),
+             init_lora_weights="gaussian",
+             target_modules=[
+                 "to_q", "to_k", "to_v", "to_out.0",
+                 "attn.to_q", "attn.to_k", "attn.to_v", "attn.to_out.0",
+                 "ff.net.0.proj", "ff.net.2",
+                 "proj_in", "proj_out",
+             ],
+             lora_dropout=0.0,
+         )
+
+         transformer = get_peft_model(transformer, lora_config)
+         trainable = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
+         total = sum(p.numel() for p in transformer.parameters())
+         print(f"[Train] Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
+
+         # Move to GPU
+         transformer.to(DEVICE)
+         print(f"[Train] Transformer on GPU: {get_gpu_memory_info()}")
+
+         if hasattr(transformer, 'enable_gradient_checkpointing'):
+             transformer.enable_gradient_checkpointing()
+
+         # Free the rest of pipeline
+         del pipe
+         aggressive_cleanup()
+
+         # Optimizer & Scheduler
+         optimizer = torch.optim.AdamW(
+             [p for p in transformer.parameters() if p.requires_grad],
+             lr=learning_rate, weight_decay=0.01, betas=(0.9, 0.999), eps=1e-8
+         )
+
+         from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR
+
+         warmup_steps = min(100, int(num_steps * 0.1))
+         warmup_scheduler = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps)
+         cosine_scheduler = CosineAnnealingWarmRestarts(
+             optimizer, T_0=max(1, int(num_steps - warmup_steps)), eta_min=learning_rate * 0.01
+         )
+         lr_scheduler = SequentialLR(
+             optimizer, [warmup_scheduler, cosine_scheduler], milestones=[warmup_steps]
+         )
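+         # Net schedule: linear warmup from 10% to 100% of the base LR over
+         # warmup_steps, then one cosine decay to 1% of the base LR (T_0 spans
+         # all remaining steps, so no restart actually fires).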
+
+         progress(0.50, desc=f"Training with {len(cached_latents)} samples...")
+
+         # ============================================
+         # PHASE 6: Training Loop with CORRECT forward signature
+         # ZImageTransformer2DModel.forward(x, t, cap_feats, ...)
+         #   x: List[Tensor] where each tensor is [C, F, H, W]
+         #   cap_feats: List[Tensor] where each tensor is [seq_len, dim]
+         # ============================================
+         transformer.train()
+         losses = []
+         successful_steps = 0
+
+         for step in range(int(num_steps)):
+             try:
+                 idx = np.random.randint(0, len(cached_latents))
+
+                 # latents: [B, C, H, W] -> need [C, F, H, W] where F=1
+                 latents = cached_latents[idx].to(DEVICE, dtype=DTYPE)
+                 # Remove batch dim, add frame dim: [1, C, H, W] -> [C, 1, H, W]
+                 latents = latents.squeeze(0).unsqueeze(1)  # [C, 1, H, W]
+
+                 # text_embeddings: [B, seq_len, dim] -> [seq_len, dim]
+                 text_embeddings = cached_text_embeddings[idx].to(DEVICE, dtype=DTYPE)
+                 text_embeddings = text_embeddings.squeeze(0)  # [seq_len, dim]
+
+                 # Timestep for flow matching (0 to 1)
+                 timesteps = torch.rand(1, device=DEVICE, dtype=DTYPE)
+
+                 # Create noisy latents using flow matching interpolation
+                 noise = torch.randn_like(latents)
+                 t = timesteps.view(-1, 1, 1, 1)
+                 noisy_latents = (1 - t) * latents + t * noise
+
+                 # Target is the velocity: noise - clean
+                 target = noise - latents
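+                 # Rectified-flow check: with x_t = (1 - t) * x0 + t * eps the
+                 # velocity dx_t/dt = eps - x0 is constant in t, so the model
+                 # regresses a single vector field; sampling integrates the
+                 # ODE from t = 1 (pure noise) back to t = 0.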
+
+                 # Scale timestep for model
+                 t_input = timesteps * 1000
+
+                 # CORRECT FORWARD CALL:
+                 # x and cap_feats must be Lists!
+                 with torch.amp.autocast('cuda', dtype=DTYPE):
+                     output = transformer(
+                         x=[noisy_latents],            # List of [C, F, H, W]
+                         t=t_input,                    # timestep
+                         cap_feats=[text_embeddings],  # List of [seq_len, dim]
+                         return_dict=True
+                     )
+
+                 # Get model output - it will also be a list
+                 if hasattr(output, 'sample'):
+                     model_output = output.sample
+                     if isinstance(model_output, list):
+                         model_output = model_output[0]
+                 elif isinstance(output, tuple):
+                     model_output = output[0]
+                     if isinstance(model_output, list):
+                         model_output = model_output[0]
+                 else:
+                     model_output = output
+                     if isinstance(model_output, list):
+                         model_output = model_output[0]
+
+                 loss = compute_flow_matching_loss(model_output, target, timesteps)
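+                 # compute_flow_matching_loss is defined earlier in app.py; a
+                 # minimal stand-in (assuming a plain, unweighted objective)
+                 # would be an MSE on the velocity:
+                 #     loss = torch.nn.functional.mse_loss(model_output.float(),
+                 #                                         target.float())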
+
+                 optimizer.zero_grad()
+                 loss.backward()
+                 torch.nn.utils.clip_grad_norm_(transformer.parameters(), 1.0)
+                 optimizer.step()
+                 lr_scheduler.step()
+
+                 losses.append(loss.item())
+                 successful_steps += 1
+
+                 del latents, text_embeddings, noise, noisy_latents, target, model_output, loss, output
+
+                 if step % 25 == 0:
+                     avg_loss = np.mean(losses[-50:]) if len(losses) >= 50 else np.mean(losses) if losses else float('nan')
+                     progress(
+                         0.50 + 0.40 * (step / int(num_steps)),
+                         desc=f"Step {step}/{int(num_steps)} | Loss: {avg_loss:.4f}"
+                     )
+                     print(f"[Train] Step {step}: Loss={avg_loss:.4f}")
+
+                 if step % 100 == 0:
+                     gc.collect()
+                     torch.cuda.empty_cache()
+
+             except Exception as e:
+                 if step < 5:
+                     print(f"[Train] Error at step {step}: {e}")
+                     import traceback
+                     traceback.print_exc()
+                 continue
+
+         if successful_steps == 0:
+             return None, "❌ Training failed - no successful steps. Check model forward signature."
+
+         # ============================================
+         # PHASE 7: Save LoRA
+         # ============================================
+         progress(0.92, desc="Saving LoRA...")
+
+         del cached_latents, cached_text_embeddings
+         aggressive_cleanup()
+
+         lora_state_dict = {}
+         for name, param in transformer.named_parameters():
+             if "lora" in name.lower() and param.requires_grad:
+                 clean_name = name.replace("base_model.model.", "")
+                 lora_state_dict[clean_name] = param.detach().cpu()
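+         # PEFT parameters are named
+         # "base_model.model.<module path>.lora_A.<adapter>.weight"; stripping
+         # the "base_model.model." prefix leaves keys rooted at the
+         # transformer's own module paths, and the requires_grad filter above
+         # keeps the frozen training adapter out of the saved file.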
963
+
964
+ if not lora_state_dict:
965
+ return None, "❌ No LoRA weights found"
966
+
967
+ final_output = f"/tmp/{output_name}.safetensors"
968
+ save_file(lora_state_dict, final_output)
969
+
970
+ file_size = os.path.getsize(final_output) / (1024 * 1024)
971
+ avg_final_loss = np.mean(losses[-100:]) if len(losses) >= 100 else np.mean(losses) if losses else float('nan')
972
+
973
+ training_info = f"""
974
+ - Images: {len(image_paths)}
975
+ - Steps: {successful_steps}
976
+ - Final Loss: {avg_final_loss:.4f}
977
+ - LR: {learning_rate}, Rank: {int(lora_rank)}, Resolution: {int(resolution)}
978
+ """
979
+
980
+ hub_result = ""
981
+ if upload_to_hub_flag:
982
+ progress(0.94, desc="Uploading to Hub...")
983
+ success, result = upload_to_hub(
984
+ final_output, hub_repo_name or output_name, trigger_word, training_info
985
+ )
986
+ hub_result = f"\n\n🚀 Uploaded: {result}" if success else f"\n\n⚠️ Upload failed: {result}"
987
+
988
+ del transformer
989
+ aggressive_cleanup()
990
+ progress(1.0, desc="Complete!")
991
+
992
+ sample_captions = "\n".join([f" - {c[:80]}..." for c in captions[:3]])
993
+
994
+ return final_output, f"""✅ Training complete!
995
+
996
+ 📁 LoRA: {output_name}.safetensors ({file_size:.1f} MB)
997
+ 🏷️ Trigger: {trigger_word}
998
+ 📊 Loss: {avg_final_loss:.4f}
999
+ 🖼️ Images: {len(image_paths)}
1000
+ ⚙️ Steps: {successful_steps}
1001
+
1002
+ **Sample captions:**
1003
+ {sample_captions}{hub_result}
1004
+
1005
+ **Usage:**
1006
+ ```python
1007
+ from diffusers import ZImagePipeline
1008
+ import torch
1009
+
1010
+ pipe = ZImagePipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
1011
+ pipe.load_lora_weights("{output_name}.safetensors")
1012
+ pipe.to("cuda")
1013
+
1014
+ image = pipe("{trigger_word}, your prompt here", num_inference_steps=8, guidance_scale=0.0).images[0]
1015
+ ```"""
1016
+
    except Exception as e:
+         aggressive_cleanup()
+         import traceback
+         return None, f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
+
+
+ # ============================================
+ # Gradio UI with Comic Style
+ # ============================================
+ with gr.Blocks(css=COMIC_CSS, theme=gr.themes.Soft(), title="Z-IMAGE GEN/LORA") as demo:
+
+     # HOME Button
+     gr.HTML("""
+     <div class="home-button-container">
+         <a href="https://www.ginigen.com" target="_blank" class="home-button">
+             🏠 HOME
+         </a>
+         <span class="url-display">🌐 www.ginigen.com</span>
+     </div>
+     """)
+
+     # Header
+     gr.HTML("""
+     <div class="header-container">
+         <div class="header-title">🎨 Z-IMAGE GEN/LORA 🎨</div>
+         <div class="header-subtitle">Train custom LoRA for Z-Image Turbo with Florence-2 auto-captioning</div>
+         <div style="margin-top:12px">
+             <span class="stats-badge">🧠 Florence-2 Caption</span>
+             <span class="stats-badge">⚡ Memory Optimized</span>
+             <span class="stats-badge">🚀 Hub Upload</span>
+             <span class="stats-badge">🎯 Person/Style/Object</span>
+         </div>
+     </div>
+     """)
+
+     # Status Row
+     gr.HTML('<div class="info-box">📊 <b>System Status</b> - Check GPU and HuggingFace connection</div>')
+
    with gr.Row():
+         with gr.Column(scale=1):
+             gpu_status = gr.Textbox(label="🖥️ GPU Status", value=check_gpu(), interactive=False)
+         with gr.Column(scale=1):
+             hf_status = gr.Textbox(label="🔑 HF Token", value=check_hf_token(), interactive=False)
+             refresh_btn = gr.Button("🔄 Refresh", size="sm")
+     refresh_btn.click(fn=lambda: (check_gpu(), check_hf_token()), outputs=[gpu_status, hf_status])
+
    with gr.Row():
+         # Left Column - Images
+         with gr.Column(scale=1):
+             gr.HTML('<div class="info-box">📸 <b>Training Images</b> - Upload 6-20 high-quality images</div>')
+             images = gr.Gallery(
+                 label="Upload Images",
+                 columns=4,
+                 height=300,
+                 type="filepath",
+                 interactive=True
            )
+             gr.HTML('<div class="tips-box">💡 <b>Tips:</b> 6-20 images, varied poses/angles, consistent subject, good lighting</div>')
+
+         # Right Column - Settings
+         with gr.Column(scale=1):
+             gr.HTML('<div class="info-box">⚙️ <b>Training Settings</b> - Configure your LoRA training</div>')
+
+             trigger_word = gr.Textbox(
+                 label="🏷️ Trigger Word",
+                 placeholder="ohwx_person",
+                 info="Use a unique token like 'ohwx' to avoid conflicts"
+             )
+             output_name = gr.Textbox(
+                 label="📁 Output Name",
+                 placeholder="my_lora"
+             )
+
            with gr.Row():
+                 use_auto_caption = gr.Checkbox(label="🔤 Auto-Caption (Florence-2)", value=True)
+                 is_person_training = gr.Checkbox(label="👤 Person/Face Training", value=True)
+
+             with gr.Row():
+                 num_steps = gr.Slider(500, 5000, 1500, step=100, label="🔢 Steps")
+                 learning_rate = gr.Slider(1e-5, 5e-4, 5e-5, step=1e-5, label="📈 Learning Rate")
+
+             with gr.Row():
+                 lora_rank = gr.Slider(4, 64, 32, step=4, label="🎚️ LoRA Rank")
+                 resolution = gr.Slider(512, 1024, 1024, step=128, label="📐 Resolution")
+
+             batch_size = gr.Slider(1, 4, 1, step=1, visible=False)
+
+             gr.HTML('<div class="info-box">🚀 <b>Hub Upload</b> - Upload trained LoRA to HuggingFace</div>')
+
+             with gr.Row():
+                 upload_to_hub_flag = gr.Checkbox(label="📤 Upload to HF Hub (Private)", value=False)
+                 hub_repo_name = gr.Textbox(label="📦 Repo Name", placeholder="my-zimage-lora")
+
+     # Train Button
+     with gr.Row():
+         train_btn = gr.Button("🚀 START TRAINING!", variant="primary", size="lg")
+
+     # Output
+     with gr.Row():
+         with gr.Column(scale=1):
+             output_file = gr.File(label="📥 Download LoRA")
+         with gr.Column(scale=1):
+             output_log = gr.Textbox(label="📋 Training Log", lines=15)
+
+     train_btn.click(
+         fn=train_lora,
+         inputs=[
+             images, trigger_word, output_name, num_steps, learning_rate, lora_rank,
+             resolution, batch_size, upload_to_hub_flag, hub_repo_name,
+             use_auto_caption, is_person_training
+         ],
+         outputs=[output_file, output_log]
    )
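+     # train_lora receives these inputs positionally, so the list above must
+     # stay in the same order as its function signature.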
+
+     # Recommended Settings Table
+     gr.HTML("""
+     <div class="info-box">
+         📋 <b>Recommended Settings by Use Case</b>
+         <table style="width:100%; margin-top:10px; border-collapse: collapse;">
+             <tr style="background:#3B82F6; color:white;">
+                 <th style="padding:8px; border:2px solid #1F2937;">Use Case</th>
+                 <th style="padding:8px; border:2px solid #1F2937;">Steps</th>
+                 <th style="padding:8px; border:2px solid #1F2937;">LR</th>
+                 <th style="padding:8px; border:2px solid #1F2937;">Rank</th>
+             </tr>
+             <tr style="background:#FEF9C3;">
+                 <td style="padding:8px; border:2px solid #1F2937;">👤 Person</td>
+                 <td style="padding:8px; border:2px solid #1F2937;">1500</td>
+                 <td style="padding:8px; border:2px solid #1F2937;">5e-5</td>
+                 <td style="padding:8px; border:2px solid #1F2937;">32</td>
+             </tr>
+             <tr style="background:#FFF;">
+                 <td style="padding:8px; border:2px solid #1F2937;">🎨 Style</td>
+                 <td style="padding:8px; border:2px solid #1F2937;">2000</td>
+                 <td style="padding:8px; border:2px solid #1F2937;">1e-4</td>
+                 <td style="padding:8px; border:2px solid #1F2937;">16</td>
+             </tr>
+             <tr style="background:#FEF9C3;">
+                 <td style="padding:8px; border:2px solid #1F2937;">📦 Object</td>
+                 <td style="padding:8px; border:2px solid #1F2937;">1200</td>
+                 <td style="padding:8px; border:2px solid #1F2937;">8e-5</td>
+                 <td style="padding:8px; border:2px solid #1F2937;">24</td>
+             </tr>
+         </table>
+     </div>
+     """)
+
+     # Footer
+     gr.HTML("""
+     <div class="footer-comic">
+         <p style="font-family:'Bangers',cursive;font-size:1.5rem;letter-spacing:2px">🎨 Z-IMAGE GEN/LORA 🎨</p>
+         <p>Powered by Z-Image Turbo + Florence-2 + PEFT</p>
+         <p>🧠 Auto Caption • ⚡ Memory Optimized • 🚀 Fast Training • 📤 Hub Upload</p>
+         <p style="margin-top:10px"><a href="https://www.ginigen.com" target="_blank" style="color:#FACC15;text-decoration:none;font-weight:bold;">🏠 www.ginigen.com</a></p>
+     </div>
+     """)
+
+ if __name__ == "__main__":
+     # 'spawn' avoids CUDA re-initialization errors in forked worker
+     # processes; force=True overrides any start method set earlier.
+     mp.set_start_method('spawn', force=True)
+     demo.launch()