# OmniStyle2 / app.py
# Source: wyjlu's Hugging Face Space — commit 5146a97 ("Update app.py", verified).
# NOTE(review): the four header lines above were non-Python page residue from the
# Space file viewer; converted to comments so the file parses.
#!/usr/bin/env python3
"""
Flux2Klein Style Transfer Demo for Hugging Face Spaces
Input: content image + style image
Output: one stylized image
"""
import os
import traceback
from pathlib import Path
import gradio as gr
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
try:
    import spaces
except Exception:
    # Fallback for environments without the `spaces` package (e.g. local runs):
    # provide a no-op stand-in so the @spaces.GPU decorator below still works.
    class _DummySpaces:
        @staticmethod
        def GPU(func=None, **_kwargs):
            """No-op replacement for ``spaces.GPU``.

            Supports both the bare form (``@spaces.GPU``) and the factory form
            (``@spaces.GPU(duration=120)``) that the real package accepts; in
            either case the wrapped function is returned unchanged.
            """
            if func is None:
                # Factory usage: called with keyword args, must return a decorator.
                return lambda f: f
            return func

    spaces = _DummySpaces()
# Base FLUX.2 model repo on the Hub (overridable via env to swap checkpoints).
BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "black-forest-labs/FLUX.2-klein-9B")
# Repo holding the fine-tuned transformer weights.
TUNED_REPO_ID = os.getenv("TUNED_REPO_ID", "wyjlu/omnistyle2-klein9b-base")
# Checkpoint filename inside TUNED_REPO_ID.
TUNED_WEIGHTS_FILENAME = os.getenv("TUNED_WEIGHTS_FILENAME", "step-3000.safetensors")
# Optional access token for gated/private Hub repos (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bf16 on GPU; fp32 fallback on CPU where bf16 support is uneven.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# Fixed instruction fed to the pipeline; Figure 1 = content image, Figure 2 = style image.
DEFAULT_PROMPT = "Transfer the style of Figure 2 into Figure 1"
# Preload weights at startup unless explicitly disabled with STARTUP_PRELOAD=0.
STARTUP_PRELOAD = os.getenv("STARTUP_PRELOAD", "1") == "1"
# Upper bound of the seed slider in the UI.
MAX_SEED = 1024
# Lazily-initialized pipeline singleton and the stored first-load error (if any).
_PIPE = None
_LOAD_ERROR = None
def list_images(folder: Path):
    """Recursively collect image files under *folder*, returned as sorted path strings."""
    if not folder.exists():
        return []
    patterns = ("*.jpg", "*.jpeg", "*.png", "*.webp", "*.JPG", "*.JPEG", "*.PNG", "*.WEBP")
    matches = [path for pattern in patterns for path in folder.rglob(pattern)]
    return sorted(str(path) for path in matches if path.is_file())
def build_example_rows():
    """Build up to 4 (content, style, seed) example rows from images next to app.py."""
    app_dir = Path(__file__).parent
    exts = (".jpg", ".jpeg", ".png", ".webp", ".JPG", ".JPEG", ".PNG", ".WEBP")
    candidates = sorted(p for p in app_dir.iterdir() if p.is_file() and p.suffix in exts)
    # Preferred naming in root:
    #   content_01.jpg, style_01.jpg, content_02.jpg, style_02.jpg, ...
    contents = {}
    styles = {}
    for path in candidates:
        stem = path.stem.lower()
        for prefix, bucket in (("content_", contents), ("style_", styles)):
            if stem.startswith(prefix):
                bucket[stem[len(prefix):]] = str(path)
    shared_keys = sorted(contents.keys() & styles.keys())
    if shared_keys:
        return [[contents[k], styles[k], 1] for k in shared_keys[:4]]
    # Fallback: pair by sorted order when files use generic content*/style* names.
    content_files = [str(p) for p in candidates if p.stem.lower().startswith("content")]
    style_files = [str(p) for p in candidates if p.stem.lower().startswith("style")]
    pair_count = min(4, len(content_files), len(style_files))
    return [[content_files[i], style_files[i], 1] for i in range(pair_count)]
def preprocess_to_square_1024(img: Image.Image, side_px: int = 1024) -> Image.Image:
    """Center-crop *img* to a square and resize it to ``side_px`` x ``side_px``.

    Args:
        img: Input image in any mode; converted to RGB first.
        side_px: Output edge length in pixels. Defaults to 1024, the resolution
            the pipeline expects; parameterized for reuse at other sizes.

    Returns:
        An RGB PIL image of size ``(side_px, side_px)``.
    """
    img = img.convert("RGB")
    w, h = img.size
    side = min(w, h)
    # Symmetric crop: keep the centered largest square.
    left = (w - side) // 2
    top = (h - side) // 2
    cropped = img.crop((left, top, left + side, top + side))
    return cropped.resize((side_px, side_px), Image.Resampling.LANCZOS)
def resolve_base_model_paths():
    """Download base-model files from the Hub; return a dict of local cache paths."""
    local_root = Path(
        snapshot_download(
            repo_id=BASE_MODEL_ID,
            token=HF_TOKEN or None,
            allow_patterns=[
                "text_encoder/*.safetensors",
                "transformer/*.safetensors",
                "vae/diffusion_pytorch_model.safetensors",
                "tokenizer/*",
            ],
        )
    )
    encoder_files = sorted(str(p) for p in (local_root / "text_encoder").glob("*.safetensors"))
    dit_files = sorted(str(p) for p in (local_root / "transformer").glob("*.safetensors"))
    vae_file = local_root / "vae" / "diffusion_pytorch_model.safetensors"
    # Fail loudly if any expected component is missing from the snapshot.
    if not encoder_files:
        raise RuntimeError(f"No text encoder weights found in cache: {local_root / 'text_encoder'}")
    if not dit_files:
        raise RuntimeError(f"No transformer weights found in cache: {local_root / 'transformer'}")
    if not vae_file.exists():
        raise RuntimeError(f"VAE weights not found: {vae_file}")
    return {
        "cache_dir": str(local_root),
        "text_encoder_paths": encoder_files,
        "transformer_paths": dit_files,
        "vae_path": str(vae_file),
    }
def load_pipeline():
    """Lazy-load the model pipeline once and reuse it.

    Returns the cached pipeline on subsequent calls. If a previous load
    attempt failed, the stored diagnostic is re-raised without retrying
    the download.

    Raises:
        RuntimeError: With a detailed multi-line diagnostic (model ids,
            error type/message, traceback) when loading fails.
    """
    global _PIPE, _LOAD_ERROR
    if _PIPE is not None:
        return _PIPE
    if _LOAD_ERROR is not None:
        # A prior attempt already failed; surface the same diagnostic.
        raise RuntimeError(_LOAD_ERROR)
    # Disable optional CUDA extensions that may be incompatible in some environments.
    os.environ.setdefault("DISABLE_FLASH_ATTN", "1")
    os.environ.setdefault("XFORMERS_DISABLED", "1")
    try:
        # Imported lazily (after the env vars above) so the flags take effect
        # before diffsynth initializes its optional CUDA backends.
        from diffsynth.core import load_state_dict
        from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
        base_paths = resolve_base_model_paths()
        # Single tuned checkpoint that overrides the base transformer weights.
        tuned_weight_path = hf_hub_download(
            repo_id=TUNED_REPO_ID,
            filename=TUNED_WEIGHTS_FILENAME,
            token=HF_TOKEN if HF_TOKEN else None,
        )
        model_path_info = (
            f"Base model cache dir: {base_paths['cache_dir']}\n"
            f"Text encoder files: {len(base_paths['text_encoder_paths'])}\n"
            f"Transformer files: {len(base_paths['transformer_paths'])}\n"
            f"VAE path: {base_paths['vae_path']}\n"
            f"Tuned weights path: {tuned_weight_path}"
        )
        print("[Model] Download/Cache resolved:")
        print(model_path_info)
        _PIPE = Flux2ImagePipeline.from_pretrained(
            torch_dtype=DTYPE,
            device=DEVICE,
            model_configs=[
                ModelConfig(path=base_paths["text_encoder_paths"]),
                ModelConfig(path=base_paths["transformer_paths"]),
                ModelConfig(path=base_paths["vae_path"]),
            ],
            tokenizer_config=ModelConfig(model_id=BASE_MODEL_ID, origin_file_pattern="tokenizer/"),
        )
        # Swap the base DiT weights for the fine-tuned checkpoint.
        state_dict = load_state_dict(tuned_weight_path, torch_dtype=DTYPE)
        _PIPE.dit.load_state_dict(state_dict)
    except Exception as e:
        # Cache the full diagnostic so later calls fail fast with the same message.
        _LOAD_ERROR = (
            f"Base model: {BASE_MODEL_ID}\n"
            f"Tuned model: {TUNED_REPO_ID}/{TUNED_WEIGHTS_FILENAME}\n"
            f"Error type: {type(e).__name__}\n"
            f"Error message: {e}\n\n"
            f"Traceback:\n{traceback.format_exc()}\n"
        )
        raise RuntimeError(_LOAD_ERROR)
    return _PIPE
def preload_pipeline_on_startup():
    """Warm the model cache at startup so the first request is served quickly."""
    print("[Startup] Preloading model pipeline...")
    try:
        load_pipeline()
    except Exception as e:
        # Keep service running; full error will be raised on first inference.
        print(f"[Startup] Model preload failed: {e}")
    else:
        print("[Startup] Model preloaded successfully.")
@spaces.GPU
def infer(content_image: Image.Image, style_image: Image.Image, seed: int, progress=gr.Progress(track_tqdm=True)):
    """Transfer the style of *style_image* onto *content_image*.

    Args:
        content_image: Content photo (any size; center-cropped/resized to 1024x1024).
        style_image: Style reference image (same preprocessing).
        seed: RNG seed for reproducible generation.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A single stylized PIL image.

    Raises:
        gr.Error: If either input image is missing.
    """
    progress(0.0, desc="Validating inputs...")
    if content_image is None or style_image is None:
        raise gr.Error("Please upload both content and style images.")
    progress(0.15, desc="Checking model status...")
    pipe = load_pipeline()
    prompt = DEFAULT_PROMPT
    progress(0.3, desc="Preprocessing images...")
    content = preprocess_to_square_1024(content_image)
    style = preprocess_to_square_1024(style_image)
    progress(0.5, desc="Generating...")
    # NOTE(review): the original wrapped this call in a try/except whose
    # "compatibility fallback" repeated the exact same call with identical
    # arguments — the retry could never behave differently, so it was removed.
    output = pipe(
        prompt,
        edit_image=[content, style],
        seed=int(seed),
        rand_device="cuda" if DEVICE.startswith("cuda") else "cpu",
        num_inference_steps=20,
        cfg_scale=4,
        height=1024,
        width=1024,
    )
    # Some diffsynth versions return a list of images; unwrap to a single one.
    if isinstance(output, list):
        output = output[0]
    progress(1.0, desc="Done")
    return output
with gr.Blocks() as demo:
    # Page title (plain HTML inside Markdown for centering).
    gr.Markdown(
        "<h2 style='text-align:center; margin:0;'>"
        "Learning to Stylize by Learning to Destylize: A Scalable Paradigm for Supervised Style Transfer"
        "</h2>"
    )
    # Three equal-height columns: content input, style input, result preview.
    with gr.Row(equal_height=True):
        with gr.Column():
            img_content = gr.Image(type="pil", label="Content Image", height=300)
        with gr.Column():
            img_style = gr.Image(type="pil", label="Style Image", height=300)
        with gr.Column():
            img_result = gr.Image(type="pil", label="Result", height=300)
    slider_seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=1, label="Seed")
    btn_run = gr.Button("Run", variant="primary")
    btn_run.click(
        fn=infer,
        inputs=[img_content, img_style, slider_seed],
        outputs=[img_result],
    )
    gr.Markdown("### Examples")
    demo_rows = build_example_rows()
    if demo_rows:
        gr.Examples(
            examples=demo_rows,
            inputs=[img_content, img_style, slider_seed],
            outputs=[img_result],
            fn=infer,
            cache_examples=False,
            run_on_click=True,
            examples_per_page=4,
        )
    else:
        # Tell maintainers how to add example pairs when none are bundled.
        gr.Markdown(
            "No example pairs found in the app root directory. "
            "Put files like `content_01.jpg` and `style_01.jpg` next to `app.py`."
        )
if __name__ == "__main__":
    # Optionally warm the model before serving to cut first-request latency.
    if STARTUP_PRELOAD:
        preload_pipeline_on_startup()
    # Queue caps concurrent/pending jobs; bind on all interfaces for the Space container.
    demo.queue(max_size=8).launch(server_name="0.0.0.0", server_port=7860, share=False)