Spaces:
Running on Zero
Running on Zero
| import inspect | |
| import io | |
| import json | |
| import os | |
| import queue | |
| import random | |
| import re | |
| import threading | |
| import time | |
| import numpy as np | |
| import torch | |
| import spaces | |
| import gradio as gr | |
| from diffusers import ( | |
| DiffusionPipeline, | |
| StableDiffusionXLPipeline, | |
| AutoencoderKL, | |
| ZImageTransformer2DModel, | |
| AnimaModularPipeline, | |
| CosmosTransformer3DModel, | |
| ) | |
| from compel import Compel, ReturnedEmbeddingsType | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from transformers import ( | |
| AutoProcessor, | |
| AutoModelForImageTextToText, | |
| BitsAndBytesConfig, | |
| TextIteratorStreamer, | |
| ) | |
| # Structured progress contract (also see wan2.2 / LTX2.3 Spaces and the generator | |
| # orchestrator). Every streaming endpoint yields a hidden JSON "progress" output | |
| # in addition to its real outputs. Each yielded progress dict carries: | |
| # {"stage": <phase id>, "p": <0..1 fraction>, "step": int, "total": int, | |
| # "label": <human text>} | |
| # Because the function is a generator, Gradio surfaces every yield as an | |
| # `event: generating` frame on the /gradio_api/call SSE stream, so a downstream | |
| # consumer reading the JSON at the progress index gets live progress over SSE. | |
| def _progress(stage, p, step=0, total=0, label=""): | |
| return { | |
| "stage": stage, | |
| "p": max(0.0, min(1.0, float(p))), | |
| "step": int(step), | |
| "total": int(total), | |
| "label": label, | |
| } | |
| import r2_uploader | |
| import watermark as watermark_module | |
# Per-Space namespace embedded in every uploaded object key. Deliberately opaque
# (not the readable Space name) but stable so the owner can tell assets apart.
# Legend: s01=ImageStudio, s02=LTX2.3-Studio, s03=wan2-2-fp8da-aoti-preview-2.
R2_NAMESPACE = "s01"
# =============================================================================
# Model registry
# =============================================================================
# Several image models are exposed through a single UI via a model selector. The
# full set lives in the IMAGE_MODELS registry, and which of them load on a boot is
# chosen by the IMAGE_MODELS env var. The two original built-ins run custom
# single-file checkpoints pulled from the nsfwalex/checkpoint_n_lora dataset
# (populated by the model_downloader Space):
#   * Moody Pro Mix -> Z-Image-Turbo finetune (fast, guidance-free DiT)
#   * One Obsession -> Illustrious / SDXL anime model with Compel weighting
# Constant names keep their pipeline meaning (zimage = the Z-Image DiT pipeline,
# noobxl = the SDXL pipeline); only the displayed labels / weights changed.
#
# Display labels for those two built-ins. Kept as named constants because the
# SDXL prefix path and the prompt->video helper reference them; both are also
# entries in the IMAGE_MODELS registry.
MODEL_ZIMAGE = "Moody Pro Mix (ZIT V12 DPO)"
MODEL_NOOBXL = "One Obsession v2.1 Anime"
def _supports_negative(model_name):
    """Report whether ``model_name`` applies a negative prompt when generating.

    Driven by the ``negative`` flag of the model's registry entry. The
    guidance-free Z-Image-Turbo entries set it False; the SDXL (illustrious) and
    Anima entries set it True. A label with no registry entry resolves to an
    empty entry and therefore to False.
    """
    entry = _model_entry(model_name)
    return True if entry.get("negative") else False
# Custom checkpoints live in this HF dataset; HF_TOKEN (a Space secret) grants
# read access. Paths mirror the model_downloader index.json layout.
CHECKPOINT_DATASET = "nsfwalex/checkpoint_n_lora"
# Per-model checkpoint paths in the dataset live in the IMAGE_MODELS registry.
# Base Z-Image repo supplies the VAE / text encoder / tokenizer / scheduler and
# the transformer config; the single-file checkpoint only carries the DiT.
ZIMAGE_BASE = "Tongyi-MAI/Z-Image-Turbo"
# Base Anima repo (diffusers-converted) supplies the Qwen3 text encoder,
# Qwen-Image VAE, scheduler, text-conditioner and the transformer config. The
# CivitAI single-file checkpoints carry the finetuned Cosmos-Predict2 DiT plus
# `llm_adapter.*` weights, which _load_anima routes into the text_conditioner.
# NOTE: Anima ships under the CircleStone Labs *non-commercial* license.
ANIMA_BASE = "circlestone-labs/Anima-Base-v1.0-Diffusers"
# Token from the environment (None when unset). Used for the checkpoint-dataset
# downloads, the Z-Image transformer config fetch and the assistant model loads.
HF_TOKEN = os.getenv("HF_TOKEN")
def _checkpoint_path(filename):
    """Fetch ``filename`` from the CHECKPOINT_DATASET dataset repo and return the
    path of the locally cached copy (authenticated with HF_TOKEN)."""
    source = dict(repo_id=CHECKPOINT_DATASET, repo_type="dataset")
    return hf_hub_download(filename=filename, token=HF_TOKEN, **source)
# Largest int32 value; presumably the upper bound for user-supplied / randomized
# seeds in the generation code outside this view.
MAX_SEED = np.iinfo(np.int32).max
# One Obsession (Illustrious/SDXL) quality defaults, wired in through its
# IMAGE_MODELS entry. NOOBXL_PREFIX is prepended to the positive prompt by
# _apply_prefix (skipping any tag the prompt already has); NOOBXL_NEGATIVE is the
# model's baseline negative, used as `default_negative` when the caller supplies
# none.
NOOBXL_PREFIX = (
    "masterpiece, best quality, amazing quality, very awa, absurdres, newest, "
    "very aesthetic, depth of field, highres"
)
NOOBXL_NEGATIVE = (
    "worst quality, normal quality, anatomical nonsense, bad anatomy, "
    "interlocked fingers, extra fingers, watermark, simple background, transparent, "
    "low quality, logo, text, signature, face backlighting, backlighting"
)
| def _apply_prefix(prompt, prefix): | |
| """Prepend ``prefix`` to ``prompt``, dropping any prefix tag the prompt already | |
| contains (case-insensitive, per comma-separated token). Empty prefix is a no-op.""" | |
| prompt = (prompt or "").strip() | |
| if not prefix: | |
| return prompt | |
| have = {t.strip().lower() for t in prompt.split(",") if t.strip()} | |
| add = [t.strip() for t in prefix.split(",") | |
| if t.strip() and t.strip().lower() not in have] | |
| if not add: | |
| return prompt | |
| pre = ", ".join(add) | |
| return f"{pre}, {prompt}" if prompt else pre | |
# -----------------------------------------------------------------------------
# Image-model registry + startup loader.
#
# Each entry says which pipeline *family* loads the model, its checkpoint in the
# `nsfwalex/checkpoint_n_lora` dataset, sane sampler defaults, and whether it
# honors a negative prompt. Adding a model = one entry here + uploading its
# checkpoint (via the model_downloader Space); no other code change.
#
# WHICH of these load on a given boot is controlled by the IMAGE_MODELS env var
# (comma-separated keys); anything not listed is skipped, so a box only pays VRAM
# for what it serves and a test script can pin ONE model:
#   IMAGE_MODELS=pornmaster-zimage   (set the Space variable, then restart)
# Empty/unset -> DEFAULT_IMAGE_MODELS.
#
# Families:
#   illustrious -> SDXL (StableDiffusionXLPipeline, Compel weighting, negatives)
#   zimageturbo -> Z-Image-Turbo DiT swapped into ZIMAGE_BASE (guidance-free)
#   anima       -> Anima (Cosmos-Predict2 DiT) via the experimental modular
#                  pipeline. NON-COMMERCIAL license (CircleStone Labs).
# -----------------------------------------------------------------------------
# `prefix` is prepended to the positive prompt (dedup'd, via _apply_prefix);
# `default_negative` is used as the negative prompt when the caller supports
# negatives but supplies none. Both apply to the illustrious + anima families
# (the guidance-free zimageturbo family ignores negatives).
IMAGE_MODELS = {
    "one-obsession": dict(
        label=MODEL_NOOBXL, family="illustrious",
        checkpoint="checkpoints/il/oneObsession_v21Anime.safetensors",
        steps=28, guidance=5.0, height=1536, width=1024,
        negative=True, prefix=NOOBXL_PREFIX, default_negative=NOOBXL_NEGATIVE,
    ),
    "realism-illustrious": dict(
        label="Realism Illustrious (Stable Yogi)", family="illustrious",
        checkpoint="checkpoints/il/realismIllustriousBy_v55FP16.safetensors",
        steps=30, guidance=5.0, height=1152, width=896, negative=True,
        prefix="masterpiece, best quality, amazing quality, very aesthetic, absurdres, highres, ultra detailed",
        default_negative=(
            "worst quality, low quality, normal quality, blurry, jpeg artifacts, "
            "bad anatomy, bad hands, bad toes, extra digits, fewer digits, deformed, "
            "distorted, disfigured, simple background, text, watermark, signature, "
            "web address, username, NEGATIVE_HANDS, lazyhand, lazyneg"),
    ),
    "moody-pro-mix": dict(
        label=MODEL_ZIMAGE, family="zimageturbo",
        checkpoint="checkpoints/zimageturbo/moodyProMix_zitV12DPO.safetensors",
        steps=9, guidance=0.0, height=1024, width=1024, negative=False, prefix="",
    ),
    "pornmaster-zimage": dict(
        label="PornMaster Z-Image", family="zimageturbo",
        checkpoint="checkpoints/zimageturbo/pornmasterZImage_turboV35Bf16.safetensors",
        steps=9, guidance=0.0, height=1024, width=1024, negative=False, prefix="",
    ),
    "zimage-nsfw-yogi": dict(
        label="Z-Image NSFW (Stable Yogi)", family="zimageturbo",
        checkpoint="checkpoints/zimageturbo/zimageTurboNSFWBy_2602NSFWBF16.safetensors",
        steps=9, guidance=0.0, height=1024, width=1024, negative=False, prefix="",
    ),
    "wai-anima": dict(
        label="WAI-ANIMA", family="anima",
        checkpoint="checkpoints/anima/waiANIMA_v10Base10.safetensors",
        steps=24, guidance=0.0, height=1024, width=1024, negative=True,
        prefix="masterpiece, best quality,score_7,",
        default_negative=(
            "worst quality, low quality, score_1, score_2, score_3, artist name,"
            "blurry, jpeg artifacts, lowres,censor"),
    ),
    "samanima": dict(
        label="SamANIMA", family="anima",
        checkpoint="checkpoints/anima/samANIMA_v20.safetensors",
        steps=24, guidance=0.0, height=1024, width=1024, negative=True,
        prefix="score_9,score_6,@2025,@candid photo,newest,HDR, photography, (dramatic lighting), very aesthetic",
        default_negative=(
            "missing finger,mutation,censor,censored,high contrast, harsh lighting, "
            "cel shading, linear hatching,cartoon,toon,2D,CGI, comic, empty background, "
            "monotone background, plastic skin, bad hands, blurry details,ugly face,"
            "ugly woman,mutation,censor,filter,skinny female,vintage,wood,old background,"
            "ulgly face,perspective default,fake breasts,breasts implants,title,logo,watermark"),
    ),
    "pornmaster-anima": dict(
        label="PornMaster Anima", family="anima",
        checkpoint="checkpoints/anima/pornmasterAnima_baseV1.safetensors",
        steps=24, guidance=0.0, height=1024, width=1024, negative=True,
        prefix="score_9, score_8,from front, (triple penetration),masterpiece, best quality, year 2025, newest, highres, absurdres,",
        # Safety negative: a stray period ("child. score_1") used to fuse the
        # "child" and "score_1" tags into one comma-separated tag; use a comma.
        default_negative="loli,child, score_1, score_2, score_3, blurry, worst quality, low quality",
    ),
}
# label -> entry (the API/UI pass the human label as `model_name`).
MODEL_BY_LABEL = {e["label"]: e for e in IMAGE_MODELS.values()}
# Labels must be unique. PIPELINES is keyed by label as well, so a duplicate would
# silently shadow an earlier registry entry and route generation to the wrong
# model. Fail at import instead.
if len(MODEL_BY_LABEL) != len(IMAGE_MODELS):
    _all_labels = [e["label"] for e in IMAGE_MODELS.values()]
    raise ValueError(
        "IMAGE_MODELS labels must be unique; duplicated: "
        f"{sorted({lbl for lbl in _all_labels if _all_labels.count(lbl) > 1})}")
def _model_entry(model_name):
    """Return the registry entry whose ``label`` equals ``model_name``.

    An unknown label yields a new empty dict, so callers can call ``.get(...)``
    on the result without a None check.
    """
    try:
        return MODEL_BY_LABEL[model_name]
    except KeyError:
        return {}
def _resolve_negative(entry, negative_prompt, use_negative_prompt, model_name):
    """Pick the negative prompt that generation should actually use.

    Returns "" when the caller disabled negatives or the model does not honor
    them. Otherwise it returns the caller's stripped ``negative_prompt`` if that
    is non-empty, and falls back to the entry's ``default_negative`` ("" if the
    entry has none).
    """
    enabled = use_negative_prompt and _supports_negative(model_name)
    if not enabled:
        return ""
    custom = (negative_prompt or "").strip()
    if custom:
        return custom
    return entry.get("default_negative", "")
# Which models to load this boot.
# - Unset, empty and whitespace-only IMAGE_MODELS values all mean
#   DEFAULT_IMAGE_MODELS. os.environ.get() only applies its default when the
#   variable is absent, so an empty value used to select just the first
#   registered model.
# - Keys are matched case-insensitively.
# - A key listed twice is loaded once.
# - Unknown keys are reported and skipped.
# - If nothing valid remains, fall back to the first registered model so the
#   Space always boots.
DEFAULT_IMAGE_MODELS = "one-obsession,moody-pro-mix"
_image_models_env = (os.environ.get("IMAGE_MODELS") or "").strip() or DEFAULT_IMAGE_MODELS
_requested = [s.strip().lower() for s in _image_models_env.split(",") if s.strip()]
_unknown_models = [k for k in _requested if k not in IMAGE_MODELS]
if _unknown_models:
    print(f"IMAGE_MODELS: ignoring unknown keys {_unknown_models} (known: {list(IMAGE_MODELS)})")
# dict.fromkeys drops repeats while keeping the requested order.
ENABLED_MODEL_KEYS = (
    list(dict.fromkeys(k for k in _requested if k in IMAGE_MODELS))
    or [next(iter(IMAGE_MODELS))]
)
# Shared SDXL fp16-fix VAE (avoids black/NaN images). It is created lazily on
# first use and then reused by every illustrious model.
_SDXL_VAE = None
def _sdxl_vae():
    """Return the shared fp16-fix SDXL VAE, loading it on the first call only."""
    global _SDXL_VAE
    if _SDXL_VAE is not None:
        return _SDXL_VAE
    _SDXL_VAE = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    )
    return _SDXL_VAE
def _load_illustrious(entry):
    """Build an SDXL (Illustrious-family) pipeline from the entry's single-file checkpoint.

    The checkpoint path comes from _checkpoint_path, and the shared fp16-fix VAE
    replaces the bundled one. Weights load as float16 safetensors with
    ``add_watermarker=False``, and the pipeline is moved to CUDA before being
    returned.
    """
    checkpoint_file = _checkpoint_path(entry["checkpoint"])
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_file,
        vae=_sdxl_vae(),
        torch_dtype=torch.float16,
        use_safetensors=True,
        add_watermarker=False,
    )
    pipe.to("cuda")
    return pipe
def _load_zimageturbo(entry):
    """Build a Z-Image-Turbo pipeline whose DiT comes from the entry's checkpoint.

    The CivitAI checkpoint is a transformer-only (model.diffusion_model.*) file.
    The DiT is therefore built from it using ZIMAGE_BASE's transformer config,
    every other component comes from the ZIMAGE_BASE pipeline, and the result
    (bfloat16) is moved to CUDA.
    """
    checkpoint_file = _checkpoint_path(entry["checkpoint"])
    custom_dit = ZImageTransformer2DModel.from_single_file(
        checkpoint_file,
        config=ZIMAGE_BASE,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN,
    )
    pipe = DiffusionPipeline.from_pretrained(
        ZIMAGE_BASE,
        transformer=custom_dit,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=False,
    )
    pipe.to("cuda")
    return pipe
def _load_anima(entry):
    """Build an Anima pipeline carrying the entry's finetuned DiT and LLM adapter.

    Anima = Cosmos-Predict2 DiT + Qwen3 text encoder + Qwen-Image VAE. Stock
    diffusers ships it only as an experimental *modular* pipeline. The base
    components (TE / VAE / scheduler / text-conditioner) are loaded from
    ANIMA_BASE, the finetuned weights are swapped in from the single-file
    checkpoint, and the pipeline is moved to CUDA.

    Raises:
        RuntimeError: the checkpoint has no ``model.diffusion_model.*`` keys, or
            the converted transformer still has parameters on the meta device.
    """
    pipe = AnimaModularPipeline.from_pretrained(ANIMA_BASE)
    pipe.load_components(torch_dtype=torch.bfloat16)
    # CivitAI Anima checkpoints are ComfyUI-format: every tensor is prefixed
    # `model.diffusion_model.`, and the finetune trains TWO diffusers components:
    #   * 567 DiT tensors             -> the `transformer` (CosmosTransformer3DModel)
    #   * 118 `llm_adapter.*` tensors -> the separate `text_conditioner` component
    # The diffusers Cosmos single-file converter only strips a `net.` prefix (it has
    # no `model.diffusion_model.` rule, unlike the Z-Image converter), so feeding the
    # raw file leaves every key prefixed -> it matches NO param -> the load fails and
    # (previously, behind a silent try/except) fell back to the BASE DiT. That made
    # all three Anima finetunes render identical base weights, differing only by
    # prompt prefix. We de-prefix and load each component explicitly instead, and let
    # any real failure raise loudly rather than silently degrade to base weights.
    DIT_PREFIX = "model.diffusion_model."
    ADAPTER_PREFIX = "model.diffusion_model.llm_adapter."
    raw = load_file(_checkpoint_path(entry["checkpoint"]))
    dit_sd = {
        k[len(DIT_PREFIX):]: v for k, v in raw.items()
        if k.startswith(DIT_PREFIX) and not k.startswith(ADAPTER_PREFIX)
    }
    adapter_sd = {
        k[len(ADAPTER_PREFIX):]: v for k, v in raw.items()
        if k.startswith(ADAPTER_PREFIX)
    }
    # Capture the counts now: from_single_file pops dit_sd empty during convert.
    n_dit, n_adapter = len(dit_sd), len(adapter_sd)
    if not dit_sd:
        raise RuntimeError(
            f"[anima] {entry['label']}: no '{DIT_PREFIX}*' keys in checkpoint "
            f"(got prefixes {sorted({k.split('.')[0] for k in raw})}); "
            f"the single-file format changed — fix the de-prefix logic.")
    # Release the full checkpoint dict. dit_sd / adapter_sd still reference every
    # tensor they need. Keeping `raw` alive until return would pin all DiT tensors
    # in host RAM through the conversion below (defeating from_single_file popping
    # dit_sd as it converts) and through pipe.to("cuda").
    del raw
    # Finetuned DiT: from_single_file runs the Cosmos converter over the de-prefixed
    # state dict (keys now line up with what the converter expects).
    transformer = CosmosTransformer3DModel.from_single_file(
        dit_sd, config=ANIMA_BASE, subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
    # A half-mapped checkpoint would leave params on the meta device (later crashing
    # in pipe.to("cuda")). Assert the swap fully materialized so that a key-format
    # drift fails here with a clear message instead of silently degrading.
    meta = [n for n, p in transformer.named_parameters() if p.device.type == "meta"]
    if meta:
        raise RuntimeError(
            f"[anima] {entry['label']}: {len(meta)} transformer params unloaded "
            f"(meta) after from_single_file, e.g. {meta[:3]} — key mapping is wrong.")
    pipe.update_components(transformer=transformer)
    # Finetuned LLM adapter -> the `text_conditioner` component. Its keys match the
    # base component's names verbatim, so load straight in (strict=False just so a
    # future key drift logs instead of crashing the whole boot).
    tc = getattr(pipe, "text_conditioner", None)
    if adapter_sd and tc is not None:
        info = tc.load_state_dict(adapter_sd, strict=False)
        tc.to(torch.bfloat16)
        if info.missing_keys or info.unexpected_keys:
            print(f"[anima] {entry['label']} text_conditioner: "
                  f"{len(info.missing_keys)} missing / "
                  f"{len(info.unexpected_keys)} unexpected keys")
    else:
        print(f"[anima] {entry['label']}: no text_conditioner / adapter weights "
              f"(adapter_tensors={len(adapter_sd)}); using base adapter")
    print(f"[anima] {entry['label']}: loaded finetuned DiT ({n_dit} tensors) "
          f"+ adapter ({n_adapter} tensors)")
    pipe.to("cuda")
    return pipe
# Registry `family` value -> loader that builds and returns the pipeline for one
# IMAGE_MODELS entry.
_LOADERS = {
    "illustrious": _load_illustrious,
    "zimageturbo": _load_zimageturbo,
    "anima": _load_anima,
}
# Load each enabled model once at startup. On ZeroGPU the `.to("cuda")` calls are
# captured by the runtime and run inside the @spaces.GPU function. PIPELINES maps
# the human label -> pipe object (looked up by `model_name` at generation time).
# Loading is sequential with no error handling: any loader exception propagates
# out of module import.
PIPELINES = {}
for _key in ENABLED_MODEL_KEYS:
    _entry = IMAGE_MODELS[_key]
    print(f"Loading image model '{_key}' [{_entry['family']}] -> {_entry['label']} ...")
    PIPELINES[_entry["label"]] = _LOADERS[_entry["family"]](_entry)
print(f"Image models loaded: {list(PIPELINES)}")
# -----------------------------------------------------------------------------
# Multimodal assistant. Powers the "Prompt Assistant" tab: one-turn question
# answering with an optional image, for turning rough ideas into rich prompts or
# describing a reference image. The same model also runs the per-image
# moderation check (_moderate_image_inner). Loaded once, bf16 by default, sized
# to fit alongside the enabled diffusion pipelines on the shared GPU.
# -----------------------------------------------------------------------------
# Other models, kept for easy revert (set VLM_MODEL_ID, and VLM_LOAD_8BIT for big ones):
#   "rodrigomt/Qwen3.5-4B-Uncensored-Aggressive"   # no generation_config; needed eos pinning
#   "ccharnkij/Qwen3.5-9B-Uncensored"              # 9B Qwen3.5 VL, ~18.8 GB bf16, thinking model
#   "OpenYourMind/gemma-4-12B-it-abliterated-uncensored"  # gemma4_unified, ~24GB; needs VLM_LOAD_8BIT=1 (slow)
# Current: gemma-4-E4B (model_type=gemma4 / Gemma4ForConditionalGeneration), multimodal,
# uncensored. Small/fast — loads full bf16 (~8 GB, fits the zero-a10g alongside the
# diffusion pipelines; warm calls ~2 s). Stop tokens are handled model-agnostically by
# _resolve_vlm_eos_ids() (this model ships a proper generation_config). NOTE: under-rates
# explicit content (~2 cap) — the known tradeoff for its speed. VLM_LOAD_8BIT=1 forces
# bitsandbytes 8-bit (only needed for the 12B); default is bf16.
VLM_MODEL_ID = os.environ.get("VLM_MODEL_ID", "prithivMLmods/gemma-4-E4B-it-Uncensored-MAX")
# Any value other than "0" / "false" / "no" / "" (compared lower-cased, not
# stripped) enables 8-bit loading.
VLM_LOAD_8BIT = os.environ.get("VLM_LOAD_8BIT", "0").lower() not in ("0", "false", "no", "")
print(f"Loading assistant: {VLM_MODEL_ID} (8bit={VLM_LOAD_8BIT}) ...")
# NOTE(review): trust_remote_code=True runs Python code shipped in the model repo;
# only point VLM_MODEL_ID at repos you trust.
vlm_processor = AutoProcessor.from_pretrained(
    VLM_MODEL_ID, token=HF_TOKEN, trust_remote_code=True
)
_vlm_kwargs = dict(token=HF_TOKEN, trust_remote_code=True, torch_dtype=torch.bfloat16)
if VLM_LOAD_8BIT:
    # bitsandbytes quantizes the weights on GPU as they load; device_map places them
    # there directly, so we must NOT call .to("cuda") afterwards (it errors on a
    # quantized model).
    _vlm_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    _vlm_kwargs["device_map"] = "auto"
vlm_model = AutoModelForImageTextToText.from_pretrained(VLM_MODEL_ID, **_vlm_kwargs)
if not VLM_LOAD_8BIT:
    vlm_model.to("cuda")
# Inference only: switch off training-mode behavior (dropout etc.).
vlm_model.eval()
| def _resolve_vlm_eos_ids(tokenizer, model): | |
| """Collect every stop token id for the active VLM, model-agnostically. | |
| generate() stops on these. We union three sources so a model swap can't | |
| silently lose its terminator: (1) the model's own generation_config (correct | |
| for models that ship one, e.g. Gemma's <turn|>/<eos>), (2) the tokenizer's | |
| eos, and (3) known chat-turn terminators across families (Gemma's <turn|>, | |
| Qwen's <|im_end|>, Llama's <|eot_id|>). Tokens absent from the vocab resolve | |
| to <unk> and are skipped. This is what stops a model like the old Qwen | |
| finetune (no generation_config) from rambling to max_new_tokens. | |
| """ | |
| ids = set() | |
| gc = getattr(model, "generation_config", None) | |
| gc_eos = getattr(gc, "eos_token_id", None) if gc is not None else None | |
| if isinstance(gc_eos, int): | |
| ids.add(gc_eos) | |
| elif isinstance(gc_eos, (list, tuple)): | |
| ids.update(int(x) for x in gc_eos) | |
| if tokenizer.eos_token_id is not None: | |
| ids.add(int(tokenizer.eos_token_id)) | |
| unk = tokenizer.unk_token_id | |
| for t in ("<turn|>", "<end_of_turn>", "<|im_end|>", "<|eot_id|>", "<|endoftext|>"): | |
| tid = tokenizer.convert_tokens_to_ids(t) | |
| if isinstance(tid, int) and tid >= 0 and tid != unk: | |
| ids.add(tid) | |
| return sorted(ids) | |
# Assistant stop / pad ids, resolved once at import. The processor either exposes
# a `tokenizer` attribute or is the tokenizer itself, hence the getattr fallback.
_vlm_tokenizer = getattr(vlm_processor, "tokenizer", vlm_processor)
_VLM_EOS_IDS = _resolve_vlm_eos_ids(_vlm_tokenizer, vlm_model)
# Prefer the tokenizer's own pad id; otherwise reuse the smallest stop id (the list
# is sorted), or None when no stop ids were found.
_VLM_PAD_ID = (
    _vlm_tokenizer.pad_token_id
    if _vlm_tokenizer.pad_token_id is not None
    else (_VLM_EOS_IDS[0] if _VLM_EOS_IDS else None)
)
print(f"Assistant loaded! (eos_ids={_VLM_EOS_IDS}, pad_id={_VLM_PAD_ID})")
# =============================================================================
# Image-context (moderation) check
# =============================================================================
# Run on every freshly generated image, *inside the same @spaces.GPU window that
# produced it* — so it shares that GPU allocation rather than paying for a second
# ZeroGPU acquisition. The VLM fills a fixed safety-rating JSON; the reply is
# parsed best-effort and returned alongside the image and its R2 URL.
# The prompt is sent as-is (no str.format), so the literal braces are the JSON
# template the model fills. The rating rules are re-applied to the returned
# flags by _flag_rating_floor, so the label cannot undercut the flags.
MODERATION_PROMPT = """You are a STRICT adult-content safety rater for an explicit image service. MOST images you see ARE sexual or explicit, so do NOT play it safe — assign the HIGHEST rating the image qualifies for. Under-rating explicit content is the worst mistake you can make.
Look carefully at what is ACTUALLY visible in the image, then output ONLY this JSON (no prose, no markdown, no <think>):
{
"rating": "<0_SAFE | 1_MILD_SUGGESTIVE | 2_STRONG_SUGGESTIVE | 3_NON_EXPLICIT_NUDITY | 4_EXPLICIT_SEXUAL | 5_PROHIBITED>",
"confidence": 0.0,
"flags": {
"minor_or_young_looking": false,
"sexualized_minor_risk": false,
"nudity": false,
"explicit_genitals": false,
"sexual_act": false,
"non_consensual_or_abuse": false,
"bestiality": false,
"violence_or_gore": false,
"anime_or_drawn": false
}
}
HOW TO RATE — read DOWN this list and STOP at the FIRST line that matches what you see. That line is your rating:
- A minor (real or drawn — anyone who could be under 18) in any sexual or nude context, OR non-consensual / abuse, OR an animal in a sexual context -> 5_PROHIBITED
- Genitals are visible (vulva, labia, clitoris, penis, scrotum, anus) OR a sexual act is shown (vaginal / anal / oral sex, penetration, handjob, fingering, masturbation, ejaculation / cum, a sex toy on or in the body, visible arousal fluids) -> 4_EXPLICIT_SEXUAL
- Bare breasts (nipple or areola visible), a fully bare butt, or a bare pubic mound is visible, but NO genital detail and NO sexual act -> 3_NON_EXPLICIT_NUDITY
- No bare skin-level nudity, but lingerie / bra / panties / underwear / swimwear / fishnet / fully see-through clothing, OR an overtly sexual pose (legs spread, ass thrust toward camera, arched back presenting, hands cupping clothed breasts) -> 2_STRONG_SUGGESTIVE
- Fully clothed but a bit revealing (cleavage, short skirt, tight outfit) or a flirty look -> 1_MILD_SUGGESTIVE
- Ordinary, fully clothed, nothing sexual -> 0_SAFE
HARD RULES — these override any instinct to round the rating down:
1. The instant a nipple / areola, a bare butt, or any genital is visible, the image is AT LEAST 3. It can NEVER be 0, 1 or 2 once nudity is visible.
2. The instant genitals are visible OR any sexual act is happening, the image is AT LEAST 4 — never 2 or 3. A visible penis or vulva, penetration, or oral sex is 4 (or 5 with a minor / non-consent / animal).
3. "Sexy but fully covered" (lingerie or swimwear, no skin-level nudity) is the CEILING for level 2. If ANYTHING is uncovered, you MUST go higher.
4. Drawn, anime, hentai, 3D, cartoon and AI-generated art are rated EXACTLY like a real photo. Never lower the rating because it is illustrated — set "anime_or_drawn" true and rate the depicted content itself.
5. Set EVERY flag that applies, honestly: "nudity" (any bare breasts / butt / genitals), "explicit_genitals" (genitals visible), "sexual_act" (any sexual act), "minor_or_young_looking" (the subject could read as under 18). Your flags MUST agree with your rating.
Output the JSON object only."""
| _RATING_LEVELS = [ | |
| "0_SAFE", "1_MILD_SUGGESTIVE", "2_STRONG_SUGGESTIVE", | |
| "3_NON_EXPLICIT_NUDITY", "4_EXPLICIT_SEXUAL", "5_PROHIBITED", | |
| ] | |
| def _rating_index(rating): | |
| """Map a rating label (e.g. '4_EXPLICIT_SEXUAL') to its 0..5 index; 0 if unknown.""" | |
| if not rating: | |
| return 0 | |
| s = str(rating).strip() | |
| head = s.split("_", 1)[0] | |
| if head.isdigit(): | |
| i = int(head) | |
| if 0 <= i <= 5: | |
| return i | |
| s_up = s.upper() | |
| for i, lvl in enumerate(_RATING_LEVELS): | |
| if lvl in s_up: | |
| return i | |
| return 0 | |
| def _flag_rating_floor(flags): | |
| """Minimum rating implied by the boolean flags. The model tends to under-rate | |
| the holistic ``rating`` label (capping at ~2) even when it correctly flags | |
| nudity / genitals / a sex act, so the flags raise a floor the label can't sink | |
| below.""" | |
| if not isinstance(flags, dict): | |
| return 0 | |
| def f(k): | |
| return flags.get(k) is True | |
| if f("sexualized_minor_risk") or f("non_consensual_or_abuse") or f("bestiality"): | |
| return 5 | |
| if f("explicit_genitals") or f("sexual_act"): | |
| return 4 | |
| if f("nudity"): | |
| return 3 | |
| return 0 | |
def _parse_moderation_json(text):
    """Best-effort parse of the moderation JSON the VLM returns.

    The model is asked for JSON only, but a reasoning finetune can still wrap it
    in <think> tags or stray prose. Everything up to the last ``</think>`` is
    therefore dropped, and the outermost ``{...}`` is sliced out before
    ``json.loads``.

    Returns a dict that always carries the same keys:
      ok           -- True when a JSON object was parsed;
      rating       -- final label, ``max(model label, flag-implied floor)``, so a
                      hedging label can't under-rate explicit content
                      (None on failure);
      model_rating -- the model's own ``rating`` value, unmodified
                      (None when absent or on failure);
      confidence   -- the model's ``confidence`` (None when absent or on failure);
      flags        -- the model's ``flags`` object (None when absent or on failure);
      raw          -- the stripped original text, so nothing is lost.
    """
    raw = (text or "").strip()
    cleaned = raw.split("</think>")[-1].strip() if "</think>" in raw else raw
    start, end = cleaned.find("{"), cleaned.rfind("}")
    parsed = None
    if start != -1 and end > start:
        try:
            parsed = json.loads(cleaned[start:end + 1])
        except (ValueError, TypeError):
            parsed = None
    if isinstance(parsed, dict):
        flags = parsed.get("flags")
        model_rating = parsed.get("rating")
        final_idx = max(_rating_index(model_rating), _flag_rating_floor(flags))
        return {
            "ok": True,
            "rating": _RATING_LEVELS[final_idx],
            "model_rating": model_rating,
            "confidence": parsed.get("confidence"),
            "flags": flags,
            "raw": raw,
        }
    # Same key set as the success path (model_rating used to be missing here), so
    # consumers can read any key without first checking ``ok``.
    return {
        "ok": False,
        "rating": None,
        "model_rating": None,
        "confidence": None,
        "flags": None,
        "raw": raw,
    }
def _moderate_image_inner(image, max_new_tokens=320):
    """Synchronously run the VLM moderation check on a PIL image (GPU-side).

    The image and MODERATION_PROMPT are sent as one user turn with thinking
    disabled. Decoding is greedy (``do_sample=False``), capped at
    ``max_new_tokens``, and stops on the resolved assistant stop ids. Returns
    the dict produced by _parse_moderation_json for the decoded reply.
    """
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": MODERATION_PROMPT},
        ],
    }]
    inputs = vlm_processor.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True,
        return_dict=True, return_tensors="pt", enable_thinking=False,
    ).to(vlm_model.device)
    with torch.inference_mode():
        generated = vlm_model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
            # Stop on the same model-agnostic terminators resolved at load time.
            # Without them, a model lacking a generation_config keeps writing past
            # the JSON, and any later "}" breaks the outermost-brace slice in
            # _parse_moderation_json.
            eos_token_id=_VLM_EOS_IDS or None,
            pad_token_id=_VLM_PAD_ID,
        )
    # Keep only the newly generated tokens (drop the echoed prompt).
    trimmed = generated[0][inputs["input_ids"].shape[-1]:]
    tokenizer = getattr(vlm_processor, "tokenizer", vlm_processor)
    text = tokenizer.decode(trimmed, skip_special_tokens=True)
    return _parse_moderation_json(text)
| # ============================================================================= | |
| # Compel-based prompt weighting helpers (ported from the NoobXL11 reference | |
| # space so long prompts and (weight:1.2) syntax work for the SDXL model). | |
| # ============================================================================= | |
# Attention-syntax tokenizer, compiled once instead of on every call (it runs on
# each SDXL prompt). Alternatives, tried in order:
#   - escaped brackets / backslashes, and a lone backslash;
#   - "(" and "[" openers;
#   - an explicit ":<weight>)" closer;
#   - ")" and "]" closers;
#   - runs of plain text, and a lone ":".
_RE_ATTENTION = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)
# Splits plain text on the BREAK keyword, consuming the surrounding whitespace.
_RE_BREAK = re.compile(r"\s*\bBREAK\b\s*", re.S)
def _attention_weight(raw):
    """Return the ":<weight>)" capture ``raw`` as a float, or None if absent or malformed."""
    if raw is None:
        return None
    try:
        return float(raw)
    except ValueError:
        # The pattern accepts any run of digits and dots, e.g. "1.2.3" or ".".
        return None
def parse_prompt_attention(text):
    """Parse A1111-style attention syntax into a list of ``[text, weight]`` pairs.

    - ``(text)`` multiplies the weight of its contents by 1.1 and ``[text]`` by
      1/1.1; nested brackets compound.
    - ``(text:1.5)`` applies an explicit multiplier. A malformed number such as
      ``(text:1.2.3)`` is kept as literal text instead of raising ValueError.
    - A backslash before a bracket or backslash escapes it.
    - The word BREAK inside plain text becomes a ``["BREAK", -1]`` marker.
    Unclosed brackets still apply their multiplier to everything after them.
    Adjacent pairs with equal weight are merged; an empty prompt yields
    ``[["", 1.0]]``.
    """
    res = []
    round_brackets = []
    square_brackets = []
    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in _RE_ATTENTION.finditer(text):
        token = m.group(0)
        weight = _attention_weight(m.group(1))
        if token.startswith('\\'):
            res.append([token[1:], 1.0])
        elif token == '(':
            round_brackets.append(len(res))
        elif token == '[':
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), weight)
        elif token == ')' and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == ']' and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # Plain text (including a ":<weight>)" with no open "(" or with a
            # malformed number); BREAK splits it into separate entries.
            for i, part in enumerate(_RE_BREAK.split(token)):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)
    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # Merge neighbours that ended up with identical weights.
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1
    return res
def prompt_attention_to_invoke_prompt(attention):
    """Render ``[text, weight]`` pairs as one Compel/Invoke-syntax string.

    Each weight is rounded to two decimals first:

    * exactly 1.0 -> the bare text;
    * between 0.8 (inclusive) and 1.0 -> ``(text)-`` plus extra ``-`` marks;
    * between 1.0 and 1.3 (exclusive) -> ``(text)`` plus ``+`` marks;
    * anything else -> ``(text)<weight>`` with the numeric weight.

    The rendered fragments are concatenated without separators.
    """
    def _render(fragment, raw_weight):
        w = round(raw_weight, 2)
        if w == 1.0:
            return fragment
        if w < 1.0:
            if w >= 0.8:
                return f"({fragment})-" + "-" * int((1.0 - w) * 10)
            return f"({fragment}){w}"
        if w < 1.3:
            return f"({fragment})" + "+" * int((w - 1.0) * 10)
        return f"({fragment}){w}"

    return "".join(_render(fragment, raw_weight) for fragment, raw_weight in attention)
def merge_embeds(prompt_chanks, compel):
    """Blend the embeddings of several prompt chunks into one.

    ``compel`` is called once on the whole list of chunks. The resulting rows
    are scaled by a descending triangular schedule and summed: the first chunk
    gets weight ``n/T``, the last ``1/T``, where ``T = n(n+1)/2``. The weights
    therefore sum to 1 and earlier chunks dominate. With no chunks, the
    embedding of the empty string is returned instead.
    """
    count = len(prompt_chanks)
    if not count:
        return compel('')
    unit = 1 / (count * (count + 1) // 2)
    rows = list(torch.split(compel(prompt_chanks), 1, dim=0))
    for rank in range(1, count + 1):
        # rank 1 addresses the last row, rank ``count`` the first.
        rows[-rank] = rows[-rank] * (rank * unit)
    return torch.stack(rows, dim=0).sum(dim=0)
def detokenize(chunk, actual_prompt):
    """Rebuild readable text for one run of tokenizer tokens.

    In ``chunk`` a trailing ``</w>`` marks the end of a word. The tokens alone
    do not say whether a word boundary was a space or punctuation, so the
    character at the same offset in ``actual_prompt`` decides. ``actual_prompt``
    is the not-yet-consumed remainder of the source text, expected to begin
    where ``chunk`` begins. A space there keeps a space; any other character
    joins the words. If the offset runs past the end of ``actual_prompt``, a
    space is used. That can happen when the rebuilt text is longer than the
    source, e.g. with byte-level tokens for non-ASCII input.

    Args:
        chunk: list of token strings. It is not modified.
        actual_prompt: remaining source text.

    Returns:
        ``(chunk_text, remaining_prompt)``: the rebuilt text, stripped, and
        ``actual_prompt`` with the first occurrence of that text removed, also
        stripped. An empty ``chunk`` yields ``("", actual_prompt.strip())``.
    """
    if not chunk:
        return "", actual_prompt.strip()
    # Work on a copy so the caller's token list is left untouched.
    tokens = list(chunk)
    tokens[-1] = tokens[-1].replace('</w>', '')
    chanked_prompt = ''.join(tokens).strip()
    while '</w>' in chanked_prompt:
        pos = chanked_prompt.find('</w>')
        if pos >= len(actual_prompt) or actual_prompt[pos] == ' ':
            chanked_prompt = chanked_prompt.replace('</w>', ' ', 1)
        else:
            chanked_prompt = chanked_prompt.replace('</w>', '', 1)
    # Consume only the first occurrence (the chunk came from the front). A
    # phrase repeated later in the prompt must survive so that later chunks
    # stay aligned with their offsets.
    actual_prompt = actual_prompt.replace(chanked_prompt, '', 1)
    return chanked_prompt.strip(), actual_prompt.strip()
def tokenize_line(line, tokenizer):
    """Split one prompt line into text pieces that each fit the token window.

    The line is lower-cased and stripped, then tokenized. The window is
    ``tokenizer.model_max_length - 2`` tokens (presumably leaving room for the
    start/end tokens). Whenever the pending tokens fill the window, a piece is
    cut:

    * if the window already ends on a comma token, or contains no comma, the
      whole window becomes one piece;
    * otherwise the cut goes right after the last comma token in the window,
      and the tokens after it carry over into the next window.

    Leftover tokens form a final piece. Each piece is turned back into text
    with ``detokenize``. Returns a list of strings, which is empty if the line
    produces no tokens.
    """
    remaining = line.lower().strip()
    tokens = tokenizer.tokenize(remaining)
    limit = tokenizer.model_max_length - 2
    comma = tokenizer.tokenize(',')[0]
    pieces = []
    pending = []
    for token in tokens:
        pending.append(token)
        if len(pending) != limit:
            continue
        cut = None
        if pending[-1] != comma:
            # Scan backwards for the last comma inside the full window.
            cut = next(
                (pos for pos in reversed(range(limit)) if pending[pos] == comma),
                None,
            )
        if cut is None:
            piece, remaining = detokenize(pending, remaining)
            pending = []
        else:
            piece, remaining = detokenize(pending[:cut + 1], remaining)
            pending = pending[cut + 1:]
        pieces.append(piece)
    if pending:
        pieces.append(detokenize(pending, remaining)[0])
    return pieces
def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False):
    """Convert an A1111-style prompt to Compel syntax, or to a merged embedding.

    With ``compel_process_sd`` the prompt is only window-split
    (``tokenize_line``) and the pieces are blended with ``merge_embeds``.

    Otherwise:

    1. ``((``/``))`` are collapsed, backslashes are escaped, and the attention
       syntax is parsed with ``parse_prompt_attention``.
    2. Each weighted span is split on commas and then into window-sized pieces.
       Every piece gets a trailing comma and is measured in tokens.
    3. Pieces are packed greedily into groups of at most
       ``model_max_length - 2`` tokens. Consecutive pieces with the same
       rounded weight are concatenated with a space.
    4. Each group is rendered with ``prompt_attention_to_invoke_prompt``.

    Args:
        prompt: prompt text; ``None`` is treated as an empty prompt.
        pipeline: object exposing ``.tokenizer``.
        compel: Compel instance (only called when an embedding is returned).
        only_convert_string: if true, return the space-joined Compel strings
            instead of an embedding.
        compel_process_sd: use the plain window-split path described above.

    Returns:
        A Compel-syntax string when ``only_convert_string`` is set (and
        ``compel_process_sd`` is not); otherwise whatever ``merge_embeds``
        returns.
    """
    if prompt is None:
        prompt = ""
    if compel_process_sd:
        return merge_embeds(tokenize_line(prompt, pipeline.tokenizer), compel)

    prompt = prompt.replace("((", "(").replace("))", ")").replace("\\", "\\\\\\")
    attention = parse_prompt_attention(prompt)
    tokenizer = pipeline.tokenizer

    # (text, rounded weight, token length) for every window-sized piece.
    pieces = []
    for span_text, span_weight in attention:
        for part in span_text.split(','):
            for small in tokenize_line(part, tokenizer):
                piece = f'{small},'
                pieces.append((piece, round(span_weight, 2), len(tokenizer.tokenize(piece))))

    max_tokens = tokenizer.model_max_length - 2
    groups = []
    current = []
    current_length = 0
    for piece, weight, length in pieces:
        # Flush only a non-empty group. An oversized first piece used to emit
        # an empty group, which rendered as a stray "" chunk.
        if current and current_length + length > max_tokens:
            groups.append(current)
            current = []
            current_length = 0
        if current and current[-1][1] == weight:
            current[-1][0] += f" {piece}"
        else:
            current.append([piece, weight])
        current_length += length
    if current:
        groups.append(current)

    converted = [prompt_attention_to_invoke_prompt(group) for group in groups]
    if only_convert_string:
        return ' '.join(converted)
    return merge_embeds(converted, compel)
| # ============================================================================= | |
| # Generation | |
| # ============================================================================= | |
| def _supports_step_callback(pipe): | |
| """True if this diffusers pipeline's __call__ accepts callback_on_step_end.""" | |
| try: | |
| return "callback_on_step_end" in inspect.signature(pipe.__call__).parameters | |
| except (TypeError, ValueError): | |
| return False | |
def generate_image(
    model_name,
    prompt,
    negative_prompt,
    use_negative_prompt,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    randomize_seed,
    prompt_improve_instruction="",
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image, streaming per-step progress.

    This is a generator (so ZeroGPU streams its yields back over SSE). It yields
    ``("progress", step, total)`` tuples during sampling and a final
    ``("image", image, seed, moderation, effective_prompt)`` tuple. The sampler
    runs in a worker thread feeding a queue so the main thread can yield progress
    as each diffusion step lands. Once the image is ready, the VLM image-context
    (moderation) check runs in this same GPU window and its best-effort-parsed
    result rides along on the final tuple.

    When ``prompt_improve_instruction`` is non-empty, the VLM (Qwen) first rewrites
    the original ``prompt`` according to that instruction — reusing this same GPU
    window — and the improved text becomes the actual image prompt (and is echoed
    back as ``effective_prompt``). When it is empty the original prompt is used
    unchanged. If the rewrite comes back empty, the original prompt is kept.

    Progress tuples are only produced when ``_generate_image_inner`` actually
    wires the step callback into the pipeline. Otherwise the consumer sees only
    the final ``"image"`` tuple.

    ``progress`` is not referenced in the body. NOTE(review): presumably it is
    present so Gradio injects a tracker that mirrors tqdm bars — confirm.

    Raises:
        Whatever ``_generate_image_inner`` raised in the worker thread; it is
        re-raised here after the thread has been joined. A failing moderation
        check does NOT raise: it is replaced by an ``ok=False`` dict.
    """
    _gpu_start = time.time()
    total_steps = int(num_inference_steps)
    # Optional prompt-improvement pass: rewrite the user's prompt with the VLM
    # before sampling. Runs inside this same @spaces.GPU window (single ZeroGPU
    # allocation), mirroring how the moderation check rides along below.
    effective_prompt = prompt
    instruction = (prompt_improve_instruction or "").strip()
    if instruction:
        # The improve instruction is the LLM's SYSTEM prompt; the user's original
        # prompt is the user message it rewrites.
        improved = ""
        for ev in _vlm_chat_core(
            prompt, None, "Off", 512, system_prompt=instruction
        ):
            # Only the final ("text", answer) event matters; streaming
            # progress events are ignored here.
            if ev[0] == "text":
                improved = ev[1]
        improved = (improved or "").strip()
        if improved:
            effective_prompt = improved
    # Worker -> main thread channel: 1-based step numbers, then a None sentinel.
    q = queue.Queue()
    # Filled by the worker with "image"/"seed" on success or "error" on failure.
    result = {}

    def _step_cb(_pipe, step, _timestep, callback_kwargs):
        # diffusers calls this after each step; `step` is the 0-based index.
        q.put(step + 1)
        # Returned unchanged: this hook only observes, it never edits latents.
        return callback_kwargs

    def _run():
        try:
            result["image"], result["seed"] = _generate_image_inner(
                model_name, effective_prompt, negative_prompt, use_negative_prompt,
                height, width, total_steps, guidance_scale, seed, randomize_seed,
                callback=_step_cb,
            )
        except Exception as exc:  # noqa: BLE001 - surfaced to the main thread
            result["error"] = exc
        finally:
            q.put(None)  # sentinel: generation finished (ok or error)

    thread = threading.Thread(target=_run, daemon=True)
    thread.start()
    try:
        # Relay steps until the sentinel arrives.
        while True:
            step = q.get()
            if step is None:
                break
            yield ("progress", step, total_steps)
    finally:
        # Also runs if the consumer stops iterating early (GeneratorExit at the
        # yield), in which case this blocks until sampling has finished.
        thread.join()
    print(
        f"[ImageStudio] GPU time consumed: {time.time() - _gpu_start:.2f}s "
        f"(model={model_name}, steps={num_inference_steps}, {int(width)}x{int(height)})",
        flush=True,
    )
    if "error" in result:
        raise result["error"]
    image = result["image"]
    # Same GPU window: run the image-context (moderation) check on the freshly
    # generated image before the GPU allocation is released.
    try:
        moderation = _moderate_image_inner(image)
    except Exception as exc:  # noqa: BLE001 - never let moderation break generation
        moderation = {"ok": False, "rating": None, "confidence": None,
                      "flags": None, "raw": "", "error": f"{type(exc).__name__}: {exc}"}
    yield ("image", image, result["seed"], moderation, effective_prompt)
def _generate_image_inner(
    model_name,
    prompt,
    negative_prompt,
    use_negative_prompt,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    randomize_seed,
    callback=None,
):
    """Run one sampling pass on the loaded pipeline and return ``(image, seed)``.

    The branch is chosen by the registry entry's ``family``:

    * ``"illustrious"``: the prompt and negative get the model prefix / default
      negative, are rewritten into Compel syntax by ``get_embed_new``, encoded
      separately and padded to a common length, then passed as embeddings.
      Uses a CPU ``torch.Generator``.
    * ``"anima"``: the modular pipeline. No ``guidance_scale`` input and no
      step callback, but the (resolved) negative prompt is passed.
    * anything else: the guidance-free Z-Image-Turbo path
      (``guidance_scale=0.0``); the negative prompt is not used.

    ``callback`` is attached as ``callback_on_step_end`` only on the
    illustrious and default paths, and only when the pipeline accepts it. The
    returned seed is the one actually used (re-drawn when ``randomize_seed``).

    Raises:
        gr.Error: if no pipeline for ``model_name`` is loaded in ``PIPELINES``.
    """
    seed = int(random.randint(0, MAX_SEED) if randomize_seed else seed)
    entry = _model_entry(model_name)
    family = entry.get("family")
    pipe = PIPELINES.get(model_name)
    if pipe is None:
        raise gr.Error(
            f"Model '{model_name}' is not loaded on this Space "
            f"(loaded: {list(PIPELINES)}). Set the IMAGE_MODELS env var to include it."
        )

    def _sample(call_kwargs):
        # Hook per-step progress in only where the pipeline signature allows it.
        if callback is not None and _supports_step_callback(pipe):
            call_kwargs["callback_on_step_end"] = callback
        return pipe(**call_kwargs).images[0]

    if family == "illustrious":
        generator = torch.Generator().manual_seed(seed)
        compel = Compel(
            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
            text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True],
            truncate_long_prompts=False,
        )
        # Quality prefix (tags already in the prompt are skipped) plus the
        # model's default negative when the caller supplied none.
        prompt = _apply_prefix(prompt, entry.get("prefix", ""))
        negative_prompt = _resolve_negative(entry, negative_prompt, use_negative_prompt, model_name)
        positive_text = get_embed_new(prompt, pipe, compel, only_convert_string=True)
        negative_text = get_embed_new(negative_prompt, pipe, compel, only_convert_string=True)
        # Encode each side on its own, then pad both with the empty-string
        # conditioning. The batched compel([...]) call relies on
        # conditioning_provider.empty_z, which the SDXL dual-encoder provider
        # lacks; that breaks when the two sides have different chunk counts.
        positive_embeds, positive_pooled = compel(positive_text)
        negative_embeds, negative_pooled = compel(negative_text)
        empty_padding, _ = compel("")
        positive_embeds, negative_embeds = compel.pad_conditioning_tensors_to_same_length(
            [positive_embeds, negative_embeds], precomputed_padding=empty_padding
        )
        image = _sample({
            "prompt_embeds": positive_embeds,
            "pooled_prompt_embeds": positive_pooled,
            "negative_prompt_embeds": negative_embeds,
            "negative_pooled_prompt_embeds": negative_pooled,
            "width": int(width),
            "height": int(height),
            "guidance_scale": float(guidance_scale),
            "num_inference_steps": int(num_inference_steps),
            "generator": generator,
            "use_resolution_binning": True,
        })
    elif family == "anima":
        # Cosmos-Predict2 DiT via the experimental modular pipeline: guidance
        # is handled internally; `output` selects what the call returns.
        generator = torch.Generator("cuda").manual_seed(seed)
        prompt = _apply_prefix(prompt, entry.get("prefix", ""))
        negative_prompt = _resolve_negative(entry, negative_prompt, use_negative_prompt, model_name)
        produced = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt or None,
            height=int(height),
            width=int(width),
            num_inference_steps=int(num_inference_steps),
            num_images_per_prompt=1,
            generator=generator,
            output_type="pil",
            output="images",
        )
        image = produced[0] if isinstance(produced, (list, tuple)) else produced
    else:
        # Z-Image-Turbo: distilled, guidance-free.
        generator = torch.Generator("cuda").manual_seed(seed)
        image = _sample({
            "prompt": prompt,
            "height": int(height),
            "width": int(width),
            "num_inference_steps": int(num_inference_steps),
            "guidance_scale": 0.0,
            "generator": generator,
        })
    return image, seed
| # ============================================================================= | |
| # Prompt Assistant (Qwen3.5-4B) — single-turn chat, optional image | |
| # ============================================================================= | |
def _vlm_chat_core(message, image, reasoning, max_new_tokens, system_prompt=""):
    """Token-streaming VLM chat body (no GPU decorator).

    Yields ``("progress", produced, budget, partial)`` per streamed chunk and a
    final ``("text", answer)`` tuple. Kept decorator-free so it can be reused
    inside a *single* ``@spaces.GPU`` window by other endpoints (e.g. the
    image→video check), avoiding a second ZeroGPU acquisition. See
    :func:`vlm_chat`.

    ``system_prompt`` (optional) steers the assistant. If building the chat
    template with a separate ``system`` turn raises, the system text is folded
    into the user turn and the template is retried. This lets the parameter
    work with templates that reject a system role (the original notes cite
    Gemma as an example).

    Args:
        message: user text; blank plus no image short-circuits with a hint.
        image: optional image attached to the user turn.
        reasoning: ``"On"`` enables the template's thinking mode; anything
            else disables it and strips a leading ``…</think>`` block from the
            answer.
        max_new_tokens: generation budget (also the ``budget`` in progress).
        system_prompt: optional system instructions.

    Raises:
        Any exception from ``vlm_model.generate`` (re-raised after the worker
        thread is joined), or from templating when no system prompt was given.
    """
    message = (message or "").strip()
    system_prompt = (system_prompt or "").strip()
    if not message and image is None:
        yield ("text", "Please enter a question (and optionally attach an image).")
        return
    enable_thinking = (reasoning == "On")
    budget = int(max_new_tokens)
    _gpu_start = time.time()
    # An image without a question gets a default captioning request.
    user_text = message or "Describe this image."
    content = []
    if image is not None:
        content.append({"type": "image", "image": image})
    content.append({"type": "text", "text": user_text})
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.append({"role": "user", "content": content})

    def _template(msgs):
        return vlm_processor.apply_chat_template(
            msgs,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
            enable_thinking=enable_thinking,
        )

    try:
        inputs = _template(messages).to(vlm_model.device)
    except Exception:
        # Template has no system role (Gemma et al.) — fold it into the user turn.
        # NOTE(review): this catches any failure (including moving to device),
        # not just a rejected system role.
        if not system_prompt:
            raise
        folded = []
        if image is not None:
            folded.append({"type": "image", "image": image})
        folded.append({"type": "text", "text": f"{system_prompt}\n\n{user_text}"})
        inputs = _template([{"role": "user", "content": folded}]).to(vlm_model.device)
    # Processors wrap a tokenizer; fall back to the object itself otherwise.
    tokenizer = getattr(vlm_processor, "tokenizer", vlm_processor)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # Worker -> main thread error hand-off.
    result = {}

    def _run():
        try:
            with torch.inference_mode():
                vlm_model.generate(
                    **inputs,
                    max_new_tokens=budget,
                    do_sample=False,
                    # Greedy decoding alone loops into repeated/rambling output that
                    # runs to max_new_tokens instead of emitting EOS. These penalties
                    # break n-gram repetition and let the model stop on its own, while
                    # keeping decoding deterministic (important for prompt rewrites).
                    repetition_penalty=1.3,
                    no_repeat_ngram_size=3,
                    # Explicit stop tokens — the model has no generation_config, so
                    # without these generate() never stops and rambles to the budget.
                    eos_token_id=_VLM_EOS_IDS,
                    pad_token_id=_VLM_PAD_ID,
                    streamer=streamer,
                )
        except Exception as exc:  # noqa: BLE001 - surfaced to the main thread
            result["error"] = exc
            # Close the stream so the consumer loop below stops waiting.
            streamer.end()

    thread = threading.Thread(target=_run, daemon=True)
    thread.start()
    try:
        text = ""
        # NOTE(review): counts streamed text chunks, not tokens; a chunk may
        # hold several tokens, so this is presumably an approximation — confirm.
        produced = 0
        for chunk in streamer:
            text += chunk
            produced += 1
            yield ("progress", produced, budget, text)
    finally:
        # Runs even if the consumer stops early; waits for generate() to end.
        thread.join()
    print(
        f"[ImageStudio] Assistant GPU time: {time.time() - _gpu_start:.2f}s "
        f"(has_image={image is not None}, reasoning={reasoning}, max_new_tokens={budget})",
        flush=True,
    )
    if "error" in result:
        raise result["error"]
    text = text.strip()
    # With reasoning off, drop any stray <think>…</think> block so the answer
    # stays clean; with it on, keep the trace so the user can see it.
    if not enable_thinking and "</think>" in text:
        text = text.split("</think>")[-1].strip()
    yield ("text", text)
def vlm_chat(message, image, reasoning, max_new_tokens, system_prompt="", progress=gr.Progress(track_tqdm=True)):
    """Single-turn assistant reply, optionally grounded on an image.

    A thin entry point over :func:`_vlm_chat_core`, which owns all of the
    streaming logic. It re-yields the core's ``("progress", …)`` events and
    its final ``("text", answer)`` event unchanged.

    ``reasoning`` is ``"On"`` or ``"Off"`` and toggles the model's
    ``enable_thinking``. Off answers directly, which suits prompt rewriting; On
    lets the model think first, which is slower. ``system_prompt`` optionally
    steers the reply.
    """
    # Delegate wholesale so close()/throw() reach the core generator too.
    yield from _vlm_chat_core(
        message,
        image,
        reasoning,
        max_new_tokens,
        system_prompt=system_prompt,
    )
def _image_video_check_gpu(image, message, reasoning, max_new_tokens, system_prompt=""):
    """Screen ``image``, then stream a video prompt for it, in one GPU window.

    First yields ``("moderation", verdict)``. If the check raises, the verdict
    is an ``ok=False`` dict carrying the error text. It then passes through
    every ``("progress", …)`` / ``("text", …)`` event from
    :func:`_vlm_chat_core`. Doing both in one ZeroGPU allocation follows the
    same pattern as :func:`generate_image`.
    """
    try:
        verdict = _moderate_image_inner(image)
    except Exception as exc:  # noqa: BLE001 - never let moderation break the call
        verdict = dict(
            ok=False,
            rating=None,
            confidence=None,
            flags=None,
            raw="",
            error=f"{type(exc).__name__}: {exc}",
        )
    yield "moderation", verdict
    yield from _vlm_chat_core(
        message, image, reasoning, max_new_tokens, system_prompt=system_prompt
    )
def generate_and_upload(
    model_name,
    prompt,
    negative_prompt,
    use_negative_prompt,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    randomize_seed,
    prompt_improve_instruction="",
    request: gr.Request = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image, then (for API callers) upload it to R2 as WebP.

    Yields ``(image, seed, r2_status, progress)`` tuples:

    * while sampling: ``(None, None, None, <0..0.95 progress>)``;
    * web-UI callers (not an API request): one final frame with the image and
      ``{"r2_skipped": …, "moderation": …}``; nothing is uploaded;
    * API callers: an ``"Uploading"`` frame at 0.97, then the final frame.
      Its status holds ``r2_filekey``/``r2_bucket``/presigned ``r2_url`` on
      success or ``r2_error`` on failure, plus ``moderation``.

    A brand watermark is stamped on the image when the request carries a valid
    ``wm`` spec. This happens after moderation, so it never affects the rating.
    The uploaded bytes are WebP, and the caller's ``uid`` cookie is recorded in
    the object metadata.
    """
    image = used = moderation = None
    effective_prompt = prompt
    events = generate_image(
        model_name, prompt, negative_prompt, use_negative_prompt,
        height, width, num_inference_steps, guidance_scale, seed, randomize_seed,
        prompt_improve_instruction,
        progress=progress,
    )
    for kind, *payload in events:
        if kind == "progress":
            step, total = payload
            # Sampling fills the first 95%; the rest is reserved for the upload.
            yield None, None, None, _progress(
                "image", (step / max(total, 1)) * 0.95, step, total,
                f"Sampling {step}/{total}",
            )
            continue
        image, used, moderation, effective_prompt = payload

    wm_spec = r2_uploader.watermark_from_request(request)
    if watermark_module.is_valid(wm_spec):
        image = watermark_module.apply_to_image(image, wm_spec)

    # Only API calls (the generator, which forwards a uid cookie) populate the
    # asset store; interactive web-UI runs are returned without uploading.
    if not r2_uploader.request_is_api(request):
        skipped = {"r2_skipped": "web-ui generation (not uploaded)", "moderation": moderation}
        yield image, used, skipped, _progress("done", 1.0, label="Done")
        return

    yield None, used, None, _progress("image", 0.97, label="Uploading")
    uid = r2_uploader.uid_from_request(request)
    encoded = io.BytesIO()
    image.save(encoded, format="WEBP", quality=95, method=6)
    metadata = {
        "model": model_name,
        "prompt": effective_prompt,
        "negative_prompt": (negative_prompt if use_negative_prompt else ""),
        "height": int(height),
        "width": int(width),
        "num_inference_steps": int(num_inference_steps),
        "guidance_scale": float(guidance_scale),
        "seed": int(used),
        "uid": uid,
    }
    upload = r2_uploader.upload_asset(
        namespace=R2_NAMESPACE,
        prompt=effective_prompt,
        params=metadata,
        data=encoded.getvalue(),
        ext=".webp",
        content_type="image/webp",
        uid=uid,
    )
    if upload.get("ok"):
        filekey, bucket = upload["filekey"], upload["bucket"]
        status = {
            "r2_filekey": filekey,
            "r2_bucket": bucket,
            "r2_url": r2_uploader.presign_get_url(filekey, bucket),
        }
    else:
        status = {"r2_error": upload.get("error", "unknown error")}
    status["moderation"] = moderation
    yield image, used, status, _progress("done", 1.0, label="Done")
def assistant_chat(
    message, image, reasoning, max_new_tokens, system_prompt="",
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio-facing streaming wrapper around :func:`vlm_chat`.

    Yields ``(text, progress)`` pairs. While the model writes, ``text`` is the
    partial answer and progress is capped at 0.99 of the token budget. The
    last pair carries the final cleaned answer with a ``"done"`` progress.
    This lets SSE consumers reading the progress index watch the node advance.
    ``system_prompt`` optionally steers the assistant.
    """
    stream = vlm_chat(message, image, reasoning, max_new_tokens, system_prompt, progress=progress)
    for kind, *payload in stream:
        if kind != "progress":
            (answer,) = payload
            yield answer, _progress("done", 1.0, label="Done")
            continue
        produced, budget, partial = payload
        fraction = min(0.99, produced / max(budget, 1))
        yield partial, _progress("prompt", fraction, produced, budget, "Writing prompt")
| # ============================================================================= | |
| # Combined endpoint: text -> first-frame image + video prompt (no UI) | |
| # ============================================================================= | |
| # One call chains LLM + image model + LLM to turn a raw idea into the two assets | |
| # an image-to-video pipeline needs: | |
| # 1. LLM (Qwen) writes a first-frame *image* prompt from | |
| # `image_instruction` + `original_input`. | |
| # 2. The image model renders that first frame. | |
| # 3. The frame is uploaded to R2 (same path as the Generate tab). | |
| # 4. LLM (Qwen) writes a *video* prompt from `video_instruction` + | |
| # `original_input`, grounded on the rendered first frame image. | |
| # It returns the video prompt and the first-frame URL (+ R2 filekey/bucket and | |
| # the intermediate image prompt). Exposed via `gr.api` so it has no Gradio UI. | |
| # | |
| # Like the other streaming endpoints it is a generator yielding one structured | |
| # dict per frame (so progress crosses the /call SSE stream, see _progress); the | |
| # same dict carries the results, filled in as each stage completes, and the final | |
| # yield has `done=True`. Stage weights below sum to 1.0. | |
# Share of prompt_to_video_assets' overall 0..1 progress given to each stage,
# in execution order. Keep the four values summing to 1.0.
_P2V_W_FRAME_PROMPT = 0.30  # LLM writes the first-frame image prompt
_P2V_W_IMAGE = 0.30  # image model renders the first frame
_P2V_W_UPLOAD = 0.05  # R2 upload
_P2V_W_VIDEO_PROMPT = 0.35  # LLM writes the video prompt
| def _compose_instruction(instruction, original_input): | |
| """Join an instruction with the raw idea into a single LLM message.""" | |
| instruction = (instruction or "").strip() | |
| original_input = (original_input or "").strip() | |
| if instruction and original_input: | |
| return f"{instruction}\n\nInput:\n{original_input}" | |
| return instruction or original_input | |
def prompt_to_video_assets(
    original_input: str,
    image_instruction: str,
    video_instruction: str,
    model_name: str = MODEL_ZIMAGE,
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 9,
    guidance_scale: float = 0.0,
    seed: int = 42,
    randomize_seed: bool = True,
    reasoning: str = "Off",
    max_new_tokens: int = 512,
    system_prompt: str = "",
    request: gr.Request = None,
) -> dict:
    """Text -> first-frame image (uploaded to R2) + video generation prompt.

    Streams progress and returns a dict with ``video_prompt``, ``first_frame_url``
    (a presigned GET to the **WebP** first frame, usable directly),
    ``r2_filekey``/``r2_bucket`` (for callers that resolve their own public URL),
    ``first_frame_prompt`` (the intermediate image prompt), ``moderation`` (the
    VLM image-context check on the rendered frame) and ``seed``. Has no UI
    (registered via ``gr.api``).

    Every yield is a snapshot copy of the same state dict. Result fields are
    filled in as stages complete, and the last yield has ``done=True``. A blank
    ``original_input`` yields a single ``done`` frame with ``error`` set. A
    failed R2 upload only sets ``error``: the video-prompt stage still runs.
    The negative prompt is ``NOOBXL_NEGATIVE`` with ``use_negative_prompt``
    enabled only when ``model_name == MODEL_NOOBXL``.

    NOTE(review): annotated ``-> dict`` although this is a generator of dicts.
    It is left unchanged because ``gr.api`` may derive the API schema from the
    annotations — confirm before changing it.
    """
    state = {
        "stage": "frame_prompt", "p": 0.0, "step": 0, "total": 0, "label": "",
        "first_frame_prompt": None,
        "first_frame_url": None,
        "r2_filekey": None,
        "r2_bucket": None,
        "video_prompt": None,
        "moderation": None,
        "seed": None,
        "done": False,
        "error": None,
    }

    def frame(stage, base, span, frac=1.0, step=0, total=0, label=""):
        # Map a stage-local fraction onto the global bar: base + span * frac,
        # with both frac and the result clamped to [0, 1]. Returns a copy so
        # already-yielded frames are not mutated later.
        state.update(
            stage=stage,
            p=max(0.0, min(1.0, base + span * max(0.0, min(1.0, frac)))),
            step=int(step), total=int(total), label=label,
        )
        return dict(state)

    if not (original_input or "").strip():
        state["error"] = "original_input is required"
        state["done"] = True
        yield dict(state)
        return
    # --- Stage 1: LLM writes the first-frame image prompt ---------------------
    base = 0.0
    frame_prompt = ""
    for ev in vlm_chat(
        _compose_instruction(image_instruction, original_input),
        None, reasoning, max_new_tokens, system_prompt,
    ):
        if ev[0] == "progress":
            _, produced, budget, partial = ev
            frame_prompt = partial
            yield frame("frame_prompt", base, _P2V_W_FRAME_PROMPT,
                        frac=produced / max(budget, 1), step=produced, total=budget,
                        label="Writing first-frame prompt")
        else:
            frame_prompt = ev[1]
    # NOTE(review): an empty LLM answer is passed on to the image model as-is.
    frame_prompt = (frame_prompt or "").strip()
    state["first_frame_prompt"] = frame_prompt
    yield frame("frame_prompt", base, _P2V_W_FRAME_PROMPT, label="First-frame prompt ready")
    # --- Stage 2: image model renders the first frame -------------------------
    base += _P2V_W_FRAME_PROMPT
    image, used_seed, moderation = None, seed, None
    use_negative_prompt = (model_name == MODEL_NOOBXL)
    for ev in generate_image(
        model_name, frame_prompt, NOOBXL_NEGATIVE, use_negative_prompt,
        height, width, num_inference_steps, guidance_scale, seed, randomize_seed,
    ):
        if ev[0] == "progress":
            _, step, total = ev
            yield frame("image", base, _P2V_W_IMAGE,
                        frac=step / max(total, 1), step=step, total=total,
                        label=f"Rendering first frame {step}/{total}")
        else:
            # The effective prompt echoed back is ignored: no improve pass here.
            _, image, used_seed, moderation, _ = ev
    state["seed"] = int(used_seed)
    state["moderation"] = moderation
    # --- Stage 3: upload the first frame to R2 (WebP) -------------------------
    base += _P2V_W_IMAGE
    yield frame("upload", base, _P2V_W_UPLOAD, frac=0.1, label="Uploading first frame")
    uid = r2_uploader.uid_from_request(request)
    buf = io.BytesIO()
    image.save(buf, format="WEBP", quality=95, method=6)
    params = {
        "model": model_name,
        "prompt": frame_prompt,
        "negative_prompt": (NOOBXL_NEGATIVE if use_negative_prompt else ""),
        "height": int(height),
        "width": int(width),
        "num_inference_steps": int(num_inference_steps),
        "guidance_scale": float(guidance_scale),
        "seed": int(used_seed),
        "uid": uid,
        "source": "prompt_to_video_assets",
        "original_input": original_input,
    }
    up = r2_uploader.upload_asset(
        namespace=R2_NAMESPACE, prompt=frame_prompt, params=params,
        data=buf.getvalue(), ext=".webp", content_type="image/webp", uid=uid,
    )
    if up.get("ok"):
        state["r2_filekey"] = up["filekey"]
        state["r2_bucket"] = up["bucket"]
        state["first_frame_url"] = r2_uploader.presign_get_url(up["filekey"], up["bucket"])
    else:
        # Non-fatal: recorded, and the video prompt is still written below.
        state["error"] = up.get("error", "R2 upload failed")
    # --- Stage 4: LLM writes the video prompt (grounded on the frame) ---------
    base += _P2V_W_UPLOAD
    video_prompt = ""
    for ev in vlm_chat(
        _compose_instruction(video_instruction, original_input),
        image, reasoning, max_new_tokens, system_prompt,
    ):
        if ev[0] == "progress":
            _, produced, budget, partial = ev
            video_prompt = partial
            yield frame("video_prompt", base, _P2V_W_VIDEO_PROMPT,
                        frac=produced / max(budget, 1), step=produced, total=budget,
                        label="Writing video prompt")
        else:
            video_prompt = ev[1]
    state["video_prompt"] = (video_prompt or "").strip()
    state["done"] = True
    # base=1.0, span=0.0 pins the final frame at exactly 100%.
    yield frame("done", 1.0, 0.0, label="Done")
| # ============================================================================= | |
| # Image -> video prompt + image moderation check (no first frame; UI-less) | |
| # ============================================================================= | |
| # The image-to-video counterpart of `prompt_to_video_assets`: the caller already | |
| # HAS the first frame (a user-uploaded image), so this endpoint just (1) screens | |
| # that image with the VLM image-context check and (2) writes the motion prompt | |
| # the video model will use — both inside one GPU window. The generator calls this | |
| # before the LTX/Wan video node, records the check against the produced asset, | |
| # and feeds the prompt forward. | |
| # | |
| # Outputs (positional, matching the structured-progress contract): | |
| # 0: video_prompt (str) — the motion prompt for the video model | |
| # 1: status (JSON dict) — { "moderation": <image-context check> } | |
| # 2: progress (JSON dict) — live 0..1 progress for SSE consumers | |
# Fallback motion-prompt instruction used by image_to_video_assets when the
# caller's ``video_instruction`` is blank.
_I2V_DEFAULT_INSTRUCTION = (
    "Look at this image. Write a concise image-to-video motion prompt focusing on "
    "the characters' actions, expressions and the scene, describing how it should "
    "animate: camera movement plus subject motion. Respond with ONE single line of "
    "one or two sentences containing ONLY the motion prompt. No headings, no "
    "markdown, no bullet points, no preamble."
)
def image_to_video_assets(
    image,
    video_instruction,
    reasoning,
    max_new_tokens,
    system_prompt="",
    progress=gr.Progress(track_tqdm=True),
):
    """Screen a user-supplied image and write a motion prompt for it.

    Yields ``(video_prompt, status, progress)`` frames so the prompt streams
    live and SSE consumers reading the progress index see the node advance.
    The stages are: screening at 0.02, the moderation result at 0.2, prompt
    writing from 0.2 to 0.99, and a final ``"done"`` frame holding the
    stripped prompt. ``status`` is ``{"moderation": …}``, filled in once the
    check completes. A blank ``video_instruction`` falls back to
    ``_I2V_DEFAULT_INSTRUCTION``; a falsy ``max_new_tokens`` becomes 256. With
    no image, a single ``"done"`` frame with an error status is yielded.
    Bound with ``api_name="image_to_video_assets"`` (hidden trigger, no UI).
    """
    instruction = (video_instruction or "").strip() or _I2V_DEFAULT_INSTRUCTION
    budget = int(max_new_tokens) if max_new_tokens else 256
    if image is None:
        yield "", {"moderation": None, "error": "image is required"}, \
            _progress("done", 1.0, label="No image")
        return

    status = {"moderation": None}
    prompt_text = ""
    yield prompt_text, status, _progress("check", 0.02, label="Screening image")
    events = _image_video_check_gpu(image, instruction, reasoning, budget, system_prompt)
    for kind, *payload in events:
        if kind == "moderation":
            # Replace (never mutate) so frames already yielded stay unchanged.
            status = {"moderation": payload[0]}
            yield prompt_text, status, _progress("check", 0.2, label="Image checked")
        elif kind == "progress":
            produced, step_budget, prompt_text = payload
            share = min(1.0, produced / max(step_budget, 1))
            yield prompt_text, status, _progress(
                "video_prompt", 0.2 + 0.79 * share, produced, step_budget, "Writing video prompt")
        else:  # ("text", final)
            prompt_text = payload[0]
    yield prompt_text.strip(), status, _progress("done", 1.0, label="Done")
def apply_model_defaults(model_name):
    """Return Gradio updates that re-seed the sampler controls for *model_name*.

    The six updates are, in order: inference steps, guidance scale, height,
    width, the "use negative prompt" checkbox and the negative-prompt textbox.
    Values and capability flags come from the model's registry entry (via
    ``_model_entry``); keys missing from the entry fall back to the literals
    used below.
    """
    entry = _model_entry(model_name)
    neg_ok = bool(entry.get("negative"))
    # Only the SDXL (illustrious) family has a user-adjustable guidance scale;
    # the other families pin guidance themselves.
    is_guided = entry.get("family") == "illustrious"

    if neg_ok:
        neg_info = "Applied via classifier-free guidance"
    else:
        neg_info = "This model is guidance-free — negative prompt is ignored"
    if is_guided:
        guidance_info = "Classifier-free guidance"
    else:
        guidance_info = "This model fixes its own guidance"

    steps_upd = gr.update(value=entry.get("steps", 9))
    guidance_upd = gr.update(
        value=entry.get("guidance", 0.0),
        interactive=is_guided,
        info=guidance_info,
    )
    height_upd = gr.update(value=entry.get("height", 1024))
    width_upd = gr.update(value=entry.get("width", 1024))
    use_neg_upd = gr.update(interactive=neg_ok)
    neg_text_upd = gr.update(interactive=neg_ok, info=neg_info)
    return steps_upd, guidance_upd, height_upd, width_upd, use_neg_upd, neg_text_upd
# Example prompts for the Generate tab, chosen to work across the supported
# model families. gr.Examples binds them to the single `prompt` input, so each
# row is a one-element list.
examples = [
    [example_prompt]
    for example_prompt in (
        "Young Chinese woman in red Hanfu, intricate embroidery. "
        "Impeccable makeup, red floral forehead pattern. "
        "Elaborate high bun, golden phoenix headdress, red flowers, beads. "
        "Holds round folding fan with lady, trees, bird. "
        "Neon lightning-bolt lamp, bright yellow glow, above extended left palm. "
        "Soft-lit outdoor night background, silhouetted tiered pagoda, blurred colorful distant lights.",
        "A majestic dragon soaring through clouds at sunset, scales shimmering with iridescent colors, detailed fantasy art style",
        "Cozy coffee shop interior, warm lighting, rain on windows, plants on shelves, vintage aesthetic, photorealistic",
        "1girl, nahida (genshin impact), white dress, green hair, looking at viewer, masterpiece, best quality, very aesthetic",
        "Portrait of a wise old wizard with a long white beard, holding a glowing crystal staff, magical forest background",
    )
]
# Custom Gradio 6 theme: a Soft base with a yellow/amber/slate palette, the Inter
# font and larger text/radius, then explicit primary-button fills and bolder
# block titles. Passed to demo.launch(theme=...) at the bottom of the file.
custom_theme = gr.themes.Soft(
    primary_hue="yellow",
    secondary_hue="amber",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
    text_size="lg",
    spacing_size="md",
    radius_size="lg",
)
custom_theme = custom_theme.set(
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_600",
    block_title_text_weight="600",
)
# Global stylesheet passed to demo.launch(css=APP_CSS) at the bottom of the file.
# - .header-text: gradient title and muted subtitle for the top Markdown header
#   (rendered with elem_classes="header-text").
# - .footer-text: padding and amber links for the footer Markdown.
# - A max-width 768px media query shrinks the header on narrow screens.
# - Buttons get a small lift and shadow on hover; the container is capped at
#   1400px wide and centered.
APP_CSS = """
.header-text h1 {
font-size: 2.5rem !important;
font-weight: 700 !important;
margin-bottom: 0.5rem !important;
background: linear-gradient(135deg, #fbbf24 0%, #f59e0b 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.header-text p {
font-size: 1.1rem !important;
color: #64748b !important;
margin-top: 0 !important;
}
.footer-text { padding: 1rem 0; }
.footer-text a {
color: #f59e0b !important;
text-decoration: none !important;
font-weight: 500;
}
.footer-text a:hover { text-decoration: underline !important; }
@media (max-width: 768px) {
.header-text h1 { font-size: 1.8rem !important; }
.header-text p { font-size: 1rem !important; }
}
button, .gr-button { transition: all 0.2s ease !important; }
button:hover, .gr-button:hover {
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15) !important;
}
.gradio-container {
max-width: 1400px !important;
margin: 0 auto !important;
}
"""
# Build the Gradio interface. In Gradio 6.x, theme/css/footer_links/mcp_server
# are arguments to demo.launch() (see bottom of file), not to Blocks().
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(
        """
# 🎨 Image Studio
**Ultra-fast AI image generation** • Choose a model and create stunning images
""",
        elem_classes="header-text"
    )
    with gr.Tabs():
        with gr.Tab("🎨 Generate"):
            with gr.Row(equal_height=False):
                # Left column - Input controls
                with gr.Column(scale=1, min_width=320):
                    # Only the models loaded this boot (IMAGE_MODELS env) are
                    # selectable. The first one seeds the initial slider defaults.
                    _enabled_labels = list(PIPELINES)
                    _first_entry = _model_entry(_enabled_labels[0])
                    model_name = gr.Dropdown(
                        choices=_enabled_labels,
                        value=_enabled_labels[0],
                        label="🧠 Model",
                        info="Pick a model — sampler defaults update to match.",
                    )
                    prompt = gr.Textbox(
                        label="✨ Your Prompt",
                        placeholder="Describe the image you want to create...",
                        lines=5,
                        max_lines=10,
                        autofocus=True,
                    )
                    with gr.Accordion("⚙️ Advanced Settings", open=False):
                        # Negative prompt is always exposed; it is only interactive
                        # (and only applied) for models that support it — seeded from
                        # the first loaded model.
                        _default_supports_neg = bool(_first_entry.get("negative"))
                        # The initial info text must agree with the interactivity
                        # flag; these are the same two strings apply_model_defaults
                        # emits when the model changes.
                        _default_neg_info = (
                            "Applied via classifier-free guidance"
                            if _default_supports_neg else
                            "This model is guidance-free — negative prompt is ignored"
                        )
                        use_negative_prompt = gr.Checkbox(
                            label="Use Negative Prompt",
                            value=True,
                            interactive=_default_supports_neg,
                            info="Only models with classifier-free guidance use this",
                        )
                        negative_prompt = gr.Textbox(
                            label="🚫 Negative Prompt",
                            value=_first_entry.get("default_negative", ""),
                            lines=3,
                            max_lines=6,
                            interactive=_default_supports_neg,
                            info=_default_neg_info,
                        )
                        with gr.Row():
                            height = gr.Slider(
                                minimum=512, maximum=2048, value=_first_entry.get("height", 1024), step=64,
                                label="Height", info="Image height in pixels",
                            )
                            width = gr.Slider(
                                minimum=512, maximum=2048, value=_first_entry.get("width", 1024), step=64,
                                label="Width", info="Image width in pixels",
                            )
                        num_inference_steps = gr.Slider(
                            minimum=1, maximum=50, value=_first_entry.get("steps", 9), step=1,
                            label="Inference Steps",
                            info="Z-Image: ~9 • SDXL/Illustrious: ~28 • Anima: ~24",
                        )
                        guidance_scale = gr.Slider(
                            minimum=0.0, maximum=10.0, value=_first_entry.get("guidance", 0.0), step=0.1,
                            label="Guidance Scale",
                            info="Only the SDXL/Illustrious family uses an adjustable guidance scale",
                            interactive=(_first_entry.get("family") == "illustrious"),
                        )
                        with gr.Row():
                            randomize_seed = gr.Checkbox(label="🎲 Random Seed", value=True)
                            seed = gr.Number(label="Seed", value=42, precision=0, minimum=0, maximum=MAX_SEED)
                        prompt_improve_instruction = gr.Textbox(
                            label="✨ Prompt Improve Instruction",
                            value="",
                            lines=2,
                            max_lines=4,
                            placeholder="Leave empty to use your prompt as-is. If set, the VLM "
                                        "rewrites your prompt with this instruction first, then "
                                        "renders the improved prompt.",
                            info="Optional — when non-empty, your prompt is improved before generation",
                        )
                    generate_btn = gr.Button(
                        "🚀 Generate Image", variant="primary", size="lg", scale=1
                    )
                    gr.Examples(
                        examples=examples,
                        inputs=[prompt],
                        label="💡 Try these prompts",
                        examples_per_page=5,
                    )
                # Right column - Output
                with gr.Column(scale=1, min_width=320):
                    output_image = gr.Image(
                        label="Generated Image",
                        type="pil",
                        format="webp",
                        show_label=False,
                        height=600,
                        buttons=["download", "share"],
                    )
                    used_seed = gr.Number(
                        label="🎲 Seed Used", interactive=False, container=True,
                    )
                    r2_status = gr.JSON(label="☁️ R2 Upload")
                    # Hidden structured-progress channel (index 3 of generate_image
                    # outputs). Surfaces every yield as an SSE `generating` frame.
                    gen_progress = gr.JSON(label="progress", visible=False)
        with gr.Tab("💬 Prompt Assistant"):
            gr.Markdown(
                "Ask the assistant model a single question — with or without an image. "
                "Turn a rough idea into a rich prompt, or upload a reference image and "
                "ask the model to describe it as a prompt. Use the **System Prompt** to "
                "steer its tone/role."
            )
            with gr.Row(equal_height=False):
                with gr.Column(scale=1, min_width=320):
                    vlm_image = gr.Image(
                        label="🖼️ Image (optional)",
                        type="pil",
                        height=320,
                    )
                    vlm_system = gr.Textbox(
                        label="🧭 System Prompt (optional)",
                        placeholder="e.g. 'You are an expert Danbooru tag writer. Output only comma-separated tags.'",
                        lines=2,
                        max_lines=8,
                    )
                    vlm_prompt = gr.Textbox(
                        label="💬 Your Question",
                        placeholder="e.g. 'Improve this prompt: a cat in a hat' — or attach an image and ask 'describe this as a prompt'",
                        lines=4,
                        max_lines=12,
                    )
                    vlm_reasoning = gr.Radio(
                        choices=["Off", "On"],
                        value="Off",
                        label="🧠 Reasoning",
                        info="Off: direct answer, best for prompts • On: think step-by-step first (slower, raise max tokens)",
                    )
                    with gr.Accordion("⚙️ Settings", open=False):
                        vlm_max_tokens = gr.Slider(
                            minimum=64, maximum=2048, value=512, step=64,
                            label="Max New Tokens",
                            info="Upper bound on the answer length",
                        )
                    vlm_btn = gr.Button(
                        "✨ Ask", variant="primary", size="lg", scale=1
                    )
                with gr.Column(scale=1, min_width=320):
                    vlm_output = gr.Textbox(
                        label="🤖 Answer",
                        lines=20,
                    )
                    # Hidden structured-progress channel (index 1 of prompt_assistant
                    # outputs).
                    vlm_progress = gr.JSON(label="progress", visible=False)
    gr.Markdown(
        """
---
<div style="text-align: center; opacity: 0.7; font-size: 0.9em; margin-top: 1rem;">
<strong>Models:</strong>
Moody Pro Mix (ZIT V12 DPO, Z-Image-Turbo finetune) •
One Obsession v2.1 Anime (Illustrious / SDXL) •
<strong>Demo by:</strong> <a href="https://x.com/realmrfakename" target="_blank">@mrfakename</a>
</div>
""",
        elem_classes="footer-text"
    )
    # Update sliders / helper visibility when the model changes
    model_name.change(
        fn=apply_model_defaults,
        inputs=[model_name],
        outputs=[num_inference_steps, guidance_scale, height, width,
                 use_negative_prompt, negative_prompt],
    )
    gen_inputs = [
        model_name, prompt, negative_prompt, use_negative_prompt,
        height, width, num_inference_steps, guidance_scale, seed, randomize_seed,
        prompt_improve_instruction,
    ]
    # api_name pinned so the generator's "/generate_image" endpoint keeps
    # resolving even though the click now runs the upload wrapper.
    generate_btn.click(
        fn=generate_and_upload, inputs=gen_inputs,
        outputs=[output_image, used_seed, r2_status, gen_progress],
        api_name="generate_image",
    )
    prompt.submit(
        fn=generate_and_upload, inputs=gen_inputs,
        outputs=[output_image, used_seed, r2_status, gen_progress],
    )
    # Prompt Assistant — single-turn, optional image, optional system prompt
    vlm_inputs = [vlm_prompt, vlm_image, vlm_reasoning, vlm_max_tokens, vlm_system]
    vlm_btn.click(
        fn=assistant_chat, inputs=vlm_inputs, outputs=[vlm_output, vlm_progress],
        api_name="prompt_assistant",
    )
    vlm_prompt.submit(
        fn=assistant_chat, inputs=vlm_inputs, outputs=[vlm_output, vlm_progress],
    )
    # UI-less endpoint: user image -> motion video prompt + image-context check.
    # Bound via a hidden trigger so the image input deserializes the uploaded
    # FileData into a PIL image (same proven path as the Prompt Assistant), then
    # exposed to the generator as api_name="image_to_video_assets".
    # The output components live inside the hidden row too; they carry no
    # visible=False of their own, so this keeps them out of the UI.
    with gr.Row(visible=False):
        i2v_image = gr.Image(label="i2v image", type="pil")
        i2v_instruction = gr.Textbox(label="i2v instruction", value="")
        i2v_reasoning = gr.Radio(choices=["Off", "On"], value="Off", label="i2v reasoning")
        i2v_max_tokens = gr.Slider(minimum=64, maximum=1024, value=256, step=64, label="i2v max tokens")
        i2v_system = gr.Textbox(label="i2v system prompt", value="")
        i2v_btn = gr.Button("i2v", visible=False)
        i2v_prompt_out = gr.Textbox(label="i2v prompt")
        i2v_status_out = gr.JSON(label="i2v status")
        i2v_progress_out = gr.JSON(label="i2v progress")
    i2v_btn.click(
        fn=image_to_video_assets,
        inputs=[i2v_image, i2v_instruction, i2v_reasoning, i2v_max_tokens, i2v_system],
        outputs=[i2v_prompt_out, i2v_status_out, i2v_progress_out],
        api_name="image_to_video_assets",
    )
    # UI-less combined endpoint: text -> first-frame image (R2) + video prompt.
    # `gr.api` derives its schema from the function's type hints and registers no
    # components, so it adds an API route without touching the visible UI.
    gr.api(prompt_to_video_assets, api_name="prompt_to_video_assets")
if __name__ == "__main__":
    # Gradio 6 takes the theme, stylesheet, footer links and MCP server toggle
    # at launch time rather than on gr.Blocks().
    demo.launch(
        mcp_server=True,
        footer_links=["api", "gradio"],
        css=APP_CSS,
        theme=custom_theme,
    )