# app.py — "Zitc" Hugging Face Space: Z-Image Turbo CPU image generator
# (provenance: upload 29958ad, ~8 kB)
import os
import gc
import sys
import time
import random
import torch
import gradio as gr
from threading import Lock, Event
from contextlib import contextmanager
from huggingface_hub import snapshot_download, LocalEntryNotFoundError
# ----------- LOGGING -----------
LOG_BUFFER = []
LOG_LOCK = Lock()


def log(msg):
    """Record *msg* with a timestamp, echo it to stdout, and return the full log text.

    The buffer is capped at the 500 most recent entries; access is guarded by
    LOG_LOCK so concurrent Gradio callbacks can log safely.
    """
    with LOG_LOCK:
        entry = f"{time.strftime('%H:%M:%S')} | {msg}"
        LOG_BUFFER.append(entry)
        # Drop the oldest entry once the cap is exceeded (one append per call,
        # so a single removal is enough).
        if len(LOG_BUFFER) > 500:
            del LOG_BUFFER[0]
        print(msg)
        return "\n".join(LOG_BUFFER)
# ----------- ENV CONFIG -----------
# Pin every BLAS/OpenMP thread pool to the same cap so native kernels do not
# oversubscribe the CPU.
CPU_THREADS = min(8, os.cpu_count() or 1)
for var in ["OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "VECLIB_MAXIMUM_THREADS", "NUMEXPR_NUM_THREADS"]:
    os.environ[var] = str(CPU_THREADS)
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # force CPU-only even if a GPU is present
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
# BUGFIX: these two were previously set to "1", which puts huggingface_hub and
# transformers into hard offline mode — the snapshot_download retry logic below
# then fails on every attempt when the cache is cold. Keep online access
# enabled; snapshot_download still skips work when the cache is populated.
os.environ["HF_HUB_OFFLINE"] = "0"
os.environ["TRANSFORMERS_OFFLINE"] = "0"
os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
os.environ["HF_DATASETS_CACHE"] = "./hf_cache"
torch.set_grad_enabled(False)  # inference-only app: never build autograd graphs
torch.set_num_threads(CPU_THREADS)
torch.backends.mkldnn.enabled = True  # enable oneDNN CPU kernels
torch.set_float32_matmul_precision("medium")
DEVICE = "cpu"
DTYPE = torch.float32
os.makedirs("./hf_cache", exist_ok=True)
# Import the Z-Image pipeline up front; the app cannot run without it.
try:
    from diffusers import ZImagePipeline
except ImportError as e:
    log(f"Import diffusers failed: {e}")
    sys.exit(1)
else:
    log("Imported diffusers successfully.")
pipe_cache = {}            # model name -> loaded pipeline, reused across requests
pipe_lock = Lock()         # serializes pipeline loading
generation_lock = Lock()   # one image generation at a time (CPU-bound)
interrupt_event = Event()  # set by the STOP button, polled in the step callback
# ----------- SNAPSHOT WITH RETRY -----------
# Display name -> Hugging Face Hub repo id.
MODEL_SPECS = {
    "Z-Image Turbo": "Tongyi-MAI/Z-Image-Turbo",
    # Optionally add quantized variants here
    # "Z-Image Turbo GGUF": "unsloth/Z-Image-Turbo-GGUF",
}
def download_snapshot_with_retry(repo_id, local_path, retries=3):
    """Download a full model snapshot into *local_path*, retrying on failure.

    Args:
        repo_id: Hugging Face Hub repository id (e.g. "Tongyi-MAI/Z-Image-Turbo").
        local_path: local directory to materialize the snapshot into.
        retries: number of attempts before giving up.

    Returns:
        The path of the downloaded snapshot.

    Raises:
        RuntimeError: when all attempts fail (chained to the last underlying error).
    """
    last_exc = None
    for attempt in range(1, retries + 1):
        log(f"Snapshot attempt {attempt}/{retries} for {repo_id}...")
        try:
            # snapshot_download respects HF cache and will skip downloads if cached.
            # (The deprecated local_dir_use_symlinks flag was dropped; current
            # huggingface_hub ignores it when local_dir is given.)
            path = snapshot_download(repo_id=repo_id, local_dir=local_path)
            log(f"Snapshot fully downloaded: {path}")
            return path
        except Exception as e:
            last_exc = e
            log(f"⚠️ snapshot_download failed: {e}")
            time.sleep(2)  # brief pause before retrying (transient network errors)
    # Chain the last failure so the real cause is visible in the traceback.
    raise RuntimeError(f"Failed to download snapshot of {repo_id} after {retries} attempts") from last_exc
# Ensure snapshot is present
for model_name, repo_id in MODEL_SPECS.items():
    # Each model gets its own snapshot directory under the local cache.
    local_dir = os.path.join("./hf_cache", f"{model_name}_snapshot")
    # (Re)download when the directory is missing or empty (e.g. aborted download).
    if not os.path.isdir(local_dir) or not os.listdir(local_dir):
        log(f"📥 No snapshot for {model_name}, starting download...")
        try:
            download_snapshot_with_retry(repo_id, local_dir, retries=3)
        except RuntimeError as err:
            # Log and continue: load_pipeline() falls back to an online load later.
            log(f"❌ Snapshot download error: {err}")
# ----------- PIPELINE LOADING -----------
def _reclaim_memory():
    """Aggressively return memory after heavy work: two GC passes plus CUDA cache drop."""
    # Second pass collects objects whose finalizers freed more garbage in the first.
    gc.collect()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


@contextmanager
def managed_memory():
    """Context manager that reclaims memory when the wrapped work finishes or fails."""
    try:
        yield
    finally:
        _reclaim_memory()
def load_pipeline(model_name):
    """Return a cached ZImagePipeline for *model_name*, loading it on first use.

    Prefers the local snapshot directory; falls back to an online load when the
    snapshot is incomplete. The pipeline is moved to CPU, put in eval mode, and
    (best-effort) torch.compile'd.
    """
    # BUGFIX: the cache check used to happen before acquiring pipe_lock, so two
    # concurrent callers could both miss the cache and load the model twice.
    # Checking inside the lock makes loading happen exactly once per model.
    with pipe_lock:
        if model_name in pipe_cache:
            return pipe_cache[model_name]
        log(f"Loading {model_name} pipeline.")
        repo_dir = os.path.join("./hf_cache", f"{model_name}_snapshot")
        try:
            pipe = ZImagePipeline.from_pretrained(repo_dir, torch_dtype=DTYPE, local_files_only=True, low_cpu_mem_usage=True)
        except LocalEntryNotFoundError:
            # Snapshot directory exists but is missing files — fetch from the Hub.
            log(f"Incomplete local snapshot for {model_name}, retrying online load.")
            pipe = ZImagePipeline.from_pretrained(MODEL_SPECS[model_name], torch_dtype=DTYPE, cache_dir="./hf_cache", low_cpu_mem_usage=True)
        pipe.to(DEVICE)
        # Inference only: switch every submodule out of training mode.
        pipe.vae.eval()
        pipe.text_encoder.eval()
        pipe.transformer.eval()
        try:
            pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
            log("Transformer compiled.")
        except Exception as e:
            # torch.compile is optional; fall back to eager mode on any failure.
            log(f"Transformer compile skipped: {e}")
        pipe_cache[model_name] = pipe
        return pipe
# ----------- GENERATION LOGIC -----------
@torch.inference_mode()  # inference_mode implies no_grad; the extra @torch.no_grad() was redundant
def generate(prompt, quality_mode, seed, model_name):
    """Generate one image on CPU.

    Args:
        prompt: text prompt (must be non-empty after stripping whitespace).
        quality_mode: one of the PRESETS keys; unknown values fall back to (1, 256).
        seed: RNG seed; negative means "pick a random seed".
        model_name: key into MODEL_SPECS selecting which pipeline to load.

    Returns:
        (final_image, seed_used, preview_images). final_image is None when the
        run was interrupted via interrupt_event.

    Raises:
        gr.Error: when the prompt is empty.
    """
    # Guard against None as well as whitespace-only prompts (gr.Textbox can
    # deliver None when the component is cleared).
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty!")
    # (num_inference_steps, square image size) per quality preset.
    PRESETS = {
        "ultra_fast": (1, 256),
        "fast": (1, 256),
        "balanced": (2, 256),
        "quality": (4, 384),
        "ultra_quality": (4, 512),
    }
    steps, size = PRESETS.get(quality_mode, (1, 256))
    width = height = size
    # Negative seed means "random"; report the seed actually used to the caller.
    seed = int(seed) if seed >= 0 else random.randint(0, (2**31) - 1)
    log(f"Generating: '{prompt[:40]}...' | {quality_mode} | {width}x{height} | seed={seed}")
    with managed_memory(), generation_lock:
        pipe = load_pipeline(model_name)
        generator = torch.Generator("cpu").manual_seed(seed)
        start_time = time.time()
        preview_images = []

        def progress_cb(pipeline, step_idx, timestep, cbk):
            # Abort mid-run when the STOP button has set the interrupt event;
            # the KeyboardInterrupt is caught below.
            if interrupt_event.is_set():
                raise KeyboardInterrupt("Generation interrupted")
            if step_idx % 2 == 0:  # preview every 2 steps
                try:
                    # NOTE(review): image_from_latents is not a documented
                    # diffusers pipeline method — confirm it exists on
                    # ZImagePipeline; failures are deliberately swallowed so
                    # previews stay best-effort.
                    preview_images.append(pipeline.image_from_latents(pipeline.latents))
                except Exception:
                    pass
            return cbk

        try:
            result = pipe(
                prompt=prompt,
                negative_prompt=None,
                width=width,
                height=height,
                num_inference_steps=steps,
                guidance_scale=0.0,
                generator=generator,
                callback_on_step_end=progress_cb,
                callback_on_step_end_tensor_inputs=["latents"],
                output_type="pil"
            )
            final_image = result.images[0]
            log(f"Done in {time.time()-start_time:.1f}s")
        except KeyboardInterrupt:
            # Interrupted: return whatever previews were captured so far.
            log("⚠️ Generation interrupted.")
            return None, seed, preview_images
        # Release the pipeline output container before returning the image.
        del result
        gc.collect()
        preview_images.append(final_image)
        return final_image, seed, preview_images
# ----------- GRADIO UI -----------
with gr.Blocks(title="🤩✨ Z‑Image Turbo CPU Ultimate + Retry + Preview + Interrupt") as demo:
    gr.Markdown("## Full feature CPU image generator — true snapshot retry + preview frames")
    with gr.Row():
        with gr.Column():
            # --- input controls ---
            prompt = gr.Textbox(label="Prompt", lines=4)
            quality_mode = gr.Radio(
                choices=["ultra_fast","fast","balanced","quality","ultra_quality"],
                value="fast",
                label="Quality Mode"
            )
            seed = gr.Number(value=-1, precision=0, label="Seed (-1=random)")
            model_choice = gr.Dropdown(list(MODEL_SPECS.keys()), value=list(MODEL_SPECS.keys())[0], label="Select model")
            gen_btn = gr.Button("GENERATE")
            interrupt_btn = gr.Button("STOP")
        with gr.Column():
            # --- outputs ---
            out_img = gr.Image(label="Final Image")
            out_seed = gr.Number(label="Seed Used", interactive=False)
            preview_gallery = gr.Gallery(label="Preview frames")
    log_output = gr.Textbox(label="Live System Log", lines=15, interactive=False)

    def on_generate(prompt, quality_mode, seed, model_choice):
        # Clear any stale interrupt from a previous run before starting.
        interrupt_event.clear()
        final_img, used_seed, previews = generate(prompt, quality_mode, seed, model_choice)
        return final_img, used_seed, previews, log("Generation done.")

    def on_interrupt():
        # Signal the running generation's step callback to abort.
        interrupt_event.set()
        return log("📌 Interrupt requested")

    gen_btn.click(on_generate, inputs=[prompt, quality_mode, seed, model_choice], outputs=[out_img, out_seed, preview_gallery, log_output])
    interrupt_btn.click(on_interrupt, inputs=None, outputs=log_output)

demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)