Alexander Bagus committed on
Commit
b30e434
·
1 Parent(s): 8009598

Initial commit

Browse files
Files changed (6) hide show
  1. README.md +2 -3
  2. app.py +421 -0
  3. requirements.txt +8 -0
  4. utils/image_utils.py +76 -0
  5. utils/prompt_utils.py +71 -0
  6. utils/repo_utils.py +33 -0
README.md CHANGED
@@ -1,11 +1,10 @@
1
  ---
2
  title: Z Image To LoRA
3
- emoji: 🌖
4
  colorFrom: gray
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 6.5.1
8
- python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
  license: apache-2.0
 
1
  ---
2
  title: Z Image To LoRA
3
+ emoji: 🥹
4
  colorFrom: gray
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 6.1.0
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch, random, json, spaces, time
4
+ from ulid import ULID
5
+ from diffsynth.pipelines.z_image import (
6
+ ModelConfig, ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
7
+ )
8
+ from diffsynth.pipelines.z_image import ZImagePipeline as ZImagePipelineDs
9
+ from diffusers import ZImagePipeline
10
+ from safetensors.torch import save_file, load_file
11
+ import torch
12
+ from PIL import Image
13
+ from pathlib import Path
14
+ from utils import repo_utils, image_utils, prompt_utils
15
+ from huggingface_hub import snapshot_download
16
+ import glob
17
+
18
+ # repo_utils.clone_repo_if_not_exists("git clone https://huggingface.co/DiffSynth-Studio/General-Image-Encoders", "app/repos")
19
+ # repo_utils.clone_repo_if_not_exists("https://huggingface.co/apple/starflow", "app/models")
20
+
21
+ URL_PUBLIC = "https://huggingface.co/spaces/AiSudo/Qwen-Image-to-LoRA/blob/main"
22
+ DTYPE = torch.bfloat16
23
+ MAX_SEED = np.iinfo(np.int32).max
24
+
25
+ MODELS_DIR = Path("./models")
26
def download_hf_models(output_dir: Path) -> dict:
    """
    Download required models from Hugging Face using huggingface_hub.

    Downloads:
    - DiffSynth-Studio/Z-Image-i2L
    - Tongyi-MAI/Z-Image
    - DiffSynth-Studio/General-Image-Encoders
    - Tongyi-MAI/Z-Image-Turbo

    Repos already present on disk (detected by any *.safetensors file under
    their target directory) are skipped.

    Returns:
        dict mapping repo_id -> local Path of the downloaded snapshot.

    Raises:
        Re-raises any exception from snapshot_download after logging it.
    """

    output_dir.mkdir(parents=True, exist_ok=True)

    models = [
        {
            "repo_id": "DiffSynth-Studio/General-Image-Encoders",
            "description": "General Image Encoders (SigLIP2-G384, DINOv3-7B)",
            "allow_patterns": None,
        },
        {
            "repo_id": "Tongyi-MAI/Z-Image-Turbo",
            "description": "Z-Image Turbo",
        },
        {
            "repo_id": "Tongyi-MAI/Z-Image",
            "description": "Z-Image base model (transformer)",
            "allow_patterns": ["transformer/*.safetensors"],
        },
        {
            "repo_id": "DiffSynth-Studio/Z-Image-i2L",
            "description": "Z-Image-i2L (Image to LoRA model)",
            "allow_patterns": ["*.safetensors"],
        },
    ]

    downloaded_paths = {}

    for model in models:
        repo_id = model["repo_id"]
        local_dir = output_dir / repo_id

        # Check if already downloaded
        if local_dir.exists() and any(local_dir.rglob("*.safetensors")):
            print(f" ✓ {repo_id} (already downloaded)")
            downloaded_paths[repo_id] = local_dir
            continue

        print(f" 📥 Downloading {repo_id}...")
        print(f" {model['description']}")

        try:
            result_path = snapshot_download(
                repo_id=repo_id,
                local_dir=str(local_dir),
                # BUG FIX: the Z-Image-Turbo entry has no "allow_patterns" key,
                # so model["allow_patterns"] raised KeyError; .get() defaults
                # to None (= download everything).
                allow_patterns=model.get("allow_patterns"),
                local_dir_use_symlinks=False,
                resume_download=True,
            )
            downloaded_paths[repo_id] = Path(result_path)
            print(f" ✓ {repo_id}")
        except Exception as e:
            print(f" ❌ Error downloading {repo_id}: {e}")
            raise

    return downloaded_paths
93
+
94
def get_model_files(base_path: Path, pattern: str) -> list:
    """Return the sorted file paths under *base_path* that match the glob *pattern*."""
    return sorted(glob.glob(str(base_path / pattern)))
99
+
100
# Fetch all model weights up front (a no-op for repos already cached in ./models).
downloaded_paths = download_hf_models(MODELS_DIR)

# Z-Image base model: only the transformer weights are fetched/used.
zimage_path = MODELS_DIR / "Tongyi-MAI" / "Z-Image"
zimage_transformer_files = get_model_files(zimage_path, "transformer/*.safetensors")

# Z-Image-Turbo: supplies the text encoder, VAE and tokenizer.
zimage_turbo_path = MODELS_DIR / "Tongyi-MAI" / "Z-Image-Turbo"
text_encoder_files = get_model_files(zimage_turbo_path, "text_encoder/*.safetensors")
vae_file = get_model_files(zimage_turbo_path, "vae/diffusion_pytorch_model.safetensors")
tokenizer_path = zimage_turbo_path / "tokenizer"

# General Image Encoders (SigLIP2 + DINOv3), consumed by the image-to-LoRA encoder.
encoders_path = MODELS_DIR / "DiffSynth-Studio" / "General-Image-Encoders"
siglip_file = get_model_files(encoders_path, "SigLIP2-G384/model.safetensors")
dino_file = get_model_files(encoders_path, "DINOv3-7B/model.safetensors")

# Z-Image-i2L: the image-to-LoRA model itself.
zimage_i2l_path = MODELS_DIR / "DiffSynth-Studio" / "Z-Image-i2L"
zimage_i2l_file = get_model_files(zimage_i2l_path, "model.safetensors")

# Startup sanity log: every count should be >= 1 if the downloads succeeded.
print(f" Z-Image transformer: {len(zimage_transformer_files)} file(s)")
print(f" Text encoder: {len(text_encoder_files)} file(s)")
print(f" VAE: {len(vae_file)} file(s)")
print(f" Tokenizer: {tokenizer_path}")
print(f" SigLIP2: {len(siglip_file)} file(s)")
print(f" DINOv3: {len(dino_file)} file(s)")
print(f" Z-Image-i2L: {len(zimage_i2l_file)} file(s)")
127
+
128
################
# Pipeline construction
################

# Run every stage (offload/onload/prepare/compute) in bfloat16 on the GPU.
vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cuda",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}

model_configs = [
    # All models from HuggingFace - use path= for local files
    ModelConfig(path=zimage_transformer_files, **vram_config),
    ModelConfig(path=text_encoder_files),
    ModelConfig(path=vae_file),
    ModelConfig(path=siglip_file),
    ModelConfig(path=dino_file),
    ModelConfig(path=zimage_i2l_file),
]

# DiffSynth pipeline: encodes reference images into LoRA weights (image-to-LoRA).
pipe_lora = ZImagePipelineDs.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=model_configs,
    tokenizer_config=ModelConfig(path=str(tokenizer_path)),
)

# Diffusers pipeline: text-to-image generation with the produced LoRA applied.
pipe_imagen = ZImagePipeline.from_pretrained(
    "./models/Tongyi-MAI/Z-Image-Turbo",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=False,
)
pipe_imagen.to("cuda")
214
+
215
+
216
@spaces.GPU
def generate_lora(
    input_images,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Encode the uploaded gallery images into a LoRA safetensors file.

    Args:
        input_images: Gradio Gallery value — a list of (filepath, caption) tuples.
        progress: Gradio progress tracker.

    Returns:
        (lora_file_name, DownloadButton update, imagen-button update) — exactly
        the three outputs wired up in the click handler.

    Raises:
        gr.Error: when no images were uploaded.
    """
    ulid = str(ULID()).lower()[:12]
    print(f"ulid: {ulid}")

    if not input_images:
        # BUG FIX: previously returned a bare `False`, but the click handler
        # expects three outputs; raise a user-visible Gradio error instead.
        raise gr.Error("Please upload at least one image.")

    progress(0.1, desc="Processing images...")
    # Gallery entries are (filepath, caption) pairs; captions are ignored.
    pil_images = [Image.open(filepath).convert("RGB") for filepath, _ in input_images]

    progress(0.3, desc="Encoding images to LoRA...")
    # Model inference
    with torch.no_grad():
        embs = ZImageUnit_Image2LoRAEncode().process(pipe_lora, image2lora_images=pil_images)
        progress(0.7, desc="Decoding LoRA weights...")
        lora = ZImageUnit_Image2LoRADecode().process(pipe_lora, **embs)["lora"]

    progress(0.9, desc="Saving LoRA file...")
    lora_name = f"{ulid}.safetensors"
    # BUG FIX: save_file does not create parent directories — ensure it exists.
    lora_dir = Path("loras")
    lora_dir.mkdir(exist_ok=True)
    lora_path = str(lora_dir / lora_name)

    save_file(lora, lora_path)
    # Report completion only after the file is actually on disk.
    progress(1.0, desc="Done!")

    return lora_name, gr.update(interactive=True, value=lora_path), gr.update(interactive=True)
248
+
249
@spaces.GPU
def generate_image(
    lora_name,
    prompt,
    negative_prompt="blurry ugly bad",
    width=1024,
    height=1024,
    seed=42,
    randomize_seed=True,
    guidance_scale=3.5,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate an image with the previously created LoRA applied.

    Args:
        lora_name: file name of a LoRA inside the local `loras/` directory.
        prompt / negative_prompt: text conditioning.
        width, height: output resolution in pixels.
        seed: RNG seed; replaced by a random one when randomize_seed is True.
        randomize_seed: draw a fresh seed in [0, MAX_SEED].
        guidance_scale: accepted for UI compatibility; currently not forwarded
            to the pipeline (distilled/turbo model uses its own default).
        num_inference_steps: denoising steps.

    Returns:
        (generated PIL image, seed actually used).
    """
    lora_path = f"loras/{lora_name}"
    pipe_imagen.clear_lora()
    pipe_imagen.load_lora(pipe_imagen.dit, lora_path)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # BUG FIX: the generator was created but never passed to the pipeline,
    # so the returned seed did not actually reproduce the image.
    generator = torch.Generator().manual_seed(seed)

    result = pipe_imagen(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        # true_cfg_scale=guidance_scale,
        # guidance_scale=1.0 # Use a fixed default for distilled guidance
    )

    # Diffusers pipelines return an output object whose images live in
    # `.images`; fall back to the raw result for pipelines that return
    # the image directly. (NOTE(review): confirm against ZImagePipeline.)
    output_image = result.images[0] if hasattr(result, "images") else result

    # BUG FIX: removed the unreachable `return True` that followed the return.
    return output_image, seed
285
+
286
+
287
def read_file(path: str) -> str:
    """Read the file at *path* as UTF-8 text and return its full contents."""
    with open(path, 'r', encoding='utf-8') as handle:
        return handle.read()
291
+
292
# Page-level CSS for the Gradio UI: center the main column and cap its width,
# and center all h3 headings.
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
h3{
    text-align: center;
    display:block;
}

"""
303
+
304
+
305
+
306
# Gallery examples shipped with the Space (list of example inputs for gr.Examples).
with open('examples/0_examples.json', 'r', encoding='utf-8') as file:
    examples = json.load(file)
308
# UI layout: image gallery -> LoRA generation on the left, LoRA download on the
# right, and a text-to-image section that unlocks once a LoRA exists.
# BUG FIX: `css` is a gr.Blocks argument — demo.launch() does not accept it.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            gr.HTML(read_file("static/header.html"))
        with gr.Row():
            with gr.Column():
                input_images = gr.Gallery(
                    label="Input images",
                    file_types=["image"],
                    show_label=False,
                    elem_id="gallery",
                    columns=2,
                    object_fit="cover",
                    height=300)

                lora_button = gr.Button("Generate LoRA", variant="primary")

            with gr.Column():
                lora_name = gr.Textbox(label="Generated LoRA path", lines=2, interactive=False)
                lora_download = gr.DownloadButton(label="Download LoRA", interactive=False)
        with gr.Column(elem_id='imagen-container') as imagen_container:
            gr.Markdown("### After your LoRA is ready, you can try generate image here.")
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(
                        label="Prompt",
                        show_label=False,
                        lines=2,
                        placeholder="Enter your prompt",
                        value="a man in a fishing boat.",
                        container=False,
                    )

                    # Disabled until generate_lora succeeds (see outputs below).
                    imagen_button = gr.Button("Generate Image", variant="primary", interactive=False)
                    with gr.Accordion("Advanced Settings", open=False):
                        negative_prompt = gr.Textbox(
                            label="Negative prompt",
                            lines=2,
                            container=False,
                            placeholder="Enter your negative prompt",
                            value="blurry ugly bad"
                        )
                        num_inference_steps = gr.Slider(
                            label="Steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=25,
                        )
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=512,
                                maximum=1280,
                                step=32,
                                value=768,
                            )

                            height = gr.Slider(
                                label="Height",
                                minimum=512,
                                maximum=1280,
                                step=32,
                                value=1024,
                            )
                        with gr.Row():
                            seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=42,
                            )
                            guidance_scale = gr.Slider(
                                label="Guidance scale",
                                minimum=0.0,
                                maximum=10.0,
                                step=0.1,
                                value=3.5,
                            )
                        randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

                with gr.Column():
                    output_image = gr.Image(label="Generated image", show_label=False)

        gr.Examples(examples=examples, inputs=[input_images])
        gr.Markdown(read_file("static/footer.md"))

    lora_button.click(
        fn=generate_lora,
        inputs=[
            input_images
        ],
        outputs=[lora_name, lora_download, imagen_button],
    )
    imagen_button.click(
        fn=generate_image,
        inputs=[
            lora_name,
            prompt,
            negative_prompt,
            width,
            height,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image, seed],
    )


if __name__ == "__main__":
    demo.launch(mcp_server=True)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers
4
+ accelerate
5
+ spaces
6
+ git+https://github.com/huggingface/diffusers.git
7
+ git+https://github.com/modelscope/DiffSynth-Studio.git
8
+ python-ulid
utils/image_utils.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
def rescale_image(img, scale, nearest=32, max_size=1280):
    """
    Resize *img* by *scale*, capping both sides at *max_size* and snapping
    each side down to a multiple of *nearest*.

    Args:
        img: PIL image.
        scale: multiplicative scale factor.
        nearest: both output sides are rounded down to a multiple of this.
        max_size: maximum allowed width/height before snapping.

    Returns:
        (resized_image, new_width, new_height).
    """
    w, h = img.size
    new_w = int(w * scale)
    new_h = int(h * scale)

    if new_w > max_size or new_h > max_size:
        # Shrink further, keeping the aspect ratio, so both sides fit.
        fit = min(max_size / new_w, max_size / new_h)
        new_w = int(new_w * fit)
        new_h = int(new_h * fit)

    # BUG FIX: snap down to a multiple of `nearest`, but never below `nearest`
    # itself — tiny inputs previously rounded to a 0-sized (invalid) image.
    new_w = max((new_w // nearest) * nearest, nearest)
    new_h = max((new_h // nearest) * nearest, nearest)

    return img.resize((new_w, new_h), Image.LANCZOS), new_w, new_h
21
+
22
def padding_image(images, new_width, new_height):
    """
    Letterbox *images* onto a white (new_width, new_height) canvas.

    The input is resized to the largest size that fits inside the canvas while
    preserving its aspect ratio, then pasted centered on a white background.

    Args:
        images: a single PIL image (name kept for interface compatibility).
        new_width, new_height: target canvas size in pixels.

    Returns:
        A new RGB PIL image of size (new_width, new_height).
    """
    new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))

    # BUG FIX (cleanup): the original if/else on the canvas aspect contained
    # two byte-identical branches; the single fit computation below is the
    # same behavior without the duplication.
    aspect_ratio = images.width / images.height
    if aspect_ratio > new_width / new_height:
        # Image is wider than the canvas: width is the limiting side.
        new_img_width = new_width
        new_img_height = int(new_img_width / aspect_ratio)
    else:
        # Image is taller than the canvas: height is the limiting side.
        new_img_height = new_height
        new_img_width = int(new_img_height * aspect_ratio)

    resized_img = images.resize((new_img_width, new_img_height))

    # Center the resized image on the canvas.
    paste_x = (new_width - new_img_width) // 2
    paste_y = (new_height - new_img_height) // 2

    new_image.paste(resized_img, (paste_x, paste_y))

    return new_image
49
+
50
+
51
def get_image_latent(ref_image=None, sample_size=None, padding=False):
    """
    Convert a reference image into a normalized float tensor.

    Args:
        ref_image: a file path, a PIL image, or an array-like already holding
            pixel data; None is passed through unchanged.
        sample_size: (height, width) target size — note the (H, W) order; the
            resize call receives (W, H).
        padding: when True, letterbox via padding_image before resizing.

    Returns:
        A tensor of shape (1, C, 1, H, W) with values in [0, 1], or None when
        ref_image is None.
    """
    def _pil_to_tensor(pil_img):
        # Shared conversion for the path and PIL branches (previously duplicated).
        if padding:
            pil_img = padding_image(pil_img, sample_size[1], sample_size[0])
        pil_img = pil_img.resize((sample_size[1], sample_size[0]))
        tensor = torch.from_numpy(np.array(pil_img))
        # (H, W, C) -> (1, H, W, C) -> (C, 1, H, W) -> (1, C, 1, H, W), scaled to [0, 1].
        return tensor.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255

    if ref_image is not None:
        if isinstance(ref_image, str):
            ref_image = _pil_to_tensor(Image.open(ref_image).convert("RGB"))
        elif isinstance(ref_image, Image.Image):
            ref_image = _pil_to_tensor(ref_image.convert("RGB"))
        else:
            # Already array-like: convert directly, with no padding/resize
            # (matches the original behavior for this branch).
            tensor = torch.from_numpy(np.array(ref_image))
            ref_image = tensor.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255

    return ref_image
utils/prompt_utils.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from huggingface_hub import InferenceClient
3
+
4
+ # source: https://huggingface.co/spaces/InstantX/Qwen-Image-ControlNet/blob/main/app.py
5
# source: https://huggingface.co/spaces/InstantX/Qwen-Image-ControlNet/blob/main/app.py
def polish_prompt(original_prompt):
    """
    Rewrite *original_prompt* with a hosted LLM for better text-to-image results.

    Uses the Hugging Face InferenceClient (provider "nebius") with HF_TOKEN
    from the environment. Best-effort: on a missing token or any API error the
    original prompt is returned unchanged; an empty prompt yields a generic
    quality-boosting prompt.
    """
    magic_prompt = "Ultra HD, 4K, cinematic composition"
    system_prompt = """
    You are a prompt engineering expert for text-to-image models. Since text-to-image models have limited capabilities in understanding user prompts, you need to identify the core theme and intent of the user's input and improve the model's understanding accuracy and generation quality through optimization and rewriting. The rewrite must strictly retain all information from the user's original prompt without deleting or distorting any details.
    Specific requirements are as follows:
    1. The rewrite must not affect any information expressed in the user's original prompt; the rewritten prompt should use coherent natural language, avoid low-information redundant descriptions, and keep the rewritten prompt length as concise as possible.
    2. Ensure consistency between input and output languages: Chinese input yields Chinese output, and English input yields English output. The rewritten token count should not exceed 512.
    3. The rewritten description should further refine subject characteristics and aesthetic techniques appearing in the original prompt, such as lighting and textures.
    4. If the original prompt does not specify an image style, ensure the rewritten prompt uses a **realistic photography style**. If the user specifies a style, retain the user's style.
    5. When the original prompt requires reasoning to clarify user intent, use logical reasoning based on world knowledge to convert vague abstract descriptions into specific tangible objects (e.g., convert "the tallest animal" to "a giraffe").
    6. When the original prompt requires text generation, please use double quotes to enclose the text part (e.g., `"50% OFF"`).
    7. When the original prompt requires generating text-heavy scenes like webpages, logos, UIs, or posters, and no specific text content is specified, you need to infer appropriate text content and enclose it in double quotes. For example, if the user inputs: "A tourism flyer with a grassland theme," it should be rewritten as: "A tourism flyer with the image title 'Grassland'."
    8. When negative words exist in the original prompt, ensure the rewritten prompt does not contain negative words. For example, "a lakeside without boats" should be rewritten such that the word "boat" does not appear at all.
    9. Except for text content explicitly requested by the user, **adding any extra text content is prohibited**.
    Here are examples of rewrites for different types of prompts:
    # Examples (Few-Shot Learning)
    1. User Input: An animal with nine lives.
    Rewrite Output: A cat bathed in soft sunlight, its fur soft and glossy. The background is a comfortable home environment with light from the window filtering through curtains, creating a warm light and shadow effect. The shot uses a medium distance perspective to highlight the cat's leisurely and stretched posture. Light cleverly hits the cat's face, emphasizing its spirited eyes and delicate whiskers, adding depth and affinity to the image.
    2. User Input: Create an anime-style tourism flyer with a grassland theme.
    Rewrite Output: In the lower right of the center, a short-haired girl sits sideways on a gray, irregularly shaped rock. She wears a white short-sleeved dress and brown flat shoes, holding a bunch of small white flowers in her left hand, smiling with her legs hanging naturally. The girl has dark brown shoulder-length hair with bangs covering her forehead, brown eyes, and a slightly open mouth. The rock surface has textures of varying depths. To the girl's left and front is lush grass, with long, yellow-green blades, some glowing golden in the sunlight. The grass extends into the distance, forming rolling green hills that fade in color as they recede. The sky occupies the upper half of the picture, pale blue dotted with a few fluffy white clouds. In the upper left corner, there is a line of text in italic, dark green font reading "Explore Nature's Peace". Colors are dominated by green, blue, and yellow, fluid lines, and distinct light and shadow contrast, creating a quiet and comfortable atmosphere.
    3. User Input: A Christmas sale poster with a red background, promoting a Buy 1 Get 1 Free milk tea offer.
    Rewrite Output: The poster features an overall red tone, embellished with white snowflake patterns on the top and left side. The upper right features a bunch of holly leaves with red berries and a pine cone. In the upper center, golden 3D text reads "Christmas Heartwarming Feedback" centered, along with red bold text "Buy 1 Get 1". Below, two transparent cups filled with bubble tea are placed side by side; the tea is light brown with dark brown pearls scattered at the bottom and middle. Below the cups, white snow piles up, decorated with pine branches, red berries, and pine cones. A blurry Christmas tree is faintly visible in the lower right corner. The image has high clarity, accurate text content, a unified design style, a prominent Christmas theme, and a reasonable layout, providing strong visual appeal.
    4. User Input: A woman indoors shot in natural light, smiling with arms crossed, showing a relaxed and confident posture.
    Rewrite Output: The image features a young Asian woman with long dark brown hair naturally falling over her shoulders, with some strands illuminated by light, showing a soft sheen. Her features are delicate, with long eyebrows, bright and spirited dark brown eyes looking directly at the camera, revealing peace and confidence. She has a high nose bridge, full lips with nude lipstick, and corners of the mouth slightly raised in a faint smile. Her skin is fair, with cheeks and collarbones illuminated by warm light, showing a healthy ruddiness. She wears a black spaghetti strap tank top revealing graceful collarbone lines, and a thin gold necklace with small beads and metal bars glinting in the light. Her outer layer is a beige knitted cardigan, soft in texture with visible knitting patterns on the sleeves. Her arms are crossed over her chest, hands covered by the cardigan sleeves, in a relaxed posture. The background is a pure dark brown without extra decoration, making the figure the absolute focus. The figure is located in the center of the frame. Light enters from the upper right, creating bright spots on her left cheek, neck, and collarbone, while the right side is slightly shadowed, creating a three-dimensional and soft tone. Image details are clear, showcasing skin texture, hair, and clothing materials well. Colors are dominated by warm tones, with the combination of beige and dark brown creating a warm and comfortable atmosphere. The overall style is natural, elegant, and artistic.
    5. User Input: Create a series of images showing the growth process of an apple from seed to fruit. The series should include four stages: 1. Sowing, 2. Seedling growth, 3. Plant maturity, 4. Fruit harvesting.
    Rewrite Output: A 4-panel exquisite illustration depicting the growth process of an apple, capturing each stage precisely and clearly. 1. "Sowing": A close-up shot of a hand gently placing a small apple seed into fertile dark soil, with visible soil texture and the seed's smooth surface. The background is a soft-focus garden dotted with green leaves and sunlight filtering through. 2. "Seedling Growth": A young apple sapling breaks through the soil, stretching tender green leaves toward the sky. The scene is set in a vibrant garden illuminated by warm golden light, highlighting the seedling's delicate structure. 3. "Plant Maturity": A mature apple tree, lush with branches and leaves, covered in tender green foliage and developing small apples. The background is a vibrant orchard under a clear blue sky, with dappled sunlight creating a peaceful atmosphere. 4. "Fruit Harvesting": A hand reaches into the tree to pick a ripe red apple, its smooth skin glistening in the sun. The scene shows the abundance of the orchard, with baskets of apples in the background, giving a sense of fulfillment. Each illustration uses a realistic style, focusing on details and harmonious colors to showcase the natural beauty and development of the apple's life cycle.
    6. User Input: If 1 represents red, 2 represents green, 3 represents purple, and 4 represents yellow, please generate a four-color rainbow based on this rule. The color order from top to bottom is 3142.
    Rewrite Output: The image consists of four horizontally arranged colored stripes, ordered from top to bottom as purple, red, yellow, and green. A white number is centered on each stripe. The top purple stripe features the number "3", the red stripe below it has the number "1", the yellow stripe further down has the number "4", and the bottom green stripe has the number "2". All numbers use a sans-serif font in pure white, forming a sharp contrast with the background colors to ensure good readability. The stripes have high color saturation and a slight texture. The overall layout is simple and clear, with distinct visual effects and no extra decorative elements, emphasizing the numerical information. The image is high definition, with accurate colors and a consistent style, offering strong visual appeal.
    7. User Input: A stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", natural light, background is a Chinese garden.
    Rewrite Output: An ancient stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", the surface covered with traces of time, the writing clear and deep. Natural light falls from above, softly illuminating every detail of the stone tablet and enhancing its sense of history. The background is an elegant Chinese garden featuring lush bamboo forests, winding paths, and quiet pools, creating a serene and distant atmosphere. The overall picture uses a realistic style with rich details and natural light and shadow effects, highlighting the cultural heritage of the stone tablet and the classical beauty of the garden.
    # Output Format
    Please directly output the rewritten and optimized Prompt content. Do not include any explanatory language or JSON formatting, and do not add opening or closing quotes yourself.
    """

    api_key = os.environ.get("HF_TOKEN")
    if not api_key:
        print("Warning: HF_TOKEN is not set. Prompt enhancement is disabled.")
        return original_prompt

    if not original_prompt:
        # Nothing to rewrite: fall back to a generic quality-boosting prompt.
        return magic_prompt

    client = InferenceClient(provider="nebius", api_key=api_key)

    try:
        completion = client.chat.completions.create(
            model="Qwen/Qwen3-Coder-30B-A3B-Instruct",
            # NOTE(review): system prompt allows up to 512 tokens but only 256
            # are requested here — confirm whether truncation is intended.
            max_tokens=256,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": original_prompt}
            ],
        )
        polished_prompt = completion.choices[0].message.content
        # Collapse newlines so the result is a single-line prompt.
        return polished_prompt.strip().replace("\n", " ")
    except Exception as e:
        # Best-effort: any API failure falls back to the unmodified prompt.
        print(f"Error during prompt enhancement: {e}")
        return original_prompt
+
utils/repo_utils.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, subprocess
2
+
3
def clone_repo_if_not_exists(git_url, git_dir):
    """
    Clone a Git repository into ~/<git_dir>/<repo-name> if it isn't there yet.

    Args:
        git_url: URL of the repository (with or without a trailing ".git").
        git_dir: directory (joined under the user's home) to clone into.

    Returns:
        None. Clone failures are printed, not raised (best-effort behavior).
    """
    home_dir = os.path.expanduser("~")
    models_dir = os.path.join(home_dir, git_dir)

    # Extract repository name from the Git URL, dropping a ".git" suffix.
    git_name = git_url.split('/')[-1]
    if git_name.endswith(".git"):
        git_name = git_name[:-4]

    repo_path = os.path.join(models_dir, git_name)

    if not os.path.exists(models_dir):
        # BUG FIX: exist_ok avoids a crash if the directory appears between
        # the check above and this call (e.g. two workers starting at once).
        os.makedirs(models_dir, exist_ok=True)
        print(f"Created directory: {models_dir}")

    if not os.path.exists(repo_path):
        print(f"Repository '{git_name}' not found in '{models_dir}'. Cloning from {git_url}...")
        try:
            subprocess.run(["git", "clone", git_url, repo_path], check=True)
            print(f"Successfully cloned '{git_name}' to '{repo_path}'.")
        except subprocess.CalledProcessError as e:
            print(f"Error cloning repository: {e}")
        except FileNotFoundError:
            print("Error: 'git' command not found. Please ensure Git is installed and in your PATH.")
    else:
        print(f"Repository '{git_name}' already exists at '{repo_path}'. Skipping clone.")
+