kelseye committed on
Commit
4a60f19
·
verified ·
1 Parent(s): 331282e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +441 -158
app.py CHANGED
@@ -1,20 +1,193 @@
1
  import spaces
2
- import os
3
- os.system("pip install diffsynth==2.0.3")
4
- from modelscope.hub.api import HubApi
5
- api = HubApi()
6
- api.login(os.environ["MODELSCOPE_TOKEN"])
7
- os.environ["DIFFSYNTH_MODEL_BASE_PATH"] = "/mnt/workspace/models"
8
- os.environ["DIFFSYNTH_DOWNLOAD_SOURCE"] = "huggingface"
9
-
10
  import gradio as gr
11
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from diffsynth.pipelines.z_image import (
13
  ZImagePipeline, ModelConfig,
14
  ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
15
  )
 
16
 
17
- # Use `vram_config` to enable LoRA hot-loading
 
 
18
  vram_config = {
19
  "offload_dtype": torch.bfloat16,
20
  "offload_device": "cuda",
@@ -26,166 +199,276 @@ vram_config = {
26
  "computation_device": "cuda",
27
  }
28
 
29
- # Load models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  pipe = ZImagePipeline.from_pretrained(
31
  torch_dtype=torch.bfloat16,
32
  device="cuda",
33
- model_configs=[
34
- ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors", **vram_config),
35
- ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
36
- ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
37
- ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"),
38
- ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"),
39
- ModelConfig(model_id="DiffSynth-Studio/Z-Image-i2L", origin_file_pattern="model.safetensors"),
40
- ],
41
- tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
42
  )
43
 
44
- @spaces.GPU(duration=60)
45
- def run_inference(style_image_1, style_image_2, style_image_3, style_image_4, style_image_5, style_image_6, prompt, negative_prompt, cfg_scale, sigma_shift, seed, num_inference_steps, height, width, progress=gr.Progress()):
46
- # Filter out None values and collect valid images
47
- style_images = [img for img in [style_image_1, style_image_2, style_image_3, style_image_4, style_image_5, style_image_6] if img is not None]
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- # Image to LoRA
50
- with torch.no_grad():
51
- embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=style_images)
52
- lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
53
-
54
- # Generate images
55
- progress(0, desc="Generating image")
56
- image = pipe(
57
- prompt=prompt,
58
- negative_prompt=negative_prompt,
59
- seed=None if seed == -1 else seed,
60
- cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
61
- positive_only_lora=lora,
62
- sigma_shift=sigma_shift,
63
- height=height,
64
- width=width,
65
- progress_bar_cmd=progress.tqdm
66
- )
67
- return image
68
-
69
-
70
- with gr.Blocks(title="Z-Image-Omni-Base") as demo:
71
- gr.Markdown("Model: https://modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L")
72
- gr.Markdown("GitHub: https://github.com/modelscope/DiffSynth-Studio")
73
- gr.Markdown("Upload images and generate new images based on text prompts")
 
 
 
74
 
75
- with gr.Row():
76
- # 第一列:输入的6张图
77
- with gr.Column():
78
- gr.Markdown("Input Images (upload 1-6 images)")
79
- with gr.Row():
80
- style_image_1 = gr.Image(label="Image 1", type="pil")
81
- style_image_2 = gr.Image(label="Image 2", type="pil")
82
- with gr.Row():
83
- style_image_3 = gr.Image(label="Image 3", type="pil")
84
- style_image_4 = gr.Image(label="Image 4", type="pil")
85
- with gr.Row():
86
- style_image_5 = gr.Image(label="Image 5", type="pil")
87
- style_image_6 = gr.Image(label="Image 6", type="pil")
88
-
89
- # 第二列:提示词等控件
90
- with gr.Column():
91
- gr.Markdown("Settings")
92
- prompt = gr.Textbox(
93
- label="Prompt",
94
- placeholder="Enter your prompt here...",
95
- value="a cat",
96
- lines=3
97
- )
98
-
99
- seed = gr.Number(
100
- label="Seed",
101
- value=-1,
102
- precision=0
103
- )
104
 
105
- height = gr.Slider(
106
- label="Height",
107
- minimum=512,
108
- maximum=1536,
109
- step=64,
110
- value=1024
111
- )
112
 
113
- width = gr.Slider(
114
- label="Width",
115
- minimum=512,
116
- maximum=1536,
117
- step=64,
118
- value=1024
119
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- num_inference_steps = gr.Slider(
122
- label="Inference Steps",
123
- minimum=10,
124
- maximum=50,
125
- step=1,
126
- value=30
127
- )
128
-
129
- # Advanced settings (collapsed by default)
130
- with gr.Accordion("Advanced Settings", open=True):
131
- negative_prompt = gr.Textbox(
132
- label="Negative Prompt",
133
- placeholder="Enter negative prompt here...",
134
- value="泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符"
135
- )
136
- cfg_scale = gr.Slider(
137
- label="CFG Scale",
138
- minimum=1.0,
139
- maximum=10.0,
140
- step=0.5,
141
- value=4.0
142
- )
143
- sigma_shift = gr.Slider(
144
- label="Sigma Shift",
145
- minimum=1.0,
146
- maximum=20.0,
147
- step=0.5,
148
- value=8.0
149
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- generate_btn = gr.Button("Generate Image")
152
-
153
- # 第三列:输出图
154
- with gr.Column():
155
- gr.Markdown("Output")
156
- output_image = gr.Image(
157
- label="Generated Image",
158
- interactive=False
159
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
- generate_btn.click(
162
- fn=run_inference,
163
- inputs=[style_image_1, style_image_2, style_image_3, style_image_4, style_image_5, style_image_6, prompt, negative_prompt, cfg_scale, sigma_shift, seed, num_inference_steps, height, width],
164
- outputs=[output_image],
165
- )
166
- gr.Examples(
167
- examples=[
168
- [
169
- "assets/style/1/0.jpg", "assets/style/1/1.jpg", "assets/style/1/2.jpg", "assets/style/1/3.jpg", None, None,
170
- "a cat", "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符",
171
- 4, 8, 0, 30, 1024, 1024
172
- ],
173
- [
174
- "assets/style/4/0.jpg", "assets/style/4/1.jpg", "assets/style/4/2.jpg", "assets/style/4/3.jpg", "assets/style/4/4.jpg", "assets/style/4/5.jpg",
175
- "a dog", "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符",
176
- 4, 8, 2, 30, 1024, 1024
177
- ],
178
- [
179
- "assets/style/3/0.jpg", "assets/style/3/1.jpg", "assets/style/3/2.jpg", "assets/style/3/3.jpg", None, None,
180
- "a girl", "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符",
181
- 4, 8, 1, 30, 1024, 1024
182
- ],
183
- ],
184
- fn=run_inference,
185
- inputs=[style_image_1, style_image_2, style_image_3, style_image_4, style_image_5, style_image_6, prompt, negative_prompt, cfg_scale, sigma_shift, seed, num_inference_steps, height, width],
186
- outputs=[output_image],
187
- cache_examples=True
188
- )
189
 
190
  if __name__ == "__main__":
191
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
1
  import spaces
 
 
 
 
 
 
 
 
2
  import gradio as gr
3
  import torch
4
+ from PIL import Image
5
+ import os
6
+ import sys
7
+ import subprocess
8
+ import tempfile
9
+ from pathlib import Path
10
+ import glob
11
+
12
+ # Default negative prompts
13
+ NEGATIVE_PROMPT_CN = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符"
14
+ NEGATIVE_PROMPT_EN = "Yellowed, green-tinted, blurry, low-resolution, low-quality image, distorted limbs, eerie appearance, ugly, AI-looking, noise, grid-like artifacts, JPEG compression artifacts, abnormal limbs, watermark, garbled text, meaningless characters"
15
+
16
+ # Model paths - can be overridden via environment variables
17
+ MODELS_DIR = Path(os.environ.get("ZIMAGE_MODELS_DIR", "./models"))
18
+
19
+
20
+ # =============================================================================
21
+ # Model Download Functions
22
+ # =============================================================================
23
+
24
def download_hf_models(output_dir: Path) -> dict:
    """
    Download required models from Hugging Face using huggingface_hub.

    Downloads:
    - DiffSynth-Studio/Z-Image-i2L
    - Tongyi-MAI/Z-Image
    - DiffSynth-Studio/General-Image-Encoders
    - Tongyi-MAI/Z-Image-Turbo

    Repos that already contain at least one ``*.safetensors`` file under
    ``output_dir / repo_id`` are skipped, so repeated startups are cheap.

    Args:
        output_dir: Directory under which each repo snapshot is stored
            as ``output_dir / repo_id``.

    Returns:
        dict mapping ``repo_id`` to the local ``Path`` of its snapshot.

    Raises:
        Exception: re-raises any error from ``snapshot_download`` after
            printing it, so startup fails fast on a broken download.
    """
    from huggingface_hub import snapshot_download

    output_dir.mkdir(parents=True, exist_ok=True)

    models = [
        {
            "repo_id": "DiffSynth-Studio/General-Image-Encoders",
            "description": "General Image Encoders (SigLIP2-G384, DINOv3-7B)",
            "allow_patterns": None,
        },
        {
            "repo_id": "Tongyi-MAI/Z-Image-Turbo",
            "description": "Z-Image Turbo (text encoder, VAE, tokenizer)",
            "allow_patterns": [
                "text_encoder/*.safetensors",
                "vae/*.safetensors",
                "tokenizer/*",
            ],
        },
        {
            "repo_id": "Tongyi-MAI/Z-Image",
            "description": "Z-Image base model (transformer)",
            "allow_patterns": ["transformer/*.safetensors"],
        },
        {
            "repo_id": "DiffSynth-Studio/Z-Image-i2L",
            "description": "Z-Image-i2L (Image to LoRA model)",
            "allow_patterns": ["*.safetensors"],
        },
    ]

    downloaded_paths = {}

    for model in models:
        repo_id = model["repo_id"]
        local_dir = output_dir / repo_id

        # Check if already downloaded (any weight file present is enough).
        if local_dir.exists() and any(local_dir.rglob("*.safetensors")):
            print(f" ✓ {repo_id} (already downloaded)")
            downloaded_paths[repo_id] = local_dir
            continue

        print(f" 📥 Downloading {repo_id}...")
        print(f" {model['description']}")

        try:
            # `local_dir_use_symlinks` and `resume_download` are deprecated
            # no-ops in recent huggingface_hub releases (downloads into
            # local_dir are real files and resume by default), so they are
            # intentionally not passed here.
            result_path = snapshot_download(
                repo_id=repo_id,
                local_dir=str(local_dir),
                allow_patterns=model["allow_patterns"],
            )
            downloaded_paths[repo_id] = Path(result_path)
            print(f" ✓ {repo_id}")
        except Exception as e:
            print(f" ❌ Error downloading {repo_id}: {e}")
            raise

    return downloaded_paths
97
+
98
+
99
def get_model_files(base_path: Path, pattern: str) -> list:
    """Return the sorted paths (as strings) matching *pattern* under *base_path*."""
    return sorted(glob.glob(str(base_path / pattern)))
104
+
105
+
106
def install_diffsynth_studio():
    """Clone and install DiffSynth-Studio from GitHub if not already importable.

    Returns:
        Tuple ``(success, message)``: a bool flag plus a human-readable
        status string suitable for logging/display.
    """
    # Fast path: package already importable in this environment — nothing to do.
    try:
        from diffsynth.pipelines.z_image import ZImagePipeline
        return True, "✅ DiffSynth-Studio is already installed."
    except ImportError:
        pass

    # Clone next to this script so repeated launches reuse the checkout.
    repo_dir = Path(__file__).parent / "DiffSynth-Studio"

    try:
        if not repo_dir.exists():
            print("📥 Cloning DiffSynth-Studio repository...")
            subprocess.run(
                ["git", "clone", "https://github.com/modelscope/DiffSynth-Studio.git", str(repo_dir)],
                capture_output=True,
                text=True,
                check=True
            )
            print("✅ Repository cloned successfully.")
        else:
            print("📁 DiffSynth-Studio directory already exists, pulling latest...")
            # Best-effort update: no check=True, so a failed pull (e.g. offline)
            # falls through to installing the existing checkout.
            subprocess.run(
                ["git", "-C", str(repo_dir), "pull"],
                capture_output=True,
                text=True
            )

        print("📦 Installing DiffSynth-Studio...")
        # Editable install (-e) so the cloned sources are used directly.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-e", str(repo_dir)],
            capture_output=True,
            text=True,
            check=True
        )
        print("✅ DiffSynth-Studio installed successfully.")

        # Also put the checkout on sys.path so the verification import below
        # works in this process even if pip's path config needs a restart.
        sys.path.insert(0, str(repo_dir))

        # Verify the install by importing the symbol the app actually uses.
        from diffsynth.pipelines.z_image import ZImagePipeline
        return True, "✅ DiffSynth-Studio installed successfully!"

    except subprocess.CalledProcessError as e:
        # git/pip exited non-zero; surface its captured stderr to the caller.
        error_msg = f"❌ Installation failed: {e.stderr}"
        print(error_msg)
        return False, error_msg
    except Exception as e:
        error_msg = f"❌ Error during installation: {str(e)}"
        print(error_msg)
        return False, error_msg
156
+
157
+
158
+ # =============================================================================
159
+ # Pipeline Initialization
160
+ # =============================================================================
161
+
162
+ print("=" * 60)
163
+ print(" Z-Image-i2L Gradio Demo - Initializing")
164
+ print("=" * 60)
165
+ print()
166
+
167
+ # Step 1: Install DiffSynth-Studio
168
+ print("🔍 Step 1: Checking DiffSynth-Studio installation...")
169
+ success, message = install_diffsynth_studio()
170
+ print(message)
171
+
172
+ if not success:
173
+ raise RuntimeError("Failed to install DiffSynth-Studio. Cannot continue.")
174
+
175
+ # Step 2: Download HuggingFace models
176
+ print()
177
+ print("🔍 Step 2: Downloading models from HuggingFace...")
178
+ print(f" Models directory: {MODELS_DIR.absolute()}")
179
+ downloaded_paths = download_hf_models(MODELS_DIR)
180
+
181
+ # Import required modules
182
  from diffsynth.pipelines.z_image import (
183
  ZImagePipeline, ModelConfig,
184
  ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
185
  )
186
+ from safetensors.torch import save_file, load_file
187
 
188
+ # Step 3: Configure VRAM settings
189
+ print()
190
+ print("⚙️ Step 3: Configuring VRAM settings...")
191
  vram_config = {
192
  "offload_dtype": torch.bfloat16,
193
  "offload_device": "cuda",
 
199
  "computation_device": "cuda",
200
  }
201
 
202
# Step 4: Resolve local model paths
print()
print("📂 Step 4: Resolving model paths...")

# Z-Image transformer (possibly sharded across several safetensors files).
zimage_path = MODELS_DIR / "Tongyi-MAI" / "Z-Image"
zimage_transformer_files = get_model_files(zimage_path, "transformer/*.safetensors")

# Z-Image-Turbo supplies the text encoder, VAE and tokenizer.
zimage_turbo_path = MODELS_DIR / "Tongyi-MAI" / "Z-Image-Turbo"
text_encoder_files = get_model_files(zimage_turbo_path, "text_encoder/*.safetensors")
vae_file = get_model_files(zimage_turbo_path, "vae/diffusion_pytorch_model.safetensors")
tokenizer_path = zimage_turbo_path / "tokenizer"

# General Image Encoders (SigLIP2 + DINOv3), used by the Image2LoRA units.
encoders_path = MODELS_DIR / "DiffSynth-Studio" / "General-Image-Encoders"
siglip_file = get_model_files(encoders_path, "SigLIP2-G384/model.safetensors")
dino_file = get_model_files(encoders_path, "DINOv3-7B/model.safetensors")

# Z-Image-i2L weights (the image-to-LoRA encoder/decoder model).
zimage_i2l_path = MODELS_DIR / "DiffSynth-Studio" / "Z-Image-i2L"
zimage_i2l_file = get_model_files(zimage_i2l_path, "model.safetensors")

print(f" Z-Image transformer: {len(zimage_transformer_files)} file(s)")
print(f" Text encoder: {len(text_encoder_files)} file(s)")
print(f" VAE: {len(vae_file)} file(s)")
print(f" Tokenizer: {tokenizer_path}")
print(f" SigLIP2: {len(siglip_file)} file(s)")
print(f" DINOv3: {len(dino_file)} file(s)")
print(f" Z-Image-i2L: {len(zimage_i2l_file)} file(s)")

# Validate files: collect everything that is absent and fail fast with one
# aggregated error rather than a confusing mid-load failure.
missing = []
if not zimage_transformer_files: missing.append("Z-Image transformer")
if not text_encoder_files: missing.append("Text encoder")
if not vae_file: missing.append("VAE")
if not tokenizer_path.exists(): missing.append("Tokenizer")
if not siglip_file: missing.append("SigLIP2")
if not dino_file: missing.append("DINOv3")
if not zimage_i2l_file: missing.append("Z-Image-i2L")

if missing:
    raise FileNotFoundError(f"Missing model files: {', '.join(missing)}")

# Step 5: Load pipeline
print()
print("🚀 Step 5: Loading Z-Image pipeline...")
print(" All models loaded from HuggingFace local paths")

model_configs = [
    # All models from HuggingFace - use path= for local files.
    # NOTE(review): get_model_files returns a *list* of path strings; this
    # assumes ModelConfig(path=...) accepts a list for sharded weights —
    # confirm against the DiffSynth-Studio ModelConfig API.
    ModelConfig(path=zimage_transformer_files, **vram_config),  # only the transformer gets the offload config
    ModelConfig(path=text_encoder_files),
    ModelConfig(path=vae_file),
    ModelConfig(path=siglip_file),
    ModelConfig(path=dino_file),
    ModelConfig(path=zimage_i2l_file),
]

pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=model_configs,
    tokenizer_config=ModelConfig(path=str(tokenizer_path)),
)

print()
print("✅ Pipeline loaded successfully!")
print("=" * 60)
print()
272
+
273
+
274
+ # =============================================================================
275
+ # Gradio Functions
276
+ # =============================================================================
277
+
278
@spaces.GPU(duration=120)
def image_to_lora(images, progress=gr.Progress()):
    """Encode the uploaded style images into LoRA weights and save them.

    Returns a ``(lora_path, status_message)`` pair; ``lora_path`` is ``None``
    when no images were supplied or an error occurred.
    """
    if images is None or len(images) == 0:
        return None, "❌ Please upload at least one image!"

    try:
        progress(0.1, desc="Processing images...")

        # Gallery items can arrive as file paths, (path, caption) tuples, or
        # raw arrays — normalize every item to an RGB PIL image.
        pil_images = []
        for item in images:
            if isinstance(item, str):
                loaded = Image.open(item)
            elif isinstance(item, tuple):
                loaded = Image.open(item[0])
            else:
                loaded = Image.fromarray(item)
            pil_images.append(loaded.convert("RGB"))

        progress(0.3, desc="Encoding images to LoRA...")

        # Inference only — no gradients needed for encode/decode.
        with torch.no_grad():
            embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=pil_images)
            progress(0.7, desc="Decoding LoRA weights...")
            lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]

        progress(0.9, desc="Saving LoRA file...")

        # Fresh temp directory per call so concurrent requests never collide.
        lora_path = os.path.join(tempfile.mkdtemp(), "generated_lora.safetensors")
        save_file(lora, lora_path)

        progress(1.0, desc="Done!")

        return lora_path, f"✅ LoRA generated successfully from {len(pil_images)} image(s)!"

    except Exception as e:
        # Surface the failure in the UI status box instead of crashing the app.
        return None, f"❌ Error generating LoRA: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
 
 
 
 
 
 
 
316
 
317
@spaces.GPU(duration=60)
def generate_image(
    lora_file,
    prompt,
    negative_prompt,
    seed,
    cfg_scale,
    sigma_shift,
    num_steps,
    progress=gr.Progress()
):
    """Run the Z-Image pipeline with the supplied LoRA and return the image.

    Returns a ``(image, status_message)`` pair; ``image`` is ``None`` when no
    LoRA file was provided or generation failed.
    """
    if lora_file is None:
        return None, "❌ Please generate or upload a LoRA file first!"

    try:
        progress(0.1, desc="Loading LoRA...")

        # Load the safetensors state dict, then move every tensor onto the
        # GPU in bfloat16 to match the pipeline's compute dtype.
        state = load_file(lora_file)
        lora_weights = {}
        for key, tensor in state.items():
            lora_weights[key] = tensor.to(device="cuda", dtype=torch.bfloat16)

        progress(0.3, desc="Generating image...")

        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=int(seed),
            cfg_scale=cfg_scale,
            num_inference_steps=int(num_steps),
            positive_only_lora=lora_weights,
            sigma_shift=sigma_shift
        )

        progress(1.0, desc="Done!")

        return result, "✅ Image generated successfully!"

    except Exception as e:
        # Report the failure through the UI status box instead of raising.
        return None, f"❌ Error generating image: {str(e)}"
357
 
358
+
359
def create_demo():
    """Create the Gradio interface.

    Builds a two-tab Blocks app: tab 1 turns uploaded style images into a
    LoRA file; tab 2 generates images from a prompt using that LoRA.
    Returns the (unlaunched) ``gr.Blocks`` instance.
    """

    with gr.Blocks(
        title="Z-Image-i2L Demo",
        theme=gr.themes.Soft(),
        css=".gradio-container { max-width: 1200px !important; margin: 0 auto}"
    ) as demo:
        gr.Markdown("""
        # 🎨 Z-Image-i2L: Image to LoRA Demo

        > 💡 **Tip**: For best results, use 4-6 images with a consistent artistic style.
        """)

        with gr.Tabs():
            # --- Tab 1: upload images, produce a downloadable LoRA file ---
            with gr.TabItem("📸 Step 1: Image to LoRA"):
                with gr.Row():
                    with gr.Column(scale=1):
                        input_gallery = gr.Gallery(
                            label="Upload Style Images (1-6 images)",
                            file_types=["image"],
                            columns=3,
                            height=300,
                            interactive=True
                        )

                        gr.Markdown("""
                        **Guidelines:**
                        - Upload 1-6 images with a consistent style
                        - Higher quality images produce better results
                        - Mix of subjects helps generalization
                        """)

                        generate_lora_btn = gr.Button("🎯 Generate LoRA", variant="primary")

                    with gr.Column(scale=1):
                        lora_output = gr.File(
                            label="Generated LoRA File",
                            file_types=[".safetensors"],
                            interactive=False
                        )
                        lora_status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=2
                        )

            # --- Tab 2: apply a LoRA (from step 1 or uploaded) to generation ---
            with gr.TabItem("🖼️ Step 2: Generate Images"):
                with gr.Row():
                    with gr.Column(scale=1):
                        lora_input = gr.File(
                            label="LoRA File (from Step 1 or upload)",
                            file_types=[".safetensors"]
                        )

                        prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe what you want to generate...",
                            value="a cat",
                            lines=2
                        )

                        with gr.Accordion("Negative Prompt", open=False):
                            negative_prompt = gr.Textbox(
                                label="Negative Prompt",
                                value=NEGATIVE_PROMPT_CN,
                                lines=3
                            )
                            # Quick toggles between the CN/EN negative presets.
                            with gr.Row():
                                use_cn_neg = gr.Button("Use Chinese", size="sm")
                                use_en_neg = gr.Button("Use English", size="sm")

                        with gr.Accordion("Advanced Settings", open=False):
                            seed = gr.Number(label="Seed", value=0, precision=0)
                            cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=10, value=4, step=0.5)
                            sigma_shift = gr.Slider(label="Sigma Shift", minimum=1, maximum=15, value=8, step=1)
                            num_steps = gr.Slider(label="Steps", minimum=20, maximum=100, value=50, step=5)

                        generate_btn = gr.Button("✨ Generate Image", variant="primary")

                    with gr.Column(scale=1):
                        output_image = gr.Image(label="Generated Image", type="pil", height=512)
                        gen_status = gr.Textbox(label="Status", interactive=False, lines=2)

        gr.Markdown("""
        ---
        **Resources:** [Z-Image-i2L (HuggingFace)](https://huggingface.co/DiffSynth-Studio/Z-Image-i2L) |
        [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) |
        **Settings:** CFG=4, Sigma Shift=8, Steps=50
        """)

        # Event handlers
        generate_lora_btn.click(
            fn=image_to_lora,
            inputs=[input_gallery],
            outputs=[lora_output, lora_status]
        )

        # Auto-forward a freshly generated LoRA into the step-2 input slot.
        lora_output.change(fn=lambda x: x, inputs=[lora_output], outputs=[lora_input])

        generate_btn.click(
            fn=generate_image,
            inputs=[lora_input, prompt, negative_prompt, seed, cfg_scale, sigma_shift, num_steps],
            outputs=[output_image, gen_status]
        )

        use_cn_neg.click(fn=lambda: NEGATIVE_PROMPT_CN, outputs=[negative_prompt])
        use_en_neg.click(fn=lambda: NEGATIVE_PROMPT_EN, outputs=[negative_prompt])

    return demo
469
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
 
471
  if __name__ == "__main__":
472
+ print("Starting Gradio server...")
473
+ demo = create_demo()
474
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)