# numzoo / image_generator.py
# goumsss's picture
# Reduce multi-animal merging: 768px + side-by-side layout cue
# 9d5c943
# Raw
# History Blame Contribute Delete
# 17.7 kB
import random
import os
import sys
import threading
import asyncio
# ---------------------------------------------------------------------------
# ZeroGPU shim — spaces MUST be imported BEFORE torch.
# On HF ZeroGPU, spaces patches CUDA initialisation; that patch must land
# before torch is imported or GPU calls can silently mis-behave.
# Locally spaces isn't installed, so we fall back to a no-op @GPU decorator
# that makes the same code run unchanged on MPS / CPU.
# ---------------------------------------------------------------------------
try:
    import spaces  # type: ignore

    # Real ZeroGPU decorator; both names are bound together or not at all.
    GPU, ON_ZEROGPU = spaces.GPU, True
except Exception:
    ON_ZEROGPU = False

    def GPU(*dargs, **dkwargs):  # noqa: N802 — mirror the spaces.GPU API
        """No-op stand-in for ``spaces.GPU``.

        Supports both spellings of the decorator: bare ``@GPU`` (the function
        is passed directly and returned untouched) and parametrised
        ``@GPU(duration=...)`` (returns a decorator that hands the function
        back unchanged).
        """
        used_bare = len(dargs) == 1 and not dkwargs and callable(dargs[0])
        if used_bare:
            return dargs[0]
        return lambda fn: fn
import torch
# Patch for torch < 2.4 which lacks torch.xpu (required by diffusers >= 0.30)
if not hasattr(torch, "xpu"):

    class _MockXPU:
        """Inert replacement for the missing ``torch.xpu`` namespace.

        Every entry reports "no XPU present" or does nothing. The stubs accept
        arbitrary positional/keyword arguments so that callers passing an
        optional device (e.g. ``synchronize(device)``) or a seed get the inert
        answer instead of a ``TypeError``.
        """

        is_available = staticmethod(lambda *args, **kwargs: False)
        device_count = staticmethod(lambda *args, **kwargs: 0)
        empty_cache = staticmethod(lambda *args, **kwargs: None)
        manual_seed = staticmethod(lambda *args, **kwargs: None)
        reset_peak_memory_stats = staticmethod(lambda *args, **kwargs: None)
        max_memory_allocated = staticmethod(lambda *args, **kwargs: 0)
        synchronize = staticmethod(lambda *args, **kwargs: None)

    torch.xpu = _MockXPU()
# ---------------------------------------------------------------------------
# Detect HuggingFace Spaces
# ---------------------------------------------------------------------------
# True whenever the SPACE_ID environment variable is defined (even if empty).
IS_HF_SPACE = "SPACE_ID" in os.environ
if IS_HF_SPACE:
    # Suppress Python 3.13 asyncio GC bug (Invalid file descriptor: -1)
    _orig_unraisable = sys.unraisablehook

    def _unraisable_hook(args):
        """Drop the known-noisy ValueError; forward everything else to the
        hook that was installed before this one."""
        is_fd_noise = (
            args.exc_type is ValueError
            and "Invalid file descriptor" in str(args.exc_value)
        )
        if not is_fd_noise:
            _orig_unraisable(args)

    sys.unraisablehook = _unraisable_hook
# ---------------------------------------------------------------------------
# Persistent storage cache — survives sleep/restart on HF Spaces
# ---------------------------------------------------------------------------
# The /data mount repeatedly gets poisoned: partial downloads leave files where
# huggingface_hub/xet expect directories ([Errno 20] Not a directory), and the
# cache root itself can end up as a file with I/O errors. This block is
# bulletproof: any failure setting up /data falls back to the ephemeral
# container cache (slower cold start, but always works) and NEVER crashes the
# import — a dead import takes the whole app down.
_MODEL_ID = "black-forest-labs/FLUX.2-klein-4B"
_MODEL_DIR = "models--black-forest-labs--FLUX.2-klein-4B"
def _force_remove(path):
"""Remove path whether it's a file, dir, or broken — never raises."""
import shutil
try:
if os.path.isdir(path) and not os.path.islink(path):
shutil.rmtree(path)
elif os.path.exists(path) or os.path.islink(path):
os.remove(path)
except Exception as e:
print(f"⚠️ Could not remove {path}: {e}")
def _setup_persistent_cache():
    """Redirect the Hugging Face cache to /data when that location is usable.

    An existing cache is kept only if it holds at least one snapshot of the
    model containing ``model_index.json``; otherwise it is wiped and rebuilt.
    On success ``HF_HOME`` and ``HF_HUB_CACHE`` are exported and True is
    returned. On any exception a warning is printed, the environment is left
    untouched and False is returned.
    """
    cache_dir = "/data/hf_cache_v4"  # bump path to escape any poisoned older cache
    hub_dir = os.path.join(cache_dir, "hub")
    try:
        if os.path.exists(cache_dir) or os.path.islink(cache_dir):
            # Something already lives at the cache root — trust it only if a
            # snapshot with model_index.json is present.
            snapshots = os.path.join(hub_dir, _MODEL_DIR, "snapshots")
            has_snapshot = False
            if os.path.isdir(snapshots):
                has_snapshot = any(
                    os.path.isfile(os.path.join(snapshots, name, "model_index.json"))
                    for name in os.listdir(snapshots)
                )
            if not has_snapshot:
                print(f"Cache at {cache_dir} is incomplete/poisoned — wiping")
                _force_remove(cache_dir)
        # Raises if the root is still a file or the mount errors, which sends
        # us to the ephemeral fallback below.
        os.makedirs(hub_dir, exist_ok=True)
    except Exception as err:
        print(f"⚠️ /data cache unusable ({err}) — using ephemeral container cache")
        return False
    os.environ.update(HF_HOME=cache_dir, HF_HUB_CACHE=hub_dir)
    print(f"Persistent cache active → {cache_dir}")
    return True
def _is_model_cached():
    """Return True if a snapshot of the model with ``model_index.json`` exists
    in the hub cache. Never raises.

    The cache root is ``HF_HUB_CACHE`` when set. Otherwise it is resolved the
    way huggingface_hub documents its defaults (``$HF_HOME/hub``, with
    ``HF_HOME`` defaulting to ``$XDG_CACHE_HOME/huggingface`` or
    ``~/.cache/huggingface``) instead of probing a path relative to the
    current working directory.
    """
    hub_cache = os.environ.get("HF_HUB_CACHE")
    if not hub_cache:
        xdg_cache = os.environ.get("XDG_CACHE_HOME") or os.path.join(
            os.path.expanduser("~"), ".cache"
        )
        hf_home = os.environ.get("HF_HOME") or os.path.join(xdg_cache, "huggingface")
        hub_cache = os.path.join(os.path.expanduser(hf_home), "hub")
    snaps = os.path.join(hub_cache, _MODEL_DIR, "snapshots")
    try:
        entries = os.listdir(snaps)
    except OSError:
        # Missing, not a directory, or an unreadable mount: treat as "not cached".
        return False
    return any(
        os.path.isfile(os.path.join(snaps, s, "model_index.json"))
        for s in entries
    )
def _predownload_model():
    """Blocking pre-download so the first @spaces.GPU call is fast. Never raises.

    Skips the download when the model is already cached. Otherwise runs
    ``snapshot_download`` on a worker thread that owns a fresh asyncio event
    loop, waits for it, and prints the outcome. Every failure — including one
    from the cache check itself — is reduced to a printed warning.
    """
    try:
        if _is_model_cached():
            print("✅ FLUX.2-klein-4B already cached — skipping download")
            return
        print("Downloading FLUX.2-klein-4B…")
        from huggingface_hub import snapshot_download

        hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
        result = {}

        def _download():
            # The worker gets its own event loop, closed again on exit.
            # NOTE(review): presumably needed by the hub's download backend — confirm.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                snapshot_download(
                    _MODEL_ID,
                    token=hf_token,
                    ignore_patterns=["*.msgpack", "*.h5", "flax_model*"],
                )
                result["ok"] = True
            except Exception as e:
                result["error"] = e
            finally:
                try:
                    loop.close()
                except Exception:
                    # Best-effort cleanup; a failed close must not mask the
                    # download result.
                    pass

        t = threading.Thread(target=_download, daemon=True)
        t.start()
        t.join()
        if "error" in result:
            raise result["error"]
        if not result.get("ok"):
            # The worker died without recording success or an error.
            raise RuntimeError("download thread exited without a result")
        print("✅ FLUX.2-klein-4B cached successfully")
    except Exception as e:
        print(f"⚠️ Pre-cache warning (will retry at generation time): {e}")
# NOTE(review): nesting reconstructed from flattened source — the pre-download
# is assumed to run on every Space start, with or without /data. Confirm.
if not IS_HF_SPACE:
    print("Local mode.")
else:
    if os.path.isdir("/data"):
        _setup_persistent_cache()  # falls back to ephemeral on any failure
    _predownload_model()
# ---------------------------------------------------------------------------
# Emoji → descriptive text maps (fed into FLUX prompt)
# ---------------------------------------------------------------------------
# Keep keys + order in sync with ANIMAL_EMOJIS in app.py (grouped by lookalike).
# Emoji key -> animal noun phrase inserted into the prompt. Insertion order
# is preserved from the original literal.
ANIMAL_MAP: dict[str, str] = {
    "🐶": "puppy",
    "🐱": "kitten",
    "🐰": "bunny",
    "🐭": "baby mouse",
    "🦊": "baby fox",
    "🐺": "baby wolf",
    "🐻": "baby bear",
    "🐼": "baby panda",
    "🐨": "baby koala",
    "🦁": "baby lion",
    "🐯": "baby tiger",
    "🐷": "baby pig",
    "🐮": "baby cow",
    "🐴": "pony",
    "🐑": "baby lamb",
    "🐸": "baby frog",
    "🦦": "baby otter",
    "🐧": "baby penguin",
    "🐙": "baby octopus",
    "🦋": "butterfly",
    "🐝": "bee",
    "🐞": "ladybug",
    "🦄": "unicorn",
    "🐉": "baby dragon",
}
# Keep keys + order in sync with PLACE_EMOJIS in app.py.
# Emoji key -> place phrase appended after the animals in the prompt.
# Insertion order is preserved from the original literal.
PLACE_MAP: dict[str, str] = {
    "🌊": "on a sunny beach with ocean waves",
    "🏔️": "on a snowy mountain top",
    "🌸": "in a cherry blossom garden",
    "🌈": "under a rainbow",
    "🌙": "on a glowing crescent moon",
    "⭐": "surrounded by sparkling stars",
    "🌴": "on a tropical island",
    "🏡": "in a cosy cottage garden",
    "🌺": "in a field of tropical flowers",
    "🍄": "in an enchanted mushroom forest",
    "🏜️": "in a red rock canyon",
    "🏰": "in a fairytale castle",
    "🎢": "at a fun theme park",
    "⛺": "in a cosy camping tent",
    "🪩": "under a sparkling disco ball",
    "🎪": "at a colourful circus",
    "⛵": "in a little boat on calm water",
    "🚀": "aboard a space rocket among the stars",
}
# Shared style suffix — must stay in sync with STYLE in scripts/generate_dataset.py.
# This exact string appears in every training caption so the LoRA learns it as a trigger.
# One descriptor per source line; adjacent literals concatenate into the same
# single comma-separated string as before.
NUMZOO_STYLE = (
    "kawaii children's book illustration, "
    "pastel anime art style, "
    "soft painterly lighting, "
    "detailed rich background with warm fairy lights, "
    "cozy magical atmosphere, "
    "cute chibi character with big sparkling eyes, "
    "soft pastel color palette, "
    "highly detailed scene, "
    "no text"
)
# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------
def _join_animals(parts: list[str]) -> str:
"""Join 1–5 animals; each one after the first is prefixed 'a cute' so the model
renders them as distinct subjects, e.g. 'puppy, a cute kitten and a cute bunny'."""
if len(parts) == 1:
return parts[0]
head, rest = parts[0], [f"a cute {p}" for p in parts[1:]]
if len(rest) == 1:
return f"{head} and {rest[0]}"
return f"{head}, " + ", ".join(rest[:-1]) + f" and {rest[-1]}"
def _join_places(parts: list[str]) -> str:
"""Join 1–3 phrases: 'a', 'a and b', 'a, b and c'."""
if len(parts) == 1:
return parts[0]
if len(parts) == 2:
return f"{parts[0]} and {parts[1]}"
return f"{parts[0]}, {parts[1]} and {parts[2]}"
def build_subject(animals: list[str], places: list[str]) -> str:
    """Build the 'A cute {animals} {places}' core of the prompt.

    SHARED with scripts/generate_dataset.py so training captions and live
    prompts use the exact same structure & vocabulary.

    At most five animals and three places are used. When either list is empty
    a random key is drawn from the corresponding map. Unknown emoji fall back
    to 'bunny' / 'in a magical garden'.
    """
    # Draw for animals before places so the RNG is consumed in a fixed order.
    if animals:
        animal_keys = animals[:5]
    else:
        animal_keys = [random.choice(list(ANIMAL_MAP))]
    if places:
        place_keys = places[:3]
    else:
        place_keys = [random.choice(list(PLACE_MAP))]

    animal_names = [ANIMAL_MAP.get(key, "bunny") for key in animal_keys]
    place_phrases = [PLACE_MAP.get(key, "in a magical garden") for key in place_keys]
    return f"A cute {_join_animals(animal_names)} {_join_places(place_phrases)}"
# Trigger word the NumZoo LoRA learned to associate with the whole cozy aesthetic.
# Training captions were "NUMZOO. A cute {animals} {places}, <detail>", so the live
# prompt mirrors that prefix exactly — the trigger now carries the style, making the
# verbose NUMZOO_STYLE suffix redundant (kept as a constant for the dataset generator).
NUMZOO_TRIGGER = "NUMZOO"


def build_prompt(animals: list[str], places: list[str]) -> str:
    """Return the full generation prompt: trigger word, subject, optional
    layout cue and the trailing 'no text' guard.

    ``animals`` may be empty or None; ``build_subject`` then picks a random
    animal and no layout cue is added, rather than ``len(None)`` raising.
    """
    # "no text" suppresses the model rendering the NUMZOO trigger word as a literal
    # sign/title (the distilled 4-step klein is prone to this without it).
    # With multiple animals, a layout hint reduces subject-merging (the model
    # otherwise tends to fuse adjacent animals into one).
    has_multiple = len(animals or ()) > 1
    layout = ", side by side, each a separate full-body character" if has_multiple else ""
    return f"{NUMZOO_TRIGGER}. {build_subject(animals, places)}{layout}, no text"
# ---------------------------------------------------------------------------
# Pipeline loader — two-phase, following the reference ZeroGPU pattern:
#
# Phase 1 · _load_pipeline_cpu() — from_pretrained to CPU RAM.
# Called at module scope (outside any @GPU function) so the weights are
# already resident when the first ZeroGPU call arrives. This keeps the
# model-download cost out of the 60 s GPU-runtime budget and prevents
# cold-start timeouts on the very first generation.
#
# Phase 2 · get_pipeline() — .to(device).
# Must be called INSIDE a @GPU-decorated function (where a GPU slice is
# guaranteed). Moving already-loaded CPU tensors to CUDA is fast (~1 s)
# and comfortably within the budget.
#
# Locally (MPS / CPU) both phases happen inside generate_reward_image because
# the no-op @GPU decorator doesn't impose any budget constraint.
# ---------------------------------------------------------------------------
_pipe = None  # shared pipeline; stays None until a load has fully completed
_pipe_lock = threading.Lock()  # serialize loading — pregenerate + on-demand can race
_gen_lock = threading.Lock()  # serialize inference — one device can't run two forwards at once
# Single-slot cache so the on-demand fallback reuses the in-flight pre-generation
# for the SAME reward instead of generating it twice. Keyed by reward_id (level);
# a different reward_id always regenerates, so rewards stay varied across levels.
_recent_reward = dict(id=None, image=None)
# MPS is chosen only off-Space and only when CUDA is absent; the Space flag is
# tested first, so CUDA is never queried at import time on a Space.
_use_mps = not (IS_HF_SPACE or torch.cuda.is_available()) and torch.backends.mps.is_available()
_dtype = torch.bfloat16 if not _use_mps else torch.float16  # float16 on MPS (bfloat16 unsupported)
def _load_pipeline_cpu() -> None:
    """Phase 1: build the pipeline in CPU RAM and publish it in ``_pipe``.

    Safe to call from several threads: the ``_pipe`` check is repeated after
    acquiring ``_pipe_lock``, so only the first caller loads and the rest
    return as soon as the pipeline is published. (A double load previously
    applied the LoRA twice and crashed with a meta-tensor error.)
    """
    global _pipe
    if _pipe is not None:  # already published — skip the lock entirely
        return
    with _pipe_lock:
        if _pipe is not None:  # another thread finished while we waited
            return
        from diffusers import Flux2KleinPipeline

        token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
        token_state = "set" if token else "NOT SET"
        print(f"Loading FLUX.2-klein-4B on CPU… (dtype={_dtype}, token={token_state})")
        loaded = Flux2KleinPipeline.from_pretrained(
            _MODEL_ID,
            torch_dtype=_dtype,
            token=token,
        )
        # Apply the NumZoo style LoRA (trained on FLUX.2-klein-base-4B, loads on distilled).
        # Bundled in the repo via Git LFS. Loaded here at module/CPU scope so it stays out
        # of the ZeroGPU 60 s budget. A missing or broken LoRA is only a warning — the
        # base model is used instead.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        lora_path = os.path.join(module_dir, "lora", "numzoo_klein.safetensors")
        if not os.path.isfile(lora_path):
            print(f"⚠️ NumZoo LoRA not found at {lora_path} — using base model")
        else:
            try:
                loaded.load_lora_weights(lora_path)
                print(f"✅ NumZoo LoRA loaded from {lora_path}")
            except Exception as err:
                print(f"⚠️ Could not load NumZoo LoRA ({err}) — using base model")
        _pipe = loaded  # publish only once fully built (incl. LoRA)
        print("Pipeline loaded on CPU — ready for device placement.")
def get_pipeline():
    """Phase 2: return the pipeline moved to the target device.

    Call inside @GPU on HF Spaces. Loads the pipeline first if it is not
    resident yet, then prefers CUDA, then MPS (when ``_use_mps`` is set),
    and otherwise stays on CPU.
    """
    if _pipe is None:
        _load_pipeline_cpu()  # fallback for local / first-call safety
    if torch.cuda.is_available():
        target = "cuda"
    elif _use_mps:
        target = "mps"
    else:
        target = "cpu"
    return _pipe.to(target)
# ---------------------------------------------------------------------------
# Core generation (always wrapped in try/except)
# ---------------------------------------------------------------------------
def _generate(animals: list[str], places: list[str], reward_id=None):
    """Generate one reward image; never raises.

    Returns ``(image, prompt)`` on success and ``(None, error_message)`` on
    any failure. When ``reward_id`` matches the single cached reward, the
    cached image is returned together with the prompt it was generated from.
    A freshly built prompt could name different random animals/places than
    the picture shows.
    """
    try:
        import time

        prompt = build_prompt(animals, places)
        # Serialize device placement + inference: pregenerate and on-demand handlers
        # can fire concurrently, but one device (MPS/GPU slice) can't run two forward
        # passes at once. ZeroGPU serializes @GPU calls for us; locally we must.
        with _gen_lock:
            # Reuse the pre-generated image for this reward instead of generating twice.
            cached_image = _recent_reward["image"]
            if (
                reward_id is not None
                and cached_image is not None
                and _recent_reward["id"] == reward_id
            ):
                print(f"♻️ Reusing pre-generated image for reward {reward_id}")
                # Fall back to the fresh prompt only if no prompt was stored.
                return cached_image, _recent_reward.get("prompt") or prompt
            pipe = get_pipeline()
            print(f"Generating | prompt: {prompt}")
            t0 = time.perf_counter()  # monotonic clock for the elapsed-time log
            result = pipe(
                prompt=prompt,
                num_inference_steps=4,
                guidance_scale=1.0,  # klein-4B distilled: always 1.0 (not schnell's 0.0)
                height=768,  # 768 (vs 512) gives multiple subjects room → less merging
                width=768,
            )
            print(f"✅ Generated in {time.perf_counter() - t0:.1f}s")
            image = result.images[0]
            if reward_id is not None:
                # Store the prompt with the image so a later reuse stays consistent.
                _recent_reward.update(id=reward_id, image=image, prompt=prompt)
            return image, prompt
    except Exception as e:
        import traceback

        print(f"[image_generator] ❌ generation failed: {e}")
        print(traceback.format_exc())
        return None, str(e)
# ---------------------------------------------------------------------------
# Module-scope CPU pre-load (HF Spaces only).
# Runs after the persistent cache is set up and the snapshot is downloaded,
# so from_pretrained finds the weights locally and completes quickly.
# Locally this is skipped — the pipeline loads lazily on first generate call.
# ---------------------------------------------------------------------------
if IS_HF_SPACE:
    # A failed pre-load must not kill the import (and with it the whole app):
    # get_pipeline() loads lazily whenever the pipeline is still missing.
    try:
        _load_pipeline_cpu()
    except Exception as e:
        print(f"⚠️ Pipeline pre-load failed ({e}) — will retry on first generation")
# ---------------------------------------------------------------------------
# Public API — single function, @GPU decorator is a no-op locally.
# ---------------------------------------------------------------------------
@GPU(duration=60)
def generate_reward_image(animals: list[str], places: list[str], reward_id=None):
    """Public entry point: produce a reward image for the given emoji lists.

    On HF Spaces this runs inside a ZeroGPU slice; locally the @GPU decorator
    is a no-op and MPS/CPU is used instead. Pass ``reward_id`` so the
    pre-generate and on-demand paths dedupe the same reward. The result of
    ``_generate`` is returned unchanged.
    """
    outcome = _generate(animals, places, reward_id=reward_id)
    return outcome