Spaces:
Running
on
Zero
Running
on
Zero
File size: 9,729 Bytes
da23dfe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 |
"""
FLUX.2 Klein Client
===================
Client for local FLUX.2 klein image generation (4B and 9B variants,
distilled and base).
Supports text-to-image and multi-reference editing.
"""
import logging
import time
from typing import Optional, List

from PIL import Image
import torch

from .models import GenerationRequest, GenerationResult

logger = logging.getLogger(__name__)
class FluxKleinClient:
    """
    Client for FLUX.2 klein models.

    Supports:
    - Text-to-image generation
    - Single and multi-reference image editing
    - Multiple model sizes (4B, 9B) and variants (distilled, base)

    Weights are loaded lazily: the constructor only records settings, and the
    first call to ``generate`` (or an explicit ``load_model``) loads the
    pipeline.
    """

    # Model variants - choose based on quality/speed tradeoff
    MODELS = {
        # 4B models (~13GB VRAM)
        "4b": "black-forest-labs/FLUX.2-klein-4B",  # Distilled, 4 steps
        "4b-base": "black-forest-labs/FLUX.2-klein-base-4B",  # Base, configurable steps
        # 9B models (~29GB VRAM, better quality)
        "9b": "black-forest-labs/FLUX.2-klein-9B",  # Distilled, 4 steps
        "9b-base": "black-forest-labs/FLUX.2-klein-base-9B",  # Base, 50 steps - BEST QUALITY
        "9b-fp8": "black-forest-labs/FLUX.2-klein-9b-fp8",  # FP8 quantized (~20GB)
    }

    # Legacy compatibility
    MODEL_ID = MODELS["4b"]
    MODEL_ID_BASE = MODELS["4b-base"]

    # Aspect ratio label -> (width, height) in pixels
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),  # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (896, 1152),
        "4:3": (1152, 896),
        "4:5": (896, 1120),
        "5:4": (1120, 896),
    }

    # (width, height) used when an aspect ratio label is missing or unknown
    _DEFAULT_DIMENSIONS = (1024, 1024)

    # Default sampling settings for each model variant
    MODEL_DEFAULTS = {
        "4b": {"steps": 4, "guidance": 1.0},
        "4b-base": {"steps": 28, "guidance": 3.5},
        "9b": {"steps": 4, "guidance": 1.0},
        "9b-base": {"steps": 50, "guidance": 4.0},  # Best quality
        # NOTE(review): 4 steps suggests a distilled checkpoint, yet guidance
        # 4.0 differs from the other distilled variants (1.0) - confirm.
        "9b-fp8": {"steps": 4, "guidance": 4.0},
    }

    def __init__(
        self,
        model_variant: str = "9b-base",  # Default to highest quality
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
        # Legacy params
        use_base_model: bool = False,
    ):
        """
        Initialize FLUX.2 klein client (does not load weights).

        Args:
            model_variant: Model variant to use:
                - "4b": Fast, 4 steps, ~13GB VRAM
                - "4b-base": Configurable steps, ~13GB VRAM
                - "9b": Better quality, 4 steps, ~29GB VRAM
                - "9b-base": BEST quality, 50 steps, ~29GB VRAM
                - "9b-fp8": FP8 quantized, ~20GB VRAM
                Unknown names log a warning and fall back to the "4b" model
                and 4-step / guidance 1.0 defaults.
            device: Device to use (cuda or cpu)
            dtype: Data type for model weights
            enable_cpu_offload: Enable CPU offload to save VRAM (ignored when
                device is the CPU, where offloading has no meaning)
            use_base_model: Legacy flag; when True and model_variant is left
                at "9b-base", selects "4b-base" instead.
        """
        # Legacy use_base_model=True meant the 4B base model (see MODEL_ID_BASE);
        # an explicitly chosen non-default variant takes precedence.
        if use_base_model and model_variant == "9b-base":
            model_variant = "4b-base"

        if model_variant not in self.MODELS:
            logger.warning(
                "Unknown FLUX.2 klein variant %r; falling back to the 4b model (valid: %s)",
                model_variant,
                ", ".join(self.MODELS),
            )

        self.model_variant = model_variant
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.pipe = None
        self._loaded = False

        # Get default settings for this variant
        defaults = self.MODEL_DEFAULTS.get(model_variant, {"steps": 4, "guidance": 1.0})
        self.default_steps = defaults["steps"]
        self.default_guidance = defaults["guidance"]
        logger.info(
            "FluxKleinClient initialized (variant: %s, steps: %s, guidance: %s)",
            model_variant,
            self.default_steps,
            self.default_guidance,
        )

    def load_model(self) -> bool:
        """
        Load the pipeline into memory and validate it with a tiny generation.

        Returns:
            True if the model is loaded (or already was), False on any failure.
            On failure any partially loaded pipeline is released, so a retry
            does not hold two copies of the weights.
        """
        if self._loaded:
            return True

        # Get model ID for selected variant
        model_id = self.MODELS.get(self.model_variant, self.MODELS["4b"])
        logger.info("Loading FLUX.2 klein (%s) from %s...", self.model_variant, model_id)
        start_time = time.time()
        try:
            # FLUX.2 klein requires Flux2KleinPipeline (specific to klein models)
            # Requires diffusers from git: pip install git+https://github.com/huggingface/diffusers.git
            from diffusers import Flux2KleinPipeline

            self.pipe = Flux2KleinPipeline.from_pretrained(
                model_id,
                torch_dtype=self.dtype,
            )
            self._place_pipeline()
        except Exception as e:
            logger.exception("Failed to load FLUX.2 klein: %s", e)
            self._release_pipe()
            return False

        load_time = time.time() - start_time
        logger.info("FLUX.2 klein (%s) loaded in %.1fs", self.model_variant, load_time)

        if not self._validate_pipeline():
            self._release_pipe()
            return False

        self._loaded = True
        return True

    def _place_pipeline(self) -> None:
        """Enable model CPU offload, or move the whole pipeline to self.device."""
        # Offload shuttles sub-models between CPU and an accelerator; when the
        # target device is the CPU itself, just place the pipeline there.
        if self.enable_cpu_offload and torch.device(self.device).type != "cpu":
            self.pipe.enable_model_cpu_offload()
            logger.info("CPU offload enabled")
        else:
            self.pipe.to(self.device)
            logger.info("Model moved to %s", self.device)

    def _validate_pipeline(self) -> bool:
        """Run a 1-step 256x256 generation; return True if it yields an image."""
        logger.info("Validating model with test generation...")
        try:
            with torch.inference_mode():
                test_result = self.pipe(
                    prompt="A simple test image",
                    height=256,
                    width=256,
                    guidance_scale=1.0,
                    num_inference_steps=1,
                    generator=torch.Generator(device="cpu").manual_seed(42),
                )
            # Indexing stays inside the try so an empty result is reported as
            # a validation failure rather than escaping load_model.
            first_image = test_result.images[0]
        except Exception as e:
            logger.exception("Model validation failed: %s", e)
            return False

        if first_image is None:
            logger.error("Model validation failed: no output image")
            return False
        logger.info("Model validation successful")
        return True

    def _release_pipe(self) -> None:
        """Drop the pipeline reference, mark the client unloaded, free CUDA cache."""
        import gc

        self.pipe = None
        self._loaded = False
        # Collect first so reference cycles inside the pipeline are freed
        # before the CUDA caching allocator is asked to release its blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def unload_model(self):
        """Unload model from memory; safe to call when nothing is loaded."""
        self._release_pipe()
        logger.info("FLUX.2 klein unloaded")

    def generate(
        self,
        request: "GenerationRequest",
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        seed: Optional[int] = 42,
    ) -> "GenerationResult":
        """
        Generate image using FLUX.2 klein, loading the model first if needed.

        Args:
            request: GenerationRequest object (reads prompt, aspect_ratio,
                has_input_images and input_images)
            num_inference_steps: Number of denoising steps (4 for klein
                distilled); defaults to the variant's configured steps
            guidance_scale: Classifier-free guidance scale; defaults to the
                variant's configured guidance
            seed: Seed for a CPU torch.Generator (42 keeps the historical
                reproducible output); None passes no generator to the pipeline

        Returns:
            GenerationResult object; failures are returned as error results,
            not raised.
        """
        if not self._loaded and not self.load_model():
            return GenerationResult.error_result("Failed to load FLUX.2 klein model")

        # Use model defaults if not specified
        if num_inference_steps is None:
            num_inference_steps = self.default_steps
        if guidance_scale is None:
            guidance_scale = self.default_guidance

        try:
            start_time = time.time()
            width, height = self._get_dimensions(request.aspect_ratio)
            logger.info(
                "Generating with %s: steps=%s, guidance=%s",
                self.model_variant,
                num_inference_steps,
                guidance_scale,
            )

            gen_kwargs = {
                "prompt": request.prompt,
                "height": height,
                "width": width,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
            }
            if seed is not None:
                gen_kwargs["generator"] = torch.Generator(device="cpu").manual_seed(seed)

            # Add input images if present (for editing). A single reference is
            # passed bare; several are passed as a list (multi-reference edit).
            if request.has_input_images:
                valid_images = [img for img in request.input_images if img is not None]
                if len(valid_images) == 1:
                    gen_kwargs["image"] = valid_images[0]
                elif valid_images:
                    gen_kwargs["image"] = valid_images

            logger.info("Generating with FLUX.2 klein: %s...", request.prompt[:80])
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
            image = output.images[0]

            generation_time = time.time() - start_time
            logger.info("Generated in %.2fs: %s", generation_time, image.size)
            return GenerationResult.success_result(
                image=image,
                message=f"Generated with FLUX.2 klein in {generation_time:.2f}s",
                generation_time=generation_time,
            )
        except Exception as e:
            logger.exception("FLUX.2 klein generation failed: %s", e)
            return GenerationResult.error_result(f"FLUX.2 klein error: {e}")

    def _get_dimensions(self, aspect_ratio: str) -> tuple:
        """Instance alias of ``get_dimensions``, kept for existing callers."""
        return self.get_dimensions(aspect_ratio)

    def is_healthy(self) -> bool:
        """Check if model is loaded and ready."""
        return self._loaded and self.pipe is not None

    @classmethod
    def get_dimensions(cls, aspect_ratio: Optional[str]) -> tuple:
        """
        Map an aspect-ratio label to ``(width, height)`` in pixels.

        Only the first whitespace-separated token is looked up, so a label such
        as ``"16:9 (Widescreen)"`` resolves like ``"16:9"``. Unknown, empty,
        whitespace-only or None labels fall back to 1024x1024.
        """
        tokens = aspect_ratio.split() if aspect_ratio else []
        if not tokens:
            return cls._DEFAULT_DIMENSIONS
        return cls.ASPECT_RATIOS.get(tokens[0], cls._DEFAULT_DIMENSIONS)
|