| """HF-native model execution for the Space runtime.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| import threading |
| import time |
| from dataclasses import dataclass, asdict |
| from pathlib import Path |
| from typing import Any |
| from uuid import uuid4 |
|
|
| from .catalog import ADAPTER_CATALOG |
| from .lora_adapter import load_and_apply, unload_all |
| from .schema import AdapterRecipe |
|
|
|
|
# Default FLUX.2 checkpoint; active_flux_repo_id() returns it when no
# environment override is set.
FLUX_REPO_ID = "black-forest-labs/FLUX.2-klein-9B"

# Smaller checkpoint: selected by NEXUS_TINY_TITAN_MODE=1 and appended by
# _repo_candidates() as a fallback after the primary repo.
TINY_TITAN_FLUX_REPO_ID = "black-forest-labs/FLUX.2-klein-4B"

# Alias of the default checkpoint (not referenced by the code visible here).
PRIVATE_RESEARCH_FLUX_REPO_ID = FLUX_REPO_ID

# Loaded pipelines keyed by repository id. Reads and writes go through
# _PIPELINE_CACHE_LOCK.
_PIPELINE_CACHE: dict[str, Any] = {}
_PIPELINE_CACHE_LOCK = threading.Lock()
|
|
|
|
@dataclass(frozen=True)
class HFGenerationResult:
    """Immutable record describing the outcome of one FLUX generation request.

    Only ``status``, ``provider_state`` and ``repo_id`` are required; every
    other field carries a default.
    """

    # Outcome label, e.g. "success", "error", "disabled", "missing_runtime", "no_cuda".
    status: str
    # Provider-side state, e.g. "generated", "blocked", "dry-run".
    provider_state: str
    # Repository the result refers to (the generating repo on success).
    repo_id: str
    # Location of the saved image; None when nothing was written.
    output_path: str | None = None
    # Human-readable summary of what happened.
    message: str = ""
    # Wall-clock duration in seconds; None when not measured.
    latency_seconds: float | None = None
    # Requested image geometry and number of inference steps.
    width: int = 1024
    height: int = 1024
    steps: int = 4
    # Whether a Hugging Face token was found in the environment.
    hf_token_present: bool = False
    # LoRA adapter outcome: status label, adapter repo and explanatory text.
    lora_status: str = "disabled"
    lora_repo_id: str | None = None
    lora_message: str = "No LoRA adapter selected for this run."
    # True when a repo other than the primary one produced the image.
    fallback_used: bool = False
    # First recorded error of the primary attempt when a fallback was used.
    primary_error: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return all fields of this result as a plain dictionary."""
        payload: dict[str, Any] = asdict(self)
        return payload
|
|
|
|
def hf_runtime_enabled() -> bool:
    """Report whether real Hugging Face generation should run.

    ``NEXUS_DISABLE_REAL_HF=1`` always wins and turns the runtime off.
    Otherwise the runtime is on when ``NEXUS_ENABLE_REAL_HF=1`` or when
    either ``SPACE_ID`` or ``HF_SPACE_ID`` is set to a non-empty value.
    """
    env = os.environ
    if env.get("NEXUS_DISABLE_REAL_HF") == "1":
        return False
    forced_on = env.get("NEXUS_ENABLE_REAL_HF") == "1"
    space_detected = bool(env.get("SPACE_ID") or env.get("HF_SPACE_ID"))
    return forced_on or space_detected
|
|
|
|
def _output_dir() -> Path:
    """Return the directory generated images are written to, creating it if needed.

    Resolution order:
      1. ``NEXUS_OUTPUT_DIR`` when set to a non-empty value.
      2. ``/data/nexus_visual_weaver`` when ``/data`` exists and the
         directory can be created.
      3. ``outputs/runtime`` relative to the current working directory.

    Returns:
        The resolved directory, which exists when the function returns.

    Raises:
        OSError: If the configured or fallback directory cannot be created.
    """
    configured = os.environ.get("NEXUS_OUTPUT_DIR")
    if not configured:
        try:
            if Path("/data").exists():
                persistent = Path("/data/nexus_visual_weaver")
                persistent.mkdir(parents=True, exist_ok=True)
                return persistent
        except OSError:
            # Best effort only. PermissionError is a subclass of OSError, so
            # this also covers a read-only mount, a file occupying the path,
            # or an unreadable /data; all of them fall through to the
            # relative default instead of aborting the run.
            pass
        configured = "outputs/runtime"
    target = Path(configured)
    target.mkdir(parents=True, exist_ok=True)
    return target
|
|
|
|
def _short_error(exc: BaseException) -> str:
    """Format an exception as a single line ``"<ClassName>: <message>"``.

    Newlines in the message become spaces, surrounding whitespace is
    stripped, and messages longer than 420 characters are cut to 417
    characters followed by ``...``.
    """
    max_chars = 420
    flattened = " ".join(str(exc).split("\n")).strip()
    if len(flattened) > max_chars:
        flattened = f"{flattened[:max_chars - 3]}..."
    return f"{type(exc).__name__}: {flattened}"
|
|
|
|
def _hf_token() -> str | None:
    """Look up the Hugging Face access token in the environment.

    ``HF_TOKEN`` takes precedence when it is set to a non-empty value;
    otherwise whatever ``HUGGING_FACE_HUB_TOKEN`` holds is returned
    (``None`` when that variable is unset).
    """
    primary = os.environ.get("HF_TOKEN")
    if primary:
        return primary
    return os.environ.get("HUGGING_FACE_HUB_TOKEN")
|
|
|
|
def active_flux_repo_id() -> str:
    """Pick the FLUX repository to generate with, based on the environment.

    A non-empty ``NEXUS_FLUX_REPO_ID`` is returned verbatim. Otherwise
    ``NEXUS_TINY_TITAN_MODE=1`` selects ``TINY_TITAN_FLUX_REPO_ID`` and any
    other setting selects ``FLUX_REPO_ID``.

    Returns:
        str: The repository ID to use for image generation.
    """
    override = os.environ.get("NEXUS_FLUX_REPO_ID")
    if override:
        return override
    tiny_mode = os.environ.get("NEXUS_TINY_TITAN_MODE") == "1"
    return TINY_TITAN_FLUX_REPO_ID if tiny_mode else FLUX_REPO_ID
|
|
|
|
def _repo_candidates(repo_id: str) -> list[str]:
    """List the repositories to try for one generation request, in order.

    The primary ``repo_id`` always comes first. ``TINY_TITAN_FLUX_REPO_ID``
    is added as a second entry unless the primary already is that repo or
    ``NEXUS_DISABLE_TINY_TITAN_FALLBACK`` is set to ``"1"``.

    Parameters:
        repo_id (str): The primary FLUX model repository identifier.

    Returns:
        list[str]: One or two repository IDs, primary first.
    """
    if repo_id == TINY_TITAN_FLUX_REPO_ID:
        return [repo_id]
    if os.environ.get("NEXUS_DISABLE_TINY_TITAN_FALLBACK") == "1":
        return [repo_id]
    return [repo_id, TINY_TITAN_FLUX_REPO_ID]
|
|
|
|
def _get_flux_pipe(repo_id: str, torch_module: Any, pipeline_cls: Any, token: str | None) -> Any:
    """Return the pipeline for ``repo_id``, loading and caching it on first use.

    The lookup, the load and the cache insertion all happen while holding
    ``_PIPELINE_CACHE_LOCK``, so concurrent callers do not load the same
    repository twice.

    Parameters:
        repo_id: Hugging Face repository to load; also the cache key.
        torch_module: The torch module; its ``bfloat16`` dtype is passed as
            ``torch_dtype``.
        pipeline_cls: Pipeline class exposing ``from_pretrained``
            (e.g., Flux2KleinPipeline).
        token: Hugging Face access token forwarded to ``from_pretrained``,
            or ``None``.

    Returns:
        The cached or freshly loaded pipeline, with
        ``enable_model_cpu_offload()`` already called on fresh loads.
    """
    with _PIPELINE_CACHE_LOCK:
        pipeline = _PIPELINE_CACHE.get(repo_id)
        if pipeline is None:
            pipeline = pipeline_cls.from_pretrained(
                repo_id,
                torch_dtype=torch_module.bfloat16,
                token=token,
            )
            pipeline.enable_model_cpu_offload()
            _PIPELINE_CACHE[repo_id] = pipeline
        return pipeline
|
|
|
|
def _adapter_recipe(repo_id: str | None) -> AdapterRecipe | None:
    """Find the catalog entry whose ``repo_id`` equals the given identifier.

    Parameters:
        repo_id: Adapter repository identifier; ``None`` or an empty string
            short-circuits to ``None`` without consulting the catalog.

    Returns:
        AdapterRecipe | None: The first matching entry of ``ADAPTER_CATALOG``,
        or ``None`` when nothing matches.
    """
    if not repo_id:
        return None
    for entry in ADAPTER_CATALOG:
        if entry.repo_id == repo_id:
            return entry
    return None
|
|
|
|
def default_lora_repo_id(target_repo_id: str) -> str | None:
    """Choose the default LoRA adapter for a FLUX repository.

    Walks ``ADAPTER_CATALOG`` in order and returns the first adapter that is
    runtime-enabled, not adult-only, needs no input image, and names
    ``target_repo_id`` either as ``adapter_for`` or in ``compatible_repo_ids``.

    Parameters:
        target_repo_id (str): The FLUX repository ID to find an adapter for.

    Returns:
        The ``repo_id`` of the first eligible adapter, or ``None`` when the
        catalog holds no eligible adapter.
    """

    def _eligible(entry: AdapterRecipe) -> bool:
        # The supported-repo set is built before the flag checks, as in the
        # original loop, so every catalog entry is inspected the same way.
        supported = {entry.adapter_for, *entry.compatible_repo_ids}
        if not entry.runtime_enabled:
            return False
        if entry.adult_only:
            return False
        if entry.requires_image:
            return False
        return target_repo_id in supported

    return next((entry.repo_id for entry in ADAPTER_CATALOG if _eligible(entry)), None)
|
|
|
|
def generate_flux_image(
    prompt: str,
    *,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    steps: int = 4,
    lora_repo_id: str | None = None,
    adult_mode: bool = False,
) -> HFGenerationResult:
    """
    Generate a FLUX.2 image from a text prompt with optional LoRA adapter application.

    The function never raises for runtime problems; every outcome is
    reported through the returned result's ``status``:

    * ``"disabled"`` - ``hf_runtime_enabled()`` is false (dry run).
    * ``"missing_runtime"`` - importing ``torch`` / ``diffusers`` failed.
    * ``"no_cuda"`` - ``torch.cuda.is_available()`` is false.
    * ``"success"`` - an image was generated and saved under ``_output_dir()``.
    * ``"error"`` - every candidate from ``_repo_candidates()`` failed.

    Parameters:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the CUDA ``torch.Generator``; also part of the file name.
        width: Requested image width, forwarded to the pipeline.
        height: Requested image height, forwarded to the pipeline.
        steps: Value passed as ``num_inference_steps``.
        lora_repo_id: Adapter repository to apply. ``None`` selects
            ``default_lora_repo_id()`` for the active repo; an identifier
            with no catalog entry (including ``""``) yields no recipe.
        adult_mode: Forwarded unchanged to ``load_and_apply``.

    Returns:
        HFGenerationResult: Immutable container with generation status, image output path, execution latency, and LoRA application results.
    """
    # Function-scope import, like torch/diffusers below: only the failure
    # path of the candidate loop needs it.
    import gc

    repo_id = active_flux_repo_id()
    selected_lora = lora_repo_id if lora_repo_id is not None else default_lora_repo_id(repo_id)
    recipe = _adapter_recipe(selected_lora)

    def _not_generated(
        status: str,
        provider_state: str,
        message: str,
        lora_message: str,
        *,
        lora_status: str = "disabled",
        latency_seconds: float | None = None,
    ) -> HFGenerationResult:
        # Single place that fills the fields shared by every non-success result.
        return HFGenerationResult(
            status=status,
            provider_state=provider_state,
            repo_id=repo_id,
            message=message,
            latency_seconds=latency_seconds,
            width=width,
            height=height,
            steps=steps,
            hf_token_present=bool(_hf_token()),
            lora_status=lora_status,
            lora_repo_id=recipe.repo_id if recipe else None,
            lora_message=lora_message,
        )

    if not hf_runtime_enabled():
        return _not_generated(
            "disabled",
            "dry-run",
            "Real HF generation disabled outside Space. Raven Quality Stack uses FLUX.2 Klein 9B by default; set NEXUS_TINY_TITAN_MODE=1 for the 4B sidecar.",
            "LoRA loading requires the HF Space GPU runtime.",
        )

    started = time.perf_counter()
    try:
        import torch
        from diffusers import Flux2KleinPipeline
    except Exception as exc:
        return _not_generated(
            "missing_runtime",
            "blocked",
            f"FLUX runtime import failed. Install diffusers main + torch. {_short_error(exc)}",
            "FLUX runtime import failed before LoRA loading.",
        )

    if not torch.cuda.is_available():
        return _not_generated(
            "no_cuda",
            "blocked",
            "CUDA is not available to the Space callback; FLUX.2 generation requires GPU execution.",
            "CUDA unavailable before LoRA loading.",
        )

    token = _hf_token()
    errors: list[str] = []
    for candidate in _repo_candidates(repo_id):
        pipe = None
        try:
            pipe = _get_flux_pipe(candidate, torch, Flux2KleinPipeline, token)
            if hasattr(pipe, "set_progress_bar_config"):
                pipe.set_progress_bar_config(disable=True)
            lora_result = load_and_apply(pipe, recipe, candidate, adult_mode=adult_mode)

            try:
                generator = torch.Generator(device="cuda").manual_seed(seed)
                image = pipe(
                    prompt=prompt,
                    height=height,
                    width=width,
                    guidance_scale=1.0,
                    num_inference_steps=steps,
                    generator=generator,
                ).images[0]
            finally:
                # The pipeline is cached and reused, so adapters are removed
                # whether or not inference succeeded.
                unload_all(pipe)

            # Timestamp + seed + random suffix keeps concurrent runs from
            # overwriting each other's files.
            output_path = _output_dir() / f"nexus_flux_{time.time_ns()}_{seed}_{uuid4().hex[:8]}.png"
            image.save(output_path)
            return HFGenerationResult(
                status="success",
                provider_state="generated",
                repo_id=candidate,
                output_path=str(output_path),
                message=f"{candidate} generated a real Raven Quality artifact on HF Space.",
                latency_seconds=round(time.perf_counter() - started, 2),
                width=width,
                height=height,
                steps=steps,
                hf_token_present=bool(token),
                lora_status=str(lora_result.get("status", "disabled")),
                lora_repo_id=lora_result.get("repo_id"),
                lora_message=str(lora_result.get("message", "")),
                fallback_used=candidate != repo_id,
                primary_error=errors[0] if candidate != repo_id and errors else None,
            )
        except Exception as exc:
            errors.append(f"{candidate}: {_short_error(exc)}")
            # Evict the pipeline so a later call reloads it from scratch.
            with _PIPELINE_CACHE_LOCK:
                _PIPELINE_CACHE.pop(candidate, None)

        # Reached only after a failure (success returned above). Drop the
        # last local reference to the evicted pipeline and collect cycles so
        # it can be reclaimed before the next candidate is loaded.
        pipe = None
        gc.collect()

    return _not_generated(
        "error",
        "blocked",
        f"FLUX.2 generation failed. Check model license acceptance, HF_TOKEN/Space access, and runtime deps. Attempts: {' | '.join(errors)}",
        "Generation failed before a usable LoRA evidence state could be produced.",
        lora_status="disabled" if recipe is None else "failed",
        latency_seconds=round(time.perf_counter() - started, 2),
    )
|
|