Spaces: Running on ZeroGPU
| #!/usr/bin/env python3 | |
| """ | |
| Flux2Klein Style Transfer Demo for Hugging Face Spaces | |
| Input: content image + style image | |
| Output: one stylized image | |
| """ | |
| import os | |
| import traceback | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| from PIL import Image | |
try:
    import spaces
except Exception:
    # Fallback stub so the app still runs outside Hugging Face Spaces,
    # where the `spaces` package (ZeroGPU support) is unavailable.
    class _DummySpaces:
        # The original defined GPU as an instance method, so calling
        # `spaces.GPU(fn)` passed the instance into the `func` slot and
        # raised TypeError. A staticmethod that accepts both usage forms
        # of the real API — `@spaces.GPU` and `@spaces.GPU(duration=...)`
        # — acts as a transparent no-op decorator.
        @staticmethod
        def GPU(func=None, **_kwargs):
            if func is None:
                # Called with config kwargs: return a pass-through decorator.
                return lambda f: f
            return func

    spaces = _DummySpaces()
# ---- Configuration (each overridable via environment variable) ----
# Base FLUX.2 model repository on the Hugging Face Hub.
BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "black-forest-labs/FLUX.2-klein-9B")
# Repository holding the fine-tuned transformer checkpoint.
TUNED_REPO_ID = os.getenv("TUNED_REPO_ID", "wyjlu/omnistyle2-klein9b-base")
# Checkpoint file name inside TUNED_REPO_ID.
TUNED_WEIGHTS_FILENAME = os.getenv("TUNED_WEIGHTS_FILENAME", "step-3000.safetensors")
# Optional Hub token for gated/private repositories (None if unset).
HF_TOKEN = os.getenv("HF_TOKEN")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on GPU; plain float32 fallback when running on CPU.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# Fixed instruction prompt: "Figure 1" = content image, "Figure 2" = style image.
DEFAULT_PROMPT = "Transfer the style of Figure 2 into Figure 1"
# Preload the model at startup unless STARTUP_PRELOAD=0.
STARTUP_PRELOAD = os.getenv("STARTUP_PRELOAD", "1") == "1"
# Upper bound of the seed slider in the UI.
MAX_SEED = 1024
# Lazily initialized pipeline singleton and the cached load error, if any.
_PIPE = None
_LOAD_ERROR = None
def list_images(folder: Path):
    """Recursively list image files under *folder*, sorted by path.

    Args:
        folder: Directory to scan; a missing folder yields an empty list.

    Returns:
        Sorted list of path strings for .jpg/.jpeg/.png/.webp files.

    The suffix is compared case-insensitively. The previous glob-per-extension
    approach ("*.jpg" plus "*.JPG", etc.) listed the same file twice on
    case-insensitive filesystems and missed mixed-case names like "x.Jpg".
    """
    if not folder.exists():
        return []
    allowed = {".jpg", ".jpeg", ".png", ".webp"}
    return sorted(
        str(path)
        for path in folder.rglob("*")
        if path.is_file() and path.suffix.lower() in allowed
    )
def build_example_rows():
    """Build up to 4 (content, style, seed) example rows from the app root.

    Preferred naming: content_01.jpg paired with style_01.jpg (and so on),
    matched on the shared key after the prefix. If no such pairs exist,
    falls back to pairing any "content*"/"style*" files by sorted order.

    Returns:
        List of [content_path, style_path, seed] rows (seed fixed to 1);
        empty list when nothing pairable is found.
    """
    base = Path(__file__).parent
    image_exts = {".jpg", ".jpeg", ".png", ".webp"}
    # Case-insensitive suffix check — the previous explicit-tuple comparison
    # missed mixed-case suffixes such as ".Jpg".
    root_images = sorted(
        p for p in base.iterdir()
        if p.is_file() and p.suffix.lower() in image_exts
    )
    content_map = {}
    style_map = {}
    for p in root_images:
        stem = p.stem.lower()
        if stem.startswith("content_"):
            content_map[stem[len("content_"):]] = str(p)
        elif stem.startswith("style_"):
            style_map[stem[len("style_"):]] = str(p)
    paired_keys = sorted(content_map.keys() & style_map.keys())
    if paired_keys:
        return [[content_map[k], style_map[k], 1] for k in paired_keys[:4]]
    # Fallback: pair generically named files positionally.
    content_files = [str(p) for p in root_images if p.stem.lower().startswith("content")]
    style_files = [str(p) for p in root_images if p.stem.lower().startswith("style")]
    n = min(4, len(content_files), len(style_files))
    return [[content_files[i], style_files[i], 1] for i in range(n)]
def preprocess_to_square_1024(img: Image.Image, size: int = 1024) -> Image.Image:
    """Center-crop *img* to 1:1 aspect and resize to *size* x *size* RGB.

    Args:
        img: Source image in any PIL-supported mode (converted to RGB).
        size: Output edge length in pixels; default 1024 matches the
            resolution the generation pipeline is invoked with.

    Returns:
        A new RGB ``Image`` of dimensions (size, size).
    """
    img = img.convert("RGB")
    w, h = img.size
    side = min(w, h)
    # Crop the largest centered square before resizing.
    left = (w - side) // 2
    top = (h - side) // 2
    cropped = img.crop((left, top, left + side, top + side))
    return cropped.resize((size, size), Image.Resampling.LANCZOS)
def resolve_base_model_paths():
    """Fetch the base model from the Hub and return local cache paths.

    Downloads only the artifacts the pipeline needs (text encoder,
    transformer, VAE, tokenizer) and validates that each is present.

    Returns:
        dict with keys "cache_dir", "text_encoder_paths",
        "transformer_paths", and "vae_path".

    Raises:
        RuntimeError: if an expected weight file is missing from the cache.
    """
    patterns = [
        "text_encoder/*.safetensors",
        "transformer/*.safetensors",
        "vae/diffusion_pytorch_model.safetensors",
        "tokenizer/*",
    ]
    root = Path(
        snapshot_download(
            repo_id=BASE_MODEL_ID,
            token=HF_TOKEN if HF_TOKEN else None,
            allow_patterns=patterns,
        )
    )
    text_encoder_paths = sorted(str(p) for p in (root / "text_encoder").glob("*.safetensors"))
    transformer_paths = sorted(str(p) for p in (root / "transformer").glob("*.safetensors"))
    vae_path = str(root / "vae" / "diffusion_pytorch_model.safetensors")
    # Fail loudly if the snapshot is incomplete rather than at load time.
    if not text_encoder_paths:
        raise RuntimeError(f"No text encoder weights found in cache: {root / 'text_encoder'}")
    if not transformer_paths:
        raise RuntimeError(f"No transformer weights found in cache: {root / 'transformer'}")
    if not Path(vae_path).exists():
        raise RuntimeError(f"VAE weights not found: {vae_path}")
    return {
        "cache_dir": str(root),
        "text_encoder_paths": text_encoder_paths,
        "transformer_paths": transformer_paths,
        "vae_path": vae_path,
    }
def load_pipeline():
    """Lazy-load the Flux2 pipeline once and reuse it across requests.

    On first call this resolves/downloads the base model and the tuned
    transformer checkpoint, builds the pipeline, and stores it in the
    module-level ``_PIPE``. A failed load is cached in ``_LOAD_ERROR`` so
    later calls fail fast with the same diagnostics instead of retrying an
    expensive, doomed load.

    Returns:
        The shared ``Flux2ImagePipeline`` instance.

    Raises:
        RuntimeError: with full diagnostics when loading failed, either on
            this call or a previous one.
    """
    global _PIPE, _LOAD_ERROR
    if _PIPE is not None:
        return _PIPE
    if _LOAD_ERROR is not None:
        # A previous attempt already failed; re-raise the cached diagnostics.
        raise RuntimeError(_LOAD_ERROR)
    # Disable optional CUDA extensions that may be incompatible in some environments.
    os.environ.setdefault("DISABLE_FLASH_ATTN", "1")
    os.environ.setdefault("XFORMERS_DISABLED", "1")
    try:
        # Imported lazily so the UI can still start when diffsynth is absent/broken.
        from diffsynth.core import load_state_dict
        from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
        base_paths = resolve_base_model_paths()
        # Single-file fine-tuned transformer checkpoint from the Hub.
        tuned_weight_path = hf_hub_download(
            repo_id=TUNED_REPO_ID,
            filename=TUNED_WEIGHTS_FILENAME,
            token=HF_TOKEN if HF_TOKEN else None,
        )
        model_path_info = (
            f"Base model cache dir: {base_paths['cache_dir']}\n"
            f"Text encoder files: {len(base_paths['text_encoder_paths'])}\n"
            f"Transformer files: {len(base_paths['transformer_paths'])}\n"
            f"VAE path: {base_paths['vae_path']}\n"
            f"Tuned weights path: {tuned_weight_path}"
        )
        print("[Model] Download/Cache resolved:")
        print(model_path_info)
        _PIPE = Flux2ImagePipeline.from_pretrained(
            torch_dtype=DTYPE,
            device=DEVICE,
            model_configs=[
                ModelConfig(path=base_paths["text_encoder_paths"]),
                ModelConfig(path=base_paths["transformer_paths"]),
                ModelConfig(path=base_paths["vae_path"]),
            ],
            tokenizer_config=ModelConfig(model_id=BASE_MODEL_ID, origin_file_pattern="tokenizer/"),
        )
        # Overwrite the base transformer ("dit") weights with the tuned checkpoint.
        state_dict = load_state_dict(tuned_weight_path, torch_dtype=DTYPE)
        _PIPE.dit.load_state_dict(state_dict)
    except Exception as e:
        # Cache a complete diagnostic string (repo IDs + traceback) for fast
        # failure on subsequent calls, then raise it immediately.
        _LOAD_ERROR = (
            f"Base model: {BASE_MODEL_ID}\n"
            f"Tuned model: {TUNED_REPO_ID}/{TUNED_WEIGHTS_FILENAME}\n"
            f"Error type: {type(e).__name__}\n"
            f"Error message: {e}\n\n"
            f"Traceback:\n{traceback.format_exc()}\n"
        )
        raise RuntimeError(_LOAD_ERROR)
    return _PIPE
def preload_pipeline_on_startup():
    """Warm the model at startup to cut first-request latency."""
    print("[Startup] Preloading model pipeline...")
    try:
        load_pipeline()
    except Exception as exc:
        # Keep the service running; the full error is raised again on the
        # first inference request.
        print(f"[Startup] Model preload failed: {exc}")
    else:
        print("[Startup] Model preloaded successfully.")
def infer(content_image: Image.Image, style_image: Image.Image, seed: int, progress=gr.Progress(track_tqdm=True)):
    """Stylize *content_image* with the style of *style_image*.

    Args:
        content_image: Content photo (any size; center-cropped to 1024x1024).
        style_image: Style reference (same preprocessing).
        seed: Integer RNG seed for reproducible generation.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        A single stylized PIL image.

    Raises:
        gr.Error: if either input image is missing.
    """
    progress(0.0, desc="Validating inputs...")
    if content_image is None or style_image is None:
        raise gr.Error("Please upload both content and style images.")
    progress(0.15, desc="Checking model status...")
    pipe = load_pipeline()
    prompt = DEFAULT_PROMPT
    progress(0.3, desc="Preprocessing images...")
    content = preprocess_to_square_1024(content_image)
    style = preprocess_to_square_1024(style_image)
    progress(0.5, desc="Generating...")
    # The previous "compatibility fallback" retried the *byte-identical* call
    # inside `except`, which cannot fix a deterministic failure and replaced
    # the first traceback with the second. Call once and let real failures
    # propagate with their original traceback.
    output = pipe(
        prompt,
        edit_image=[content, style],
        seed=int(seed),
        rand_device="cuda" if DEVICE.startswith("cuda") else "cpu",
        num_inference_steps=20,
        cfg_scale=4,
        height=1024,
        width=1024,
    )
    # Some pipeline versions return a list of images; unwrap the first.
    if isinstance(output, list):
        output = output[0]
    progress(1.0, desc="Done")
    return output
# ---- Gradio UI definition (module-level, so `demo` is importable) ----
with gr.Blocks() as demo:
    gr.Markdown(
        "<h2 style='text-align:center; margin:0;'>"
        "Learning to Stylize by Learning to Destylize: A Scalable Paradigm for Supervised Style Transfer"
        "</h2>"
    )
    # Three side-by-side panels: content input, style input, generated result.
    with gr.Row(equal_height=True):
        with gr.Column():
            content_input = gr.Image(type="pil", label="Content Image", height=300)
        with gr.Column():
            style_input = gr.Image(type="pil", label="Style Image", height=300)
        with gr.Column():
            result_output = gr.Image(type="pil", label="Result", height=300)
    seed_input = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=1, label="Seed")
    run_button = gr.Button("Run", variant="primary")
    run_button.click(
        fn=infer,
        inputs=[content_input, style_input, seed_input],
        outputs=[result_output],
    )
    gr.Markdown("### Examples")
    # Example pairs are discovered from image files placed next to app.py.
    example_rows = build_example_rows()
    if example_rows:
        gr.Examples(
            examples=example_rows,
            inputs=[content_input, style_input, seed_input],
            outputs=[result_output],
            fn=infer,
            cache_examples=False,  # run live on click; nothing precomputed
            run_on_click=True,
            examples_per_page=4,
        )
    else:
        # Tell the Space owner how to enable examples.
        gr.Markdown(
            "No example pairs found in the app root directory. "
            "Put files like `content_01.jpg` and `style_01.jpg` next to `app.py`."
        )
if __name__ == "__main__":
    if STARTUP_PRELOAD:
        # Warm the model before serving so the first request is not slow.
        preload_pipeline_on_startup()
    # Bind all interfaces on port 7860 (Hugging Face Spaces convention);
    # queue limits concurrent pending requests to 8.
    demo.queue(max_size=8).launch(server_name="0.0.0.0", server_port=7860, share=False)