File size: 1,992 Bytes
b9ba589
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# kaggle_template.py
# =====================================================
# KAGGLE IMAGE WORKER (FLUX.1-SCHNELL)
# =====================================================
# Standalone script: loads FLUX.1-schnell with a 4-bit NF4-quantized
# transformer and renders one image per entry of PROMPTS into OUTPUT_DIR.
import gc
import importlib
import os
import subprocess
import sys
from pathlib import Path

import torch

# bitsandbytes backs the 4-bit quantization used below (critical for fitting
# the model on a T4 GPU). Install it on the fly if the kernel lacks it.
try:
    import bitsandbytes  # noqa: F401  (availability probe only)
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "bitsandbytes"])
    # The package was written to site-packages after the interpreter started;
    # drop the import system's cached directory listings so the imports below
    # can find it within this same process.
    importlib.invalidate_caches()

from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

# --- CONFIG ---
# The automation script injects the prompts here, presumably by replacing the
# placeholder comment line below -- keep that line byte-identical.
PROMPTS = [
    # {{PROMPTS_PLACEHOLDER}}
]

# Fail fast when no prompts were injected, instead of spending minutes
# loading the model only to render nothing and report success.
if not PROMPTS:
    sys.exit("❌ PROMPTS is empty - prompt injection failed or no prompts were supplied.")

OUTPUT_DIR = Path("/kaggle/working/images")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- MODEL LOADING ---
print("📦 Loading Quantized Flux...")

# Single source of truth for the checkpoint and the working dtype.
_MODEL_ID = "black-forest-labs/FLUX.1-schnell"
_DTYPE = torch.bfloat16

# NF4 4-bit weights so the transformer fits on Kaggle T4s; compute in bf16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=_DTYPE,
)

# Only the transformer sub-folder is loaded with the quantization config.
transformer = FluxTransformer2DModel.from_pretrained(
    _MODEL_ID,
    subfolder="transformer",
    quantization_config=bnb_config,
    torch_dtype=_DTYPE,
    low_cpu_mem_usage=True,
)

# Build the full pipeline around the already-quantized transformer; the
# remaining components load in bf16.
pipe = FluxPipeline.from_pretrained(
    _MODEL_ID,
    transformer=transformer,
    torch_dtype=_DTYPE,
)

# Keep sub-models on the CPU and move each to the GPU only while it runs,
# trading speed for a lower peak VRAM footprint.
pipe.enable_model_cpu_offload()

# --- GENERATION ---
total = len(PROMPTS)
print(f"🎨 Generating {total} images...")

# Zero-pad frame numbers to at least two digits, widening for 100+ prompts so
# that lexical filename order still matches frame order ("100" vs "11").
pad = max(2, len(str(total)))

for i, prompt in enumerate(PROMPTS, 1):
    print(f"   Frame {i}/{total}...")
    # Release Python garbage and cached CUDA blocks left by the previous frame
    # before the next allocation-heavy pipeline call.
    gc.collect()
    torch.cuda.empty_cache()

    # Generate
    image = pipe(
        prompt=prompt,
        # Both sides are multiples of 16 (67*16 x 120*16), which the FLUX
        # pipeline expects; presumably chosen as the closest fit to 1080x1920.
        width=1072,
        height=1920,
        num_inference_steps=4,  # Schnell is timestep-distilled for few steps
        # Per the diffusers FLUX docs, the timestep-distilled schnell model
        # needs guidance_scale=0 and max_sequence_length <= 256; T5 prompt
        # tokens beyond 256 are therefore truncated.
        guidance_scale=0.0,
        max_sequence_length=256,
    ).images[0]

    # Save
    save_path = OUTPUT_DIR / f"{i:0{pad}d}.png"
    image.save(save_path)
    print(f"   ✅ Saved: {save_path.name}")

print("🏁 Job Complete")