Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import random | |
| import shlex | |
| import sys | |
| from dataclasses import asdict, dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Literal, Optional | |
| import torch | |
| from einops import rearrange | |
| from PIL import ExifTags, Image | |
| from flux2.openrouter_api_client import DEFAULT_SAMPLING_PARAMS, OpenRouterAPIClient | |
| from flux2.sampling import ( | |
| batched_prc_img, | |
| batched_prc_txt, | |
| denoise, | |
| encode_image_refs, | |
| get_schedule, | |
| scatter_ids, | |
| ) | |
| from flux2.util import FLUX2_MODEL_INFO, load_ae, load_flow_model, load_mistral_small_embedder | |
| # from flux2.watermark import embed_watermark | |
@dataclass
class Config:
    """Generation settings for the interactive FLUX.2 sampling session.

    NOTE(review): the original class was missing the ``@dataclass`` decorator
    even though it uses ``dataclasses.field()`` for ``input_images``, is fed
    to ``dataclasses.asdict()`` by ``print_config()``, and ``copy()``
    constructs it with keyword arguments — all of which require a dataclass.
    Restoring the decorator fixes construction, ``copy()`` and ``asdict()``.
    """

    # Text prompt used for generation.
    prompt: str = "a photo of a forest with mist swirling around the tree trunks. The word 'FLUX.2' is painted over it in big, red brush strokes with visible texture"
    # Random seed; None means "pick a fresh random seed each run".
    seed: Optional[int] = None
    width: int = 1360
    height: int = 768
    num_steps: int = 50
    guidance: float = 4.0
    # Reference images used as conditioning context.
    input_images: List[Path] = field(default_factory=list)
    match_image_size: Optional[int] = None  # Index of input_images to match size from
    upsample_prompt_mode: Literal["none", "local", "openrouter"] = "none"
    openrouter_model: str = "mistralai/pixtral-large-2411"  # OpenRouter model name

    def copy(self) -> "Config":
        """Return an independent copy (the input_images list is duplicated)."""
        return Config(
            prompt=self.prompt,
            seed=self.seed,
            width=self.width,
            height=self.height,
            num_steps=self.num_steps,
            guidance=self.guidance,
            input_images=list(self.input_images),
            match_image_size=self.match_image_size,
            upsample_prompt_mode=self.upsample_prompt_mode,
            openrouter_model=self.openrouter_model,
        )
# Module-wide default configuration; sessions take a .copy() and never mutate it.
DEFAULTS = Config()
# Field-name sets that tell coerce_value() which conversion to apply.
INT_FIELDS = {"width", "height", "seed", "num_steps", "match_image_size"}
FLOAT_FIELDS = {"guidance"}
LIST_FIELDS = {"input_images"}
# Accepted values for the "upsample_prompt_mode" field.
UPSAMPLING_MODE_FIELDS = ("none", "local", "openrouter")
STR_FIELDS = {"openrouter_model"}


def coerce_value(key: str, raw: str):
    """Convert a raw string to the correct field type.

    Args:
        key: Config field name; selects the conversion.
        raw: Raw string value as typed by the user.

    Returns:
        int, float, str, ``list[Path]`` or ``None`` depending on ``key``.
        Unknown keys fall through and return ``raw`` unchanged.

    Raises:
        ValueError: for an invalid ``upsample_prompt_mode`` value, or when
            int/float conversion of a numeric field fails.
    """
    if key in INT_FIELDS:
        # "none"/"" mean "unset" (e.g. seed=none -> random seed per run).
        if raw.lower() == "none" or raw == "":
            return None
        return int(raw)
    if key in FLOAT_FIELDS:
        return float(raw)
    if key in STR_FIELDS:
        return raw.strip().strip('"').strip("'")
    if key in LIST_FIELDS:
        # Handle empty list cases
        if raw == "" or raw == "[]":
            return []
        # Accept comma-separated or space-separated; strip quotes.
        items = []
        # If user passed a single token that contains commas, split on commas.
        tokens = [raw] if ("," in raw and " " not in raw) else shlex.split(raw)
        for tok in tokens:
            for part in tok.split(","):
                part = part.strip()
                if part:
                    if os.path.exists(part):
                        items.append(Path(part))
                    else:
                        # Fix: warning now goes to stderr like every other
                        # diagnostic in this file, keeping stdout clean.
                        print(
                            f"File {part} not found. Skipping for now. Please check your path",
                            file=sys.stderr,
                        )
        return items
    if key == "upsample_prompt_mode":
        v = str(raw).strip().strip('"').strip("'").lower()
        if v in UPSAMPLING_MODE_FIELDS:
            return v
        raise ValueError(
            f"invalid upsample_prompt_mode: {v}. Must be one of: {', '.join(UPSAMPLING_MODE_FIELDS)}"
        )
    # plain strings
    return raw
def apply_updates(cfg: "Config", updates: Dict[str, Any]) -> None:
    """Assign each known key/value pair from *updates* onto *cfg* in place.

    Unknown keys and invalid ``upsample_prompt_mode`` values are reported on
    stderr and skipped; every other pair is set verbatim via setattr.
    """
    valid_modes = {"none", "local", "openrouter"}
    for name, value in updates.items():
        if not hasattr(cfg, name):
            print(f" ! unknown key: {name}", file=sys.stderr)
            continue
        # Guard the mode field against values that slipped past coercion.
        if name == "upsample_prompt_mode" and value not in valid_modes:
            print(
                f" ! Invalid upsample_prompt_mode: {value}. Must be one of: {', '.join(valid_modes)}",
                file=sys.stderr,
            )
            continue
        setattr(cfg, name, value)
def parse_key_values(line: str) -> Dict[str, Any]:
    """
    Parse shell-like 'key=value' pairs. Values can be quoted.
    Example: prompt="a dog" width=768 input_images="in1.png,in2.jpg"
    """
    result: Dict[str, Any] = {}
    for token in shlex.split(line):
        if "=" not in token:
            # Bare words (run, show, reset, quit, ...) become True flags.
            result[token] = True
        else:
            key, _, val = token.partition("=")
            key, val = key.strip(), val.strip()
            try:
                result[key] = coerce_value(key, val)
            except Exception as e:
                print(f" ! could not parse {key}={val!r}: {e}", file=sys.stderr)
    return result
def print_config(cfg: "Config"):
    """Pretty-print every field of *cfg*, one per line, in a fixed order."""
    field_order = (
        "prompt",
        "seed",
        "width",
        "height",
        "num_steps",
        "guidance",
        "input_images",
        "match_image_size",
        "upsample_prompt_mode",
        "openrouter_model",
    )
    values = asdict(cfg)
    # Paths render more readably as plain strings.
    values["input_images"] = [str(p) for p in cfg.input_images]
    print("Current config:")
    for name in field_order:
        print(f" {name}: {values[name]}")
    print()
def print_help():
    """Print the REPL command reference and per-parameter documentation."""
    # NOTE(review): the source formatting was mangled, so the original
    # in-string indentation of this help text is unknown — reproduced as-is.
    print("""
Available commands:
[Enter] - Run generation with current config
run - Run generation with current config
show - Show current configuration
reset - Reset configuration to defaults
help, h, ? - Show this help message
quit, q, exit - Exit the program
Setting parameters:
key=value - Update a config parameter (shows updated config, doesn't run)
Examples:
prompt="a cat in a hat"
width=768 height=768
seed=42
num_steps=30
guidance=3.5
input_images="img1.jpg,img2.jpg"
match_image_size=0 (use dimensions from first input image)
upsample_prompt_mode="none" (prompt upsampling mode: "none", "local", or "openrouter")
openrouter_model="mistralai/pixtral-large-2411" (OpenRouter model name)
You can combine parameter updates:
prompt="sunset" width=1920 height=1080
Parameters:
prompt - Text prompt for generation (string)
seed - Random seed (integer or 'none' for random)
width - Output width in pixels (integer)
height - Output height in pixels (integer)
num_steps - Number of denoising steps (integer)
guidance - Guidance scale (float)
input_images - Comma-separated list of input image paths (list)
match_image_size - Index of input image to match dimensions from (integer, 0-based)
upsample_prompt_mode - Prompt upsampling mode: "none" (default), "local", or "openrouter" (string)
openrouter_model - OpenRouter model name (string, default: "mistralai/pixtral-large-2411")
Examples: "mistralai/pixtral-large-2411", "qwen/qwen3-vl-235b-a22b-instruct", etc.
Note: For "openrouter" mode, set OPENROUTER_API_KEY environment variable
""")
# ---------- Main Loop ----------
def main(
    model_name: str = "flux.2-dev",
    single_eval: bool = False,
    prompt: str | None = None,
    debug_mode: bool = False,
    cpu_offloading: bool = False,
    **overwrite,
):
    """Run the FLUX.2 text-to-image REPL (or a single one-shot generation).

    Args:
        model_name: Key into FLUX2_MODEL_INFO selecting the flow model.
        single_eval: If True, skip the interactive prompt entirely and
            generate once with the configured settings, then exit.
        prompt: Optional prompt overriding the config default.
        debug_mode: Forwarded to load_flow_model.
        cpu_offloading: If True, shuttle the flow model and the text
            embedder between CPU and GPU to cap peak VRAM.
        **overwrite: Extra key=value overrides applied to the default
            Config before the loop starts.

    NOTE(review): indentation was reconstructed from mangled source; the
    command-handling section is nested under ``if not single_eval:`` so
    ``line``/``updates``/``cmd`` are only referenced when input was read.
    """
    assert (
        model_name.lower() in FLUX2_MODEL_INFO
    ), f"{model_name} is not available, choose from {FLUX2_MODEL_INFO.keys()}"
    torch_device = torch.device("cuda")
    # Load embedder, flow model and autoencoder once; they stay resident for
    # the whole session (the except handler below relies on that).
    mistral = load_mistral_small_embedder()
    model = load_flow_model(
        model_name, debug_mode=debug_mode, device="cpu" if cpu_offloading else torch_device
    )
    ae = load_ae(model_name)
    ae.eval()
    mistral.eval()
    # API client will be initialized lazily when needed
    openrouter_api_client: Optional[OpenRouterAPIClient] = None
    cfg = DEFAULTS.copy()
    # Round-trip **overwrite kwargs through the same key=value parser the
    # REPL uses so types are coerced consistently.
    changes = [f"{key}={value}" for key, value in overwrite.items()]
    updates = parse_key_values(" ".join(changes))
    apply_updates(cfg, updates)
    if prompt is not None:
        cfg.prompt = prompt
    print_config(cfg)
    while True:
        if not single_eval:
            try:
                line = input("> ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nbye!")
                break
            if not line:
                # Empty -> run with current config
                cmd = "run"
                updates = {}
            else:
                try:
                    updates = parse_key_values(line)
                except Exception as e:  # noqa: BLE001
                    print(f" ! Failed to parse command: {type(e).__name__}: {e}", file=sys.stderr)
                    print(
                        " ! Please check your syntax (e.g., matching quotes) and try again.\n",
                        file=sys.stderr,
                    )
                    continue
                # Content moderation: drop flagged prompts / images before
                # they ever reach the config.
                if "prompt" in updates and mistral.test_txt(updates["prompt"]):
                    print(
                        "Your prompt has been flagged for potential copyright or public personas concerns. Please choose another."
                    )
                    updates.pop("prompt")
                if "input_images" in updates:
                    flagged = False
                    for image in updates["input_images"]:
                        if mistral.test_image(image):
                            print(f"The image {image} has been flagged as unsuitable. Please choose another.")
                            flagged = True
                    if flagged:
                        updates.pop("input_images")
                # If the line was only 'run' / 'show' / ... it will appear as {cmd: True}
                # If it had key=val pairs, there may be no bare command -> just update config
                bare_cmds = [k for k, v in updates.items() if v is True and k.isalpha()]
                cmd = bare_cmds[0] if bare_cmds else None
                # Remove bare commands from updates so they don't get applied as fields
                for c in bare_cmds:
                    updates.pop(c, None)
            if cmd in ("quit", "q", "exit"):
                print("bye!")
                break
            elif cmd == "reset":
                cfg = DEFAULTS.copy()
                print_config(cfg)
                continue
            elif cmd == "show":
                print_config(cfg)
                continue
            elif cmd in ("help", "h", "?"):
                print_help()
                continue
            # Apply key=value changes
            if updates:
                apply_updates(cfg, updates)
                print_config(cfg)
                continue
            # Only run if explicitly requested (empty line or 'run' command)
            if cmd != "run":
                if cmd is not None:
                    print(f" ! Unknown command: '{cmd}'", file=sys.stderr)
                    print(" ! Type 'help' to see available commands.\n", file=sys.stderr)
                continue
        try:
            # Load input images first to potentially match dimensions
            img_ctx = [Image.open(input_image) for input_image in cfg.input_images]
            # Apply match_image_size if specified
            width = cfg.width
            height = cfg.height
            if cfg.match_image_size is not None:
                if cfg.match_image_size < 0 or cfg.match_image_size >= len(img_ctx):
                    print(
                        f" ! match_image_size={cfg.match_image_size} is out of range (0-{len(img_ctx)-1})",
                        file=sys.stderr,
                    )
                    print(f" ! Using default dimensions: {width}x{height}", file=sys.stderr)
                else:
                    ref_img = img_ctx[cfg.match_image_size]
                    width, height = ref_img.size
                    print(f" Matched dimensions from image {cfg.match_image_size}: {width}x{height}")
            seed = cfg.seed if cfg.seed is not None else random.randrange(2**31)
            # NOTE(review): `dir` shadows the builtin of the same name.
            dir = Path("output")
            dir.mkdir(exist_ok=True)
            # Output file is numbered by the current file count in ./output.
            output_name = dir / f"sample_{len(list(dir.glob('*')))}.png"
            with torch.no_grad():
                ref_tokens, ref_ids = encode_image_refs(ae, img_ctx)
                if cfg.upsample_prompt_mode == "openrouter":
                    try:
                        # Ensure API key is available, otherwise prompt the user
                        api_key = os.environ.get("OPENROUTER_API_KEY", "").strip()
                        if not api_key:
                            try:
                                entered = input(
                                    "OPENROUTER_API_KEY not set. Enter it now (leave blank to skip OpenRouter upsampling): "
                                ).strip()
                            except (EOFError, KeyboardInterrupt):
                                entered = ""
                            if entered:
                                os.environ["OPENROUTER_API_KEY"] = entered
                            else:
                                print(
                                    " ! No API key provided; disabling OpenRouter upsampling",
                                    file=sys.stderr,
                                )
                                cfg.upsample_prompt_mode = "none"
                                prompt = cfg.prompt
                                # Skip OpenRouter flow
                        # Only proceed if still in openrouter mode (not disabled above)
                        if cfg.upsample_prompt_mode == "openrouter":
                            # Let user specify sampling params, or use model defaults if available
                            sampling_params_input = ""
                            try:
                                sampling_params_input = input(
                                    "Enter OpenRouter sampling params as JSON or key=value (blank to use defaults): "
                                ).strip()
                            except (EOFError, KeyboardInterrupt):
                                sampling_params_input = ""
                            sampling_params: Dict[str, Any] = {}
                            if sampling_params_input:
                                # Try JSON first
                                parsed_ok = False
                                try:
                                    parsed = json.loads(sampling_params_input)
                                    if isinstance(parsed, dict):
                                        sampling_params = parsed
                                        parsed_ok = True
                                except Exception:
                                    parsed_ok = False
                                if not parsed_ok:
                                    # Fallback: parse key=value pairs separated by spaces or commas
                                    tokens = [
                                        tok
                                        for tok in sampling_params_input.replace(",", " ").split(" ")
                                        if tok
                                    ]
                                    for tok in tokens:
                                        if "=" not in tok:
                                            continue
                                        k, v = tok.split("=", 1)
                                        v_str = v.strip()
                                        v_low = v_str.lower()
                                        if v_low in {"true", "false"}:
                                            val: Any = v_low == "true"
                                        else:
                                            try:
                                                # Ints stay ints; "2.0" collapses back to int 2.
                                                if "." in v_str:
                                                    num = float(v_str)
                                                    val = int(num) if num.is_integer() else num
                                                else:
                                                    val = int(v_str)
                                            except Exception:
                                                val = v_str
                                        sampling_params[k.strip()] = val
                                print(f" Using custom OpenRouter sampling params: {sampling_params}")
                            else:
                                model_key = cfg.openrouter_model
                                default_params = DEFAULT_SAMPLING_PARAMS.get(model_key)
                                if default_params:
                                    sampling_params = default_params
                                    print(
                                        f" Using default OpenRouter sampling params for {model_key}: {sampling_params}"
                                    )
                                else:
                                    print(
                                        f" Setting no OpenRouter sampling params: not set for this model ({model_key})"
                                    )
                            # Initialize or reinitialize client if model changed
                            if (
                                openrouter_api_client is None
                                or openrouter_api_client.model != cfg.openrouter_model
                                or getattr(openrouter_api_client, "sampling_params", None) != sampling_params
                            ):
                                openrouter_api_client = OpenRouterAPIClient(
                                    model=cfg.openrouter_model,
                                    sampling_params=sampling_params,
                                )
                            else:
                                # Ensure client uses latest sampling params
                                openrouter_api_client.sampling_params = sampling_params
                            upsampled_prompts = openrouter_api_client.upsample_prompt(
                                [cfg.prompt], img=[img_ctx] if img_ctx else None
                            )
                            prompt = upsampled_prompts[0] if upsampled_prompts else cfg.prompt
                    except Exception as e:
                        # Any API failure degrades gracefully to the raw prompt.
                        print(f" ! Failed to upsample prompt via OpenRouter API: {e}", file=sys.stderr)
                        print(
                            " ! Disabling OpenRouter upsampling and falling back to original prompt",
                            file=sys.stderr,
                        )
                        cfg.upsample_prompt_mode = "none"
                        prompt = cfg.prompt
                elif cfg.upsample_prompt_mode == "local":
                    # Use local model for upsampling
                    upsampled_prompts = mistral.upsample_prompt(
                        [cfg.prompt], img=[img_ctx] if img_ctx else None
                    )
                    prompt = upsampled_prompts[0] if upsampled_prompts else cfg.prompt
                else:
                    # upsample_prompt_mode == "none" or invalid value
                    prompt = cfg.prompt
                print("Generating with prompt: ", prompt)
                ctx = mistral([prompt]).to(torch.bfloat16)
                ctx, ctx_ids = batched_prc_txt(ctx)
                if cpu_offloading:
                    # Free the embedder's VRAM before moving the flow model on.
                    mistral = mistral.cpu()
                    torch.cuda.empty_cache()
                    model = model.to(torch_device)
                # Create noise
                shape = (1, 128, height // 16, width // 16)
                generator = torch.Generator(device="cuda").manual_seed(seed)
                randn = torch.randn(shape, generator=generator, dtype=torch.bfloat16, device="cuda")
                x, x_ids = batched_prc_img(randn)
                timesteps = get_schedule(cfg.num_steps, x.shape[1])
                x = denoise(
                    model,
                    x,
                    x_ids,
                    ctx,
                    ctx_ids,
                    timesteps=timesteps,
                    guidance=cfg.guidance,
                    img_cond_seq=ref_tokens,
                    img_cond_seq_ids=ref_ids,
                )
                x = torch.cat(scatter_ids(x, x_ids)).squeeze(2)
                x = ae.decode(x).float()
                # x = embed_watermark(x)
                if cpu_offloading:
                    # Swap models back so the embedder is ready for the next
                    # iteration's moderation / embedding calls.
                    model = model.cpu()
                    torch.cuda.empty_cache()
                    mistral = mistral.to(torch_device)
                x = x.clamp(-1, 1)
                x = rearrange(x[0], "c h w -> h w c")
                img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
                if mistral.test_image(img):
                    print("Your output has been flagged. Please choose another prompt / input image combination")
                else:
                    # Tag output as AI-generated in EXIF before saving.
                    exif_data = Image.Exif()
                    exif_data[ExifTags.Base.Software] = "AI generated;flux2"
                    exif_data[ExifTags.Base.Make] = "Black Forest Labs"
                    img.save(output_name, exif=exif_data, quality=95, subsampling=0)
                    print(f"Saved {output_name}")
        except Exception as e:  # noqa: BLE001
            print(f"\n ERROR: {type(e).__name__}: {e}", file=sys.stderr)
            print(" The model is still loaded. Please fix the error and try again.\n", file=sys.stderr)
        if single_eval:
            break
if __name__ == "__main__":
    # CLI entry point: Fire maps command-line flags onto main()'s signature.
    from fire import Fire

    Fire(main)