# Z-IMAGE-LORA / app.py
import os
import re
import time
import random
import pathlib
import hashlib
import shutil
from typing import Dict, Any, Tuple, List, Optional
import gradio as gr
import requests
import torch
from PIL import Image
import spaces
from huggingface_hub import snapshot_download, hf_hub_download, list_repo_files
# =========================
# UI CSS (provided)
# =========================
RESPONSIVE_CSS = """
/* ===== Tunables ===== */
:root{
--hf-topbar: 64px; /* HF header height */
--left-col: 30%; /* width of left column */
--gap: 0.75rem;
--pad-y: 0.60rem;
--pad-x: 0.85rem;
/* Global scale: 0.75 = ~25% smaller, 0.66 = ~34% smaller */
--ui-zoom: 0.75;
/* Output sizing */
--history-h: 110px;
}
/* Scale everything via root font-size (more reliable than transform:scale) */
html, body { height: 100%; }
html { font-size: calc(16px * var(--ui-zoom)); }
body { margin: 0; overflow: hidden; }
/* HuggingFace Spaces wrapper */
.gradio-container{
max-width: 100vw !important;
height: calc(100vh - var(--hf-topbar));
overflow: hidden;
padding: var(--pad-y) var(--pad-x);
box-sizing: border-box;
}
/* ===== Main two-column layout fills viewport ===== */
.layout-main{
display: grid !important;
grid-template-columns: minmax(320px, var(--left-col)) 1fr;
gap: var(--gap);
align-items: stretch;
height: 100%;
min-height: 0;
}
/* ===== Panels ===== */
.panel{
background: #0f172a;
border-radius: 12px;
padding: 0.75rem;
box-shadow: 0 10px 26px rgba(0,0,0,0.40);
border: 1px solid rgba(255,255,255,0.06);
overflow: hidden;
min-height: 0;
}
/* ===== Left panel ===== */
.panel.controls{
display: flex;
flex-direction: column;
min-height: 0;
overflow: hidden;
}
/* Tabs need to be allowed to shrink */
.panel.controls .gr-tabs{ min-height: 0; overflow: hidden; }
/* Make tab content scroll only if absolutely necessary */
.panel.controls .tabitem{
min-height: 0 !important;
overflow: auto !important;
padding-right: 2px;
}
/* Single big box in Generate tab */
#left_box{
display: flex;
flex-direction: column;
gap: 0.6rem;
min-height: 0;
}
/* Tighten default Gradio spacing */
.panel .gr-markdown{ margin: 0.15rem 0 0.35rem !important; }
.panel .gr-form{ gap: 0.35rem; }
.compact-row{ gap: 0.45rem; }
/* Make textboxes compact */
.gradio-container textarea{
padding: 0.45rem 0.55rem !important;
line-height: 1.2 !important;
}
.gradio-container input[type="text"], .gradio-container input[type="number"]{
padding: 0.40rem 0.55rem !important;
}
/* Reduce slider vertical padding */
.gradio-container .gr-slider,
.gradio-container .slider-container{
margin-top: 0.10rem !important;
margin-bottom: 0.10rem !important;
}
/* ===== Right panel (output) ===== */
.panel.output{
display: flex;
flex-direction: column;
min-height: 0;
height: 100%;
overflow: hidden;
}
/* Prevent "infinite height" from nested wrappers */
.panel.output > *,
.panel.output .gr-block,
.panel.output .gr-box,
.panel.output .gr-row,
.panel.output .gr-column,
.panel.output .gr-form{
min-height: 0 !important;
}
/* Main image takes remaining space */
#main_image{
flex: 1 1 auto;
min-height: 0 !important;
overflow: hidden;
display: flex;
align-items: center;
justify-content: center;
}
#main_image > div,
#main_image .wrap,
#main_image .image-container,
#main_image .image-preview{
height: 100% !important;
min-height: 0 !important;
}
#main_image img{
max-width: 100% !important;
max-height: 100% !important;
object-fit: contain !important;
border-radius: 10px;
}
/* History strip (fixed) */
#history_gallery{
flex: 0 0 var(--history-h);
height: var(--history-h) !important;
margin-top: 0.5rem;
overflow: hidden;
}
#history_gallery .grid,
#history_gallery .gallery,
#history_gallery .wrap{
height: 100% !important;
overflow: hidden !important;
}
#history_gallery img{
height: calc(var(--history-h) - 28px) !important;
width: auto !important;
object-fit: cover !important;
border-radius: 8px;
}
/* Logs accordion compact */
#logs_accordion{ flex: 0 0 auto; }
#logs_accordion textarea{ min-height: 9.5rem !important; }
/* Generate button */
#generate_btn{
width: 100%;
font-weight: 700;
font-size: 1.02rem;
padding: 0.55rem;
margin-top: 0.35rem;
}
/* ===== Mobile: stack and allow normal page scroll ===== */
@media (max-width: 900px){
body{ overflow: auto; }
.gradio-container{ height: auto; overflow: visible; }
.layout-main{ grid-template-columns: 1fr; height: auto; }
.panel{ overflow: visible; }
.panel.controls .tabitem{ overflow: visible !important; }
#history_gallery{ height: 100px !important; flex-basis: 100px; }
#history_gallery img{ height: 68px !important; }
}
"""
# =========================
# Space-friendly caching
# =========================
os.environ.setdefault("HF_HOME", "/data/.huggingface" if os.path.isdir("/data") else os.environ.get("HF_HOME", ""))
if os.environ.get("HF_HOME"):
os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(os.environ["HF_HOME"], "hub"))
os.environ.setdefault("DIFFUSERS_CACHE", os.path.join(os.environ["HF_HOME"], "hub"))
# =========================
# Defaults
# =========================
ZIMAGE_REPO_DEFAULT = "Tongyi-MAI/Z-Image-Turbo"
ZIMAGE_LOCAL_DIR = str((pathlib.Path(__file__).resolve().parent / "models" / "zimage").resolve())
# Optional Turbo preservation patch LoRA (stacked, never fused)
DISTILLPATCH_REPO = "DiffSynth-Studio/Z-Image-Turbo-DistillPatch"
DISTILLPATCH_REVISION = "main"
LORAS_ROOT = pathlib.Path(__file__).resolve().parent / "loras"
LORAS_ROOT.mkdir(parents=True, exist_ok=True)
# =========================
# Startup prefetch (downloads on app start, not on first generate)
# =========================
PREFETCH_STATUS = {"ok": False, "msg": "Not started", "path": ""}
def prefetch_zimage_repo(repo_id: str, local_dir: str) -> str:
os.makedirs(local_dir, exist_ok=True)
snapshot_download(
repo_id=repo_id,
repo_type="model",
local_dir=local_dir,
)
return local_dir
try:
PREFETCH_STATUS["msg"] = f"Prefetching {ZIMAGE_REPO_DEFAULT}{ZIMAGE_LOCAL_DIR} ..."
prefetch_zimage_repo(ZIMAGE_REPO_DEFAULT, ZIMAGE_LOCAL_DIR)
PREFETCH_STATUS["ok"] = True
PREFETCH_STATUS["path"] = ZIMAGE_LOCAL_DIR
PREFETCH_STATUS["msg"] = f"FASTER"
except Exception as e:
PREFETCH_STATUS["ok"] = False
PREFETCH_STATUS["msg"] = f"⚠️ Prefetch failed: {type(e).__name__}: {e}"
# =========================
# Session + LoRA helpers
# =========================
def _new_session_id() -> str:
return hashlib.sha1(os.urandom(16)).hexdigest()[:10]
def _safe_filename(name: str) -> str:
name = (name or "").strip()
name = re.sub(r"[^a-zA-Z0-9._-]+", "_", name)
name = re.sub(r"_+", "_", name).strip("_")
return name or "lora"
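# Hedged sanity checks for the sanitizer above (pure string work, runs once at
# import): unsafe characters collapse to single underscores and an empty input
# falls back to "lora".
assert _safe_filename("my lora (v2).safetensors") == "my_lora_v2_.safetensors"
assert _safe_filename("///") == "lora"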
def _sha1_file(path: str) -> str:
h = hashlib.sha1()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(1024 * 1024), b""):
h.update(chunk)
return h.hexdigest()[:12]
def _session_lora_dir(session_id: str) -> pathlib.Path:
d = LORAS_ROOT / f"session_{session_id}"
d.mkdir(parents=True, exist_ok=True)
return d
def _list_loras(session_id: str) -> List[str]:
d = _session_lora_dir(session_id)
items: List[str] = []
for ext in ("*.safetensors", "*.bin"):
for p in sorted(d.glob(ext)):
items.append(p.name)
return items
def _download_file(url: str, dst_path: pathlib.Path, progress_cb=None, headers=None) -> pathlib.Path:
headers = headers or {}
with requests.get(url, stream=True, headers=headers, timeout=60, allow_redirects=True) as r:
r.raise_for_status()
total = int(r.headers.get("content-length", "0") or "0")
dst_path.parent.mkdir(parents=True, exist_ok=True)
tmp = dst_path.with_suffix(dst_path.suffix + ".part")
downloaded = 0
with open(tmp, "wb") as f:
for chunk in r.iter_content(chunk_size=1024 * 1024):
if not chunk:
continue
f.write(chunk)
downloaded += len(chunk)
if progress_cb and total > 0:
progress_cb(downloaded / total)
tmp.replace(dst_path)
return dst_path
def _validate_downloaded_lora(path: pathlib.Path) -> None:
# A real LoRA is almost never < 1MB. If it is, it's usually HTML or an error blob.
if not path.exists():
raise RuntimeError("Downloaded LoRA file is missing.")
sz = path.stat().st_size
if sz < 1024 * 1024:
head = path.read_bytes()[:256]
raise RuntimeError(f"Downloaded file is too small ({sz} bytes). First bytes: {head!r} (likely HTML/error).")
    head8 = path.read_bytes()[:8]
    if head8.startswith(b"PK"):
        raise RuntimeError("Downloaded file looks like a ZIP archive (PK...), not a safetensors LoRA.")
    if head8.lower().startswith((b"<html", b"<!doctyp")):
        raise RuntimeError("Downloaded file is HTML (likely an auth/permission issue).")
def refresh_loras_for_session(session_id: Optional[str]) -> Tuple[str, Any]:
sid = session_id or _new_session_id()
choices = ["<none>"] + _list_loras(sid)
return sid, gr.update(choices=choices, value="<none>")
def _civitai_extract_model_version_id(url: str) -> Optional[str]:
m = re.search(r"modelVersionId=(\d+)", url)
if m:
return m.group(1)
m = re.search(r"/model-versions/(\d+)", url)
if m:
return m.group(1)
return None
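# Hedged examples of the two CivitAI URL shapes handled above; both forms are
# assumptions based on current civitai.com routing.
assert _civitai_extract_model_version_id("https://civitai.com/models/1?modelVersionId=42") == "42"
assert _civitai_extract_model_version_id("https://civitai.com/api/v1/model-versions/42") == "42"
assert _civitai_extract_model_version_id("https://civitai.com/models/1") is None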
def _civitai_fetch_trigger_words(model_version_id: str, civitai_token: str) -> List[str]:
api = f"https://civitai.com/api/v1/model-versions/{model_version_id}"
headers = {"User-Agent": "Mozilla/5.0 (HF Spaces) ZImage-LoRA-Downloader"}
if civitai_token:
headers["Authorization"] = f"Bearer {civitai_token}"
headers["X-Api-Key"] = civitai_token
r = requests.get(api, headers=headers, timeout=30)
r.raise_for_status()
data = r.json()
words = data.get("trainedWords") or data.get("triggerWords") or []
if isinstance(words, str):
words = [words]
out = []
for w in words:
w = str(w).strip()
if w and w not in out:
out.append(w)
return out
def _merge_prompt_with_triggers(prompt: str, triggers: List[str]) -> str:
p = (prompt or "").strip()
if not triggers:
return p
existing = set(re.findall(r"[\w\-\:]+", p.lower()))
to_add = [t for t in triggers if t.lower() not in existing]
if not to_add:
return p
return (p + " " + " ".join(to_add)).strip()
def _is_hf_repo_id(s: str) -> bool:
return bool(re.fullmatch(r"[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+", (s or "").strip()))
def _hf_extract_repo_id(url: str) -> Optional[str]:
# https://huggingface.co/user/repo (optionally with /tree/main etc)
m = re.search(r"huggingface\.co/([^/\s]+)/([^/\s?#]+)", url)
if not m:
return None
return f"{m.group(1)}/{m.group(2)}"
def _hf_pick_lora_file(repo_id: str, revision: str = "main", hf_token: str = "") -> str:
files = list_repo_files(repo_id, revision=revision or "main", token=hf_token or None)
safes = [f for f in files if f.lower().endswith(".safetensors")]
if not safes:
raise RuntimeError("No .safetensors files found in that Hugging Face repo.")
if len(safes) == 1:
return safes[0]
# Prefer files that look like LoRAs
lora_like = [f for f in safes if "lora" in f.lower()]
return lora_like[0] if lora_like else safes[0]
def _hf_fetch_trigger_words(repo_id: str, revision: str = "main", hf_token: str = "") -> List[str]:
"""
Best-effort: download README/model card and extract trigger words.
Looks for patterns like: trigger word(s): `xxx` or Trigger: xxx, or "trained words".
"""
candidates = ["README.md", "readme.md", "README.MD", "modelcard.md", "MODEL_CARD.md"]
text = ""
for fn in candidates:
try:
p = hf_hub_download(repo_id=repo_id, filename=fn, revision=revision or "main", token=hf_token or None)
text = pathlib.Path(p).read_text(encoding="utf-8", errors="ignore")
if text.strip():
break
except Exception:
continue
if not text:
return []
words: List[str] = []
# 1) backticked tokens after "trigger"
for m in re.finditer(r"(?i)trigger\w*\s*(?:words?|token|phrase)?\s*[:\-]\s*([^\n]+)", text):
chunk = m.group(1)
# collect backticked and comma-separated tokens
back = re.findall(r"`([^`]+)`", chunk)
if back:
for b in back:
b = b.strip()
if b and b not in words:
words.append(b)
else:
# split common separators
for part in re.split(r"[,\|/]", chunk):
part = part.strip()
# keep short-ish single tokens
if 1 <= len(part) <= 64 and re.fullmatch(r"[A-Za-z0-9_\-:]+", part):
if part not in words:
words.append(part)
# 2) "trained words" style
for m in re.finditer(r"(?i)trained\s*words?\s*[:\-]\s*([^\n]+)", text):
chunk = m.group(1)
for b in re.findall(r"`([^`]+)`", chunk):
b = b.strip()
if b and b not in words:
words.append(b)
# prune empties
words = [w for w in words if w.strip()]
return words[:20]
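# Hedged, offline sanity check for the "trigger" pattern above, applied to a
# synthetic model-card snippet (the full function additionally needs hub access).
_demo_card = "Trigger words: `ohwx person`, `studio photo`"
_m = re.search(r"(?i)trigger\w*\s*(?:words?|token|phrase)?\s*[:\-]\s*([^\n]+)", _demo_card)
assert _m and re.findall(r"`([^`]+)`", _m.group(1)) == ["ohwx person", "studio photo"]
del _demo_card, _m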
def download_lora_for_session(
session_id: Optional[str],
lora_url: str,
lora_filename: str,
hf_token: str,
civitai_token: str,
current_prompt: str,
progress=gr.Progress(track_tqdm=False),
):
sid = session_id or _new_session_id()
url = (lora_url or "").strip()
if not url:
return sid, gr.update(choices=["<none>"] + _list_loras(sid), value="<none>"), "❌ Please provide a URL or repo id.", "", current_prompt
# env/secret fallback
hf_token = (hf_token or "").strip() or os.environ.get("HF_TOKEN", "").strip() or os.environ.get("HUGGINGFACE_HUB_TOKEN", "").strip()
civitai_token = (civitai_token or "").strip() or os.environ.get("CIVITAI_TOKEN", "").strip() or os.environ.get("CIVITAI_API_KEY", "").strip()
trigger_words: List[str] = []
mv_id = None
lower_url = url.lower()
# =========================
# Hugging Face repo-id / repo URL (EARLY RETURN)
# =========================
hf_repo_id = None
hf_revision = "main"
if _is_hf_repo_id(url):
hf_repo_id = url
elif "huggingface.co/" in lower_url and "/resolve/" not in lower_url:
hf_repo_id = _hf_extract_repo_id(url)
if hf_repo_id:
try:
filename_in_repo = _hf_pick_lora_file(hf_repo_id, revision=hf_revision, hf_token=hf_token)
out_dir = _session_lora_dir(sid)
base = _safe_filename(lora_filename) if lora_filename else _safe_filename(pathlib.Path(filename_in_repo).name)
if not base.lower().endswith(".safetensors"):
base += ".safetensors"
dst = out_dir / base
local_path = hf_hub_download(repo_id=hf_repo_id, filename=filename_in_repo, revision=hf_revision, token=hf_token or None)
# Copy into session folder so it appears in dropdown
if not dst.exists():
shutil.copyfile(local_path, dst)
_validate_downloaded_lora(dst)
# Trigger words from README (best effort)
try:
trigger_words = _hf_fetch_trigger_words(hf_repo_id, revision=hf_revision, hf_token=hf_token)
except Exception:
trigger_words = []
triggers_text = ", ".join(trigger_words) if trigger_words else ""
new_prompt = _merge_prompt_with_triggers(current_prompt, trigger_words)
msg = f"✅ Downloaded from Hugging Face: {hf_repo_id}/{filename_in_repo}"
return (
sid,
gr.update(choices=["<none>"] + _list_loras(sid), value=dst.name),
msg,
triggers_text,
new_prompt,
)
except Exception as e:
msg = f"❌ Hugging Face download failed: {type(e).__name__}: {e}"
return sid, gr.update(choices=["<none>"] + _list_loras(sid), value="<none>"), msg, "", current_prompt
# =========================
# CivitAI conversion
# =========================
if "civitai.com" in lower_url:
mv_id = _civitai_extract_model_version_id(url)
if mv_id:
url = f"https://civitai.com/api/download/models/{mv_id}?type=Model&format=SafeTensor"
lower_url = url.lower()
out_dir = _session_lora_dir(sid)
base = _safe_filename(lora_filename) if lora_filename else _safe_filename(pathlib.Path(url.split("?")[0]).name)
if not base.lower().endswith((".safetensors", ".bin")):
base += ".safetensors"
dst = out_dir / base
headers = {"User-Agent": "Mozilla/5.0 (HF Spaces) ZImage-LoRA-Downloader"}
# Auth headers
if "civitai.com/api/download/models" in lower_url:
if civitai_token:
headers["Authorization"] = f"Bearer {civitai_token}"
headers["X-Api-Key"] = civitai_token
else:
return sid, gr.update(choices=["<none>"] + _list_loras(sid), value="<none>"), (
"❌ CivitAI download requires a token. Set the Space Secret CIVITAI_TOKEN."
), "", current_prompt
if "huggingface.co" in lower_url and "/resolve/" in lower_url and hf_token:
headers["Authorization"] = f"Bearer {hf_token}"
def _cb(p):
progress(p)
try:
_download_file(url, dst, progress_cb=_cb, headers=headers)
_validate_downloaded_lora(dst)
msg = f"✅ Downloaded: {dst.name}"
except Exception as e:
msg = f"❌ Download failed: {type(e).__name__}: {e}"
return sid, gr.update(choices=["<none>"] + _list_loras(sid), value="<none>"), msg, "", current_prompt
# Trigger words from CivitAI
try:
if mv_id:
trigger_words = _civitai_fetch_trigger_words(mv_id, civitai_token)
except Exception:
trigger_words = []
triggers_text = ", ".join(trigger_words) if trigger_words else ""
new_prompt = _merge_prompt_with_triggers(current_prompt, trigger_words)
return (
sid,
gr.update(choices=["<none>"] + _list_loras(sid), value=dst.name if dst.exists() else "<none>"),
msg,
triggers_text,
new_prompt,
)
def _resolve_lora_path(session_id: str, lora_choice: str) -> str:
if not lora_choice or lora_choice == "<none>":
return ""
return str(_session_lora_dir(session_id) / lora_choice)
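# Hedged sanity check: the "<none>" sentinel and an empty choice both resolve
# to an empty path (before touching the filesystem), so callers can treat
# falsy as "no LoRA selected".
assert _resolve_lora_path("any", "<none>") == ""
assert _resolve_lora_path("any", "") == ""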
# =========================
# SeedVarianceEnhancer v2.1 (repo-like)
# =========================
@torch.no_grad()
def _auto_mask_trailing_zeros(embeds: torch.Tensor) -> torch.Tensor:
token_is_zero = (embeds.abs().sum(dim=-1, keepdim=True) == 0)
return ~token_is_zero
@torch.no_grad()
def sve_apply(
prompt_embeds: torch.Tensor,
seed: int,
strength: float,
randomize_percent: float,
mask_starts_at_mode: str = "beginning",
mask_percent: float = 0.0,
) -> torch.Tensor:
if strength == 0 or randomize_percent <= 0:
return prompt_embeds
device = prompt_embeds.device
g = torch.Generator(device=device).manual_seed(int(seed) & 0x7FFFFFFF)
# Some pipelines return prompt_embeds as 2D (B,C) instead of 3D (B,T,C).
sve__squeeze_token_dim = False
if isinstance(prompt_embeds, torch.Tensor) and prompt_embeds.ndim == 2:
prompt_embeds = prompt_embeds.unsqueeze(1)
sve__squeeze_token_dim = True
elif isinstance(prompt_embeds, torch.Tensor) and prompt_embeds.ndim == 1:
prompt_embeds = prompt_embeds.view(1, 1, -1)
sve__squeeze_token_dim = True
p = float(randomize_percent) / 100.0
elem_mask = torch.ones_like(prompt_embeds, dtype=torch.bool) if p >= 1.0 else (torch.rand(prompt_embeds.shape, device=device, dtype=prompt_embeds.dtype, generator=g) < p)
B, T, _ = prompt_embeds.shape
token_allow = torch.ones((B, T, 1), device=device, dtype=torch.bool)
mp = max(0.0, min(100.0, float(mask_percent)))
if mp > 0:
length = int(round((mp / 100.0) * T))
length = max(0, min(T, length))
mode = (mask_starts_at_mode or "beginning").strip().lower()
if length > 0:
if mode in ("begin", "beginning", "start"):
start = 0
elif mode in ("middle", "center", "centre"):
start = max(0, (T - length) // 2)
elif mode in ("end", "ending", "finish"):
start = max(0, T - length)
else:
# fallback: treat unknown as beginning
start = 0
end = min(T, start + length)
if start < T and end > start:
# False = protected region (no noise)
token_allow[:, start:end, :] = False
token_allow = token_allow & _auto_mask_trailing_zeros(prompt_embeds)
mask = elem_mask & token_allow
noise = torch.randn(prompt_embeds.shape, device=device, dtype=prompt_embeds.dtype, generator=g) * float(strength)
out = torch.where(mask, prompt_embeds + noise, prompt_embeds)
    if sve__squeeze_token_dim:
# Back to original rank
if out.ndim == 3 and out.shape[1] == 1:
out = out.squeeze(1)
return out
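# Hedged CPU smoke test (float32, runs once at import): SVE preserves shape and
# never perturbs all-zero padding tokens, per _auto_mask_trailing_zeros above.
# The "noised" assert is statistically certain rather than strictly guaranteed.
_demo_pe = torch.zeros(1, 8, 16)
_demo_pe[:, :6, :] = 1.0  # last two tokens simulate zero padding
_demo_out = sve_apply(_demo_pe, seed=0, strength=1.0, randomize_percent=100.0)
assert _demo_out.shape == _demo_pe.shape
assert torch.equal(_demo_out[:, 6:, :], _demo_pe[:, 6:, :])  # padding untouched
assert not torch.equal(_demo_out[:, :6, :], _demo_pe[:, :6, :])  # rest is noised
del _demo_pe, _demo_out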
def sve_apply_any(
prompt_embeds,
seed: int,
strength: float,
randomize_percent: float,
mask_starts_at_mode: str = "beginning",
mask_percent: float = 0.0,
):
"""Apply SVE to a tensor or a (nested) list/tuple of tensors.
Some Diffusers pipelines return prompt_embeds as a list (multi-encoder); we preserve structure.
"""
if isinstance(prompt_embeds, (list, tuple)):
out = []
# use different sub-seeds so encoders don't get identical noise patterns
for i, pe in enumerate(prompt_embeds):
out.append(
sve_apply_any(
pe,
seed=int(seed) + int(i),
strength=strength,
randomize_percent=randomize_percent,
mask_starts_at_mode=mask_starts_at_mode,
mask_percent=mask_percent,
)
)
return out if isinstance(prompt_embeds, list) else tuple(out)
# base case: tensor
return sve_apply(
prompt_embeds=prompt_embeds,
seed=seed,
strength=strength,
randomize_percent=randomize_percent,
mask_starts_at_mode=mask_starts_at_mode,
mask_percent=mask_percent,
)
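# Hedged demo: list inputs (multi-encoder pipelines) keep their container type
# and each entry gets a different sub-seed, so the noise patterns differ.
_demo_pair = [torch.ones(1, 4, 8), torch.ones(1, 4, 8)]
_demo_out = sve_apply_any(_demo_pair, seed=7, strength=0.5, randomize_percent=100.0)
assert isinstance(_demo_out, list) and len(_demo_out) == 2
assert not torch.equal(_demo_out[0], _demo_out[1])  # sub-seeds 7 and 8 differ
del _demo_pair, _demo_out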
def _build_sve_callback(clean_embeds: torch.Tensor, noisy_embeds: torch.Tensor, switch_step: int):
# Backwards-compatible: noisy before switch_step, clean after.
def _cb(pipe, step_index: int, timestep: int, callback_kwargs: Dict[str, Any]):
callback_kwargs["prompt_embeds"] = noisy_embeds if step_index < switch_step else clean_embeds
return callback_kwargs
return _cb
def _build_sve_callback_mode(clean_embeds: torch.Tensor, noisy_embeds: torch.Tensor, start_step: int, mode: str):
"""Mode-aware callback.
- mode='beginning': noisy before start_step, clean after (start_step = switch)
- mode='ending': clean before start_step, noisy after (start_step = start of noisy region)
"""
mode = (mode or 'beginning').strip().lower()
def _cb(pipe, step_index: int, timestep: int, callback_kwargs: Dict[str, Any]):
if mode in ('end', 'ending'):
callback_kwargs['prompt_embeds'] = noisy_embeds if step_index >= start_step else clean_embeds
else:
callback_kwargs['prompt_embeds'] = noisy_embeds if step_index < start_step else clean_embeds
return callback_kwargs
return _cb
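# Hedged illustration with dummy tensors: for steps=8 and start_step=6 in
# 'ending' mode, steps 0..5 keep the clean embeds and steps 6..7 get the
# noisy ones; 'beginning' mode is the mirror image.
_c, _n = torch.zeros(1), torch.ones(1)
_cb = _build_sve_callback_mode(_c, _n, start_step=6, mode="ending")
assert _cb(None, 5, 0, {})["prompt_embeds"] is _c
assert _cb(None, 6, 0, {})["prompt_embeds"] is _n
del _c, _n, _cb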
# =========================
# GPU: cached Z-Image pipeline
# =========================
GPU_ENV: Dict[str, Any] = {"pipe": None, "dtype": None, "loaded_loras": {}, "distillpatch_loaded": False, "distillpatch_adapter": "distillpatch"}
CPU_ENV: Dict[str, Any] = {"pipe": None, "path": None}
def _ensure_cpu_pipe() -> None:
"""Load the Z-Image pipeline ON CPU once per container (ZeroGPU-safe)."""
if CPU_ENV.get("pipe") is not None:
return
from diffusers import ZImagePipeline
load_path = PREFETCH_STATUS.get("path") or ZIMAGE_LOCAL_DIR or ZIMAGE_REPO_DEFAULT
CPU_ENV["path"] = load_path
# NOTE: keep on CPU; do NOT touch CUDA outside @spaces.GPU
CPU_ENV["pipe"] = ZImagePipeline.from_pretrained(
load_path,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=False,
)
def _torch_dtype(name: str):
name = (name or "bf16").lower()
if name in ("bf16", "bfloat16"):
return torch.bfloat16
if name in ("fp16", "float16", "half"):
return torch.float16
return torch.float32
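# Hedged mapping check for the Precision dropdown values. Note: generate_route
# currently pins bf16 and does not consult this helper; it is kept for the UI
# contract and any future dtype switching.
assert _torch_dtype("bf16") is torch.bfloat16
assert _torch_dtype("fp16") is torch.float16
assert _torch_dtype("anything-else") is torch.float32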
@spaces.GPU
def generate_route(
session_id: str,
prompt: str,
negative: str,
steps: int,
cfg: float,
width: int,
height: int,
seed: int,
zimage_shift: float,
max_sequence_length: int,
precision: str,
use_lora: bool,
use_distillpatch: bool,
selected_lora: str,
lora_scale: float,
distillpatch_scale: float,
sve_enabled: bool,
sve_noise_insert: str,
sve_steps_switchover_percent: float,
sve_seed: int,
sve_control_after_generate: str,
sve_strength: float,
sve_random_percent: float,
sve_mask_starts_at: str,
sve_mask_percent: float,
sve_log_to_console: bool,
) -> Tuple[str, List[Image.Image], str, str]:
logs: List[str] = []
t0 = time.time()
seed = int(seed) if seed is not None else -1
if seed == -1:
seed = random.randint(0, 2**31 - 1)
from diffusers import ZImagePipeline, FlowMatchEulerDiscreteScheduler
# We pin dtype to bf16.
dtype = torch.bfloat16
# ZeroGPU-safe strategy:
# - Load the pipeline ON CPU once per container (outside GPU allocation).
# - On each GPU call, move that already-loaded pipeline to CUDA.
# This avoids re-reading checkpoint shards on every request.
_ensure_cpu_pipe()
if GPU_ENV["pipe"] is None:
logs.append(f"📦 Moving Z-Image pipeline to CUDA ({dtype})")
# Move the already-loaded CPU pipeline to GPU for this allocated GPU session.
GPU_ENV["pipe"] = CPU_ENV["pipe"].to("cuda")
GPU_ENV["dtype"] = str(dtype)
GPU_ENV["compiled"] = False
pipe = GPU_ENV["pipe"]
# Update scheduler shift each run
try:
pipe.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(zimage_shift))
except Exception as e:
logs.append(f"⚠️ scheduler shift failed: {type(e).__name__}: {e}")
# ---- Optional Turbo DistillPatch LoRA (stacked, never fused) ----
# Note: DiffSynth's reference snippet loads user LoRA first, then DistillPatch.
# In Diffusers/PEFT multi-adapter mode the effects are generally additive, but we keep
# DistillPatch LAST in the adapter list to mirror the official intent.
active_adapters: List[str] = []
active_weights: List[float] = []
distill_ready = False
if use_distillpatch:
try:
if not GPU_ENV.get("distillpatch_loaded", False):
logs.append(f"🩹 Loading DistillPatch: {DISTILLPATCH_REPO}")
# DistillPatch repo typically exposes `model.safetensors`.
# We try that first for determinism; then fall back to first .safetensors in repo.
try:
pipe.load_lora_weights(
DISTILLPATCH_REPO,
revision=DISTILLPATCH_REVISION,
weight_name="model.safetensors",
adapter_name=GPU_ENV["distillpatch_adapter"],
)
except Exception:
files = list_repo_files(DISTILLPATCH_REPO, revision=DISTILLPATCH_REVISION, token=None)
cand = [f for f in files if f.lower().endswith(".safetensors")]
pipe.load_lora_weights(
DISTILLPATCH_REPO,
revision=DISTILLPATCH_REVISION,
weight_name=(cand[0] if cand else None),
adapter_name=GPU_ENV["distillpatch_adapter"],
)
GPU_ENV["distillpatch_loaded"] = True
distill_ready = True
except Exception as e:
logs.append(f"⚠️ DistillPatch load failed: {type(e).__name__}: {e}")
distill_ready = False
# LoRA (NO FUSING)
if use_lora and selected_lora and selected_lora != "<none>":
lora_path = _resolve_lora_path(session_id, selected_lora)
if os.path.exists(lora_path):
lora_abs = os.path.abspath(lora_path)
try:
_validate_downloaded_lora(pathlib.Path(lora_abs))
except Exception as e:
logs.append(f"❌ LoRA file invalid: {e}")
mem_html = f"<span style='font-family:monospace;font-size:0.76rem;'>seed={seed} · steps={steps} · shift={zimage_shift} · dtype={dtype}</span>"
return session_id, [], "\n".join(logs), mem_html
key = (str(dtype), lora_abs)
if key in GPU_ENV["loaded_loras"]:
adapter_name = GPU_ENV["loaded_loras"][key]
logs.append(f"🧩 LoRA already loaded: {selected_lora} as {adapter_name}")
else:
adapter_name = f"lora_{_sha1_file(lora_abs)}"
logs.append(f"🧩 Loading LoRA: {selected_lora} -> {adapter_name}")
pipe.load_lora_weights(lora_abs, adapter_name=adapter_name)
GPU_ENV["loaded_loras"][key] = adapter_name
active_adapters.append(adapter_name)
active_weights.append(float(lora_scale))
# Keep DistillPatch last (mirror official snippet)
if distill_ready:
active_adapters.append(GPU_ENV["distillpatch_adapter"])
active_weights.append(float(distillpatch_scale))
try:
pipe.set_adapters(active_adapters, adapter_weights=active_weights)
except Exception as e:
logs.append(f"⚠️ set_adapters failed: {type(e).__name__}: {e}")
else:
logs.append("⚠️ LoRA file missing on disk")
else:
# No user LoRA. Keep only DistillPatch if enabled; otherwise clear.
try:
if distill_ready:
pipe.set_adapters([GPU_ENV["distillpatch_adapter"]], adapter_weights=[float(distillpatch_scale)])
else:
pipe.set_adapters([])
except Exception:
pass
# Ensure dimensions are valid for Z-Image (multiple of 16)
w0, h0 = int(width), int(height)
w = max(16, w0 - (w0 % 16))
h = max(16, h0 - (h0 % 16))
if (w, h) != (w0, h0):
logs.append(f"ℹ️ Rounded size to multiple-of-16: {w0}x{h0} -> {w}x{h}")
width, height = w, h
generator = torch.Generator(device="cuda").manual_seed(seed)
call_kwargs: Dict[str, Any] = dict(
height=int(height),
width=int(width),
num_inference_steps=int(steps),
guidance_scale=float(cfg),
generator=generator,
max_sequence_length=int(max_sequence_length),
)
# SeedVarianceEnhancer (1:1-ish with ChangeTheConstants node UI)
# noise_insert: disabled | noise on beginning steps | noise on all steps | noise on ending steps
if sve_enabled:
mode_raw = (sve_noise_insert or "disabled").strip().lower()
if mode_raw.startswith("noise on beginning"):
mode = "beginning"
elif mode_raw.startswith("noise on ending"):
mode = "ending"
elif mode_raw.startswith("noise on all"):
mode = "all"
else:
mode = "disabled"
# Node uses a separate seed for embedding noise; allow -1 to "follow main seed"
sve_seed_used = int(sve_seed) if sve_seed is not None else -1
if sve_seed_used == -1:
sve_seed_used = int(seed)
# switchover percent means: beginning steps noise % OR ending steps noise %
sw = max(0.0, min(100.0, float(sve_steps_switchover_percent)))
k = int(round((sw / 100.0) * int(steps)))
k = max(0, min(int(steps), k))
if mode == "beginning":
switch_step = k
elif mode == "ending":
switch_step = max(0, int(steps) - k) # start of noisy region
else:
switch_step = 0
logs.append(
f"🎲 SVE: mode={mode} seed={sve_seed_used} strength={float(sve_strength)} rand%={float(sve_random_percent)} "
f"sw%={sw} mask_start={sve_mask_starts_at} mask%={float(sve_mask_percent)}"
)
try:
dev = pipe._execution_device if hasattr(pipe, "_execution_device") else "cuda"
# Encode to embeddings (diffusers variant-safe)
if hasattr(pipe, "encode_prompt"):
pe, ne = pipe.encode_prompt(
prompt=prompt,
negative_prompt=negative,
device=dev,
do_classifier_free_guidance=True,
max_sequence_length=int(max_sequence_length),
)
else:
pe, ne = pipe._encode_prompt(
prompt=prompt,
negative_prompt=negative,
device=dev,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
max_sequence_length=int(max_sequence_length),
)
clean = pe
noisy = sve_apply_any(
prompt_embeds=clean,
seed=sve_seed_used,
strength=float(sve_strength),
randomize_percent=float(sve_random_percent),
mask_starts_at_mode=str(sve_mask_starts_at),
mask_percent=float(sve_mask_percent),
)
# Apply according to mode
if mode == "disabled" or float(sve_strength) == 0 or float(sve_random_percent) <= 0:
pass
elif mode == "all":
call_kwargs["prompt_embeds"] = noisy
call_kwargs["negative_prompt_embeds"] = ne
elif mode == "beginning":
# noisy for first switch_step steps
call_kwargs["prompt_embeds"] = noisy if switch_step > 0 else clean
call_kwargs["negative_prompt_embeds"] = ne
if 0 < switch_step < int(steps):
call_kwargs["callback_on_step_end"] = _build_sve_callback_mode(clean, noisy, switch_step, "beginning")
call_kwargs["callback_on_step_end_tensor_inputs"] = ["prompt_embeds"]
elif mode == "ending":
# noisy for last k steps starting at switch_step
call_kwargs["prompt_embeds"] = clean
call_kwargs["negative_prompt_embeds"] = ne
if 0 <= switch_step < int(steps):
call_kwargs["callback_on_step_end"] = _build_sve_callback_mode(clean, noisy, switch_step, "ending")
call_kwargs["callback_on_step_end_tensor_inputs"] = ["prompt_embeds"]
except Exception as e:
logs.append(f"⚠️ SVE embedding path failed: {type(e).__name__}: {e} (fall back to raw prompt)")
# Optional log_to_console (mirrors node toggle)
if sve_log_to_console:
try:
print("[SVE]", logs[-1])
except Exception:
pass
logs.append("🚀 Generating…")
with torch.inference_mode():
out = pipe(
prompt=None if "prompt_embeds" in call_kwargs else prompt,
negative_prompt=None if "negative_prompt_embeds" in call_kwargs else negative,
**call_kwargs,
)
img = out.images[0]
logs.append(f"✅ Done in {time.time() - t0:.2f}s (seed={seed})")
mem_html = f"<span style='font-family:monospace;font-size:0.76rem;'>seed={seed} · steps={steps} · shift={zimage_shift} · dtype={dtype}</span>"
return session_id, [img], "\n".join(logs), mem_html
# =========================
# Gallery normalization
# =========================
def _normalize_gallery_list(lst):
from PIL import Image as PILImage
import numpy as np
if not isinstance(lst, list):
return []
out = []
for item in lst:
if isinstance(item, PILImage.Image):
out.append(item)
continue
if isinstance(item, np.ndarray):
arr = item
if arr.dtype != np.uint8:
arr = np.clip(arr, 0, 255).astype(np.uint8)
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] != arr.shape[-1]:
arr = np.transpose(arr, (1, 2, 0))
out.append(PILImage.fromarray(arr))
continue
if isinstance(item, tuple) and item:
sub = _normalize_gallery_list([item[0]])
if sub:
out.append(sub[0])
continue
if isinstance(item, dict):
data = item.get("image") or item.get("data") or item.get("value")
sub = _normalize_gallery_list([data])
if sub:
out.append(sub[0])
continue
return out
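# Hedged demo of the payload shapes Gradio galleries can hand back; all of
# them normalize to plain PIL images. numpy is already a dependency (it is
# imported inside the helper above).
import numpy as _np
_demo_items = [
    Image.new("RGB", (4, 4)),
    _np.zeros((5, 5, 3), dtype=_np.uint8),
    (Image.new("RGB", (4, 4)), "caption"),
    {"image": Image.new("RGB", (4, 4))},
]
_demo_norm = _normalize_gallery_list(_demo_items)
assert len(_demo_norm) == 4 and all(isinstance(i, Image.Image) for i in _demo_norm)
del _demo_items, _demo_norm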
# =========================
# Build UI
# =========================
with gr.Blocks(css=RESPONSIVE_CSS) as demo:
session_state = gr.State(value=None)
# Hidden token inputs (kept to match downloader signature; env/secrets fallback)
hf_token_hidden = gr.Textbox(value="", visible=False)
civitai_token_hidden = gr.Textbox(value="", visible=False)
with gr.Row(elem_classes="layout-main"):
# -------- LEFT: Controls --------
with gr.Column(elem_classes=["panel", "controls"]):
with gr.Tabs():
# --- Generate ---
with gr.Tab("Generate"):
with gr.Column(elem_id="left_box"):
gr.Markdown("### Generate")
prefetch_status = gr.Markdown(value=PREFETCH_STATUS.get("msg", ""))
run_btn = gr.Button("Generate 🎨", elem_id="generate_btn")
with gr.Row(elem_classes="compact-row"):
prompt = gr.Textbox(label="Prompt", value="", lines=2, placeholder="Describe what you want to see...")
negative = gr.Textbox(label="Negative", value="", lines=2, placeholder="What to avoid...")
with gr.Row(elem_classes="compact-row"):
steps = gr.Slider(1, 64, 8, step=1, label="Steps")
cfg = gr.Slider(0.0, 8.0, 1.0, step=0.1, label="CFG")
with gr.Row(elem_classes="compact-row"):
seed = gr.Number(-1, label="Seed (-1=random)")
precision = gr.Dropdown(choices=["bf16", "fp16", "fp32"], value="bf16", label="Precision")
with gr.Row(elem_classes="compact-row"):
width = gr.Slider(256, 1536, 1024, step=16, label="Width")
height = gr.Slider(256, 1536, 1024, step=16, label="Height")
zimage_shift = gr.Slider(0.0, 10.0, 5.0, step=0.1, label="FlowMatch shift")
max_sequence_length = gr.Slider(64, 512, 512, step=8, label="Max sequence length")
with gr.Row(elem_classes="compact-row"):
use_lora = gr.Checkbox(False, label="Use LoRA")
use_distillpatch = gr.Checkbox(False, label="Turbo DistillPatch (default OFF)")
lora_dropdown = gr.Dropdown(choices=["<none>"], value="<none>", label="LoRA file")
refresh_loras_btn = gr.Button("🔄 Refresh")
with gr.Row(elem_classes="compact-row"):
lora_scale = gr.Slider(0.0, 2.0, 0.8, step=0.05, label="LoRA scale")
distillpatch_scale = gr.Slider(0.0, 2.0, 1.0, step=0.05, label="DistillPatch scale")
# ---- SeedVarianceEnhancer (node-like UI) ----
with gr.Row(elem_classes="compact-row"):
sve_enabled = gr.Checkbox(value=False, label="SeedVarianceEnhancer")
sve_noise_insert = gr.Dropdown(
choices=[
"disabled",
"noise on beginning steps",
"noise on all steps",
"noise on ending steps",
],
value="noise on beginning steps",
label="noise_insert",
)
with gr.Row(elem_classes="compact-row"):
sve_random_percent = gr.Slider(0.0, 100.0, 50.0, step=1.0, label="randomize_percent")
sve_strength = gr.Slider(0.0, 80.0, 20.0, step=1.0, label="strength")
with gr.Row(elem_classes="compact-row"):
sve_steps_switchover = gr.Slider(0.0, 100.0, 20.0, step=1.0, label="steps_switchover_percent")
sve_seed = gr.Number(2019, label="seed")
sve_control_after_generate = gr.Dropdown(
choices=["keep", "increment", "randomize"],
value="randomize",
label="control_after_generate",
)
with gr.Row(elem_classes="compact-row"):
sve_mask_starts = gr.Dropdown(
choices=["beginning", "middle", "end"],
value="beginning",
label="mask_starts_at",
)
sve_mask_percent = gr.Slider(0.0, 100.0, 0.0, step=1.0, label="mask_percent")
sve_log_to_console = gr.Checkbox(value=False, label="log_to_console")
trigger_words_box = gr.Textbox(label="Detected trigger words", value="", lines=1, interactive=False)
# --- LoRA Downloader ---
with gr.Tab("LoRA Downloader"):
gr.Markdown("### Download a LoRA")
gr.Markdown("**Tip:** Use **CivitAI model URL** or **Hugging Face URL/repo id** (e.g. `user/repo`).")
with gr.Row(elem_classes="compact-row"):
lora_url = gr.Textbox(label="LoRA URL / Repo ID", placeholder="Civitai model page URL, HF repo id (user/repo), or direct file URL…")
lora_filename = gr.Textbox(label="Save name (optional)", placeholder="my_lora_name")
download_lora_btn = gr.Button("📥 Download LoRA")
lora_dl_log = gr.Textbox(label="Downloader log", lines=6, interactive=False)
gr.Markdown("After downloading: go to **Generate** → **Refresh**.")
# -------- RIGHT: Output --------
with gr.Column(elem_classes=["panel", "output"]):
main_image = gr.Image(label="Result", interactive=False, elem_id="main_image")
history_gallery = gr.Gallery(
label="History",
columns=6,
height=110,
elem_id="history_gallery",
type="pil",
)
with gr.Accordion("Logs & system info", open=False, elem_id="logs_accordion"):
log_box = gr.Textbox(label="Generation Logs", lines=10, interactive=False)
mem_status = gr.HTML(value="<span style='font-family:monospace;font-size:0.76rem;'>No memory data yet.</span>")
download_lora_btn.click(
fn=download_lora_for_session,
inputs=[session_state, lora_url, lora_filename, hf_token_hidden, civitai_token_hidden, prompt],
outputs=[session_state, lora_dropdown, lora_dl_log, trigger_words_box, prompt],
)
refresh_loras_btn.click(
fn=refresh_loras_for_session,
inputs=[session_state],
outputs=[session_state, lora_dropdown],
)
def generate_and_clip_gallery(
session_id,
prompt,
negative,
steps,
cfg,
width,
height,
seed,
zimage_shift,
max_sequence_length,
precision,
use_lora,
use_distillpatch,
selected_lora,
lora_scale,
distillpatch_scale,
sve_enabled,
sve_noise_insert,
sve_steps_switchover,
sve_seed,
sve_control_after_generate,
sve_strength,
sve_random_percent,
sve_mask_starts,
sve_mask_percent,
sve_log_to_console,
current_history,
):
session_id = session_id or _new_session_id()
# Call GPU route with a small retry: ZeroGPU can race CUDA init
last_err = None
mem_html = ""
for _attempt in range(6):
try:
session_id, images, logs, mem_html = generate_route(
session_id,
prompt,
negative,
steps,
cfg,
width,
height,
seed,
zimage_shift,
max_sequence_length,
precision,
use_lora,
use_distillpatch,
selected_lora,
lora_scale,
distillpatch_scale,
sve_enabled,
sve_noise_insert,
sve_steps_switchover,
sve_seed,
sve_control_after_generate,
sve_strength,
sve_random_percent,
sve_mask_starts,
sve_mask_percent,
sve_log_to_console,
)
last_err = None
break
except Exception as e:
# ZeroGPU may briefly fail CUDA init if a GPU is not yet allocated.
# Never crash the UI: retry a few times, then return a friendly message.
last_err = e
msg = str(e) or repr(e)
low = msg.lower()
transient_cuda = (
"cuda driver initialization failed" in low
or "might not have a cuda gpu" in low
or "_cuda_init" in low
or "torch.init" in low
or "cuda initialization" in low
)
if transient_cuda:
time.sleep(0.7)
continue
# Not a GPU-allocation race: surface the real error in logs, but keep the app alive.
logs = f"❌ Generation failed: {type(e).__name__}: {msg}"
history_list = _normalize_gallery_list(current_history)
new_history = history_list[-10:] if len(history_list) > 10 else history_list
main = new_history[-1] if new_history else None
sve_seed_out = int(sve_seed) if sve_seed is not None else 2019
return session_id, main, new_history, logs, mem_html, sve_seed_out
# Node-like control_after_generate update for the UI seed field
sve_seed_out = int(sve_seed) if sve_seed is not None else 2019
if sve_enabled:
mode = (sve_control_after_generate or "keep").strip().lower()
if mode == "increment":
sve_seed_out += 1
elif mode == "randomize":
sve_seed_out = random.randint(0, 2**31 - 1)
if last_err is not None:
# Graceful error: keep history unchanged; no new images
logs = "⚠️ GPU wasn’t ready (ZeroGPU race). Please click **Generate** again."
mem_html = mem_html if "mem_html" in locals() else ""
history_list = _normalize_gallery_list(current_history)
new_history = history_list[-10:] if len(history_list) > 10 else history_list
main = new_history[-1] if new_history else None
return session_id, main, new_history, logs, mem_html, sve_seed_out
history_list = _normalize_gallery_list(current_history)
images = _normalize_gallery_list(images)
new_history = (history_list + images)[-10:]
main = new_history[-1] if new_history else None
return session_id, main, new_history, logs, mem_html, sve_seed_out
run_btn.click(
fn=generate_and_clip_gallery,
inputs=[
session_state,
prompt,
negative,
steps,
cfg,
width,
height,
seed,
zimage_shift,
max_sequence_length,
precision,
use_lora,
use_distillpatch,
lora_dropdown,
lora_scale,
distillpatch_scale,
sve_enabled,
sve_noise_insert,
sve_steps_switchover,
sve_seed,
sve_control_after_generate,
sve_strength,
sve_random_percent,
sve_mask_starts,
sve_mask_percent,
sve_log_to_console,
history_gallery,
],
outputs=[session_state, main_image, history_gallery, log_box, mem_status, sve_seed],
)
def select_from_history(history, evt: gr.SelectData):
normalized = _normalize_gallery_list(history)
idx = evt.index
if isinstance(idx, int) and 0 <= idx < len(normalized):
return normalized[idx]
return None
history_gallery.select(fn=select_from_history, inputs=[history_gallery], outputs=[main_image])
def _on_load(session_id):
_ensure_cpu_pipe()
sid, dd = refresh_loras_for_session(session_id)
return sid, dd, PREFETCH_STATUS.get("msg", "")
demo.load(fn=_on_load, inputs=[session_state], outputs=[session_state, lora_dropdown, prefetch_status])
# CSS is passed to gr.Blocks above; Blocks.launch() has no `css` parameter.
demo.queue(default_concurrency_limit=1, max_size=20).launch(ssr_mode=False, show_error=False)