# kaggle_template.py
# (Captured from a Hugging Face Space, status "Paused", at commit b9ba589; file size 1,992 bytes.)
# =====================================================
# KAGGLE IMAGE WORKER (FLUX.1-SCHNELL)
# =====================================================
import gc
import importlib
import os
import subprocess
import sys
from pathlib import Path

import torch

# Install bitsandbytes if it is missing. The 4-bit (NF4) BitsAndBytesConfig
# quantization used to fit Flux on a Kaggle T4 cannot run without it.
try:
    import bitsandbytes  # noqa: F401  (only probed for presence here)
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "bitsandbytes"])
    # The running interpreter's import finders may still hold a cached listing
    # of site-packages from before the install. Refresh them so the diffusers
    # import below can see the freshly installed package.
    importlib.invalidate_caches()

from diffusers import FluxPipeline, FluxTransformer2DModel, BitsAndBytesConfig
# --- CONFIG ---
# The automation script will inject the prompts here automatically: one image
# is generated per string in PROMPTS.
# NOTE(review): the injection presumably matches the placeholder line textually,
# so keep it (and the surrounding list brackets) exactly as written.
# NOTE(review): if nothing is injected, PROMPTS stays empty. The model is still
# loaded and the job reports completion without writing any image.
PROMPTS = [
# {{PROMPTS_PLACEHOLDER}}
]
# Destination for the generated frames. Created up front (parents included,
# no error if it already exists) so saving an image never hits a missing folder.
OUTPUT_DIR = Path("/kaggle/working/images")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# --- MODEL LOADING ---
# The quantized transformer and the rest of the pipeline come from the same
# Hub repository and share one dtype, so both are named once here.
FLUX_REPO = "black-forest-labs/FLUX.1-schnell"
FLUX_DTYPE = torch.bfloat16

print("📦 Loading Quantized Flux...")

# Only the transformer is quantized: NF4 4-bit weights with bfloat16 compute.
# This is what lets Flux fit in a Kaggle T4's memory.
# NOTE(review): T4 (Turing) GPUs have no native bfloat16 support; confirm that
# bf16 compute performs acceptably there or evaluate float16 instead.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=FLUX_DTYPE,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Quantized transformer, read from the repo's "transformer" subfolder.
transformer_kwargs = dict(
    subfolder="transformer",
    quantization_config=bnb_config,
    torch_dtype=FLUX_DTYPE,
    low_cpu_mem_usage=True,
)
transformer = FluxTransformer2DModel.from_pretrained(FLUX_REPO, **transformer_kwargs)

# The remaining pipeline components load from the same repo. The quantized
# transformer above is passed in to replace the default one.
pipe = FluxPipeline.from_pretrained(FLUX_REPO, transformer=transformer, torch_dtype=FLUX_DTYPE)

# Keep sub-models off the GPU until each is needed, to lower peak VRAM usage.
pipe.enable_model_cpu_offload()
# --- GENERATION ---
# FLUX.1-schnell is timestep-distilled. Its model card and the diffusers docs
# run it with guidance_scale=0.0 and max_sequence_length=256; 256 is the
# documented maximum for schnell (512 is the FLUX.1-dev setting). Prompts
# longer than the sequence limit are truncated by the pipeline's T5 tokenizer.
print(f"🎨 Generating {len(PROMPTS)} images...")
for i, prompt in enumerate(PROMPTS, 1):
    print(f" Frame {i}/{len(PROMPTS)}...")
    # Drop unreferenced Python objects and return PyTorch's cached-but-unused
    # CUDA memory before each frame, to keep peak VRAM down.
    gc.collect()
    torch.cuda.empty_cache()

    # Generate a 1072x1920 portrait frame. Both sides are multiples of 16, so
    # the pipeline does not need to adjust them.
    image = pipe(
        prompt=prompt,
        width=1072,
        height=1920,
        num_inference_steps=4,  # Schnell is distilled for very few steps
        guidance_scale=0.0,  # schnell has no guidance embedding; value is ignored
        max_sequence_length=256,  # schnell's documented maximum
    ).images[0]

    # Save with 1-based, zero-padded frame numbers: 01.png, 02.png, ...
    save_path = OUTPUT_DIR / f"{i:02d}.png"
    image.save(save_path)
    print(f" ✅ Saved: {save_path.name}")
print("🏁 Job Complete")