File size: 2,069 Bytes
53f316c
3e0f4f1
53f316c
 
5158491
53f316c
 
 
 
 
5158491
 
 
 
 
 
 
 
5ff15a8
 
 
 
 
 
5158491
 
 
 
 
 
 
53f316c
 
 
 
 
3e0f4f1
53f316c
3e0f4f1
5158491
3e0f4f1
53f316c
 
 
 
 
 
57ebbd6
 
53f316c
 
57ebbd6
 
 
 
 
53f316c
 
413ab6b
53f316c
 
 
413ab6b
53f316c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Concept image generation via FLUX.2-klein.
"""

import threading
from functools import lru_cache

import torch
from PIL import Image

# Hugging Face Hub repo id of the FLUX.2-klein checkpoint. The same id is used
# by the background prefetch (snapshot_download) and by the pipeline loader
# (from_pretrained), so both resolve to the same local HF cache entry.
MODEL_ID = "black-forest-labs/FLUX.2-klein-4B"


def _prefetch_weights():
    """Warm the Hugging Face cache with the FLUX.2-klein weights.

    Intended to run in a background thread started at import time, so the
    ``@spaces.GPU`` window is spent loading and generating rather than
    downloading. This is best-effort: any failure is reported on stdout and
    otherwise ignored, leaving the download to happen on the first call.
    """
    # Files diffusers never reads: the single-file ComfyUI checkpoint
    # (a 7.75 GB duplicate of the transformer) plus example images and docs.
    skipped_files = ["flux-2-klein-4b.safetensors", "*.jpg", "*.png", "*.md"]
    try:
        from huggingface_hub import snapshot_download

        snapshot_download(MODEL_ID, ignore_patterns=skipped_files)
        print("[image_gen] FLUX.2-klein weights prefetched")
    except Exception as err:
        print(f"[image_gen] prefetch failed (will download on first call): {err}")


# Kick off the weight prefetch in the background as soon as this module is
# imported, so importing does not block on the download. daemon=True means
# this thread will not keep the interpreter alive when the program exits.
threading.Thread(target=_prefetch_weights, daemon=True).start()


@lru_cache(maxsize=1)
def _load_pipeline():
    """Build the FLUX.2-klein pipeline in bfloat16 and move it to CUDA.

    Memoized with ``lru_cache``: after the first successful call, later calls
    return the same pipeline object. ``diffusers`` is imported here rather
    than at module level so the app can start, with CUDA untouched, outside
    the ``@spaces.GPU`` context that calls ``generate_concept_image``.
    """
    from diffusers import Flux2KleinPipeline

    pipeline = Flux2KleinPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    pipeline.to("cuda")
    return pipeline


def generate_concept_image(concept: str) -> Image.Image:
    """Render a celebratory illustration with FLUX.2-klein.

    Args:
        concept: Accepted to keep the public signature, but intentionally
            unused. FLUX.2 renders any text it sees into the image, so a
            fixed, purely visual celebration prompt is used instead.

    Returns:
        The first image produced by the pipeline, requested at 512x512
        with 4 inference steps and guidance scale 1.0.
    """
    generation_kwargs = dict(
        prompt=(
            "A cute minimalist illustration celebrating learning and success, "
            "no text, no words, no letters, no labels, purely visual, "
            "soft glowing shapes, stars, abstract celebration, "
            "dark navy background, purple and cyan colors, "
            "warm and encouraging mood, flat design"
        ),
        height=512,
        width=512,
        num_inference_steps=4,
        guidance_scale=1.0,
    )
    flux = _load_pipeline()
    output = flux(**generation_kwargs)
    return output.images[0]