programmersd committed on
Commit
d183607
·
verified ·
1 Parent(s): 29958ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -152
app.py CHANGED
@@ -1,97 +1,198 @@
1
  import os
2
- import gc
3
  import sys
4
  import time
 
 
5
  import random
6
  import torch
7
  import gradio as gr
8
- from threading import Lock, Event
 
9
  from contextlib import contextmanager
10
- from huggingface_hub import snapshot_download, LocalEntryNotFoundError
11
 
12
- # ----------- LOGGING -----------
 
13
 
 
14
  LOG_BUFFER = []
15
  LOG_LOCK = Lock()
16
-
17
  def log(msg):
18
  with LOG_LOCK:
19
- timestamp = time.strftime('%H:%M:%S')
20
- LOG_BUFFER.append(f"{timestamp} | {msg}")
 
21
  if len(LOG_BUFFER) > 500:
22
  LOG_BUFFER.pop(0)
23
- print(msg)
24
  return "\n".join(LOG_BUFFER)
25
 
26
- # ----------- ENV CONFIG -----------
27
-
28
- CPU_THREADS = min(8, os.cpu_count() or 1)
29
- for var in ["OMP_NUM_THREADS","MKL_NUM_THREADS","OPENBLAS_NUM_THREADS","VECLIB_MAXIMUM_THREADS","NUMEXPR_NUM_THREADS"]:
30
- os.environ[var] = str(CPU_THREADS)
31
-
32
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
33
  os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
34
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
35
- os.environ["HF_HUB_OFFLINE"] = "1"
36
- os.environ["TRANSFORMERS_OFFLINE"] = "1"
37
  os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
38
  os.environ["HF_DATASETS_CACHE"] = "./hf_cache"
 
39
 
40
  torch.set_grad_enabled(False)
41
- torch.set_num_threads(CPU_THREADS)
42
- torch.backends.mkldnn.enabled = True
43
  torch.set_float32_matmul_precision("medium")
44
 
45
  DEVICE = "cpu"
46
  DTYPE = torch.float32
47
- os.makedirs("./hf_cache", exist_ok=True)
48
 
49
  try:
50
- from diffusers import ZImagePipeline
51
- log("Imported diffusers successfully.")
52
  except ImportError as e:
53
- log(f"Import diffusers failed: {e}")
54
  sys.exit(1)
55
 
56
- pipe_cache = {}
57
- pipe_lock = Lock()
58
- generation_lock = Lock()
59
  interrupt_event = Event()
 
 
60
 
61
- # ----------- SNAPSHOT WITH RETRY -----------
62
-
63
  MODEL_SPECS = {
64
- "Z-Image Turbo": "Tongyi-MAI/Z-Image-Turbo",
65
- # Optionally add quantized variants here
66
- # "Z-Image Turbo GGUF": "unsloth/Z-Image-Turbo-GGUF",
67
  }
68
 
69
- def download_snapshot_with_retry(repo_id, local_path, retries=3):
70
- attempt = 1
71
- while attempt <= retries:
72
- log(f"Snapshot attempt {attempt}/{retries} for {repo_id}...")
73
- try:
74
- # snapshot_download respects HF cache and will skip downloads if cached
75
- path = snapshot_download(repo_id=repo_id, local_dir=local_path, local_dir_use_symlinks=False)
76
- log(f"Snapshot fully downloaded: {path}")
77
- return path
78
- except Exception as e:
79
- log(f"⚠️ snapshot_download failed: {e}")
80
- attempt += 1
81
- time.sleep(2)
82
- raise RuntimeError(f"Failed to download snapshot of {repo_id} after {retries} attempts")
83
-
84
- # Ensure snapshot is present
85
- for model_name, repo_id in MODEL_SPECS.items():
86
- local_dir = os.path.join("./hf_cache", f"{model_name}_snapshot")
87
- if not os.path.isdir(local_dir) or not os.listdir(local_dir):
88
- log(f"📥 No snapshot for {model_name}, starting download...")
89
- try:
90
- download_snapshot_with_retry(repo_id, local_dir, retries=3)
91
- except RuntimeError as err:
92
- log(f"❌ Snapshot download error: {err}")
93
-
94
- # ----------- PIPELINE LOADING -----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  @contextmanager
97
  def managed_memory():
@@ -103,127 +204,97 @@ def managed_memory():
103
  if torch.cuda.is_available():
104
  torch.cuda.empty_cache()
105
 
106
- def load_pipeline(model_name):
107
- if model_name in pipe_cache:
108
- return pipe_cache[model_name]
109
- with pipe_lock:
110
- log(f"Loading {model_name} pipeline.")
111
- repo_dir = os.path.join("./hf_cache", f"{model_name}_snapshot")
112
- try:
113
- pipe = ZImagePipeline.from_pretrained(repo_dir, torch_dtype=DTYPE, local_files_only=True, low_cpu_mem_usage=True)
114
- except LocalEntryNotFoundError:
115
- log(f"Incomplete local snapshot for {model_name}, retrying online load.")
116
- pipe = ZImagePipeline.from_pretrained(MODEL_SPECS[model_name], torch_dtype=DTYPE, cache_dir="./hf_cache", low_cpu_mem_usage=True)
117
- pipe.to(DEVICE)
118
- pipe.vae.eval()
119
- pipe.text_encoder.eval()
120
- pipe.transformer.eval()
121
- try:
122
- pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
123
- log("Transformer compiled.")
124
- except Exception as e:
125
- log(f"Transformer compile skipped: {e}")
126
- pipe_cache[model_name] = pipe
127
- return pipe
128
-
129
- # ----------- GENERATION LOGIC -----------
130
-
131
  @torch.inference_mode()
132
  @torch.no_grad()
133
- def generate(prompt, quality_mode, seed, model_name):
134
  if not prompt.strip():
135
- raise gr.Error("Prompt cannot be empty!")
136
 
137
  PRESETS = {
138
- "ultra_fast": (1, 256),
139
- "fast": (1, 256),
140
- "balanced": (2, 256),
141
- "quality": (4, 384),
142
- "ultra_quality": (4, 512),
143
  }
144
- steps, size = PRESETS.get(quality_mode, (1, 256))
145
  width = height = size
146
 
147
- seed = int(seed) if seed >= 0 else random.randint(0, (2**31)-1)
148
- log(f"Generating: '{prompt[:40]}...' | {quality_mode} | {width}x{height} | seed={seed}")
149
-
150
- with managed_memory(), generation_lock:
151
- pipe = load_pipeline(model_name)
152
- generator = torch.Generator("cpu").manual_seed(seed)
153
- start_time = time.time()
154
 
155
- preview_images = []
 
 
 
 
156
 
157
- def progress_cb(pipeline, step_idx, timestep, cbk):
158
  if interrupt_event.is_set():
159
- raise KeyboardInterrupt("Generation interrupted")
160
- if step_idx % 2 == 0: # preview every 2 steps
161
  try:
162
- preview_images.append(pipeline.image_from_latents(pipeline.latents))
163
- except Exception:
164
  pass
165
  return cbk
166
 
167
- try:
168
- result = pipe(
169
- prompt=prompt,
170
- negative_prompt=None,
171
- width=width,
172
- height=height,
173
- num_inference_steps=steps,
174
- guidance_scale=0.0,
175
- generator=generator,
176
- callback_on_step_end=progress_cb,
177
- callback_on_step_end_tensor_inputs=["latents"],
178
- output_type="pil"
179
- )
180
- final_image = result.images[0]
181
- log(f"Done in {time.time()-start_time:.1f}s")
182
- except KeyboardInterrupt:
183
- log("⚠️ Generation interrupted.")
184
- return None, seed, preview_images
185
-
186
- del result
187
- gc.collect()
188
-
189
- preview_images.append(final_image)
190
- return final_image, seed, preview_images
191
-
192
- # ----------- GRADIO UI -----------
193
-
194
- with gr.Blocks(title="🤩✨ Z‑Image Turbo CPU Ultimate + Retry + Preview + Interrupt") as demo:
195
- gr.Markdown("## Full feature CPU image generator — true snapshot retry + preview frames")
196
 
197
  with gr.Row():
198
  with gr.Column():
199
  prompt = gr.Textbox(label="Prompt", lines=4)
200
- quality_mode = gr.Radio(
201
- choices=["ultra_fast","fast","balanced","quality","ultra_quality"],
202
- value="fast",
203
- label="Quality Mode"
204
- )
205
- seed = gr.Number(value=-1, precision=0, label="Seed (-1=random)")
206
- model_choice = gr.Dropdown(list(MODEL_SPECS.keys()), value=list(MODEL_SPECS.keys())[0], label="Select model")
207
  gen_btn = gr.Button("GENERATE")
208
- interrupt_btn = gr.Button("STOP")
209
-
210
  with gr.Column():
211
- out_img = gr.Image(label="Final Image")
212
- out_seed = gr.Number(label="Seed Used", interactive=False)
213
- preview_gallery = gr.Gallery(label="Preview frames")
214
- log_output = gr.Textbox(label="Live System Log", lines=15, interactive=False)
 
 
 
 
 
 
215
 
216
- def on_generate(prompt, quality_mode, seed, model_choice):
217
  interrupt_event.clear()
218
- final_img, used_seed, previews = generate(prompt, quality_mode, seed, model_choice)
219
- return final_img, used_seed, previews, log("Generation done.")
220
 
221
- def on_interrupt():
222
  interrupt_event.set()
223
- return log("📌 Interrupt requested")
224
 
225
- gen_btn.click(on_generate, inputs=[prompt, quality_mode, seed, model_choice], outputs=[out_img, out_seed, preview_gallery, log_output])
226
- interrupt_btn.click(on_interrupt, inputs=None, outputs=log_output)
 
 
227
 
228
  demo.queue()
229
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
 
2
  import sys
3
  import time
4
+ import json
5
+ import gc
6
  import random
7
  import torch
8
  import gradio as gr
9
+ import requests
10
+ from threading import Lock, Event, Thread
11
  from contextlib import contextmanager
12
+ from urllib.parse import urlparse
13
 
14
+ from huggingface_hub import hf_hub_download, hf_hub_url
15
+ from huggingface_hub.utils import RepositoryNotFoundError
16
 
17
+ # ===== LOGGING =====
18
  LOG_BUFFER = []
19
  LOG_LOCK = Lock()
 
20
  def log(msg):
21
  with LOG_LOCK:
22
+ t = time.strftime("%H:%M:%S")
23
+ entry = f"{t} | {msg}"
24
+ LOG_BUFFER.append(entry)
25
  if len(LOG_BUFFER) > 500:
26
  LOG_BUFFER.pop(0)
 
27
  return "\n".join(LOG_BUFFER)
28
 
29
+ # ===== ENV SETUP =====
 
 
 
 
 
 
30
  os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
31
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
 
 
32
  os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
33
  os.environ["HF_DATASETS_CACHE"] = "./hf_cache"
34
+ os.makedirs("./hf_cache", exist_ok=True)
35
 
36
  torch.set_grad_enabled(False)
37
+ torch.set_num_threads(min(8, os.cpu_count() or 1))
 
38
  torch.set_float32_matmul_precision("medium")
39
 
40
  DEVICE = "cpu"
41
  DTYPE = torch.float32
 
42
 
43
  try:
44
+ from diffusers import ZImagePipeline, GGUFQuantizationConfig, ZImageTransformer2DModel
45
+ log("Loaded diffusers modules")
46
  except ImportError as e:
47
+ log(f"Import error: {e}")
48
  sys.exit(1)
49
 
50
+ # ===== DOWNLOAD CONTEXT =====
 
 
51
  interrupt_event = Event()
52
+ pipe_cache = {}
53
+ download_lock = Lock()
54
 
55
+ # ===== MODEL LIST =====
 
56
  MODEL_SPECS = {
57
+ "Turbo Full": "Tongyi-MAI/Z-Image-Turbo",
58
+ "Turbo Q2_K GGUF": "unsloth/Z-Image-Turbo-GGUF"
 
59
  }
60
 
61
+ # ===== DOWNLOAD HELPERS =====
62
def list_repo_files(repo_id):
    """
    Return ``(filename, size_in_bytes)`` tuples for every file in a Hub repo.

    Only repository metadata is requested; no file contents are downloaded.
    A size the Hub does not report is returned as 0. Any failure is logged
    and an empty list is returned, which callers treat as "nothing to
    download".
    """
    # Function-scope import keeps this block self-contained; huggingface_hub
    # is already a dependency of this file.
    from huggingface_hub import HfApi

    try:
        # hf_hub_download() operates on one named file and cannot enumerate a
        # repository; model_info(files_metadata=True) lists every sibling
        # together with its size.
        info = HfApi().model_info(repo_id, files_metadata=True)
        return [(sib.rfilename, sib.size or 0) for sib in (info.siblings or [])]
    except Exception as e:
        log(f"List failed: {e}")
        return []
73
+
74
def download_file_chunked(repo_id, filename, target_dir, progress_updater):
    """
    Stream one file of a Hub repo into ``target_dir``, resuming a partial copy.

    Bytes are written to ``<filename>.part`` and moved onto the final path
    only after the stream ends, so an interrupted run never leaves a
    truncated file under the real name.

    Args:
        repo_id: Hub repository id, e.g. ``"org/name"``.
        filename: Path of the file inside the repository.
        target_dir: Local directory mirroring the repository layout.
        progress_updater: Callable receiving the completed fraction (0..1)
            after each chunk; not called when the server reports no length.

    Returns:
        True when the file was fully written, False when ``interrupt_event``
        was set mid-download (the ``.part`` file is kept for a later resume).

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    local_path = os.path.join(target_dir, filename)
    tmp_path = local_path + ".part"

    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    already = os.path.getsize(tmp_path) if os.path.exists(tmp_path) else 0

    # hf_hub_url() only builds the resolve URL (no network access), so there
    # is no repository error to catch at this point.
    url = hf_hub_url(repo_id, filename)
    headers = {"Range": f"bytes={already}-"} if already > 0 else {}

    response = requests.get(url, headers=headers, stream=True, timeout=10)
    if already > 0 and response.status_code == 416:
        # The requested range is not satisfiable (stale or oversized partial
        # file): discard the resume offset and fetch the whole file again.
        response.close()
        already = 0
        response = requests.get(url, stream=True, timeout=10)

    with response as r:
        # Never write an HTTP error page into the model file.
        r.raise_for_status()
        if r.status_code != 206:
            # Server ignored the Range header and is sending the full body;
            # appending it to the partial file would corrupt the result.
            already = 0
        total = int(r.headers.get("Content-Length", 0)) + already
        with open(tmp_path, "ab" if already else "wb") as f:
            downloaded = already
            for chunk in r.iter_content(chunk_size=1024 * 256):
                if interrupt_event.is_set():
                    return False
                if not chunk:
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                if total > 0:  # length unknown -> no meaningful fraction
                    progress_updater(min(1.0, downloaded / total))

    # os.replace overwrites an existing destination on every platform.
    os.replace(tmp_path, local_path)
    return True
113
+
114
def parallel_download_repo(repo_id, progress: gr.Progress, max_workers=4):
    """
    Download every file of a repo concurrently, reporting overall progress.

    Args:
        repo_id: Hub repository id; files land in ``./hf_cache/<org>_<name>``.
        progress: Gradio progress callable, invoked as
            ``progress(fraction, desc=...)``.
        max_workers: Upper bound on simultaneous downloads.

    Files whose local size already matches the reported size are skipped.
    A failure of one file is logged and does not abort the others. Setting
    ``interrupt_event`` stops files that have not started yet; running
    downloads stop themselves at their next chunk.
    """
    from concurrent.futures import ThreadPoolExecutor

    base_dir = os.path.join("./hf_cache", repo_id.replace("/", "_"))
    files = list_repo_files(repo_id)
    if not files:
        progress(1.0, desc="No files to download")
        return

    sizes = {name: (size or 0) for name, size in files}
    # "or 1" avoids a ZeroDivisionError when no sizes are reported.
    total_bytes = sum(sizes.values()) or 1
    done_bytes = {}  # filename -> bytes accounted for so far

    def report(name, frac):
        # Update and read the shared table under the lock so the overall
        # fraction reflects every in-flight file consistently.
        with download_lock:
            done_bytes[name] = frac * sizes[name]
            overall = sum(done_bytes.values()) / total_bytes
        progress(min(1.0, overall), desc=f"{name} {frac*100:.1f}%")

    def fetch(name):
        if interrupt_event.is_set():
            return
        try:
            download_file_chunked(
                repo_id, name, base_dir,
                lambda frac: report(name, frac),
            )
        except Exception as e:
            # Keep the remaining downloads going; surface the failure in the log.
            log(f"Download failed for {name}: {e}")

    pending = []
    for name, size in sizes.items():
        if interrupt_event.is_set():
            break
        # skip if fully cached already
        local_full = os.path.join(base_dir, name)
        if os.path.exists(local_full) and os.path.getsize(local_full) == size:
            done_bytes[name] = size
            continue
        pending.append(name)

    # A bounded pool replaces one-thread-per-file, which is unbounded for
    # repositories with many files.
    with ThreadPoolExecutor(max_workers=max(1, max_workers)) as pool:
        list(pool.map(fetch, pending))
153
+
154
+ # ===== PIPELINE LOADER =====
155
def load_pipeline(model_key):
    """
    Return the (cached) pipeline for ``model_key``, loading it on first use.

    Args:
        model_key: A key of ``MODEL_SPECS``. Keys containing ``"GGUF"`` load a
            quantized transformer from a local ``.gguf`` file and combine it
            with the remaining components of the base Z-Image-Turbo repo.

    Returns:
        The pipeline moved to ``DEVICE`` with its sub-models in eval mode.

    Raises:
        gr.Error: if the repo's cache directory is missing or empty, or no
            ``.gguf`` file is present for a GGUF model.
    """
    if model_key in pipe_cache:
        return pipe_cache[model_key]

    repo = MODEL_SPECS[model_key]
    repo_cache = os.path.join("./hf_cache", repo.replace("/", "_"))

    # ensure cache
    if not os.path.isdir(repo_cache) or not os.listdir(repo_cache):
        raise gr.Error("Model not downloaded; press Preload first")

    # load model
    if "GGUF" in model_key:
        # os.listdir() order is arbitrary and the repo may hold several
        # quantization levels: sort for determinism and prefer the file whose
        # name matches the most words of the UI label (e.g. "Q2_K").
        candidates = sorted(f for f in os.listdir(repo_cache) if f.endswith(".gguf"))
        if not candidates:
            raise gr.Error("Quantized file not found")
        tokens = [tok.lower() for tok in model_key.split()]
        chosen = max(candidates, key=lambda f: sum(tok in f.lower() for tok in tokens))
        log(f"Using GGUF file: {chosen}")
        gguf = os.path.join(repo_cache, chosen)
        transformer = ZImageTransformer2DModel.from_single_file(
            gguf,
            quantization_config=GGUFQuantizationConfig(compute_dtype=DTYPE),
            torch_dtype=DTYPE,
        )
        pipe = ZImagePipeline.from_pretrained(
            "Tongyi-MAI/Z-Image-Turbo",
            transformer=transformer,
            torch_dtype=DTYPE,
            cache_dir="./hf_cache",
        )
    else:
        pipe = ZImagePipeline.from_pretrained(repo_cache, torch_dtype=DTYPE, local_files_only=True)

    pipe.to(DEVICE)
    pipe.vae.eval()
    pipe.text_encoder.eval()
    pipe.transformer.eval()
    pipe_cache[model_key] = pipe
    return pipe
196
 
197
  @contextmanager
198
  def managed_memory():
 
204
  if torch.cuda.is_available():
205
  torch.cuda.empty_cache()
206
 
207
+ # ===== GENERATION =====
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  @torch.inference_mode()
209
  @torch.no_grad()
210
def generate(prompt, quality_mode, seed, model_key):
    """
    Generate one image and collect intermediate preview frames.

    Args:
        prompt: Text prompt; must contain non-whitespace characters.
        quality_mode: Key of ``PRESETS``; unknown values fall back to
            1 step at 256 px.
        seed: Non-negative seed, or a negative value / ``None`` (empty UI
            field) to draw a random one.
        model_key: Key of ``MODEL_SPECS`` passed to ``load_pipeline``.

    Returns:
        ``(final_image, seed_used, previews)`` where ``previews`` ends with
        the final image.

    Raises:
        gr.Error: if the prompt is empty.
    """
    if not (prompt or "").strip():
        raise gr.Error("Prompt cannot be empty")

    # quality_mode -> (inference steps, square edge length in pixels)
    PRESETS = {
        "ultra_fast": (1, 256),
        "fast": (1, 256),
        "balanced": (2, 256),
        "quality": (4, 384),
        "ultra_quality": (4, 512),
    }
    steps, size = PRESETS.get(quality_mode, (1, 256))
    width = height = size

    # An emptied Number field arrives as None; treat it like "-1 = random"
    # instead of failing on the comparison.
    if seed is not None and seed >= 0:
        seed = int(seed)
    else:
        seed = random.randint(0, 2**31 - 1)
    log(f"Gen: {prompt[:30]} | {quality_mode} | {model_key} | seed={seed}")

    with managed_memory():
        pipe = load_pipeline(model_key)
        gen = torch.Generator("cpu").manual_seed(seed)
        previews = []
        start = time.time()

        def cb(ppl, step, timestep, cbk):
            # Step-end callback: receives the pipeline, the step index, the
            # timestep and the dict of requested tensors, and must return
            # that dict.
            if interrupt_event.is_set():
                # Flag the pipeline so it stops denoising (STOP button).
                ppl._interrupt = True
            if step % 2 == 0:
                try:
                    previews.append(ppl.image_from_latents(cbk["latents"]))
                except Exception as e:
                    # Previews are best-effort, but record why one is missing
                    # rather than hiding the failure completely.
                    log(f"Preview skipped at step {step}: {e}")
            return cbk

        result = pipe(
            prompt=prompt,
            negative_prompt=None,
            width=width,
            height=height,
            num_inference_steps=steps,
            guidance_scale=0.0,
            generator=gen,
            callback_on_step_end=cb,
            callback_on_step_end_tensor_inputs=["latents"],
            output_type="pil",
        )
        final = result.images[0]
        previews.append(final)
        log(f"Generated in {time.time()-start:.1f}s")
        return final, seed, previews
259
+
260
+ # ===== GRADIO UI =====
261
# Gradio UI: inputs and action buttons in the left column, results and the
# log box in the right column. NOTE(review): nesting was reconstructed from a
# flattened diff view — confirm the placement of `logs` against the real file.
with gr.Blocks(title="Z‑Image Turbo CPU Downloader + UI A‑Progress") as demo:
    gr.Markdown("## True parallel download UI + chunked progress")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=4)
            quality = gr.Radio(["ultra_fast","fast","balanced","quality","ultra_quality"], value="fast")
            seed = gr.Number(value=-1, precision=0, label="Seed")
            model_select = gr.Dropdown(list(MODEL_SPECS.keys()), value=list(MODEL_SPECS.keys())[0], label="Model")
            preload = gr.Button("PRELOAD MODELS")
            gen_btn = gr.Button("GENERATE")
            stop_btn = gr.Button("STOP")
        with gr.Column():
            out_image = gr.Image(label="Final")
            used_seed = gr.Number(label="Seed Used")
            preview = gr.Gallery(label="Preview Frames")
            logs = gr.Textbox(label="Logs", lines=25)

    def do_preload(progress=gr.Progress()):
        """Download every repo in MODEL_SPECS; return the updated log text."""
        interrupt_event.clear()
        for repo in MODEL_SPECS.values():
            # Do not start the next repository once STOP was pressed.
            if interrupt_event.is_set():
                break
            parallel_download_repo(repo, progress)
        if interrupt_event.is_set():
            # Report the real outcome instead of claiming success.
            return log("⏹️ Preload interrupted")
        return log("📦 Preload finished")

    def do_gen(prompt, quality, seed, model_key):
        """Run one generation and return image, seed, previews and log text."""
        interrupt_event.clear()
        img, used, previews = generate(prompt, quality, seed, model_key)
        return img, used, previews, log("🧠 Generation done")

    def do_stop():
        """Signal running downloads/generation to stop; return the log text."""
        interrupt_event.set()
        return log("🔴 Interrupt set")

    preload.click(do_preload, outputs=logs)
    gen_btn.click(do_gen, inputs=[prompt,quality,seed,model_select],
                  outputs=[out_image,used_seed,preview,logs])
    stop_btn.click(do_stop, outputs=logs)
298
 
299
  demo.queue()
300
  demo.launch(server_name="0.0.0.0", server_port=7860)