import random
import os
import sys
import threading
import asyncio

# ---------------------------------------------------------------------------
# ZeroGPU shim — `spaces` has to be imported BEFORE torch.
# On HF ZeroGPU, `spaces` patches CUDA initialisation, and that patch must be
# in place before torch is imported or GPU calls can silently mis-behave.
# Where `spaces` is not installed, a do-nothing @GPU decorator stands in so
# the same code runs unchanged on MPS / CPU.
# ---------------------------------------------------------------------------
try:
    import spaces  # type: ignore

    GPU = spaces.GPU
    ON_ZEROGPU = True
except Exception:
    ON_ZEROGPU = False

    def GPU(*dargs, **dkwargs):  # noqa: N802 — mirror the spaces.GPU API
        """No-op stand-in for ``spaces.GPU``.

        Works both bare (``@GPU``) and parameterised (``@GPU(duration=60)``);
        either way the decorated function is returned untouched.
        """
        used_bare = len(dargs) == 1 and not dkwargs and callable(dargs[0])
        if used_bare:
            # ``@GPU`` directly on a function: hand the function back.
            return dargs[0]

        def _passthrough(fn):
            return fn

        # ``@GPU(...)`` with options: return a decorator that changes nothing.
        return _passthrough


import torch

# torch < 2.4 has no torch.xpu, which diffusers >= 0.30 requires; install an
# inert stand-in exposing the handful of functions that get probed.
if not hasattr(torch, "xpu"):
    class _MockXPU:
        """Inert replacement for ``torch.xpu``: reports that no XPU exists."""

        @staticmethod
        def is_available():
            return False

        @staticmethod
        def device_count():
            return 0

        @staticmethod
        def empty_cache():
            return None

        @staticmethod
        def manual_seed(seed):
            return None

        @staticmethod
        def reset_peak_memory_stats():
            return None

        @staticmethod
        def max_memory_allocated():
            return 0

        @staticmethod
        def synchronize():
            return None

    torch.xpu = _MockXPU()

# ---------------------------------------------------------------------------
# Detect HuggingFace Spaces
# ---------------------------------------------------------------------------
IS_HF_SPACE = "SPACE_ID" in os.environ

if IS_HF_SPACE:
    # Suppress Python 3.13 asyncio GC bug (Invalid file descriptor: -1)
    _orig_unraisable = sys.unraisablehook

    def _unraisable_hook(args):
        """Drop the known-noisy ValueError; forward everything else."""
        is_fd_noise = (
            args.exc_type is ValueError
            and "Invalid file descriptor" in str(args.exc_value)
        )
        if not is_fd_noise:
            _orig_unraisable(args)

    sys.unraisablehook = _unraisable_hook

# ---------------------------------------------------------------------------
# Persistent storage cache — survives sleep/restart on HF Spaces
# ---------------------------------------------------------------------------
# The /data mount repeatedly gets poisoned: partial downloads leave files where
# huggingface_hub/xet expect directories ([Errno 20] Not a directory), and the
# cache root itself can end up as a file with I/O errors. This block is
# bulletproof: any failure setting up /data falls back to the ephemeral
# container cache (slower cold start, but always works) and NEVER crashes the
# import — a dead import takes the whole app down.
_MODEL_ID = "black-forest-labs/FLUX.2-klein-4B"
_MODEL_DIR = "models--black-forest-labs--FLUX.2-klein-4B"


def _force_remove(path):
    """Remove path whether it's a file, dir, or broken — never raises.

    Real directories are removed recursively; files and symlinks (including
    dangling ones) are unlinked. Any failure is printed and swallowed.
    """
    import shutil

    try:
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        elif os.path.exists(path) or os.path.islink(path):
            os.remove(path)
    except Exception as e:
        print(f"⚠️ Could not remove {path}: {e}")


def _has_valid_snapshot(hub_dir):
    """Return True if ``hub_dir`` holds a model snapshot with model_index.json.

    May raise ``OSError`` (e.g. from ``os.listdir`` on a broken mount); callers
    decide how to treat that.
    """
    snaps = os.path.join(hub_dir, _MODEL_DIR, "snapshots")
    if not os.path.isdir(snaps):
        return False
    return any(
        os.path.isfile(os.path.join(snaps, s, "model_index.json"))
        for s in os.listdir(snaps)
    )


def _hub_cache_dir():
    """Resolve the hub cache directory the way huggingface_hub documents it.

    Order: $HF_HUB_CACHE, else $HF_HOME/hub, else
    ${XDG_CACHE_HOME:-~/.cache}/huggingface/hub.
    """
    hub = os.environ.get("HF_HUB_CACHE")
    if hub:
        return os.path.expanduser(hub)
    hf_home = os.environ.get("HF_HOME")
    if not hf_home:
        cache_root = os.environ.get("XDG_CACHE_HOME") or os.path.join(
            os.path.expanduser("~"), ".cache"
        )
        hf_home = os.path.join(cache_root, "huggingface")
    return os.path.join(os.path.expanduser(hf_home), "hub")


def _setup_persistent_cache():
    """Point HF cache at /data if usable. Returns True if persistence is active.

    An existing cache directory is kept only if it already contains a model
    snapshot with ``model_index.json``; otherwise it is wiped and recreated.
    On any error nothing is exported and False is returned, so the caller
    stays on the ephemeral container cache.

    NOTE(review): huggingface_hub is believed to read HF_HOME / HF_HUB_CACHE
    when it is first imported — confirm nothing imports it before this runs.
    """
    cache_dir = "/data/hf_cache_v4"  # bump path to escape any poisoned older cache
    hub_dir = os.path.join(cache_dir, "hub")

    # Probe: can we create a clean directory tree at cache_dir? If the path is a
    # poisoned file or the mount errors, bail out to ephemeral caching.
    try:
        if os.path.exists(cache_dir) or os.path.islink(cache_dir):
            # Validate existing cache: a real model snapshot with model_index.json
            if not _has_valid_snapshot(hub_dir):
                print(f"Cache at {cache_dir} is incomplete/poisoned — wiping")
                _force_remove(cache_dir)
        os.makedirs(hub_dir, exist_ok=True)
    except Exception as e:
        print(f"⚠️ /data cache unusable ({e}) — using ephemeral container cache")
        return False

    os.environ["HF_HOME"] = cache_dir
    os.environ["HF_HUB_CACHE"] = hub_dir
    print(f"Persistent cache active → {cache_dir}")
    return True


def _is_model_cached():
    """Return True if the active hub cache already holds a usable snapshot.

    Never raises: a filesystem error while probing counts as "not cached", so
    the caller simply attempts the download.
    """
    try:
        return _has_valid_snapshot(_hub_cache_dir())
    except OSError as e:
        print(f"⚠️ Could not inspect model cache ({e}) — treating as not cached")
        return False


def _predownload_model():
    """Blocking pre-download so the first @spaces.GPU call is fast. Never raises.

    The download runs in a helper thread that owns a fresh asyncio event loop
    (presumably so it cannot collide with a loop on the importing thread —
    TODO confirm). The caller blocks until the thread finishes.
    """
    try:
        if _is_model_cached():
            print("✅ FLUX.2-klein-4B already cached — skipping download")
            return
        print("Downloading FLUX.2-klein-4B…")

        from huggingface_hub import snapshot_download

        hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
        outcome = {}

        def _download():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                snapshot_download(
                    _MODEL_ID,
                    token=hf_token,
                    ignore_patterns=["*.msgpack", "*.h5", "flax_model*"],
                )
                outcome["ok"] = True
            except Exception as e:
                outcome["error"] = e
            finally:
                try:
                    loop.close()
                except Exception as close_err:
                    # Best-effort cleanup; report instead of hiding it.
                    print(f"⚠️ Could not close download event loop: {close_err}")

        worker = threading.Thread(target=_download, daemon=True)
        worker.start()
        worker.join()

        if "error" in outcome:
            raise outcome["error"]
        if not outcome.get("ok"):
            # The thread ended without recording success or an Exception
            # (e.g. it was killed by a BaseException) — do not claim success.
            raise RuntimeError("download thread exited without reporting a result")
        print("✅ FLUX.2-klein-4B cached successfully")
    except Exception as e:
        print(f"⚠️ Pre-cache warning (will retry at generation time): {e}")


if IS_HF_SPACE:
    if os.path.isdir("/data"):
        _setup_persistent_cache()  # falls back to ephemeral on any failure
    _predownload_model()
else:
    print("Local mode.")

# ---------------------------------------------------------------------------
# Emoji → descriptive text maps (fed into FLUX prompt)
# ---------------------------------------------------------------------------
# Keep keys + order in sync with ANIMAL_EMOJIS in app.py (grouped by lookalike).
ANIMAL_MAP: dict[str, str] = {
    "🐶": "puppy",
    "🐱": "kitten",
    "🐰": "bunny",
    "🐭": "baby mouse",
    "🦊": "baby fox",
    "🐺": "baby wolf",
    "🐻": "baby bear",
    "🐼": "baby panda",
    "🐨": "baby koala",
    "🦁": "baby lion",
    "🐯": "baby tiger",
    "🐷": "baby pig",
    "🐮": "baby cow",
    "🐴": "pony",
    "🐑": "baby lamb",
    "🐸": "baby frog",
    "🦦": "baby otter",
    "🐧": "baby penguin",
    "🐙": "baby octopus",
    "🦋": "butterfly",
    "🐝": "bee",
    "🐞": "ladybug",
    "🦄": "unicorn",
    "🐉": "baby dragon",
}

# Keep keys + order in sync with PLACE_EMOJIS in app.py.
PLACE_MAP: dict[str, str] = {
    "🌊": "on a sunny beach with ocean waves",
    "🏔️": "on a snowy mountain top",
    "🌸": "in a cherry blossom garden",
    "🌈": "under a rainbow",
    "🌙": "on a glowing crescent moon",
    "⭐": "surrounded by sparkling stars",
    "🌴": "on a tropical island",
    "🏡": "in a cosy cottage garden",
    "🌺": "in a field of tropical flowers",
    "🍄": "in an enchanted mushroom forest",
    "🏜️": "in a red rock canyon",
    "🏰": "in a fairytale castle",
    "🎢": "at a fun theme park",
    "⛺": "in a cosy camping tent",
    "🪩": "under a sparkling disco ball",
    "🎪": "at a colourful circus",
    "⛵": "in a little boat on calm water",
    "🚀": "aboard a space rocket among the stars",
}

# Shared style suffix — must stay in sync with STYLE in scripts/generate_dataset.py.
# This exact string appears in every training caption so the LoRA learns it as a trigger.
NUMZOO_STYLE = (
    "kawaii children's book illustration, pastel anime art style, "
    "soft painterly lighting, detailed rich background with warm fairy lights, "
    "cozy magical atmosphere, cute chibi character with big sparkling eyes, "
    "soft pastel color palette, highly detailed scene, no text"
)


# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------
def _join_with_and(items: list[str]) -> str:
    """Join items as 'a', 'a and b', 'a, b and c', …; empty input gives ''."""
    if not items:
        return ""
    if len(items) == 1:
        return items[0]
    return ", ".join(items[:-1]) + f" and {items[-1]}"


def _join_animals(parts: list[str]) -> str:
    """Join 1–5 animals; each one after the first is prefixed 'a cute' so the
    model renders them as distinct subjects, e.g.
    'puppy, a cute kitten and a cute bunny'.

    An empty list yields '' instead of raising IndexError.
    """
    if not parts:
        return ""
    return _join_with_and([parts[0], *(f"a cute {p}" for p in parts[1:])])


def _join_places(parts: list[str]) -> str:
    """Join phrases: 'a', 'a and b', 'a, b and c'.

    Output for 1–3 phrases is unchanged; longer lists are joined in full
    rather than silently cut after the third item, and an empty list yields ''.
    """
    return _join_with_and(list(parts))


def build_subject(animals: list[str], places: list[str]) -> str:
    """The 'A cute {animals} {places}' core — SHARED with scripts/generate_dataset.py
    so training captions and live prompts use the exact same structure & vocabulary.

    Args:
        animals: animal emoji keys of ANIMAL_MAP; at most the first 5 are used.
            Empty/None picks one random animal. Unknown keys become "bunny".
        places: place emoji keys of PLACE_MAP; at most the first 3 are used.
            Empty/None picks one random place. Unknown keys become
            "in a magical garden".

    Returns:
        The subject phrase, e.g. "A cute puppy and a cute kitten under a rainbow".
    """
    animal_list = animals[:5] if animals else [random.choice(list(ANIMAL_MAP))]
    place_list = places[:3] if places else [random.choice(list(PLACE_MAP))]
    animal_text = _join_animals([ANIMAL_MAP.get(a, "bunny") for a in animal_list])
    place_text = _join_places([PLACE_MAP.get(p, "in a magical garden") for p in place_list])
    return f"A cute {animal_text} {place_text}"


# Trigger word the NumZoo LoRA learned to associate with the whole cozy aesthetic.
# Training captions were "NUMZOO. A cute {animals} {places}, ", so the live
# prompt mirrors that prefix exactly — the trigger now carries the style, making the
# verbose NUMZOO_STYLE suffix redundant (kept as a constant for the dataset generator).
NUMZOO_TRIGGER = "NUMZOO"


def build_prompt(animals: list[str], places: list[str]) -> str:
    """Build the full live prompt: trigger word, subject, optional layout hint, 'no text'.

    Accepts the same empty/None inputs as build_subject (which then picks a
    random animal/place) instead of failing on ``len(None)``.
    """
    # "no text" suppresses the model rendering the NUMZOO trigger word as a literal
    # sign/title (the distilled 4-step klein is prone to this without it).
    # With multiple animals, a layout hint reduces subject-merging (the model
    # otherwise tends to fuse adjacent animals into one).
    multiple = len(animals or ()) > 1
    layout = ", side by side, each a separate full-body character" if multiple else ""
    return f"{NUMZOO_TRIGGER}. {build_subject(animals, places)}{layout}, no text"


# ---------------------------------------------------------------------------
# Pipeline loader — two-phase, following the reference ZeroGPU pattern:
#
# Phase 1 · _load_pipeline_cpu() — from_pretrained to CPU RAM.
#   Called at module scope (outside any @GPU function) so the weights are
#   already resident when the first ZeroGPU call arrives. This keeps the
#   model-download cost out of the 60 s GPU-runtime budget and prevents
#   cold-start timeouts on the very first generation.
#
# Phase 2 · get_pipeline() — .to(device).
#   Must be called INSIDE a @GPU-decorated function (where a GPU slice is
#   guaranteed). Moving already-loaded CPU tensors to CUDA is fast (~1 s)
#   and comfortably within the budget.
#
# Locally (MPS / CPU) both phases happen inside generate_reward_image because
# the no-op @GPU decorator doesn't impose any budget constraint.
# ---------------------------------------------------------------------------
_pipe = None
_pipe_lock = threading.Lock()  # serialize loading — pregenerate + on-demand can race
_gen_lock = threading.Lock()   # serialize inference — one device can't run two forwards at once

# Single-slot cache so the on-demand fallback reuses the in-flight pre-generation
# for the SAME reward instead of generating it twice. Keyed by reward_id (level);
# a different reward_id always regenerates, so rewards stay varied across levels.
# "prompt" records the prompt that actually produced the cached image.
_recent_reward = {"id": None, "image": None, "prompt": None}

_use_mps = (not IS_HF_SPACE) and (not torch.cuda.is_available()) and torch.backends.mps.is_available()
_dtype = torch.float16 if _use_mps else torch.bfloat16  # float16 on MPS (bfloat16 unsupported)


def _load_pipeline_cpu() -> None:
    """Phase 1: load model weights into CPU RAM.

    Thread-safe — pregenerate and on-demand handlers can call this
    concurrently; the lock prevents a double load (which previously applied
    the LoRA twice and crashed with a meta-tensor error). The module-level
    ``_pipe`` is assigned only after the pipeline, including the optional
    LoRA, is fully built.
    """
    global _pipe
    if _pipe is not None:
        return
    with _pipe_lock:
        if _pipe is not None:  # re-check inside the lock
            return

        from diffusers import Flux2KleinPipeline

        hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
        print(f"Loading FLUX.2-klein-4B on CPU… (dtype={_dtype}, token={'set' if hf_token else 'NOT SET'})")
        pipe = Flux2KleinPipeline.from_pretrained(
            _MODEL_ID,
            torch_dtype=_dtype,
            token=hf_token,
        )

        # Apply the NumZoo style LoRA (trained on FLUX.2-klein-base-4B, loads on distilled).
        # Bundled in the repo via Git LFS. Loaded here at module/CPU scope so it stays out
        # of the ZeroGPU 60 s budget. Never crash the import — fall back to the base model.
        lora_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "lora", "numzoo_klein.safetensors"
        )
        if os.path.isfile(lora_path):
            try:
                pipe.load_lora_weights(lora_path)
                print(f"✅ NumZoo LoRA loaded from {lora_path}")
            except Exception as e:
                print(f"⚠️ Could not load NumZoo LoRA ({e}) — using base model")
        else:
            print(f"⚠️ NumZoo LoRA not found at {lora_path} — using base model")

        _pipe = pipe  # publish only once fully built (incl. LoRA)
        print("Pipeline loaded on CPU — ready for device placement.")


def get_pipeline():
    """Phase 2: move pipeline to the target device. Call inside @GPU on HF Spaces.

    Loads the pipeline first if that has not happened yet, then returns it
    after ``.to(...)`` on CUDA when available, else MPS when ``_use_mps`` is
    set, else CPU.
    """
    if _pipe is None:
        _load_pipeline_cpu()  # fallback for local / first-call safety
    if torch.cuda.is_available():
        device = "cuda"
    elif _use_mps:
        device = "mps"
    else:
        device = "cpu"
    return _pipe.to(device)


# ---------------------------------------------------------------------------
# Core generation (always wrapped in try/except)
# ---------------------------------------------------------------------------
def _generate(animals: list[str], places: list[str], reward_id=None):
    """Generate one image for the given emoji selections.

    Args:
        animals: animal emoji keys (see build_prompt).
        places: place emoji keys (see build_prompt).
        reward_id: optional dedupe key. If the most recent generation was
            stored under the same id, that image is returned without running
            the pipeline again.

    Returns:
        ``(image, prompt)`` on success, where ``prompt`` is the prompt that
        produced the returned image; ``(None, error_message)`` on any failure.
        Never raises.
    """
    try:
        import time

        prompt = build_prompt(animals, places)
        # Serialize device placement + inference: pregenerate and on-demand handlers
        # can fire concurrently, but one device (MPS/GPU slice) can't run two forward
        # passes at once. ZeroGPU serializes @GPU calls for us; locally we must.
        with _gen_lock:
            # Reuse the pre-generated image for this reward instead of generating twice.
            cached_image = _recent_reward["image"]
            if reward_id is not None and _recent_reward["id"] == reward_id \
                    and cached_image is not None:
                print(f"♻️ Reusing pre-generated image for reward {reward_id}")
                # Report the prompt that made the cached image: the one built
                # above may differ when animals/places were chosen at random.
                return cached_image, _recent_reward.get("prompt") or prompt

            pipe = get_pipeline()
            print(f"Generating | prompt: {prompt}")
            t0 = time.time()
            result = pipe(
                prompt=prompt,
                num_inference_steps=4,
                guidance_scale=1.0,  # klein-4B distilled: always 1.0 (not schnell's 0.0)
                height=768,  # 768 (vs 512) gives multiple subjects room → less merging
                width=768,
            )
            print(f"✅ Generated in {time.time() - t0:.1f}s")
            image = result.images[0]
            if reward_id is not None:
                _recent_reward.update(id=reward_id, image=image, prompt=prompt)
            return image, prompt
    except Exception as e:
        import traceback

        print(f"[image_generator] ❌ generation failed: {e}")
        print(traceback.format_exc())
        return None, str(e)


# ---------------------------------------------------------------------------
# Module-scope CPU pre-load (HF Spaces only).
# Runs after the persistent cache is set up and the snapshot is downloaded,
# so from_pretrained finds the weights locally and completes quickly.
# Locally this is skipped — the pipeline loads lazily on first generate call.
# ---------------------------------------------------------------------------
if IS_HF_SPACE:
    _load_pipeline_cpu()


# ---------------------------------------------------------------------------
# Public API — single function, @GPU decorator is a no-op locally.
# ---------------------------------------------------------------------------
@GPU(duration=60)
def generate_reward_image(animals: list[str], places: list[str], reward_id=None):
    """Produce a reward image for the chosen animal and place emojis.

    On HF Spaces the call executes inside a ZeroGPU slice; elsewhere the
    @GPU decorator does nothing and MPS/CPU is used. Supply ``reward_id`` so
    that the pre-generate and on-demand paths share one image per reward.

    Returns whatever ``_generate`` returns for the same arguments.
    """
    outcome = _generate(animals, places, reward_id=reward_id)
    return outcome