"""HF-native model execution for the Space runtime.""" from __future__ import annotations import os import threading import time from dataclasses import dataclass, asdict from pathlib import Path from typing import Any from uuid import uuid4 from .catalog import ADAPTER_CATALOG from .lora_adapter import load_and_apply, unload_all from .schema import AdapterRecipe FLUX_REPO_ID = "black-forest-labs/FLUX.2-klein-9B" TINY_TITAN_FLUX_REPO_ID = "black-forest-labs/FLUX.2-klein-4B" PRIVATE_RESEARCH_FLUX_REPO_ID = FLUX_REPO_ID _PIPELINE_CACHE: dict[str, Any] = {} _PIPELINE_CACHE_LOCK = threading.Lock() @dataclass(frozen=True) class HFGenerationResult: status: str provider_state: str repo_id: str output_path: str | None = None message: str = "" latency_seconds: float | None = None width: int = 1024 height: int = 1024 steps: int = 4 hf_token_present: bool = False lora_status: str = "disabled" lora_repo_id: str | None = None lora_message: str = "No LoRA adapter selected for this run." fallback_used: bool = False primary_error: str | None = None def to_dict(self) -> dict[str, Any]: """ Convert the result to a dictionary. Returns: The result as a dictionary. """ return asdict(self) def hf_runtime_enabled() -> bool: if os.environ.get("NEXUS_DISABLE_REAL_HF") == "1": return False if os.environ.get("NEXUS_ENABLE_REAL_HF") == "1": return True return bool(os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID")) def _output_dir() -> Path: path_str = os.environ.get("NEXUS_OUTPUT_DIR") if not path_str: if Path("/data").exists(): try: root = Path("/data/nexus_visual_weaver") root.mkdir(parents=True, exist_ok=True) return root except PermissionError: pass path_str = "outputs/runtime" root = Path(path_str) root.mkdir(parents=True, exist_ok=True) return root def _short_error(exc: BaseException) -> str: text = str(exc).replace("\n", " ").strip() if len(text) > 420: text = text[:417] + "..." return f"{exc.__class__.__name__}: {text}" def _hf_token() -> str | None: """ Retrieve the HuggingFace authentication token from environment variables. Checks HF_TOKEN first, then HUGGING_FACE_HUB_TOKEN. Returns: str | None: The token if available, None otherwise. """ return os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") def active_flux_repo_id() -> str: """ Determines the FLUX repository ID based on environment configuration. Returns: str: The FLUX repository ID to use for image generation. """ configured = os.environ.get("NEXUS_FLUX_REPO_ID") if configured: return configured if os.environ.get("NEXUS_TINY_TITAN_MODE") == "1": return TINY_TITAN_FLUX_REPO_ID return FLUX_REPO_ID def _repo_candidates(repo_id: str) -> list[str]: """ Return a list of repository candidates for generation attempts. The primary repo_id is always included. If repo_id is not the Tiny Titan model and fallback is not disabled via NEXUS_DISABLE_TINY_TITAN_FALLBACK, the Tiny Titan repository is appended as a fallback candidate. Parameters: repo_id (str): The primary FLUX model repository identifier. Returns: list[str]: Repository IDs to attempt, starting with the primary repo and optionally including a fallback candidate. """ candidates = [repo_id] if repo_id != TINY_TITAN_FLUX_REPO_ID and os.environ.get("NEXUS_DISABLE_TINY_TITAN_FALLBACK") != "1": candidates.append(TINY_TITAN_FLUX_REPO_ID) return candidates def _get_flux_pipe(repo_id: str, torch_module: Any, pipeline_cls: Any, token: str | None) -> Any: """ Get or load a cached FLUX pipeline from a Hugging Face repository. Maintains a thread-safe cache of pipelines keyed by repository ID to avoid repeated loading. Parameters: torch_module: The torch module, providing the dtype for the pipeline. pipeline_cls: The pipeline class (e.g., Flux2KleinPipeline) to instantiate. Returns: A FLUX pipeline instance. """ with _PIPELINE_CACHE_LOCK: cached = _PIPELINE_CACHE.get(repo_id) if cached is not None: return cached pipe = pipeline_cls.from_pretrained(repo_id, torch_dtype=torch_module.bfloat16, token=token) pipe.enable_model_cpu_offload() _PIPELINE_CACHE[repo_id] = pipe return pipe def _adapter_recipe(repo_id: str | None) -> AdapterRecipe | None: """ Retrieves an adapter recipe from the catalog by repository identifier. Returns: AdapterRecipe | None: The matching adapter recipe, or None if no recipe is found. """ if not repo_id: return None return next((recipe for recipe in ADAPTER_CATALOG if recipe.repo_id == repo_id), None) def default_lora_repo_id(target_repo_id: str) -> str | None: """ Selects the default LoRA adapter compatible with a FLUX repository. Returns the repository ID of the first adapter that is runtime-enabled, is not adult-only, does not require an input image, and is compatible with the target repository. Parameters: target_repo_id (str): The FLUX repository ID to find a compatible adapter for. Returns: The repository ID of the first matching adapter, or `None` if no compatible adapter is found. """ for recipe in ADAPTER_CATALOG: compatible_ids = {recipe.adapter_for, *recipe.compatible_repo_ids} if recipe.runtime_enabled and not recipe.adult_only and not recipe.requires_image and target_repo_id in compatible_ids: return recipe.repo_id return None def generate_flux_image( prompt: str, *, seed: int = 0, width: int = 1024, height: int = 1024, steps: int = 4, lora_repo_id: str | None = None, adult_mode: bool = False, ) -> HFGenerationResult: """ Generate a FLUX.2 image from a text prompt with optional LoRA adapter application. Returns: HFGenerationResult: Immutable container with generation status, image output path, execution latency, and LoRA application results. """ repo_id = active_flux_repo_id() selected_lora = lora_repo_id if lora_repo_id is not None else default_lora_repo_id(repo_id) recipe = _adapter_recipe(selected_lora) if not hf_runtime_enabled(): return HFGenerationResult( status="disabled", provider_state="dry-run", repo_id=repo_id, message="Real HF generation disabled outside Space. Raven Quality Stack uses FLUX.2 Klein 9B by default; set NEXUS_TINY_TITAN_MODE=1 for the 4B sidecar.", width=width, height=height, steps=steps, hf_token_present=bool(_hf_token()), lora_status="disabled", lora_repo_id=recipe.repo_id if recipe else None, lora_message="LoRA loading requires the HF Space GPU runtime.", ) started = time.perf_counter() try: import torch from diffusers import Flux2KleinPipeline except Exception as exc: # pragma: no cover - depends on Space runtime packages. return HFGenerationResult( status="missing_runtime", provider_state="blocked", repo_id=repo_id, message=f"FLUX runtime import failed. Install diffusers main + torch. {_short_error(exc)}", width=width, height=height, steps=steps, hf_token_present=bool(_hf_token()), lora_status="disabled", lora_repo_id=recipe.repo_id if recipe else None, lora_message="FLUX runtime import failed before LoRA loading.", ) if not torch.cuda.is_available(): return HFGenerationResult( status="no_cuda", provider_state="blocked", repo_id=repo_id, message="CUDA is not available to the Space callback; FLUX.2 generation requires GPU execution.", width=width, height=height, steps=steps, hf_token_present=bool(_hf_token()), lora_status="disabled", lora_repo_id=recipe.repo_id if recipe else None, lora_message="CUDA unavailable before LoRA loading.", ) token = _hf_token() errors: list[str] = [] for candidate in _repo_candidates(repo_id): try: pipe = _get_flux_pipe(candidate, torch, Flux2KleinPipeline, token) if hasattr(pipe, "set_progress_bar_config"): pipe.set_progress_bar_config(disable=True) lora_result = load_and_apply(pipe, recipe, candidate, adult_mode=adult_mode) try: generator = torch.Generator(device="cuda").manual_seed(seed) image = pipe( prompt=prompt, height=height, width=width, guidance_scale=1.0, num_inference_steps=steps, generator=generator, ).images[0] finally: unload_all(pipe) output_path = _output_dir() / f"nexus_flux_{time.time_ns()}_{seed}_{uuid4().hex[:8]}.png" image.save(output_path) return HFGenerationResult( status="success", provider_state="generated", repo_id=candidate, output_path=str(output_path), message=f"{candidate} generated a real Raven Quality artifact on HF Space.", latency_seconds=round(time.perf_counter() - started, 2), width=width, height=height, steps=steps, hf_token_present=bool(token), lora_status=str(lora_result.get("status", "disabled")), lora_repo_id=lora_result.get("repo_id"), lora_message=str(lora_result.get("message", "")), fallback_used=candidate != repo_id, primary_error=errors[0] if candidate != repo_id and errors else None, ) except Exception as exc: # pragma: no cover - exercised on HF Space with gated/runtime conditions. errors.append(f"{candidate}: {_short_error(exc)}") with _PIPELINE_CACHE_LOCK: _PIPELINE_CACHE.pop(candidate, None) continue return HFGenerationResult( status="error", provider_state="blocked", repo_id=repo_id, message=f"FLUX.2 generation failed. Check model license acceptance, HF_TOKEN/Space access, and runtime deps. Attempts: {' | '.join(errors)}", latency_seconds=round(time.perf_counter() - started, 2), width=width, height=height, steps=steps, hf_token_present=bool(_hf_token()), lora_status="disabled" if recipe is None else "failed", lora_repo_id=recipe.repo_id if recipe else None, lora_message="Generation failed before a usable LoRA evidence state could be produced.", )