import os
import gc
import sys
import time
import random
import torch
import gradio as gr
from threading import Lock, Event
from contextlib import contextmanager
from huggingface_hub import snapshot_download
from huggingface_hub.utils import LocalEntryNotFoundError
# ----------- LOGGING -----------
LOG_BUFFER = []
LOG_LOCK = Lock()

def log(msg):
    with LOG_LOCK:
        timestamp = time.strftime('%H:%M:%S')
        LOG_BUFFER.append(f"{timestamp} | {msg}")
        if len(LOG_BUFFER) > 500:
            LOG_BUFFER.pop(0)
        print(msg)
        return "\n".join(LOG_BUFFER)
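
# NOTE: log() returns the full buffer so its return value can be bound
# directly to the "Live System Log" textbox in the UI below.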
# ----------- ENV CONFIG -----------
CPU_THREADS = min(8, os.cpu_count() or 1)
for var in ["OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "VECLIB_MAXIMUM_THREADS", "NUMEXPR_NUM_THREADS"]:
    os.environ[var] = str(CPU_THREADS)
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
# NOTE: HF_HUB_OFFLINE / TRANSFORMERS_OFFLINE are deliberately NOT set here:
# they would block the initial snapshot_download and the online fallback in
# load_pipeline. Offline loading is enforced per call via local_files_only.
os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
os.environ["HF_DATASETS_CACHE"] = "./hf_cache"

torch.set_grad_enabled(False)
torch.set_num_threads(CPU_THREADS)
torch.backends.mkldnn.enabled = True
torch.set_float32_matmul_precision("medium")

DEVICE = "cpu"
DTYPE = torch.float32

os.makedirs("./hf_cache", exist_ok=True)
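
# ZImagePipeline only ships in recent diffusers releases, so fail fast with
# a readable log message instead of a bare traceback if it is missing.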
try:
    from diffusers import ZImagePipeline
    log("Imported diffusers successfully.")
except ImportError as e:
    log(f"Import diffusers failed: {e}")
    sys.exit(1)
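
# Shared state: one cached pipeline per model name, a lock serializing
# generation on the shared CPU, and an event the step callback polls so a
# running generation can be interrupted.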
pipe_cache = {}
pipe_lock = Lock()
generation_lock = Lock()
interrupt_event = Event()
# ----------- SNAPSHOT WITH RETRY -----------
MODEL_SPECS = {
    "Z-Image Turbo": "Tongyi-MAI/Z-Image-Turbo",
    # Optionally add quantized variants here
    # "Z-Image Turbo GGUF": "unsloth/Z-Image-Turbo-GGUF",
}
def download_snapshot_with_retry(repo_id, local_path, retries=3):
    attempt = 1
    while attempt <= retries:
        log(f"Snapshot attempt {attempt}/{retries} for {repo_id}...")
        try:
            # snapshot_download respects the HF cache and skips files that
            # are already present locally.
            path = snapshot_download(repo_id=repo_id, local_dir=local_path)
            log(f"Snapshot fully downloaded: {path}")
            return path
        except Exception as e:
            log(f"⚠️ snapshot_download failed: {e}")
            attempt += 1
            time.sleep(2)
    raise RuntimeError(f"Failed to download snapshot of {repo_id} after {retries} attempts")
# Ensure snapshots are present
for model_name, repo_id in MODEL_SPECS.items():
    local_dir = os.path.join("./hf_cache", f"{model_name}_snapshot")
    if not os.path.isdir(local_dir) or not os.listdir(local_dir):
        log(f"📥 No snapshot for {model_name}, starting download...")
        try:
            download_snapshot_with_retry(repo_id, local_dir, retries=3)
        except RuntimeError as err:
            log(f"❌ Snapshot download error: {err}")
# ----------- PIPELINE LOADING -----------
@contextmanager
def managed_memory():
    try:
        yield
    finally:
        gc.collect()
        gc.collect()  # second pass collects cycles freed by finalizers
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
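
# Load each pipeline once and cache it. The cache check is repeated inside
# the lock so concurrent first requests do not both load the model.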
def load_pipeline(model_name):
    if model_name in pipe_cache:
        return pipe_cache[model_name]
    with pipe_lock:
        if model_name in pipe_cache:  # re-check after acquiring the lock
            return pipe_cache[model_name]
        log(f"Loading {model_name} pipeline.")
        repo_dir = os.path.join("./hf_cache", f"{model_name}_snapshot")
        try:
            pipe = ZImagePipeline.from_pretrained(repo_dir, torch_dtype=DTYPE, local_files_only=True, low_cpu_mem_usage=True)
        except (LocalEntryNotFoundError, OSError):
            # An incomplete local snapshot surfaces as an OSError for
            # directory loads; fall back to a normal hub load.
            log(f"Incomplete local snapshot for {model_name}, retrying online load.")
            pipe = ZImagePipeline.from_pretrained(MODEL_SPECS[model_name], torch_dtype=DTYPE, cache_dir="./hf_cache", low_cpu_mem_usage=True)
        pipe.to(DEVICE)
        pipe.vae.eval()
        pipe.text_encoder.eval()
        pipe.transformer.eval()
        try:
            # torch.compile is lazy, so real failures tend to surface at the
            # first forward pass rather than here; this is best-effort.
            pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
            log("Transformer compiled.")
        except Exception as e:
            log(f"Transformer compile skipped: {e}")
        pipe_cache[model_name] = pipe
        return pipe
# ----------- GENERATION LOGIC -----------
def generate(prompt, quality_mode, seed, model_name):
    if not prompt.strip():
        raise gr.Error("Prompt cannot be empty!")
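    # Each preset maps to (num_inference_steps, square output size in px);
    # sizes stay small to keep CPU-only generation times tolerable.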
    PRESETS = {
        "ultra_fast": (1, 256),
        "fast": (1, 256),
        "balanced": (2, 256),
        "quality": (4, 384),
        "ultra_quality": (4, 512),
    }
    steps, size = PRESETS.get(quality_mode, (1, 256))
    width = height = size
    seed = int(seed) if seed is not None and seed >= 0 else random.randint(0, (2**31) - 1)
    log(f"Generating: '{prompt[:40]}...' | {quality_mode} | {width}x{height} | seed={seed}")
    with managed_memory(), generation_lock:
        pipe = load_pipeline(model_name)
        generator = torch.Generator("cpu").manual_seed(seed)
        start_time = time.time()
        preview_images = []

        def progress_cb(pipeline, step_idx, timestep, cbk):
            # Called by diffusers after each denoising step; must return the
            # callback kwargs dict.
            if interrupt_event.is_set():
                raise KeyboardInterrupt("Generation interrupted")
            if step_idx % 2 == 0:  # preview every 2 steps
                try:
                    # Best-effort preview: decode the intermediate latents
                    # (requested via callback_on_step_end_tensor_inputs)
                    # with the pipeline's VAE. The exact latent scaling is
                    # model-specific, so any failure here is silently
                    # ignored rather than aborting the generation.
                    latents = cbk["latents"]
                    decoded = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0]
                    preview_images.append(pipeline.image_processor.postprocess(decoded, output_type="pil")[0])
                except Exception:
                    pass
            return cbk
        try:
            result = pipe(
                prompt=prompt,
                negative_prompt=None,
                width=width,
                height=height,
                num_inference_steps=steps,
                guidance_scale=0.0,
                generator=generator,
                callback_on_step_end=progress_cb,
                callback_on_step_end_tensor_inputs=["latents"],
                output_type="pil",
            )
            final_image = result.images[0]
            log(f"Done in {time.time() - start_time:.1f}s")
        except KeyboardInterrupt:
            log("⚠️ Generation interrupted.")
            return None, seed, preview_images
        del result
        gc.collect()
        preview_images.append(final_image)
        return final_image, seed, preview_images
# ----------- GRADIO UI -----------
with gr.Blocks(title="🤩✨ Z-Image Turbo CPU Ultimate + Retry + Preview + Interrupt") as demo:
    gr.Markdown("## Full feature CPU image generator — true snapshot retry + preview frames")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=4)
            quality_mode = gr.Radio(
                choices=["ultra_fast", "fast", "balanced", "quality", "ultra_quality"],
                value="fast",
                label="Quality Mode"
            )
            seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
            model_choice = gr.Dropdown(list(MODEL_SPECS.keys()), value=list(MODEL_SPECS.keys())[0], label="Select model")
            gen_btn = gr.Button("GENERATE")
            interrupt_btn = gr.Button("STOP")
        with gr.Column():
            out_img = gr.Image(label="Final Image")
            out_seed = gr.Number(label="Seed Used", interactive=False)
            preview_gallery = gr.Gallery(label="Preview frames")
            log_output = gr.Textbox(label="Live System Log", lines=15, interactive=False)
    def on_generate(prompt, quality_mode, seed, model_choice):
        interrupt_event.clear()
        final_img, used_seed, previews = generate(prompt, quality_mode, seed, model_choice)
        return final_img, used_seed, previews, log("Generation done.")

    def on_interrupt():
        interrupt_event.set()
        return log("📌 Interrupt requested")

    gen_btn.click(on_generate, inputs=[prompt, quality_mode, seed, model_choice], outputs=[out_img, out_seed, preview_gallery, log_output])
    # concurrency_limit=None lets STOP run even while a generation occupies
    # the default queue slot; otherwise the interrupt would wait in line.
    interrupt_btn.click(on_interrupt, inputs=None, outputs=log_output, concurrency_limit=None)

demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)