# Z-Image Turbo — Gradio app for Hugging Face Spaces
| import os | |
| import re | |
| import time | |
| import random | |
| import pathlib | |
| import hashlib | |
| import shutil | |
| from typing import Dict, Any, Tuple, List, Optional | |
| import gradio as gr | |
| import requests | |
| import torch | |
| from PIL import Image | |
| import spaces | |
| from huggingface_hub import snapshot_download, hf_hub_download, list_repo_files | |
# =========================
# UI CSS (provided)
# =========================
# Injected into the Gradio Blocks app. The :root custom properties at the top
# are the intended tuning knobs (global zoom, left-column width, history-strip
# height); element ids (#main_image, #history_gallery, ...) must match the
# elem_id values used when the UI is built.
RESPONSIVE_CSS = """
/* ===== Tunables ===== */
:root{
--hf-topbar: 64px; /* HF header height */
--left-col: 30%; /* width of left column */
--gap: 0.75rem;
--pad-y: 0.60rem;
--pad-x: 0.85rem;
/* Global scale: 0.75 = ~25% smaller, 0.66 = ~34% smaller */
--ui-zoom: 0.75;
/* Output sizing */
--history-h: 110px;
}
/* Scale everything via root font-size (more reliable than transform:scale) */
html, body { height: 100%; }
html { font-size: calc(16px * var(--ui-zoom)); }
body { margin: 0; overflow: hidden; }
/* HuggingFace Spaces wrapper */
.gradio-container{
max-width: 100vw !important;
height: calc(100vh - var(--hf-topbar));
overflow: hidden;
padding: var(--pad-y) var(--pad-x);
box-sizing: border-box;
}
/* ===== Main two-column layout fills viewport ===== */
.layout-main{
display: grid !important;
grid-template-columns: minmax(320px, var(--left-col)) 1fr;
gap: var(--gap);
align-items: stretch;
height: 100%;
min-height: 0;
}
/* ===== Panels ===== */
.panel{
background: #0f172a;
border-radius: 12px;
padding: 0.75rem;
box-shadow: 0 10px 26px rgba(0,0,0,0.40);
border: 1px solid rgba(255,255,255,0.06);
overflow: hidden;
min-height: 0;
}
/* ===== Left panel ===== */
.panel.controls{
display: flex;
flex-direction: column;
min-height: 0;
overflow: hidden;
}
/* Tabs need to be allowed to shrink */
.panel.controls .gr-tabs{ min-height: 0; overflow: hidden; }
/* Make tab content scroll only if absolutely necessary */
.panel.controls .tabitem{
min-height: 0 !important;
overflow: auto !important;
padding-right: 2px;
}
/* Single big box in Generate tab */
#left_box{
display: flex;
flex-direction: column;
gap: 0.6rem;
min-height: 0;
}
/* Tighten default Gradio spacing */
.panel .gr-markdown{ margin: 0.15rem 0 0.35rem !important; }
.panel .gr-form{ gap: 0.35rem; }
.compact-row{ gap: 0.45rem; }
/* Make textboxes compact */
.gradio-container textarea{
padding: 0.45rem 0.55rem !important;
line-height: 1.2 !important;
}
.gradio-container input[type="text"], .gradio-container input[type="number"]{
padding: 0.40rem 0.55rem !important;
}
/* Reduce slider vertical padding */
.gradio-container .gr-slider,
.gradio-container .slider-container{
margin-top: 0.10rem !important;
margin-bottom: 0.10rem !important;
}
/* ===== Right panel (output) ===== */
.panel.output{
display: flex;
flex-direction: column;
min-height: 0;
height: 100%;
overflow: hidden;
}
/* Prevent "infinite height" from nested wrappers */
.panel.output > *,
.panel.output .gr-block,
.panel.output .gr-box,
.panel.output .gr-row,
.panel.output .gr-column,
.panel.output .gr-form{
min-height: 0 !important;
}
/* Main image takes remaining space */
#main_image{
flex: 1 1 auto;
min-height: 0 !important;
overflow: hidden;
display: flex;
align-items: center;
justify-content: center;
}
#main_image > div,
#main_image .wrap,
#main_image .image-container,
#main_image .image-preview{
height: 100% !important;
min-height: 0 !important;
}
#main_image img{
max-width: 100% !important;
max-height: 100% !important;
object-fit: contain !important;
border-radius: 10px;
}
/* History strip (fixed) */
#history_gallery{
flex: 0 0 var(--history-h);
height: var(--history-h) !important;
margin-top: 0.5rem;
overflow: hidden;
}
#history_gallery .grid,
#history_gallery .gallery,
#history_gallery .wrap{
height: 100% !important;
overflow: hidden !important;
}
#history_gallery img{
height: calc(var(--history-h) - 28px) !important;
width: auto !important;
object-fit: cover !important;
border-radius: 8px;
}
/* Logs accordion compact */
#logs_accordion{ flex: 0 0 auto; }
#logs_accordion textarea{ min-height: 9.5rem !important; }
/* Generate button */
#generate_btn{
width: 100%;
font-weight: 700;
font-size: 1.02rem;
padding: 0.55rem;
margin-top: 0.35rem;
}
/* ===== Mobile: stack and allow normal page scroll ===== */
@media (max-width: 900px){
body{ overflow: auto; }
.gradio-container{ height: auto; overflow: visible; }
.layout-main{ grid-template-columns: 1fr; height: auto; }
.panel{ overflow: visible; }
.panel.controls .tabitem{ overflow: visible !important; }
#history_gallery{ height: 100px !important; flex-basis: 100px; }
#history_gallery img{ height: 68px !important; }
}
"""
# =========================
# Space-friendly caching
# =========================
# Prefer the persistent /data volume (present on persistent-storage Spaces) for
# the HF cache; otherwise keep whatever HF_HOME is already set to.
# NOTE(review): when /data is absent and HF_HOME is unset, this sets HF_HOME to
# "" — the `if` below then skips the cache redirects, so it is harmless.
os.environ.setdefault("HF_HOME", "/data/.huggingface" if os.path.isdir("/data") else os.environ.get("HF_HOME", ""))
if os.environ.get("HF_HOME"):
    # Point both legacy cache env vars at the same hub directory.
    os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(os.environ["HF_HOME"], "hub"))
    os.environ.setdefault("DIFFUSERS_CACHE", os.path.join(os.environ["HF_HOME"], "hub"))
# =========================
# Defaults
# =========================
ZIMAGE_REPO_DEFAULT = "Tongyi-MAI/Z-Image-Turbo"  # hub repo for the base pipeline
ZIMAGE_LOCAL_DIR = str((pathlib.Path(__file__).resolve().parent / "models" / "zimage").resolve())  # local snapshot target
# Optional Turbo preservation patch LoRA (stacked, never fused)
DISTILLPATCH_REPO = "DiffSynth-Studio/Z-Image-Turbo-DistillPatch"
DISTILLPATCH_REVISION = "main"
# Root folder for per-session user-downloaded LoRAs (created eagerly at import).
LORAS_ROOT = pathlib.Path(__file__).resolve().parent / "loras"
LORAS_ROOT.mkdir(parents=True, exist_ok=True)
# =========================
# Startup prefetch (downloads on app start, not on first generate)
# =========================
# Mutable status record read later by _ensure_cpu_pipe and the UI.
PREFETCH_STATUS = {"ok": False, "msg": "Not started", "path": ""}
def prefetch_zimage_repo(repo_id: str, local_dir: str) -> str:
    """Download the full model snapshot for *repo_id* into *local_dir*.

    Returns *local_dir* so callers can record where the snapshot lives.
    """
    pathlib.Path(local_dir).mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=local_dir,
    )
    return local_dir
# Run the prefetch at import time so the first generate does not pay the
# multi-GB download cost inside a GPU allocation.
try:
    PREFETCH_STATUS["msg"] = f"Prefetching {ZIMAGE_REPO_DEFAULT} → {ZIMAGE_LOCAL_DIR} ..."
    prefetch_zimage_repo(ZIMAGE_REPO_DEFAULT, ZIMAGE_LOCAL_DIR)
    PREFETCH_STATUS["ok"] = True
    PREFETCH_STATUS["path"] = ZIMAGE_LOCAL_DIR
    PREFETCH_STATUS["msg"] = f"FASTER"  # NOTE(review): f-string without placeholders — a literal status label
except Exception as e:
    # Best-effort: on failure _ensure_cpu_pipe falls back to the hub repo id.
    PREFETCH_STATUS["ok"] = False
    PREFETCH_STATUS["msg"] = f"⚠️ Prefetch failed: {type(e).__name__}: {e}"
| # ========================= | |
| # Session + LoRA helpers | |
| # ========================= | |
| def _new_session_id() -> str: | |
| return hashlib.sha1(os.urandom(16)).hexdigest()[:10] | |
| def _safe_filename(name: str) -> str: | |
| name = (name or "").strip() | |
| name = re.sub(r"[^a-zA-Z0-9._-]+", "_", name) | |
| name = re.sub(r"_+", "_", name).strip("_") | |
| return name or "lora" | |
| def _sha1_file(path: str) -> str: | |
| h = hashlib.sha1() | |
| with open(path, "rb") as f: | |
| for chunk in iter(lambda: f.read(1024 * 1024), b""): | |
| h.update(chunk) | |
| return h.hexdigest()[:12] | |
def _session_lora_dir(session_id: str) -> pathlib.Path:
    """Per-session LoRA folder under LORAS_ROOT, created on first use."""
    session_dir = LORAS_ROOT / f"session_{session_id}"
    session_dir.mkdir(parents=True, exist_ok=True)
    return session_dir
def _list_loras(session_id: str) -> List[str]:
    """Names of LoRA weight files in the session folder (.safetensors first, then .bin, each sorted)."""
    folder = _session_lora_dir(session_id)
    return [
        entry.name
        for pattern in ("*.safetensors", "*.bin")
        for entry in sorted(folder.glob(pattern))
    ]
def _download_file(url: str, dst_path: pathlib.Path, progress_cb=None, headers=None) -> pathlib.Path:
    """Stream *url* into *dst_path* via a temporary ".part" file and return *dst_path*.

    *progress_cb*, when given, receives a float in [0, 1] — but only when the
    server reports a content-length.
    """
    req_headers = headers if headers is not None else {}
    with requests.get(url, stream=True, headers=req_headers, timeout=60, allow_redirects=True) as resp:
        resp.raise_for_status()
        total_bytes = int(resp.headers.get("content-length", "0") or "0")
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        part_path = dst_path.with_suffix(dst_path.suffix + ".part")
        seen = 0
        with open(part_path, "wb") as out:
            for block in resp.iter_content(chunk_size=1024 * 1024):
                if block:
                    out.write(block)
                    seen += len(block)
                    if progress_cb and total_bytes > 0:
                        progress_cb(seen / total_bytes)
    # Atomic publish: only expose the file once fully written.
    part_path.replace(dst_path)
    return dst_path
| def _validate_downloaded_lora(path: pathlib.Path) -> None: | |
| # A real LoRA is almost never < 1MB. If it is, it's usually HTML or an error blob. | |
| if not path.exists(): | |
| raise RuntimeError("Downloaded LoRA file is missing.") | |
| sz = path.stat().st_size | |
| if sz < 1024 * 1024: | |
| head = path.read_bytes()[:256] | |
| raise RuntimeError(f"Downloaded file is too small ({sz} bytes). First bytes: {head!r} (likely HTML/error).") | |
| head8 = path.read_bytes()[:8] | |
| if head8.startswith(b"PK"): | |
| raise RuntimeError("Downloaded file looks like a ZIP (PK...). Not a safetensors LoRA.") | |
| if head8.lower().startswith(b"<html") or head8.startswith(b"<!DOCTYP"): | |
| raise RuntimeError("Downloaded file is HTML (likely auth/permission issue).") | |
def refresh_loras_for_session(session_id: Optional[str]) -> Tuple[str, Any]:
    """(Re)establish a session id and rebuild the LoRA dropdown choices."""
    sid = session_id if session_id else _new_session_id()
    options = ["<none>"]
    options.extend(_list_loras(sid))
    return sid, gr.update(choices=options, value="<none>")
| def _civitai_extract_model_version_id(url: str) -> Optional[str]: | |
| m = re.search(r"modelVersionId=(\d+)", url) | |
| if m: | |
| return m.group(1) | |
| m = re.search(r"/model-versions/(\d+)", url) | |
| if m: | |
| return m.group(1) | |
| return None | |
def _civitai_fetch_trigger_words(model_version_id: str, civitai_token: str) -> List[str]:
    """Query the CivitAI API for a model version's trained/trigger words (deduped, order kept)."""
    endpoint = f"https://civitai.com/api/v1/model-versions/{model_version_id}"
    req_headers = {"User-Agent": "Mozilla/5.0 (HF Spaces) ZImage-LoRA-Downloader"}
    if civitai_token:
        # CivitAI accepts either header form; send both.
        req_headers["Authorization"] = f"Bearer {civitai_token}"
        req_headers["X-Api-Key"] = civitai_token
    resp = requests.get(endpoint, headers=req_headers, timeout=30)
    resp.raise_for_status()
    payload = resp.json()
    raw = payload.get("trainedWords") or payload.get("triggerWords") or []
    if isinstance(raw, str):
        raw = [raw]
    unique: List[str] = []
    for entry in raw:
        entry = str(entry).strip()
        if entry and entry not in unique:
            unique.append(entry)
    return unique
| def _merge_prompt_with_triggers(prompt: str, triggers: List[str]) -> str: | |
| p = (prompt or "").strip() | |
| if not triggers: | |
| return p | |
| existing = set(re.findall(r"[\w\-\:]+", p.lower())) | |
| to_add = [t for t in triggers if t.lower() not in existing] | |
| if not to_add: | |
| return p | |
| return (p + " " + " ".join(to_add)).strip() | |
| def _is_hf_repo_id(s: str) -> bool: | |
| return bool(re.fullmatch(r"[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+", (s or "").strip())) | |
| def _hf_extract_repo_id(url: str) -> Optional[str]: | |
| # https://huggingface.co/user/repo (optionally with /tree/main etc) | |
| m = re.search(r"huggingface\.co/([^/\s]+)/([^/\s?#]+)", url) | |
| if not m: | |
| return None | |
| return f"{m.group(1)}/{m.group(2)}" | |
def _hf_pick_lora_file(repo_id: str, revision: str = "main", hf_token: str = "") -> str:
    """Choose the most LoRA-looking .safetensors file in an HF repo; raise when none exists."""
    all_files = list_repo_files(repo_id, revision=revision or "main", token=hf_token or None)
    tensors = [name for name in all_files if name.lower().endswith(".safetensors")]
    if not tensors:
        raise RuntimeError("No .safetensors files found in that Hugging Face repo.")
    if len(tensors) == 1:
        return tensors[0]
    # Several candidates: prefer the first whose name mentions "lora".
    for name in tensors:
        if "lora" in name.lower():
            return name
    return tensors[0]
def _hf_fetch_trigger_words(repo_id: str, revision: str = "main", hf_token: str = "") -> List[str]:
    """
    Best-effort: download README/model card and extract trigger words.
    Looks for patterns like: trigger word(s): `xxx` or Trigger: xxx, or "trained words".
    """
    # Try a handful of common model-card filenames; the first non-empty one wins.
    candidates = ["README.md", "readme.md", "README.MD", "modelcard.md", "MODEL_CARD.md"]
    text = ""
    for fn in candidates:
        try:
            p = hf_hub_download(repo_id=repo_id, filename=fn, revision=revision or "main", token=hf_token or None)
            text = pathlib.Path(p).read_text(encoding="utf-8", errors="ignore")
            if text.strip():
                break
        except Exception:
            # Missing file / auth failure: move on to the next candidate name.
            continue
    if not text:
        return []
    words: List[str] = []
    # 1) backticked tokens after "trigger"
    for m in re.finditer(r"(?i)trigger\w*\s*(?:words?|token|phrase)?\s*[:\-]\s*([^\n]+)", text):
        chunk = m.group(1)
        # collect backticked and comma-separated tokens
        back = re.findall(r"`([^`]+)`", chunk)
        if back:
            # Backticked tokens are the most reliable; take them verbatim (deduped).
            for b in back:
                b = b.strip()
                if b and b not in words:
                    words.append(b)
        else:
            # split common separators
            for part in re.split(r"[,\|/]", chunk):
                part = part.strip()
                # keep short-ish single tokens
                if 1 <= len(part) <= 64 and re.fullmatch(r"[A-Za-z0-9_\-:]+", part):
                    if part not in words:
                        words.append(part)
    # 2) "trained words" style
    for m in re.finditer(r"(?i)trained\s*words?\s*[:\-]\s*([^\n]+)", text):
        chunk = m.group(1)
        for b in re.findall(r"`([^`]+)`", chunk):
            b = b.strip()
            if b and b not in words:
                words.append(b)
    # prune empties
    words = [w for w in words if w.strip()]
    # Cap the list so a verbose card cannot flood the prompt.
    return words[:20]
def download_lora_for_session(
    session_id: Optional[str],
    lora_url: str,
    lora_filename: str,
    hf_token: str,
    civitai_token: str,
    current_prompt: str,
    progress=gr.Progress(track_tqdm=False),
):
    """Download a LoRA into the per-session folder.

    Accepts a bare HF repo id ("owner/name"), an HF repo URL, a CivitAI
    model/model-version URL, or a direct file URL.

    Returns the 5-tuple the Gradio handler expects:
    (session_id, dropdown update, status message, trigger-words text,
    prompt possibly augmented with trigger words).
    """
    sid = session_id or _new_session_id()
    url = (lora_url or "").strip()
    if not url:
        return sid, gr.update(choices=["<none>"] + _list_loras(sid), value="<none>"), "❌ Please provide a URL or repo id.", "", current_prompt
    # env/secret fallback
    hf_token = (hf_token or "").strip() or os.environ.get("HF_TOKEN", "").strip() or os.environ.get("HUGGINGFACE_HUB_TOKEN", "").strip()
    civitai_token = (civitai_token or "").strip() or os.environ.get("CIVITAI_TOKEN", "").strip() or os.environ.get("CIVITAI_API_KEY", "").strip()
    trigger_words: List[str] = []
    mv_id = None
    lower_url = url.lower()
    # =========================
    # Hugging Face repo-id / repo URL (EARLY RETURN)
    # =========================
    # Direct /resolve/ file URLs deliberately fall through to the generic
    # streaming path below.
    hf_repo_id = None
    hf_revision = "main"
    if _is_hf_repo_id(url):
        hf_repo_id = url
    elif "huggingface.co/" in lower_url and "/resolve/" not in lower_url:
        hf_repo_id = _hf_extract_repo_id(url)
    if hf_repo_id:
        try:
            filename_in_repo = _hf_pick_lora_file(hf_repo_id, revision=hf_revision, hf_token=hf_token)
            out_dir = _session_lora_dir(sid)
            # Destination name: user-chosen if given, else the repo file's own name.
            base = _safe_filename(lora_filename) if lora_filename else _safe_filename(pathlib.Path(filename_in_repo).name)
            if not base.lower().endswith(".safetensors"):
                base += ".safetensors"
            dst = out_dir / base
            local_path = hf_hub_download(repo_id=hf_repo_id, filename=filename_in_repo, revision=hf_revision, token=hf_token or None)
            # Copy into session folder so it appears in dropdown
            if not dst.exists():
                shutil.copyfile(local_path, dst)
            _validate_downloaded_lora(dst)
            # Trigger words from README (best effort)
            try:
                trigger_words = _hf_fetch_trigger_words(hf_repo_id, revision=hf_revision, hf_token=hf_token)
            except Exception:
                trigger_words = []
            triggers_text = ", ".join(trigger_words) if trigger_words else ""
            new_prompt = _merge_prompt_with_triggers(current_prompt, trigger_words)
            msg = f"✅ Downloaded from Hugging Face: {hf_repo_id}/{filename_in_repo}"
            return (
                sid,
                gr.update(choices=["<none>"] + _list_loras(sid), value=dst.name),
                msg,
                triggers_text,
                new_prompt,
            )
        except Exception as e:
            msg = f"❌ Hugging Face download failed: {type(e).__name__}: {e}"
            return sid, gr.update(choices=["<none>"] + _list_loras(sid), value="<none>"), msg, "", current_prompt
    # =========================
    # CivitAI conversion
    # =========================
    # Rewrite a page URL into the direct API download URL when a version id is found.
    if "civitai.com" in lower_url:
        mv_id = _civitai_extract_model_version_id(url)
        if mv_id:
            url = f"https://civitai.com/api/download/models/{mv_id}?type=Model&format=SafeTensor"
            lower_url = url.lower()
    out_dir = _session_lora_dir(sid)
    # Destination name: user-chosen if given, else derived from the URL path.
    base = _safe_filename(lora_filename) if lora_filename else _safe_filename(pathlib.Path(url.split("?")[0]).name)
    if not base.lower().endswith((".safetensors", ".bin")):
        base += ".safetensors"
    dst = out_dir / base
    headers = {"User-Agent": "Mozilla/5.0 (HF Spaces) ZImage-LoRA-Downloader"}
    # Auth headers
    if "civitai.com/api/download/models" in lower_url:
        if civitai_token:
            headers["Authorization"] = f"Bearer {civitai_token}"
            headers["X-Api-Key"] = civitai_token
        else:
            # CivitAI API downloads reject anonymous requests — fail early with guidance.
            return sid, gr.update(choices=["<none>"] + _list_loras(sid), value="<none>"), (
                "❌ CivitAI download requires a token. Set the Space Secret CIVITAI_TOKEN."
            ), "", current_prompt
    if "huggingface.co" in lower_url and "/resolve/" in lower_url and hf_token:
        headers["Authorization"] = f"Bearer {hf_token}"
    def _cb(p):
        # Forward stream progress (0..1) to the Gradio progress bar.
        progress(p)
    try:
        _download_file(url, dst, progress_cb=_cb, headers=headers)
        _validate_downloaded_lora(dst)
        msg = f"✅ Downloaded: {dst.name}"
    except Exception as e:
        msg = f"❌ Download failed: {type(e).__name__}: {e}"
        return sid, gr.update(choices=["<none>"] + _list_loras(sid), value="<none>"), msg, "", current_prompt
    # Trigger words from CivitAI
    try:
        if mv_id:
            trigger_words = _civitai_fetch_trigger_words(mv_id, civitai_token)
    except Exception:
        trigger_words = []
    triggers_text = ", ".join(trigger_words) if trigger_words else ""
    new_prompt = _merge_prompt_with_triggers(current_prompt, trigger_words)
    return (
        sid,
        gr.update(choices=["<none>"] + _list_loras(sid), value=dst.name if dst.exists() else "<none>"),
        msg,
        triggers_text,
        new_prompt,
    )
def _resolve_lora_path(session_id: str, lora_choice: str) -> str:
    """Map a dropdown choice to a path in the session folder; "" for empty/"<none>"."""
    if lora_choice and lora_choice != "<none>":
        return str(_session_lora_dir(session_id) / lora_choice)
    return ""
| # ========================= | |
| # SeedVarianceEnhancer v2.1 (repo-like) | |
| # ========================= | |
| def _auto_mask_trailing_zeros(embeds: torch.Tensor) -> torch.Tensor: | |
| token_is_zero = (embeds.abs().sum(dim=-1, keepdim=True) == 0) | |
| return ~token_is_zero | |
def sve_apply(
    prompt_embeds: torch.Tensor,
    seed: int,
    strength: float,
    randomize_percent: float,
    mask_starts_at_mode: str = "beginning",
    mask_percent: float = 0.0,
) -> torch.Tensor:
    """Inject seeded Gaussian noise into prompt embeddings for seed variance.

    Args:
        prompt_embeds: embeddings, normally (B, T, C); 2D/1D inputs are lifted
            to 3D and squeezed back afterwards.
        seed: generator seed (masked to 31 bits).
        strength: noise scale; 0 disables.
        randomize_percent: percent of elements to perturb; <= 0 disables.
        mask_starts_at_mode: where the protected (no-noise) token span starts —
            "beginning"/"middle"/"end" (unknown values fall back to beginning).
        mask_percent: percent of tokens to protect from noise.

    Returns a tensor of the same shape; all-zero (padding) token rows are never
    perturbed. Deterministic for fixed inputs and seed.

    Cleanups vs. original: removed the always-true ``'...' in locals()`` guard
    and redundant isinstance checks, and inlined the trivial trailing-zero
    token mask so the function is self-contained.
    """
    if strength == 0 or randomize_percent <= 0:
        return prompt_embeds
    device = prompt_embeds.device
    g = torch.Generator(device=device).manual_seed(int(seed) & 0x7FFFFFFF)
    # Normalize to 3D (B, T, C); remember the original rank to restore later.
    squeeze_token_dim = False
    if prompt_embeds.ndim == 2:
        prompt_embeds = prompt_embeds.unsqueeze(1)
        squeeze_token_dim = True
    elif prompt_embeds.ndim == 1:
        prompt_embeds = prompt_embeds.view(1, 1, -1)
        squeeze_token_dim = True
    # Element-level mask: each element is perturbed with probability p.
    p = float(randomize_percent) / 100.0
    if p >= 1.0:
        elem_mask = torch.ones_like(prompt_embeds, dtype=torch.bool)
    else:
        elem_mask = torch.rand(prompt_embeds.shape, device=device, dtype=prompt_embeds.dtype, generator=g) < p
    B, T, _ = prompt_embeds.shape
    token_allow = torch.ones((B, T, 1), device=device, dtype=torch.bool)
    # Token-level protection: a contiguous span of mask_percent% of tokens
    # (anchored at beginning/middle/end) receives no noise at all.
    mp = max(0.0, min(100.0, float(mask_percent)))
    if mp > 0:
        length = max(0, min(T, int(round((mp / 100.0) * T))))
        mode = (mask_starts_at_mode or "beginning").strip().lower()
        if length > 0:
            if mode in ("middle", "center", "centre"):
                start = max(0, (T - length) // 2)
            elif mode in ("end", "ending", "finish"):
                start = max(0, T - length)
            else:
                # "beginning"/"begin"/"start" and any unknown mode.
                start = 0
            end = min(T, start + length)
            if start < T and end > start:
                # False = protected region (no noise)
                token_allow[:, start:end, :] = False
    # Never touch padding: all-zero embedding rows are excluded from noise.
    token_allow = token_allow & (prompt_embeds.abs().sum(dim=-1, keepdim=True) != 0)
    mask = elem_mask & token_allow
    noise = torch.randn(prompt_embeds.shape, device=device, dtype=prompt_embeds.dtype, generator=g) * float(strength)
    out = torch.where(mask, prompt_embeds + noise, prompt_embeds)
    if squeeze_token_dim and out.ndim == 3 and out.shape[1] == 1:
        # Restore the original (lower) rank.
        out = out.squeeze(1)
    return out
def sve_apply_any(
    prompt_embeds,
    seed: int,
    strength: float,
    randomize_percent: float,
    mask_starts_at_mode: str = "beginning",
    mask_percent: float = 0.0,
):
    """Apply SVE to a single tensor or recursively to a (nested) list/tuple of tensors.

    Multi-encoder pipelines may hand back prompt embeddings as a container;
    the container type (list vs tuple) is preserved, and each element gets an
    offset sub-seed so the encoders do not share one noise pattern.
    """
    if not isinstance(prompt_embeds, (list, tuple)):
        # Base case: a plain tensor.
        return sve_apply(
            prompt_embeds=prompt_embeds,
            seed=seed,
            strength=strength,
            randomize_percent=randomize_percent,
            mask_starts_at_mode=mask_starts_at_mode,
            mask_percent=mask_percent,
        )
    transformed = [
        sve_apply_any(
            item,
            seed=int(seed) + int(idx),
            strength=strength,
            randomize_percent=randomize_percent,
            mask_starts_at_mode=mask_starts_at_mode,
            mask_percent=mask_percent,
        )
        for idx, item in enumerate(prompt_embeds)
    ]
    return transformed if isinstance(prompt_embeds, list) else tuple(transformed)
| def _build_sve_callback(clean_embeds: torch.Tensor, noisy_embeds: torch.Tensor, switch_step: int): | |
| # Backwards-compatible: noisy before switch_step, clean after. | |
| def _cb(pipe, step_index: int, timestep: int, callback_kwargs: Dict[str, Any]): | |
| callback_kwargs["prompt_embeds"] = noisy_embeds if step_index < switch_step else clean_embeds | |
| return callback_kwargs | |
| return _cb | |
| def _build_sve_callback_mode(clean_embeds: torch.Tensor, noisy_embeds: torch.Tensor, start_step: int, mode: str): | |
| """Mode-aware callback. | |
| - mode='beginning': noisy before start_step, clean after (start_step = switch) | |
| - mode='ending': clean before start_step, noisy after (start_step = start of noisy region) | |
| """ | |
| mode = (mode or 'beginning').strip().lower() | |
| def _cb(pipe, step_index: int, timestep: int, callback_kwargs: Dict[str, Any]): | |
| if mode in ('end', 'ending'): | |
| callback_kwargs['prompt_embeds'] = noisy_embeds if step_index >= start_step else clean_embeds | |
| else: | |
| callback_kwargs['prompt_embeds'] = noisy_embeds if step_index < start_step else clean_embeds | |
| return callback_kwargs | |
| return _cb | |
# =========================
# GPU: cached Z-Image pipeline
# =========================
# Per-process caches. GPU_ENV tracks the CUDA-resident pipeline, its dtype,
# which LoRA adapters were already registered on it (keyed by (dtype, path)),
# and the DistillPatch adapter state. CPU_ENV holds the one-time CPU-loaded
# pipeline and where it was loaded from (ZeroGPU pattern: load once on CPU,
# move to CUDA inside each GPU allocation).
GPU_ENV: Dict[str, Any] = {"pipe": None, "dtype": None, "loaded_loras": {}, "distillpatch_loaded": False, "distillpatch_adapter": "distillpatch"}
CPU_ENV: Dict[str, Any] = {"pipe": None, "path": None}
def _ensure_cpu_pipe() -> None:
    """Load the Z-Image pipeline ON CPU once per container (ZeroGPU-safe)."""
    if CPU_ENV.get("pipe") is not None:
        return  # already loaded in this process
    # Imported lazily so module import stays cheap and CUDA is never touched here.
    from diffusers import ZImagePipeline
    # Prefer the prefetched local snapshot; fall back to the local dir, then the hub id.
    load_path = PREFETCH_STATUS.get("path") or ZIMAGE_LOCAL_DIR or ZIMAGE_REPO_DEFAULT
    CPU_ENV["path"] = load_path
    # NOTE: keep on CPU; do NOT touch CUDA outside @spaces.GPU
    CPU_ENV["pipe"] = ZImagePipeline.from_pretrained(
        load_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=False,
    )
| def _torch_dtype(name: str): | |
| name = (name or "bf16").lower() | |
| if name in ("bf16", "bfloat16"): | |
| return torch.bfloat16 | |
| if name in ("fp16", "float16", "half"): | |
| return torch.float16 | |
| return torch.float32 | |
| def generate_route( | |
| session_id: str, | |
| prompt: str, | |
| negative: str, | |
| steps: int, | |
| cfg: float, | |
| width: int, | |
| height: int, | |
| seed: int, | |
| zimage_shift: float, | |
| max_sequence_length: int, | |
| precision: str, | |
| use_lora: bool, | |
| use_distillpatch: bool, | |
| selected_lora: str, | |
| lora_scale: float, | |
| distillpatch_scale: float, | |
| sve_enabled: bool, | |
| sve_noise_insert: str, | |
| sve_steps_switchover_percent: float, | |
| sve_seed: int, | |
| sve_control_after_generate: str, | |
| sve_strength: float, | |
| sve_random_percent: float, | |
| sve_mask_starts_at: str, | |
| sve_mask_percent: float, | |
| sve_log_to_console: bool, | |
| ) -> Tuple[str, List[Image.Image], str, str]: | |
| logs: List[str] = [] | |
| t0 = time.time() | |
| seed = int(seed) if seed is not None else -1 | |
| if seed == -1: | |
| seed = random.randint(0, 2**31 - 1) | |
| from diffusers import ZImagePipeline, FlowMatchEulerDiscreteScheduler | |
| # We pin dtype to bf16. | |
| dtype = torch.bfloat16 | |
| # ZeroGPU-safe strategy: | |
| # - Load the pipeline ON CPU once per container (outside GPU allocation). | |
| # - On each GPU call, move that already-loaded pipeline to CUDA. | |
| # This avoids re-reading checkpoint shards on every request. | |
| _ensure_cpu_pipe() | |
| if GPU_ENV["pipe"] is None: | |
| logs.append(f"📦 Moving Z-Image pipeline to CUDA ({dtype})") | |
| # Move the already-loaded CPU pipeline to GPU for this allocated GPU session. | |
| GPU_ENV["pipe"] = CPU_ENV["pipe"].to("cuda") | |
| GPU_ENV["dtype"] = str(dtype) | |
| GPU_ENV["compiled"] = False | |
| pipe = GPU_ENV["pipe"] | |
| # Update scheduler shift each run | |
| try: | |
| pipe.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(zimage_shift)) | |
| except Exception as e: | |
| logs.append(f"⚠️ scheduler shift failed: {type(e).__name__}: {e}") | |
| # ---- Optional Turbo DistillPatch LoRA (stacked, never fused) ---- | |
| # Note: DiffSynth's reference snippet loads user LoRA first, then DistillPatch. | |
| # In Diffusers/PEFT multi-adapter mode the effects are generally additive, but we keep | |
| # DistillPatch LAST in the adapter list to mirror the official intent. | |
| active_adapters: List[str] = [] | |
| active_weights: List[float] = [] | |
| distill_ready = False | |
| if use_distillpatch: | |
| try: | |
| if not GPU_ENV.get("distillpatch_loaded", False): | |
| logs.append(f"🩹 Loading DistillPatch: {DISTILLPATCH_REPO}") | |
| # DistillPatch repo typically exposes `model.safetensors`. | |
| # We try that first for determinism; then fall back to first .safetensors in repo. | |
| try: | |
| pipe.load_lora_weights( | |
| DISTILLPATCH_REPO, | |
| revision=DISTILLPATCH_REVISION, | |
| weight_name="model.safetensors", | |
| adapter_name=GPU_ENV["distillpatch_adapter"], | |
| ) | |
| except Exception: | |
| files = list_repo_files(DISTILLPATCH_REPO, revision=DISTILLPATCH_REVISION, token=None) | |
| cand = [f for f in files if f.lower().endswith(".safetensors")] | |
| pipe.load_lora_weights( | |
| DISTILLPATCH_REPO, | |
| revision=DISTILLPATCH_REVISION, | |
| weight_name=(cand[0] if cand else None), | |
| adapter_name=GPU_ENV["distillpatch_adapter"], | |
| ) | |
| GPU_ENV["distillpatch_loaded"] = True | |
| distill_ready = True | |
| except Exception as e: | |
| logs.append(f"⚠️ DistillPatch load failed: {type(e).__name__}: {e}") | |
| distill_ready = False | |
| # LoRA (NO FUSING) | |
| if use_lora and selected_lora and selected_lora != "<none>": | |
| lora_path = _resolve_lora_path(session_id, selected_lora) | |
| if os.path.exists(lora_path): | |
| lora_abs = os.path.abspath(lora_path) | |
| try: | |
| _validate_downloaded_lora(pathlib.Path(lora_abs)) | |
| except Exception as e: | |
| logs.append(f"❌ LoRA file invalid: {e}") | |
| mem_html = f"<span style='font-family:monospace;font-size:0.76rem;'>seed={seed} · steps={steps} · shift={zimage_shift} · dtype={dtype}</span>" | |
| return session_id, [], "\n".join(logs), mem_html | |
| key = (str(dtype), lora_abs) | |
| if key in GPU_ENV["loaded_loras"]: | |
| adapter_name = GPU_ENV["loaded_loras"][key] | |
| logs.append(f"🧩 LoRA already loaded: {selected_lora} as {adapter_name}") | |
| else: | |
| adapter_name = f"lora_{_sha1_file(lora_abs)}" | |
| logs.append(f"🧩 Loading LoRA: {selected_lora} -> {adapter_name}") | |
| pipe.load_lora_weights(lora_abs, adapter_name=adapter_name) | |
| GPU_ENV["loaded_loras"][key] = adapter_name | |
| active_adapters.append(adapter_name) | |
| active_weights.append(float(lora_scale)) | |
| # Keep DistillPatch last (mirror official snippet) | |
| if distill_ready: | |
| active_adapters.append(GPU_ENV["distillpatch_adapter"]) | |
| active_weights.append(float(distillpatch_scale)) | |
| try: | |
| pipe.set_adapters(active_adapters, adapter_weights=active_weights) | |
| except Exception as e: | |
| logs.append(f"⚠️ set_adapters failed: {type(e).__name__}: {e}") | |
| else: | |
| logs.append("⚠️ LoRA file missing on disk") | |
| else: | |
| # No user LoRA. Keep only DistillPatch if enabled; otherwise clear. | |
| try: | |
| if distill_ready: | |
| pipe.set_adapters([GPU_ENV["distillpatch_adapter"]], adapter_weights=[float(distillpatch_scale)]) | |
| else: | |
| pipe.set_adapters([]) | |
| except Exception: | |
| pass | |
| # Ensure dimensions are valid for Z-Image (multiple of 16) | |
| w0, h0 = int(width), int(height) | |
| w = max(16, w0 - (w0 % 16)) | |
| h = max(16, h0 - (h0 % 16)) | |
| if (w, h) != (w0, h0): | |
| logs.append(f"ℹ️ Rounded size to multiple-of-16: {w0}x{h0} -> {w}x{h}") | |
| width, height = w, h | |
| generator = torch.Generator(device="cuda").manual_seed(seed) | |
| call_kwargs: Dict[str, Any] = dict( | |
| height=int(height), | |
| width=int(width), | |
| num_inference_steps=int(steps), | |
| guidance_scale=float(cfg), | |
| generator=generator, | |
| max_sequence_length=int(max_sequence_length), | |
| ) | |
| # SeedVarianceEnhancer (1:1-ish with ChangeTheConstants node UI) | |
| # noise_insert: disabled | noise on beginning steps | noise on all steps | noise on ending steps | |
| if sve_enabled: | |
| mode_raw = (sve_noise_insert or "disabled").strip().lower() | |
| if mode_raw.startswith("noise on beginning"): | |
| mode = "beginning" | |
| elif mode_raw.startswith("noise on ending"): | |
| mode = "ending" | |
| elif mode_raw.startswith("noise on all"): | |
| mode = "all" | |
| else: | |
| mode = "disabled" | |
| # Node uses a separate seed for embedding noise; allow -1 to "follow main seed" | |
| sve_seed_used = int(sve_seed) if sve_seed is not None else -1 | |
| if sve_seed_used == -1: | |
| sve_seed_used = int(seed) | |
| # switchover percent means: beginning steps noise % OR ending steps noise % | |
| sw = max(0.0, min(100.0, float(sve_steps_switchover_percent))) | |
| k = int(round((sw / 100.0) * int(steps))) | |
| k = max(0, min(int(steps), k)) | |
| if mode == "beginning": | |
| switch_step = k | |
| elif mode == "ending": | |
| switch_step = max(0, int(steps) - k) # start of noisy region | |
| else: | |
| switch_step = 0 | |
| logs.append( | |
| f"🎲 SVE: mode={mode} seed={sve_seed_used} strength={float(sve_strength)} rand%={float(sve_random_percent)} " | |
| f"sw%={sw} mask_start={sve_mask_starts_at} mask%={float(sve_mask_percent)}" | |
| ) | |
| try: | |
| dev = pipe._execution_device if hasattr(pipe, "_execution_device") else "cuda" | |
| # Encode to embeddings (diffusers variant-safe) | |
| if hasattr(pipe, "encode_prompt"): | |
| pe, ne = pipe.encode_prompt( | |
| prompt=prompt, | |
| negative_prompt=negative, | |
| device=dev, | |
| do_classifier_free_guidance=True, | |
| max_sequence_length=int(max_sequence_length), | |
| ) | |
| else: | |
| pe, ne = pipe._encode_prompt( | |
| prompt=prompt, | |
| negative_prompt=negative, | |
| device=dev, | |
| num_images_per_prompt=1, | |
| do_classifier_free_guidance=True, | |
| max_sequence_length=int(max_sequence_length), | |
| ) | |
| clean = pe | |
| noisy = sve_apply_any( | |
| prompt_embeds=clean, | |
| seed=sve_seed_used, | |
| strength=float(sve_strength), | |
| randomize_percent=float(sve_random_percent), | |
| mask_starts_at_mode=str(sve_mask_starts_at), | |
| mask_percent=float(sve_mask_percent), | |
| ) | |
| # Apply according to mode | |
| if mode == "disabled" or float(sve_strength) == 0 or float(sve_random_percent) <= 0: | |
| pass | |
| elif mode == "all": | |
| call_kwargs["prompt_embeds"] = noisy | |
| call_kwargs["negative_prompt_embeds"] = ne | |
| elif mode == "beginning": | |
| # noisy for first switch_step steps | |
| call_kwargs["prompt_embeds"] = noisy if switch_step > 0 else clean | |
| call_kwargs["negative_prompt_embeds"] = ne | |
| if 0 < switch_step < int(steps): | |
| call_kwargs["callback_on_step_end"] = _build_sve_callback_mode(clean, noisy, switch_step, "beginning") | |
| call_kwargs["callback_on_step_end_tensor_inputs"] = ["prompt_embeds"] | |
| elif mode == "ending": | |
| # noisy for last k steps starting at switch_step | |
| call_kwargs["prompt_embeds"] = clean | |
| call_kwargs["negative_prompt_embeds"] = ne | |
| if 0 <= switch_step < int(steps): | |
| call_kwargs["callback_on_step_end"] = _build_sve_callback_mode(clean, noisy, switch_step, "ending") | |
| call_kwargs["callback_on_step_end_tensor_inputs"] = ["prompt_embeds"] | |
| except Exception as e: | |
| logs.append(f"⚠️ SVE embedding path failed: {type(e).__name__}: {e} (fall back to raw prompt)") | |
| # Optional log_to_console (mirrors node toggle) | |
| if sve_log_to_console: | |
| try: | |
| print("[SVE]", logs[-1]) | |
| except Exception: | |
| pass | |
| logs.append("🚀 Generating…") | |
| with torch.inference_mode(): | |
| out = pipe( | |
| prompt=None if "prompt_embeds" in call_kwargs else prompt, | |
| negative_prompt=None if "negative_prompt_embeds" in call_kwargs else negative, | |
| **call_kwargs, | |
| ) | |
| img = out.images[0] | |
| logs.append(f"✅ Done in {time.time() - t0:.2f}s (seed={seed})") | |
| mem_html = f"<span style='font-family:monospace;font-size:0.76rem;'>seed={seed} · steps={steps} · shift={zimage_shift} · dtype={dtype}</span>" | |
| return session_id, [img], "\n".join(logs), mem_html | |
| # ========================= | |
| # Gallery normalization | |
| # ========================= | |
def _normalize_gallery_list(lst):
    """Coerce a Gradio gallery value into a flat list of PIL Images.

    Depending on the Gradio version, a Gallery's value may contain PIL
    images, numpy arrays (HWC or CHW), ``(image, caption)`` tuples, dicts
    with an ``image``/``data``/``value`` key, or plain file paths.  Entries
    that cannot be converted are silently dropped rather than raising, so a
    stale or odd history entry never breaks generation.
    """
    from PIL import Image as PILImage
    import numpy as np
    if not isinstance(lst, list):
        return []
    out = []
    for item in lst:
        if isinstance(item, PILImage.Image):
            out.append(item)
            continue
        if isinstance(item, np.ndarray):
            arr = item
            if arr.dtype != np.uint8:
                arr = np.clip(arr, 0, 255).astype(np.uint8)
            # Convert channel-first (C,H,W) layouts to channel-last for PIL.
            if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] != arr.shape[-1]:
                arr = np.transpose(arr, (1, 2, 0))
            out.append(PILImage.fromarray(arr))
            continue
        # Newer Gradio versions hand back file paths for gallery entries;
        # previously these fell through and the history image was lost.
        if isinstance(item, (str, pathlib.Path)):
            try:
                out.append(PILImage.open(item))
            except Exception:
                pass  # unreadable/missing file: drop the entry
            continue
        if isinstance(item, tuple) and item:
            # Gallery entries may be (image, caption) pairs.
            sub = _normalize_gallery_list([item[0]])
            if sub:
                out.append(sub[0])
            continue
        if isinstance(item, dict):
            data = item.get("image") or item.get("data") or item.get("value")
            sub = _normalize_gallery_list([data])
            if sub:
                out.append(sub[0])
            continue
    return out
| # ========================= | |
| # Build UI | |
| # ========================= | |
with gr.Blocks() as demo:
    # Per-browser session id; populated lazily by the handlers below.
    session_state = gr.State(value=None)
    # Hidden token inputs (kept to match downloader signature; env/secrets fallback)
    hf_token_hidden = gr.Textbox(value="", visible=False)
    civitai_token_hidden = gr.Textbox(value="", visible=False)
    with gr.Row(elem_classes="layout-main"):
        # -------- LEFT: Controls --------
        with gr.Column(elem_classes=["panel", "controls"]):
            with gr.Tabs():
                # --- Generate ---
                with gr.Tab("Generate"):
                    with gr.Column(elem_id="left_box"):
                        gr.Markdown("### Generate")
                        # Shows model-prefetch progress filled in by _on_load.
                        prefetch_status = gr.Markdown(value=PREFETCH_STATUS.get("msg", ""))
                        run_btn = gr.Button("Generate 🎨", elem_id="generate_btn")
                        with gr.Row(elem_classes="compact-row"):
                            prompt = gr.Textbox(label="Prompt", value="", lines=2, placeholder="Describe what you want to see...")
                            negative = gr.Textbox(label="Negative", value="", lines=2, placeholder="What to avoid...")
                        with gr.Row(elem_classes="compact-row"):
                            steps = gr.Slider(1, 64, 8, step=1, label="Steps")
                            cfg = gr.Slider(0.0, 8.0, 1.0, step=0.1, label="CFG")
                        with gr.Row(elem_classes="compact-row"):
                            seed = gr.Number(-1, label="Seed (-1=random)")
                            precision = gr.Dropdown(choices=["bf16", "fp16", "fp32"], value="bf16", label="Precision")
                        # Width/height steps of 16 match the pipeline's
                        # multiple-of-16 rounding in the GPU route.
                        with gr.Row(elem_classes="compact-row"):
                            width = gr.Slider(256, 1536, 1024, step=16, label="Width")
                            height = gr.Slider(256, 1536, 1024, step=16, label="Height")
                            zimage_shift = gr.Slider(0.0, 10.0, 5.0, step=0.1, label="FlowMatch shift")
                            max_sequence_length = gr.Slider(64, 512, 512, step=8, label="Max sequence length")
                        with gr.Row(elem_classes="compact-row"):
                            use_lora = gr.Checkbox(False, label="Use LoRA")
                            use_distillpatch = gr.Checkbox(False, label="Turbo DistillPatch (default OFF)")
                            lora_dropdown = gr.Dropdown(choices=["<none>"], value="<none>", label="LoRA file")
                            refresh_loras_btn = gr.Button("🔄 Refresh")
                        with gr.Row(elem_classes="compact-row"):
                            lora_scale = gr.Slider(0.0, 2.0, 0.8, step=0.05, label="LoRA scale")
                            distillpatch_scale = gr.Slider(0.0, 2.0, 1.0, step=0.05, label="DistillPatch scale")
                        # ---- SeedVarianceEnhancer (node-like UI) ----
                        # Widget names/choices mirror the ComfyUI node this
                        # feature is ported from; the GPU route parses the
                        # noise_insert strings by prefix.
                        with gr.Row(elem_classes="compact-row"):
                            sve_enabled = gr.Checkbox(value=False, label="SeedVarianceEnhancer")
                            sve_noise_insert = gr.Dropdown(
                                choices=[
                                    "disabled",
                                    "noise on beginning steps",
                                    "noise on all steps",
                                    "noise on ending steps",
                                ],
                                value="noise on beginning steps",
                                label="noise_insert",
                            )
                        with gr.Row(elem_classes="compact-row"):
                            sve_random_percent = gr.Slider(0.0, 100.0, 50.0, step=1.0, label="randomize_percent")
                            sve_strength = gr.Slider(0.0, 80.0, 20.0, step=1.0, label="strength")
                        with gr.Row(elem_classes="compact-row"):
                            sve_steps_switchover = gr.Slider(0.0, 100.0, 20.0, step=1.0, label="steps_switchover_percent")
                            sve_seed = gr.Number(2019, label="seed")
                            sve_control_after_generate = gr.Dropdown(
                                choices=["keep", "increment", "randomize"],
                                value="randomize",
                                label="control_after_generate",
                            )
                        with gr.Row(elem_classes="compact-row"):
                            sve_mask_starts = gr.Dropdown(
                                choices=["beginning", "middle", "end"],
                                value="beginning",
                                label="mask_starts_at",
                            )
                            sve_mask_percent = gr.Slider(0.0, 100.0, 0.0, step=1.0, label="mask_percent")
                            sve_log_to_console = gr.Checkbox(value=False, label="log_to_console")
                        # Filled by the LoRA downloader handler (read-only).
                        trigger_words_box = gr.Textbox(label="Detected trigger words", value="", lines=1, interactive=False)
                # --- LoRA Downloader ---
                with gr.Tab("LoRA Downloader"):
                    gr.Markdown("### Download a LoRA")
                    gr.Markdown("**Tip:** Use **CivitAI model URL** or **Hugging Face URL/repo id** (e.g. `user/repo`).")
                    with gr.Row(elem_classes="compact-row"):
                        lora_url = gr.Textbox(label="LoRA URL / Repo ID", placeholder="Civitai model page URL, HF repo id (user/repo), or direct file URL…")
                        lora_filename = gr.Textbox(label="Save name (optional)", placeholder="my_lora_name")
                    download_lora_btn = gr.Button("📥 Download LoRA")
                    lora_dl_log = gr.Textbox(label="Downloader log", lines=6, interactive=False)
                    gr.Markdown("After downloading: go to **Generate** → **Refresh**.")
        # -------- RIGHT: Output --------
        with gr.Column(elem_classes=["panel", "output"]):
            main_image = gr.Image(label="Result", interactive=False, elem_id="main_image")
            # type="pil" so select/normalize handlers receive PIL-friendly values.
            history_gallery = gr.Gallery(
                label="History",
                columns=6,
                height=110,
                elem_id="history_gallery",
                type="pil",
            )
            with gr.Accordion("Logs & system info", open=False, elem_id="logs_accordion"):
                log_box = gr.Textbox(label="Generation Logs", lines=10, interactive=False)
                mem_status = gr.HTML(value="<span style='font-family:monospace;font-size:0.76rem;'>No memory data yet.</span>")
    # Downloader also receives/returns `prompt` so detected trigger words can
    # be appended to the prompt box.
    download_lora_btn.click(
        fn=download_lora_for_session,
        inputs=[session_state, lora_url, lora_filename, hf_token_hidden, civitai_token_hidden, prompt],
        outputs=[session_state, lora_dropdown, lora_dl_log, trigger_words_box, prompt],
    )
    refresh_loras_btn.click(
        fn=refresh_loras_for_session,
        inputs=[session_state],
        outputs=[session_state, lora_dropdown],
    )
| def generate_and_clip_gallery( | |
| session_id, | |
| prompt, | |
| negative, | |
| steps, | |
| cfg, | |
| width, | |
| height, | |
| seed, | |
| zimage_shift, | |
| max_sequence_length, | |
| precision, | |
| use_lora, | |
| use_distillpatch, | |
| selected_lora, | |
| lora_scale, | |
| distillpatch_scale, | |
| sve_enabled, | |
| sve_noise_insert, | |
| sve_steps_switchover, | |
| sve_seed, | |
| sve_control_after_generate, | |
| sve_strength, | |
| sve_random_percent, | |
| sve_mask_starts, | |
| sve_mask_percent, | |
| sve_log_to_console, | |
| current_history, | |
| ): | |
| session_id = session_id or _new_session_id() | |
| # Call GPU route with a small retry: ZeroGPU can race CUDA init | |
| last_err = None | |
| mem_html = "" | |
| for _attempt in range(6): | |
| try: | |
| session_id, images, logs, mem_html = generate_route( | |
| session_id, | |
| prompt, | |
| negative, | |
| steps, | |
| cfg, | |
| width, | |
| height, | |
| seed, | |
| zimage_shift, | |
| max_sequence_length, | |
| precision, | |
| use_lora, | |
| use_distillpatch, | |
| selected_lora, | |
| lora_scale, | |
| distillpatch_scale, | |
| sve_enabled, | |
| sve_noise_insert, | |
| sve_steps_switchover, | |
| sve_seed, | |
| sve_control_after_generate, | |
| sve_strength, | |
| sve_random_percent, | |
| sve_mask_starts, | |
| sve_mask_percent, | |
| sve_log_to_console, | |
| ) | |
| last_err = None | |
| break | |
| except Exception as e: | |
| # ZeroGPU may briefly fail CUDA init if a GPU is not yet allocated. | |
| # Never crash the UI: retry a few times, then return a friendly message. | |
| last_err = e | |
| msg = str(e) or repr(e) | |
| low = msg.lower() | |
| transient_cuda = ( | |
| "cuda driver initialization failed" in low | |
| or "might not have a cuda gpu" in low | |
| or "_cuda_init" in low | |
| or "torch.init" in low | |
| or "cuda initialization" in low | |
| ) | |
| if transient_cuda: | |
| time.sleep(0.7) | |
| continue | |
| # Not a GPU-allocation race: surface the real error in logs, but keep the app alive. | |
| logs = f"❌ Generation failed: {type(e).__name__}: {msg}" | |
| history_list = _normalize_gallery_list(current_history) | |
| new_history = history_list[-10:] if len(history_list) > 10 else history_list | |
| main = new_history[-1] if new_history else None | |
| sve_seed_out = int(sve_seed) if sve_seed is not None else 2019 | |
| return session_id, main, new_history, logs, mem_html, sve_seed_out | |
| # Node-like control_after_generate update for the UI seed field | |
| sve_seed_out = int(sve_seed) if sve_seed is not None else 2019 | |
| if sve_enabled: | |
| mode = (sve_control_after_generate or "keep").strip().lower() | |
| if mode == "increment": | |
| sve_seed_out += 1 | |
| elif mode == "randomize": | |
| sve_seed_out = random.randint(0, 2**31 - 1) | |
| if last_err is not None: | |
| # Graceful error: keep history unchanged; no new images | |
| logs = "⚠️ GPU wasn’t ready (ZeroGPU race). Please click **Generate** again." | |
| mem_html = mem_html if "mem_html" in locals() else "" | |
| history_list = _normalize_gallery_list(current_history) | |
| new_history = history_list[-10:] if len(history_list) > 10 else history_list | |
| main = new_history[-1] if new_history else None | |
| return session_id, main, new_history, logs, mem_html, sve_seed_out | |
| history_list = _normalize_gallery_list(current_history) | |
| images = _normalize_gallery_list(images) | |
| new_history = (history_list + images)[-10:] | |
| main = new_history[-1] if new_history else None | |
| return session_id, main, new_history, logs, mem_html, sve_seed_out | |
    # Wire the Generate button.  The positional `inputs` list must match the
    # parameter order of generate_and_clip_gallery exactly; `history_gallery`
    # is both an input (current history) and an output (clipped history), and
    # `sve_seed` is an output so control_after_generate can write the next
    # seed back into the UI field.
    run_btn.click(
        fn=generate_and_clip_gallery,
        inputs=[
            session_state,
            prompt,
            negative,
            steps,
            cfg,
            width,
            height,
            seed,
            zimage_shift,
            max_sequence_length,
            precision,
            use_lora,
            use_distillpatch,
            lora_dropdown,
            lora_scale,
            distillpatch_scale,
            sve_enabled,
            sve_noise_insert,
            sve_steps_switchover,
            sve_seed,
            sve_control_after_generate,
            sve_strength,
            sve_random_percent,
            sve_mask_starts,
            sve_mask_percent,
            sve_log_to_console,
            history_gallery,
        ],
        outputs=[session_state, main_image, history_gallery, log_box, mem_status, sve_seed],
    )
| def select_from_history(history, evt: gr.SelectData): | |
| normalized = _normalize_gallery_list(history) | |
| idx = evt.index | |
| if isinstance(idx, int) and 0 <= idx < len(normalized): | |
| return normalized[idx] | |
| return None | |
    # Clicking a history thumbnail promotes it to the main result image.
    history_gallery.select(fn=select_from_history, inputs=[history_gallery], outputs=[main_image])
    def _on_load(session_id):
        # Runs once per page load: warm the CPU pipeline, refresh the LoRA
        # dropdown for this session, and surface the prefetch status banner.
        _ensure_cpu_pipe()
        sid, dd = refresh_loras_for_session(session_id)
        return sid, dd, PREFETCH_STATUS.get("msg", "")
    demo.load(fn=_on_load, inputs=[session_state], outputs=[session_state, lora_dropdown, prefetch_status])
# Single-worker queue: one generation at a time, up to 20 queued requests.
demo.queue(default_concurrency_limit=1, max_size=20).launch(ssr_mode=False, css=RESPONSIVE_CSS, show_error=False)