"""Interactive command-line interface for FLUX.2 image generation.

Runs a small REPL where the user edits a generation ``Config`` via
``key=value`` pairs and triggers sampling with ``run`` (or an empty line).
Supports optional prompt upsampling either locally or via the OpenRouter API.
"""

import json
import os
import random
import shlex
import sys
from dataclasses import asdict, dataclass, field, replace
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

import torch
from einops import rearrange
from PIL import ExifTags, Image

from flux2.openrouter_api_client import DEFAULT_SAMPLING_PARAMS, OpenRouterAPIClient
from flux2.sampling import (
    batched_prc_img,
    batched_prc_txt,
    denoise,
    encode_image_refs,
    get_schedule,
    scatter_ids,
)
from flux2.util import FLUX2_MODEL_INFO, load_ae, load_flow_model, load_mistral_small_embedder

# from flux2.watermark import embed_watermark


@dataclass
class Config:
    """All user-tunable generation settings for one sampling run."""

    prompt: str = "a photo of a forest with mist swirling around the tree trunks. The word 'FLUX.2' is painted over it in big, red brush strokes with visible texture"
    seed: Optional[int] = None  # None -> a fresh random seed is drawn per run
    width: int = 1360
    height: int = 768
    num_steps: int = 50
    guidance: float = 4.0
    input_images: List[Path] = field(default_factory=list)
    match_image_size: Optional[int] = None  # Index of input_images to match size from
    upsample_prompt_mode: Literal["none", "local", "openrouter"] = "none"
    openrouter_model: str = "mistralai/pixtral-large-2411"  # OpenRouter model name

    def copy(self) -> "Config":
        """Return an independent copy of this config.

        Uses dataclasses.replace so newly added fields are copied automatically;
        the mutable input_images list is duplicated so the copy never aliases it.
        """
        return replace(self, input_images=list(self.input_images))


DEFAULTS = Config()

# Field-name sets used by coerce_value to pick the right string->value conversion.
INT_FIELDS = {"width", "height", "seed", "num_steps", "match_image_size"}
FLOAT_FIELDS = {"guidance"}
LIST_FIELDS = {"input_images"}
UPSAMPLING_MODE_FIELDS = ("none", "local", "openrouter")
STR_FIELDS = {"openrouter_model"}


def coerce_value(key: str, raw: str):
    """Convert a raw string to the correct field type.

    Integer fields accept "none"/"" as None (used for optional seed /
    match_image_size).  List fields accept comma- and/or space-separated
    paths; nonexistent paths are skipped with a warning rather than raising.
    Unknown keys are returned as plain strings.

    Raises:
        ValueError: if ``upsample_prompt_mode`` is not one of the known modes
            (int()/float() may also raise for malformed numbers).
    """
    if key in INT_FIELDS:
        if raw.lower() == "none" or raw == "":
            return None
        return int(raw)
    if key in FLOAT_FIELDS:
        return float(raw)
    if key in STR_FIELDS:
        return raw.strip().strip('"').strip("'")
    if key in LIST_FIELDS:
        # Handle empty list cases
        if raw == "" or raw == "[]":
            return []
        # Accept comma-separated or space-separated; strip quotes.
        items = []
        # If user passed a single token that contains commas, split on commas.
        tokens = [raw] if ("," in raw and " " not in raw) else shlex.split(raw)
        for tok in tokens:
            for part in tok.split(","):
                part = part.strip()
                if part:
                    if os.path.exists(part):
                        items.append(Path(part))
                    else:
                        print(f"File {part} not found. Skipping for now. Please check your path")
        return items
    if key == "upsample_prompt_mode":
        v = str(raw).strip().strip('"').strip("'").lower()
        if v in UPSAMPLING_MODE_FIELDS:
            return v
        raise ValueError(
            f"invalid upsample_prompt_mode: {v}. Must be one of: {', '.join(UPSAMPLING_MODE_FIELDS)}"
        )
    # plain strings
    return raw


def apply_updates(cfg: Config, updates: Dict[str, Any]) -> None:
    """Apply already-coerced key/value updates to cfg in place.

    Unknown keys and invalid upsample_prompt_mode values are reported to
    stderr and skipped; everything else is assigned via setattr.
    """
    for k, v in updates.items():
        if not hasattr(cfg, k):
            print(f" ! unknown key: {k}", file=sys.stderr)
            continue
        # Validate upsample_prompt_mode
        if k == "upsample_prompt_mode":
            valid_modes = {"none", "local", "openrouter"}
            if v not in valid_modes:
                print(
                    f" ! Invalid upsample_prompt_mode: {v}. Must be one of: {', '.join(valid_modes)}",
                    file=sys.stderr,
                )
                continue
        setattr(cfg, k, v)


def parse_key_values(line: str) -> Dict[str, Any]:
    """
    Parse shell-like 'key=value' pairs. Values can be quoted.
    Example: prompt="a dog" width=768 input_images="in1.png,in2.jpg"

    Tokens without '=' (bare commands such as 'run', 'show', 'quit') are
    recorded as ``{token: True}`` so the caller can treat them as commands.
    Values that fail to coerce are reported to stderr and omitted.
    """
    updates: Dict[str, Any] = {}
    for token in shlex.split(line):
        if "=" not in token:
            # Allow bare commands like: run, show, reset, quit
            updates[token] = True
            continue
        key, val = token.split("=", 1)
        key = key.strip()
        val = val.strip()
        try:
            updates[key] = coerce_value(key, val)
        except Exception as e:
            print(f" ! could not parse {key}={val!r}: {e}", file=sys.stderr)
    return updates


def print_config(cfg: Config):
    """Print every field of cfg in a fixed, human-friendly order."""
    d = asdict(cfg)
    d["input_images"] = [str(p) for p in cfg.input_images]
    print("Current config:")
    for k in [
        "prompt",
        "seed",
        "width",
        "height",
        "num_steps",
        "guidance",
        "input_images",
        "match_image_size",
        "upsample_prompt_mode",
        "openrouter_model",
    ]:
        print(f"  {k}: {d[k]}")
    print()


def print_help():
    """Print the REPL usage reference (commands, parameters, examples)."""
    print("""
Available commands:
  [Enter]          - Run generation with current config
  run              - Run generation with current config
  show             - Show current configuration
  reset            - Reset configuration to defaults
  help, h, ?       - Show this help message
  quit, q, exit    - Exit the program

Setting parameters:
  key=value        - Update a config parameter (shows updated config, doesn't run)

  Examples:
    prompt="a cat in a hat"
    width=768 height=768
    seed=42
    num_steps=30
    guidance=3.5
    input_images="img1.jpg,img2.jpg"
    match_image_size=0  (use dimensions from first input image)
    upsample_prompt_mode="none"  (prompt upsampling mode: "none", "local", or "openrouter")
    openrouter_model="mistralai/pixtral-large-2411"  (OpenRouter model name)

  You can combine parameter updates:
    prompt="sunset" width=1920 height=1080

Parameters:
  prompt               - Text prompt for generation (string)
  seed                 - Random seed (integer or 'none' for random)
  width                - Output width in pixels (integer)
  height               - Output height in pixels (integer)
  num_steps            - Number of denoising steps (integer)
  guidance             - Guidance scale (float)
  input_images         - Comma-separated list of input image paths (list)
  match_image_size     - Index of input image to match dimensions from (integer, 0-based)
  upsample_prompt_mode - Prompt upsampling mode: "none" (default), "local", or "openrouter" (string)
  openrouter_model     - OpenRouter model name (string, default: "mistralai/pixtral-large-2411")
                         Examples: "mistralai/pixtral-large-2411", "qwen/qwen3-vl-235b-a22b-instruct", etc.

Note: For "openrouter" mode, set OPENROUTER_API_KEY environment variable
""")


# ---------- Main Loop ----------
def main(
    model_name: str = "flux.2-dev",
    single_eval: bool = False,
    prompt: str | None = None,
    debug_mode: bool = False,
    cpu_offloading: bool = False,
    **overwrite,
):
    """Run the interactive (or single-shot) FLUX.2 generation loop.

    Args:
        model_name: key into FLUX2_MODEL_INFO selecting the flow model.
        single_eval: if True, run one generation with the current config and exit
            instead of entering the interactive prompt.
        prompt: optional prompt that overrides the config default.
        debug_mode: forwarded to load_flow_model.
        cpu_offloading: keep the flow model on CPU and swap it with the text
            embedder around the denoising step to save GPU memory.
        **overwrite: additional key=value config overrides (same keys as Config).
    """
    assert (
        model_name.lower() in FLUX2_MODEL_INFO
    ), f"{model_name} is not available, choose from {FLUX2_MODEL_INFO.keys()}"
    torch_device = torch.device("cuda")
    mistral = load_mistral_small_embedder()
    model = load_flow_model(
        model_name, debug_mode=debug_mode, device="cpu" if cpu_offloading else torch_device
    )
    ae = load_ae(model_name)
    ae.eval()
    mistral.eval()

    # API client will be initialized lazily when needed
    openrouter_api_client: Optional[OpenRouterAPIClient] = None

    cfg = DEFAULTS.copy()
    # shlex.quote keeps multi-word override values (e.g. a prompt with spaces)
    # intact through the shlex.split round-trip inside parse_key_values.
    changes = [f"{key}={shlex.quote(str(value))}" for key, value in overwrite.items()]
    updates = parse_key_values(" ".join(changes))
    apply_updates(cfg, updates)
    if prompt is not None:
        cfg.prompt = prompt
    print_config(cfg)

    while True:
        if single_eval:
            # Non-interactive mode: behave as if the user pressed Enter once.
            # (Without this, `line` would be referenced while unbound below.)
            line = ""
        else:
            try:
                line = input("> ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nbye!")
                break

        if not line:
            # Empty -> run with current config
            cmd = "run"
            updates = {}
        else:
            try:
                updates = parse_key_values(line)
            except Exception as e:  # noqa: BLE001
                print(f" ! Failed to parse command: {type(e).__name__}: {e}", file=sys.stderr)
                print(
                    " ! Please check your syntax (e.g., matching quotes) and try again.\n",
                    file=sys.stderr,
                )
                continue

            # Content-safety gate on newly supplied prompt / images.
            if "prompt" in updates and mistral.test_txt(updates["prompt"]):
                print(
                    "Your prompt has been flagged for potential copyright or public personas concerns. Please choose another."
                )
                updates.pop("prompt")
            if "input_images" in updates:
                flagged = False
                for image in updates["input_images"]:
                    if mistral.test_image(image):
                        print(f"The image {image} has been flagged as unsuitable. Please choose another.")
                        flagged = True
                if flagged:
                    updates.pop("input_images")

            # If the line was only 'run' / 'show' / ... it will appear as {cmd: True}
            # If it had key=val pairs, there may be no bare command -> just update config
            bare_cmds = [k for k, v in updates.items() if v is True and k.isalpha()]
            cmd = bare_cmds[0] if bare_cmds else None
            # Remove bare commands from updates so they don't get applied as fields
            for c in bare_cmds:
                updates.pop(c, None)

        if cmd in ("quit", "q", "exit"):
            print("bye!")
            break
        elif cmd == "reset":
            cfg = DEFAULTS.copy()
            print_config(cfg)
            continue
        elif cmd == "show":
            print_config(cfg)
            continue
        elif cmd in ("help", "h", "?"):
            print_help()
            continue

        # Apply key=value changes
        if updates:
            apply_updates(cfg, updates)
            print_config(cfg)
            continue

        # Only run if explicitly requested (empty line or 'run' command)
        if cmd != "run":
            if cmd is not None:
                print(f" ! Unknown command: '{cmd}'", file=sys.stderr)
                print(" ! Type 'help' to see available commands.\n", file=sys.stderr)
            continue

        try:
            # Load input images first to potentially match dimensions
            img_ctx = [Image.open(input_image) for input_image in cfg.input_images]

            # Apply match_image_size if specified
            width = cfg.width
            height = cfg.height
            if cfg.match_image_size is not None:
                if cfg.match_image_size < 0 or cfg.match_image_size >= len(img_ctx):
                    print(
                        f" ! match_image_size={cfg.match_image_size} is out of range (0-{len(img_ctx)-1})",
                        file=sys.stderr,
                    )
                    print(f" ! Using default dimensions: {width}x{height}", file=sys.stderr)
                else:
                    ref_img = img_ctx[cfg.match_image_size]
                    width, height = ref_img.size
                    print(f" Matched dimensions from image {cfg.match_image_size}: {width}x{height}")

            seed = cfg.seed if cfg.seed is not None else random.randrange(2**31)
            # Renamed from `dir` to avoid shadowing the builtin.
            out_dir = Path("output")
            out_dir.mkdir(exist_ok=True)
            output_name = out_dir / f"sample_{len(list(out_dir.glob('*')))}.png"

            with torch.no_grad():
                ref_tokens, ref_ids = encode_image_refs(ae, img_ctx)

                if cfg.upsample_prompt_mode == "openrouter":
                    try:
                        # Ensure API key is available, otherwise prompt the user
                        api_key = os.environ.get("OPENROUTER_API_KEY", "").strip()
                        if not api_key:
                            try:
                                entered = input(
                                    "OPENROUTER_API_KEY not set. Enter it now (leave blank to skip OpenRouter upsampling): "
                                ).strip()
                            except (EOFError, KeyboardInterrupt):
                                entered = ""
                            if entered:
                                os.environ["OPENROUTER_API_KEY"] = entered
                            else:
                                print(
                                    " ! No API key provided; disabling OpenRouter upsampling",
                                    file=sys.stderr,
                                )
                                cfg.upsample_prompt_mode = "none"
                                prompt = cfg.prompt  # Skip OpenRouter flow

                        # Only proceed if still in openrouter mode (not disabled above)
                        if cfg.upsample_prompt_mode == "openrouter":
                            # Let user specify sampling params, or use model defaults if available
                            sampling_params_input = ""
                            try:
                                sampling_params_input = input(
                                    "Enter OpenRouter sampling params as JSON or key=value (blank to use defaults): "
                                ).strip()
                            except (EOFError, KeyboardInterrupt):
                                sampling_params_input = ""

                            sampling_params: Dict[str, Any] = {}
                            if sampling_params_input:
                                # Try JSON first
                                parsed_ok = False
                                try:
                                    parsed = json.loads(sampling_params_input)
                                    if isinstance(parsed, dict):
                                        sampling_params = parsed
                                        parsed_ok = True
                                except Exception:
                                    parsed_ok = False
                                if not parsed_ok:
                                    # Fallback: parse key=value pairs separated by spaces or commas
                                    tokens = [
                                        tok
                                        for tok in sampling_params_input.replace(",", " ").split(" ")
                                        if tok
                                    ]
                                    for tok in tokens:
                                        if "=" not in tok:
                                            continue
                                        k, v = tok.split("=", 1)
                                        v_str = v.strip()
                                        v_low = v_str.lower()
                                        if v_low in {"true", "false"}:
                                            val: Any = v_low == "true"
                                        else:
                                            # Best-effort numeric coercion; fall back to string.
                                            try:
                                                if "." in v_str:
                                                    num = float(v_str)
                                                    val = int(num) if num.is_integer() else num
                                                else:
                                                    val = int(v_str)
                                            except Exception:
                                                val = v_str
                                        sampling_params[k.strip()] = val
                                print(f" Using custom OpenRouter sampling params: {sampling_params}")
                            else:
                                model_key = cfg.openrouter_model
                                default_params = DEFAULT_SAMPLING_PARAMS.get(model_key)
                                if default_params:
                                    sampling_params = default_params
                                    print(
                                        f" Using default OpenRouter sampling params for {model_key}: {sampling_params}"
                                    )
                                else:
                                    print(
                                        f" Setting no OpenRouter sampling params: not set for this model ({model_key})"
                                    )

                            # Initialize or reinitialize client if model changed
                            if (
                                openrouter_api_client is None
                                or openrouter_api_client.model != cfg.openrouter_model
                                or getattr(openrouter_api_client, "sampling_params", None)
                                != sampling_params
                            ):
                                openrouter_api_client = OpenRouterAPIClient(
                                    model=cfg.openrouter_model,
                                    sampling_params=sampling_params,
                                )
                            else:
                                # Ensure client uses latest sampling params
                                openrouter_api_client.sampling_params = sampling_params

                            upsampled_prompts = openrouter_api_client.upsample_prompt(
                                [cfg.prompt], img=[img_ctx] if img_ctx else None
                            )
                            prompt = upsampled_prompts[0] if upsampled_prompts else cfg.prompt
                    except Exception as e:
                        print(f" ! Failed to upsample prompt via OpenRouter API: {e}", file=sys.stderr)
                        print(
                            " ! Disabling OpenRouter upsampling and falling back to original prompt",
                            file=sys.stderr,
                        )
                        cfg.upsample_prompt_mode = "none"
                        prompt = cfg.prompt
                elif cfg.upsample_prompt_mode == "local":
                    # Use local model for upsampling
                    upsampled_prompts = mistral.upsample_prompt(
                        [cfg.prompt], img=[img_ctx] if img_ctx else None
                    )
                    prompt = upsampled_prompts[0] if upsampled_prompts else cfg.prompt
                else:
                    # upsample_prompt_mode == "none" or invalid value
                    prompt = cfg.prompt

                print("Generating with prompt: ", prompt)
                ctx = mistral([prompt]).to(torch.bfloat16)
                ctx, ctx_ids = batched_prc_txt(ctx)

                if cpu_offloading:
                    # Swap the embedder out for the flow model before denoising.
                    mistral = mistral.cpu()
                    torch.cuda.empty_cache()
                    model = model.to(torch_device)

                # Create noise
                shape = (1, 128, height // 16, width // 16)
                generator = torch.Generator(device="cuda").manual_seed(seed)
                randn = torch.randn(shape, generator=generator, dtype=torch.bfloat16, device="cuda")
                x, x_ids = batched_prc_img(randn)

                timesteps = get_schedule(cfg.num_steps, x.shape[1])
                x = denoise(
                    model,
                    x,
                    x_ids,
                    ctx,
                    ctx_ids,
                    timesteps=timesteps,
                    guidance=cfg.guidance,
                    img_cond_seq=ref_tokens,
                    img_cond_seq_ids=ref_ids,
                )
                x = torch.cat(scatter_ids(x, x_ids)).squeeze(2)
                x = ae.decode(x).float()
                # x = embed_watermark(x)

                if cpu_offloading:
                    # Swap the flow model back out so the embedder can screen the output.
                    model = model.cpu()
                    torch.cuda.empty_cache()
                    mistral = mistral.to(torch_device)

                # Map latents from [-1, 1] to 8-bit RGB.
                x = x.clamp(-1, 1)
                x = rearrange(x[0], "c h w -> h w c")
                img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
                if mistral.test_image(img):
                    print("Your output has been flagged. Please choose another prompt / input image combination")
                else:
                    exif_data = Image.Exif()
                    exif_data[ExifTags.Base.Software] = "AI generated;flux2"
                    exif_data[ExifTags.Base.Make] = "Black Forest Labs"
                    # NOTE: quality/subsampling are JPEG options; PIL's PNG writer ignores them.
                    img.save(output_name, exif=exif_data, quality=95, subsampling=0)
                    print(f"Saved {output_name}")
        except Exception as e:  # noqa: BLE001
            # Keep the REPL alive on any generation error; models stay loaded.
            print(f"\n ERROR: {type(e).__name__}: {e}", file=sys.stderr)
            print(" The model is still loaded. Please fix the error and try again.\n", file=sys.stderr)

        if single_eval:
            break


if __name__ == "__main__":
    from fire import Fire

    Fire(main)