PaperProf / core /image_gen.py
Ryadg's picture
feat: amphitheater hero with mouse-tracking robot professor, fixed particles
57ebbd6
Raw
History Blame Contribute Delete
2.07 kB
"""
Concept image generation via FLUX.2-klein.
"""
import threading
from functools import lru_cache
import torch
from PIL import Image
# Hugging Face Hub repo id of the FLUX.2-klein checkpoint. It is passed both
# to the startup weight prefetch and to the diffusers pipeline loader.
MODEL_ID = "black-forest-labs/FLUX.2-klein-4B"
def _prefetch_weights():
    """Best-effort warm-up of the Hugging Face cache with the model weights.

    Fetches the ~16 GB repo snapshot ahead of time so that the
    @spaces.GPU window is spent loading + generating rather than
    downloading. Any failure (including a missing ``huggingface_hub``) is
    reported on stdout and swallowed; the weights are then downloaded on
    the first real call instead.
    """
    # Files diffusers never reads: the single-file ComfyUI checkpoint
    # (a 7.75 GB duplicate of the transformer), example images and docs.
    skipped_files = [
        "flux-2-klein-4b.safetensors",
        "*.jpg",
        "*.png",
        "*.md",
    ]
    try:
        # Imported inside the try so an ImportError is handled like any
        # other prefetch failure.
        from huggingface_hub import snapshot_download

        snapshot_download(repo_id=MODEL_ID, ignore_patterns=skipped_files)
        print("[image_gen] FLUX.2-klein weights prefetched")
    except Exception as exc:
        print(f"[image_gen] prefetch failed (will download on first call): {exc}")
# Start the prefetch as soon as this module is imported. It runs on a daemon
# thread so the import does not wait for the download and the thread does not
# keep the interpreter alive at shutdown.
threading.Thread(target=_prefetch_weights, daemon=True).start()
@lru_cache(maxsize=1)
def _load_pipeline():
    """Build the FLUX.2-klein pipeline on the GPU and memoize it.

    The ``lru_cache`` decorator keeps the single constructed pipeline, so
    later calls return the same object instead of reloading the weights.

    Returns:
        The ``Flux2KleinPipeline`` loaded from ``MODEL_ID`` in bfloat16,
        after being moved to the ``"cuda"`` device.
    """
    # Deferred import: lets the app start (and leaves CUDA untouched)
    # outside the @spaces.GPU context that calls generate_concept_image.
    from diffusers import Flux2KleinPipeline

    pipeline = Flux2KleinPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    pipeline.to("cuda")
    return pipeline
def generate_concept_image(concept: str) -> Image.Image:
    """Render the fixed "celebration" illustration with FLUX.2-klein.

    Args:
        concept: Accepted for interface compatibility but intentionally
            unused. FLUX.2 renders any text it sees into the image, so a
            fixed, purely visual prompt is used instead.

    Returns:
        The first image produced by the pipeline for a 512x512 request.
    """
    # Comma-separated visual descriptors; joining with ", " yields the
    # exact prompt string sent to the model.
    descriptors = (
        "A cute minimalist illustration celebrating learning and success",
        "no text",
        "no words",
        "no letters",
        "no labels",
        "purely visual",
        "soft glowing shapes",
        "stars",
        "abstract celebration",
        "dark navy background",
        "purple and cyan colors",
        "warm and encouraging mood",
        "flat design",
    )
    prompt = ", ".join(descriptors)

    # Sampling settings passed to the pipeline call.
    sampling = {
        "height": 512,
        "width": 512,
        "num_inference_steps": 4,
        "guidance_scale": 1.0,
    }

    pipe = _load_pipeline()
    output = pipe(prompt=prompt, **sampling)
    return output.images[0]