Spaces:
Runtime error
Runtime error
| import random | |
| import os | |
| import sys | |
| import threading | |
| import asyncio | |
# ---------------------------------------------------------------------------
# ZeroGPU shim — `spaces` MUST be imported BEFORE torch.
# On HF ZeroGPU, `spaces` patches CUDA initialisation, and that patch has to
# be in place before torch is imported or GPU calls can silently misbehave.
# Locally `spaces` isn't installed, so a no-op stand-in for @GPU is defined
# instead and the same code runs unchanged on MPS / CPU.
# ---------------------------------------------------------------------------
try:
    import spaces  # type: ignore

    GPU = spaces.GPU
    ON_ZEROGPU = True
except Exception:

    def GPU(*dargs, **dkwargs):  # noqa: N802 — mirror the spaces.GPU API
        """No-op replacement for ``spaces.GPU``.

        Handles both decorator spellings: bare ``@GPU`` hands the function
        straight back, while ``@GPU(...)`` (with any arguments) returns an
        identity decorator.
        """
        used_bare = len(dargs) == 1 and callable(dargs[0]) and not dkwargs
        if used_bare:
            return dargs[0]
        return lambda fn: fn

    ON_ZEROGPU = False
| import torch | |
# Compatibility patch: torch < 2.4 has no `torch.xpu`, which diffusers >= 0.30
# requires. Install a stub that always reports "no XPU device".
if not hasattr(torch, "xpu"):

    class _MockXPU:
        """Stand-in for ``torch.xpu`` whose queries all report no XPU."""

        @staticmethod
        def is_available():
            return False

        @staticmethod
        def device_count():
            return 0

        @staticmethod
        def empty_cache():
            return None

        @staticmethod
        def manual_seed(seed):
            return None

        @staticmethod
        def reset_peak_memory_stats():
            return None

        @staticmethod
        def max_memory_allocated():
            return 0

        @staticmethod
        def synchronize():
            return None

    torch.xpu = _MockXPU()
# ---------------------------------------------------------------------------
# Detect HuggingFace Spaces via the SPACE_ID environment variable
# ---------------------------------------------------------------------------
IS_HF_SPACE = "SPACE_ID" in os.environ
if IS_HF_SPACE:
    # Swallow the "Invalid file descriptor: -1" ValueError that Python 3.13's
    # asyncio can emit during garbage collection; every other unraisable
    # exception is forwarded to the previously installed hook.
    _orig_unraisable = sys.unraisablehook

    def _unraisable_hook(args):
        fd_noise = args.exc_type is ValueError and "Invalid file descriptor" in str(args.exc_value)
        if not fd_noise:
            _orig_unraisable(args)

    sys.unraisablehook = _unraisable_hook
| # --------------------------------------------------------------------------- | |
| # Persistent storage cache — survives sleep/restart on HF Spaces | |
| # --------------------------------------------------------------------------- | |
| # The /data mount repeatedly gets poisoned: partial downloads leave files where | |
| # huggingface_hub/xet expect directories ([Errno 20] Not a directory), and the | |
| # cache root itself can end up as a file with I/O errors. This block is | |
| # bulletproof: any failure setting up /data falls back to the ephemeral | |
| # container cache (slower cold start, but always works) and NEVER crashes the | |
| # import — a dead import takes the whole app down. | |
| _MODEL_ID = "black-forest-labs/FLUX.2-klein-4B" | |
| _MODEL_DIR = "models--black-forest-labs--FLUX.2-klein-4B" | |
| def _force_remove(path): | |
| """Remove path whether it's a file, dir, or broken — never raises.""" | |
| import shutil | |
| try: | |
| if os.path.isdir(path) and not os.path.islink(path): | |
| shutil.rmtree(path) | |
| elif os.path.exists(path) or os.path.islink(path): | |
| os.remove(path) | |
| except Exception as e: | |
| print(f"⚠️ Could not remove {path}: {e}") | |
| def _setup_persistent_cache(): | |
| """Point HF cache at /data if usable. Returns True if persistence is active.""" | |
| cache_dir = "/data/hf_cache_v4" # bump path to escape any poisoned older cache | |
| hub_dir = os.path.join(cache_dir, "hub") | |
| # Probe: can we create a clean directory tree at cache_dir? If the path is a | |
| # poisoned file or the mount errors, bail out to ephemeral caching. | |
| try: | |
| if os.path.exists(cache_dir) or os.path.islink(cache_dir): | |
| # Validate existing cache: a real model snapshot with model_index.json | |
| snaps = os.path.join(hub_dir, _MODEL_DIR, "snapshots") | |
| valid = os.path.isdir(snaps) and any( | |
| os.path.isfile(os.path.join(snaps, s, "model_index.json")) | |
| for s in os.listdir(snaps) | |
| ) if os.path.isdir(snaps) else False | |
| if not valid: | |
| print(f"Cache at {cache_dir} is incomplete/poisoned — wiping") | |
| _force_remove(cache_dir) | |
| os.makedirs(hub_dir, exist_ok=True) | |
| except Exception as e: | |
| print(f"⚠️ /data cache unusable ({e}) — using ephemeral container cache") | |
| return False | |
| os.environ["HF_HOME"] = cache_dir | |
| os.environ["HF_HUB_CACHE"] = hub_dir | |
| print(f"Persistent cache active → {cache_dir}") | |
| return True | |
def _is_model_cached():
    """Return True if the hub cache holds a model snapshot with a model_index.json.

    The hub cache location follows huggingface_hub's documented env-var
    defaults: HF_HUB_CACHE (or legacy HUGGINGFACE_HUB_CACHE), else
    $HF_HOME/hub, else $XDG_CACHE_HOME/huggingface/hub, else
    ~/.cache/huggingface/hub. Previously an unset HF_HUB_CACHE made this
    probe a path relative to the current working directory.

    Never raises: a missing, non-directory, or unreadable snapshots folder
    (e.g. on a poisoned /data mount) counts as "not cached".
    """
    hub_cache = os.environ.get("HF_HUB_CACHE") or os.environ.get("HUGGINGFACE_HUB_CACHE")
    if not hub_cache:
        xdg_cache = os.environ.get("XDG_CACHE_HOME") or os.path.join("~", ".cache")
        hf_home = os.environ.get("HF_HOME") or os.path.join(xdg_cache, "huggingface")
        hub_cache = os.path.join(hf_home, "hub")
    snaps = os.path.join(os.path.expanduser(hub_cache), _MODEL_DIR, "snapshots")
    try:
        entries = os.listdir(snaps)
    except OSError:  # not there, not a directory, or an I/O error on the mount
        return False
    return any(
        os.path.isfile(os.path.join(snaps, s, "model_index.json"))
        for s in entries
    )
def _predownload_model():
    """Download the FLUX.2-klein-4B snapshot before the first generation.

    Blocking, so the first @spaces.GPU call doesn't pay for the download.
    Skips the download when _is_model_cached() reports a usable snapshot.
    Never raises: every failure, including the cache check itself on a
    poisoned mount, is printed and the download is retried at generation
    time.
    """
    try:
        if _is_model_cached():
            print("✅ FLUX.2-klein-4B already cached — skipping download")
            return
        print("Downloading FLUX.2-klein-4B…")

        from huggingface_hub import snapshot_download

        hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
        result = {}

        def _download():
            # Runs on a dedicated thread with its own fresh event loop, so any
            # asyncio use during the download stays off the caller's thread.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                snapshot_download(
                    _MODEL_ID,
                    token=hf_token,
                    # skip Flax / TF-style weight files
                    ignore_patterns=["*.msgpack", "*.h5", "flax_model*"],
                )
                result["ok"] = True
            except Exception as e:
                result["error"] = e  # re-raised on the calling thread below
            finally:
                try:
                    loop.close()
                except Exception as close_err:
                    print(f"⚠️ Could not close download event loop: {close_err}")

        worker = threading.Thread(target=_download, daemon=True)
        worker.start()
        worker.join()
        if "error" in result:
            raise result["error"]
        print("✅ FLUX.2-klein-4B cached successfully")
    except Exception as e:
        print(f"⚠️ Pre-cache warning (will retry at generation time): {e}")
if not IS_HF_SPACE:
    print("Local mode.")
else:
    # Move the HF cache onto persistent /data when the mount exists (this
    # falls back to the ephemeral cache on any failure), then fetch weights.
    if os.path.isdir("/data"):
        _setup_persistent_cache()
    _predownload_model()
# ---------------------------------------------------------------------------
# Emoji → descriptive text maps (fed into the FLUX prompt)
# ---------------------------------------------------------------------------
# Keep keys + order in sync with ANIMAL_EMOJIS in app.py (grouped by lookalike,
# three per group).
ANIMAL_MAP: dict[str, str] = {
    "🐶": "puppy",
    "🐱": "kitten",
    "🐰": "bunny",

    "🐭": "baby mouse",
    "🦊": "baby fox",
    "🐺": "baby wolf",

    "🐻": "baby bear",
    "🐼": "baby panda",
    "🐨": "baby koala",

    "🦁": "baby lion",
    "🐯": "baby tiger",
    "🐷": "baby pig",

    "🐮": "baby cow",
    "🐴": "pony",
    "🐑": "baby lamb",

    "🐸": "baby frog",
    "🦦": "baby otter",
    "🐧": "baby penguin",

    "🐙": "baby octopus",
    "🦋": "butterfly",
    "🐝": "bee",

    "🐞": "ladybug",
    "🦄": "unicorn",
    "🐉": "baby dragon",
}
# Keep keys + order in sync with PLACE_EMOJIS in app.py. Each value is a
# locative phrase appended directly after the animals in the prompt.
PLACE_MAP: dict[str, str] = {
    # landscapes
    "🌊": "on a sunny beach with ocean waves",
    "🏔️": "on a snowy mountain top",
    "🌸": "in a cherry blossom garden",
    # sky
    "🌈": "under a rainbow",
    "🌙": "on a glowing crescent moon",
    "⭐": "surrounded by sparkling stars",
    # gardens, islands, forests, canyons
    "🌴": "on a tropical island",
    "🏡": "in a cosy cottage garden",
    "🌺": "in a field of tropical flowers",
    "🍄": "in an enchanted mushroom forest",
    "🏜️": "in a red rock canyon",
    # venues & adventures
    "🏰": "in a fairytale castle",
    "🎢": "at a fun theme park",
    "⛺": "in a cosy camping tent",
    "🪩": "under a sparkling disco ball",
    "🎪": "at a colourful circus",
    "⛵": "in a little boat on calm water",
    "🚀": "aboard a space rocket among the stars",
}
# Shared style suffix — must stay in sync with STYLE in scripts/generate_dataset.py.
# The live prompt (build_prompt) does not append it; it is kept for the dataset
# generator. NOTE(review): confirm whether training captions still include this
# suffix — the trigger-word notes describe captions as "NUMZOO. A cute ...".
NUMZOO_STYLE = ", ".join([
    "kawaii children's book illustration",
    "pastel anime art style",
    "soft painterly lighting",
    "detailed rich background with warm fairy lights",
    "cozy magical atmosphere",
    "cute chibi character with big sparkling eyes",
    "soft pastel color palette",
    "highly detailed scene",
    "no text",
])
| # --------------------------------------------------------------------------- | |
| # Prompt builder | |
| # --------------------------------------------------------------------------- | |
| def _join_animals(parts: list[str]) -> str: | |
| """Join 1–5 animals; each one after the first is prefixed 'a cute' so the model | |
| renders them as distinct subjects, e.g. 'puppy, a cute kitten and a cute bunny'.""" | |
| if len(parts) == 1: | |
| return parts[0] | |
| head, rest = parts[0], [f"a cute {p}" for p in parts[1:]] | |
| if len(rest) == 1: | |
| return f"{head} and {rest[0]}" | |
| return f"{head}, " + ", ".join(rest[:-1]) + f" and {rest[-1]}" | |
| def _join_places(parts: list[str]) -> str: | |
| """Join 1–3 phrases: 'a', 'a and b', 'a, b and c'.""" | |
| if len(parts) == 1: | |
| return parts[0] | |
| if len(parts) == 2: | |
| return f"{parts[0]} and {parts[1]}" | |
| return f"{parts[0]}, {parts[1]} and {parts[2]}" | |
def build_subject(animals: list[str], places: list[str]) -> str:
    """Return the "A cute {animals} {places}" core of the prompt.

    SHARED with scripts/generate_dataset.py so training captions and live
    prompts use the exact same structure and vocabulary. At most 5 animals
    and 3 places are used; an empty selection is replaced by one random pick
    (animal first, then place). Unknown emoji fall back to "bunny" /
    "in a magical garden".
    """
    if animals:
        animal_keys = animals[:5]
    else:
        animal_keys = [random.choice(list(ANIMAL_MAP))]
    if places:
        place_keys = places[:3]
    else:
        place_keys = [random.choice(list(PLACE_MAP))]
    animal_words = [ANIMAL_MAP.get(key, "bunny") for key in animal_keys]
    place_words = [PLACE_MAP.get(key, "in a magical garden") for key in place_keys]
    return "A cute {} {}".format(_join_animals(animal_words), _join_places(place_words))
# Trigger word the NumZoo LoRA learned to associate with the whole cozy aesthetic.
# Training captions were "NUMZOO. A cute {animals} {places}, <detail>", so the
# live prompt mirrors that prefix exactly. The trigger carries the style, which
# makes the verbose NUMZOO_STYLE suffix redundant here (that constant is kept
# for the dataset generator).
NUMZOO_TRIGGER = "NUMZOO"


def build_prompt(animals: list[str], places: list[str]) -> str:
    """Return the full FLUX prompt: "NUMZOO. <subject>[<layout hint>], no text".

    - The trailing "no text" stops the model rendering the NUMZOO trigger as
      a literal sign/title; the distilled 4-step klein is prone to this
      without it.
    - With more than one animal, a layout hint is appended to reduce
      subject-merging, since adjacent animals otherwise tend to fuse into one.
    """
    several_animals = len(animals) > 1
    layout = ", side by side, each a separate full-body character" if several_animals else ""
    subject = build_subject(animals, places)
    return NUMZOO_TRIGGER + ". " + subject + layout + ", no text"
# ---------------------------------------------------------------------------
# Pipeline loader — two phases, following the reference ZeroGPU pattern:
#
#   Phase 1 · _load_pipeline_cpu() — from_pretrained into CPU RAM.
#     On HF Spaces this runs at module scope (outside any @GPU function), so
#     the weights are already resident when the first ZeroGPU call arrives.
#     That keeps model loading out of the 60 s GPU-runtime budget and avoids
#     cold-start timeouts on the very first generation.
#
#   Phase 2 · get_pipeline() — .to(device).
#     Meant to run INSIDE a @GPU-decorated call, where a GPU slice is
#     available. Moving already-loaded CPU tensors to the device is expected
#     to be quick relative to that budget (NOTE(review): measure).
#
# Locally (MPS / CPU) both phases run inside generate_reward_image, since the
# no-op @GPU decorator imposes no time budget.
# ---------------------------------------------------------------------------
_pipe = None  # the built pipeline; assigned once by _load_pipeline_cpu()
_pipe_lock = threading.Lock()  # serialize loading — pregenerate + on-demand can race
_gen_lock = threading.Lock()  # serialize inference — one device can't run two forwards at once

# Single-slot cache: lets the on-demand fallback reuse the in-flight
# pre-generation for the SAME reward instead of generating it twice. Keyed by
# reward_id (level); any other reward_id regenerates, so rewards stay varied.
_recent_reward = {"id": None, "image": None}

# MPS is used only locally, when CUDA is absent; it gets float16 because
# bfloat16 is unsupported there. Everything else uses bfloat16.
_use_mps = (
    not IS_HF_SPACE
    and not torch.cuda.is_available()
    and torch.backends.mps.is_available()
)
if _use_mps:
    _dtype = torch.float16
else:
    _dtype = torch.bfloat16
def _load_pipeline_cpu() -> None:
    """Phase 1: build the pipeline in CPU RAM and publish it as ``_pipe``.

    Thread-safe: pregenerate and on-demand handlers may call this at the same
    time. A fast path returns once ``_pipe`` is set, and a re-check under
    ``_pipe_lock`` ensures only one thread loads. A double load previously
    applied the LoRA twice and crashed with a meta-tensor error. ``_pipe`` is
    assigned only after the LoRA step, so no caller sees a half-built
    pipeline. Exceptions from ``from_pretrained`` propagate to the caller.
    """
    global _pipe
    if _pipe is not None:
        return
    with _pipe_lock:
        if _pipe is not None:  # another thread finished loading while we waited
            return
        from diffusers import Flux2KleinPipeline

        token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
        token_status = "set" if token else "NOT SET"
        print(f"Loading FLUX.2-klein-4B on CPU… (dtype={_dtype}, token={token_status})")
        candidate = Flux2KleinPipeline.from_pretrained(
            _MODEL_ID,
            torch_dtype=_dtype,
            token=token,
        )

        # NumZoo style LoRA (trained on FLUX.2-klein-base-4B, loads on the
        # distilled model), bundled in the repo via Git LFS. Applied here at
        # CPU scope so it stays out of the ZeroGPU 60 s budget. Any problem
        # falls back to the base model instead of failing the load.
        lora_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "lora")
        lora_path = os.path.join(lora_dir, "numzoo_klein.safetensors")
        if not os.path.isfile(lora_path):
            print(f"⚠️ NumZoo LoRA not found at {lora_path} — using base model")
        else:
            try:
                candidate.load_lora_weights(lora_path)
                print(f"✅ NumZoo LoRA loaded from {lora_path}")
            except Exception as e:
                print(f"⚠️ Could not load NumZoo LoRA ({e}) — using base model")

        _pipe = candidate  # publish only once fully built (incl. LoRA)
        print("Pipeline loaded on CPU — ready for device placement.")
def get_pipeline():
    """Phase 2: return the pipeline moved to the best available device.

    On HF Spaces call this inside a @GPU function. If phase 1 has not run yet
    (local mode, or a failed pre-load), it is run first. Device preference:
    CUDA, then MPS (local only), then CPU.
    """
    if _pipe is None:
        _load_pipeline_cpu()  # fallback for local / first-call safety
    if torch.cuda.is_available():
        target = "cuda"
    elif _use_mps:
        target = "mps"
    else:
        target = "cpu"
    return _pipe.to(target)
# ---------------------------------------------------------------------------
# Core generation (always wrapped in try/except)
# ---------------------------------------------------------------------------
def _generate(animals: list[str], places: list[str], reward_id=None):
    """Build the prompt and render one 768x768 image; never raises.

    Args:
        animals: ANIMAL_MAP emoji keys (empty → one random animal).
        places: PLACE_MAP emoji keys (empty → one random place).
        reward_id: Optional key for the single-slot reward cache. If the last
            image stored under this same id is still cached, it is returned
            instead of generating again.

    Returns:
        ``(image, prompt)`` on success. On a cache hit, ``prompt`` is the one
        that actually produced the cached image. ``(None, error_message)`` if
        anything fails.
    """
    try:
        import time

        prompt = build_prompt(animals, places)
        # Serialize device placement + inference: pregenerate and on-demand
        # handlers can fire concurrently, but one device (MPS/GPU slice) can't
        # run two forward passes at once.
        with _gen_lock:
            # Reuse the pre-generated image for this reward instead of generating twice.
            if (
                reward_id is not None
                and _recent_reward["id"] == reward_id
                and _recent_reward["image"] is not None
            ):
                print(f"♻️ Reusing pre-generated image for reward {reward_id}")
                # Empty animal/place lists make build_prompt pick at random, so
                # this call's prompt can differ from the one behind the cached
                # image. Return the stored prompt when one was recorded.
                return _recent_reward["image"], _recent_reward.get("prompt") or prompt
            pipe = get_pipeline()
            print(f"Generating | prompt: {prompt}")
            t0 = time.time()
            result = pipe(
                prompt=prompt,
                num_inference_steps=4,
                guidance_scale=1.0,  # klein-4B distilled: always 1.0 (not schnell's 0.0)
                height=768,  # 768 (vs 512) gives multiple subjects room → less merging
                width=768,
            )
            print(f"✅ Generated in {time.time() - t0:.1f}s")
            image = result.images[0]
            if reward_id is not None:
                _recent_reward["id"], _recent_reward["image"] = reward_id, image
                _recent_reward["prompt"] = prompt
            return image, prompt
    except Exception as e:
        import traceback

        print(f"[image_generator] ❌ generation failed: {e}")
        print(traceback.format_exc())
        return None, str(e)
# ---------------------------------------------------------------------------
# Module-scope CPU pre-load (HF Spaces only).
# Runs after the persistent cache is set up and the snapshot is downloaded,
# so from_pretrained can find the weights locally. Locally this is skipped —
# the pipeline loads lazily on the first generate call.
# A failure here (e.g. the download failed, or the token is missing) is
# logged, not raised: a dead import takes the whole app down, and
# get_pipeline() retries the load on the first generation.
# ---------------------------------------------------------------------------
if IS_HF_SPACE:
    try:
        _load_pipeline_cpu()
    except Exception as e:
        import traceback

        print(f"⚠️ CPU pre-load failed ({e}) — will retry at generation time")
        print(traceback.format_exc())
# ---------------------------------------------------------------------------
# Public API — single function; the @GPU decorator is a no-op locally.
# ---------------------------------------------------------------------------
@GPU
def generate_reward_image(animals: list[str], places: list[str], reward_id=None):
    """Generate a reward image; returns ``(image, prompt)`` or ``(None, error)``.

    Decorated with @GPU so that on HF ZeroGPU (spaces.GPU) the call runs
    inside a GPU slice. Locally the decorator is a no-op and MPS/CPU is used
    instead. Pass reward_id so the pre-generate and on-demand paths dedupe
    the same reward.

    NOTE(review): if ZeroGPU runs @GPU functions outside the main process,
    module-state updates made during the call (e.g. the reward cache) may not
    persist between calls. Confirm on the deployed Space.
    """
    return _generate(animals, places, reward_id=reward_id)