Spaces:
Paused
Paused
| import spaces | |
| import gradio as gr | |
| import torch | |
| from typing import TypedDict | |
| from PIL import Image, ImageDraw, ImageFont | |
| from diffusers.pipelines import FluxPipeline | |
| from diffusers import FluxTransformer2DModel | |
| import numpy as np | |
| import examples_db | |
| from flux.condition import Condition | |
| from flux.generate import seed_everything, generate | |
| from flux.lora_controller import set_lora_scale | |
# ---- module-level state ----
# Shared pipeline object; stays None until init_pipeline() assigns it.
pipe = None
# Name of the LoRA adapter most recently activated by the UI click handler.
current_adapter = None
# When True, init_pipeline() picks the int8 transformer without checking GPU memory.
use_int8 = False
# Passed unchanged as ``model_config=`` to flux.generate.generate().
model_config = dict(
    union_cond_attn=True,
    add_cond_attn=False,
    latent_lora=False,
    independent_condition=True,
)
def get_gpu_memory():
    """Return the total memory of CUDA device 0 in GiB (bytes / 1024**3)."""
    bytes_per_gib = 1024**3
    device_props = torch.cuda.get_device_properties(0)
    return device_props.total_memory / bytes_per_gib
def init_pipeline():
    """Build the FLUX.1-schnell pipeline on CUDA and store it in the global ``pipe``.

    The int8 weight-only transformer is used when ``use_int8`` is set or when
    ``get_gpu_memory()`` reports less than 33 (GiB); otherwise the stock
    bfloat16 pipeline is loaded. The ZenCtrl LoRA is then attached under the
    adapter name ``"subject"``, which ``process_image_and_text`` refers to.

    The global is assigned only after every step has succeeded, so a failed
    download, device transfer or LoRA load never leaves a half-initialised
    pipeline behind for ``get_pipeline()`` to hand out on later calls.
    """
    global pipe

    base_repo = "black-forest-labs/FLUX.1-schnell"

    if use_int8 or get_gpu_memory() < 33:
        # Low-memory path: swap in the pre-quantised int8 transformer.
        transformer_model = FluxTransformer2DModel.from_pretrained(
            "sayakpaul/flux.1-schell-int8wo-improved",
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )
        new_pipe = FluxPipeline.from_pretrained(
            base_repo,
            transformer=transformer_model,
            torch_dtype=torch.bfloat16,
        )
    else:
        new_pipe = FluxPipeline.from_pretrained(base_repo, torch_dtype=torch.bfloat16)

    new_pipe = new_pipe.to("cuda")

    # Default LoRA, registered under the adapter name "subject".
    new_pipe.load_lora_weights(
        "fotographerai/zenctrl_tools",
        weight_name="weights/zen2con_1024_10000/pytorch_lora_weights.safetensors",
        adapter_name="subject",
    )
    # Optional: further adapters can be loaded here, e.g.
    # new_pipe.load_lora_weights("XLabs-AI/flux-RealismLora", adapter_name="realism")

    # Publish only the fully initialised pipeline.
    pipe = new_pipe
def paste_on_white_background(image: Image.Image) -> Image.Image:
    """Flatten ``image`` onto an opaque white canvas of the same size; return it as RGB."""
    rgba = image if image.mode == "RGBA" else image.convert("RGBA")
    canvas = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    # The image itself is used as the paste mask, so its alpha channel
    # decides where the white canvas shows through.
    canvas.paste(rgba, (0, 0), mask=rgba)
    return canvas.convert("RGB")
# @spaces.GPU  (decorator currently commented out)
def process_image_and_text(image, text, steps=8, strength_sub=1.0, strength_spat=1.0, size=1024):
    """Generate one image from a subject image and a text prompt.

    The input is centre-cropped to a square, resized to ``size`` x ``size``,
    flattened onto white, and fed to ``generate`` twice as a "subject"
    condition with opposite horizontal position offsets of ``size // 16``.

    Args:
        image: PIL image of the subject. Must not be ``None``.
        text: Prompt; surrounding whitespace is stripped. ``None`` is treated
            as an empty prompt.
        steps: Number of inference steps (coerced to ``int``).
        strength_sub: First entry of ``condition_scale`` (presumably applied
            to the first condition; confirm against ``flux.generate``).
        strength_spat: Second entry of ``condition_scale``.
        size: Side length in pixels of the condition image (coerced to
            ``int``). The generated image itself is always 1024 x 1024.

    Returns:
        The first image of the result returned by ``generate``.

    Raises:
        ValueError: If ``image`` is ``None`` or ``size`` is not positive.
    """
    # UI sliders may deliver floats; PIL sizes and step counts must be ints.
    size = int(size)
    steps = int(steps)
    if size <= 0:
        raise ValueError(f"size must be a positive number of pixels, got {size}")
    if image is None:
        raise ValueError("An input image is required.")
    prompt = (text or "").strip()

    # Centre-crop to the largest square that fits.
    w, h = image.size
    side = min(w, h)
    left = (w - side) // 2
    top = (h - side) // 2
    image = image.crop((left, top, left + side, top + side))

    image = image.resize((size, size))
    image = paste_on_white_background(image)

    condition0 = Condition("subject", image, position_delta=(0, size // 16))
    condition1 = Condition("subject", image, position_delta=(0, -size // 16))

    pipe = get_pipeline()
    with set_lora_scale(["subject"], scale=3.0):
        result_img = generate(
            pipe,
            prompt=prompt,
            conditions=[condition0, condition1],
            num_inference_steps=steps,
            # Output resolution is fixed; ``size`` only affects the conditions.
            height=1024,
            width=1024,
            condition_scale=[strength_sub, strength_spat],
            model_config=model_config,
        ).images[0]
    return result_img
# ================== MODE CONFIG =====================
class Mode(TypedDict):
    """Per-tab defaults for one generation mode of the demo UI."""

    model: str  # initially selected entry of the model dropdown
    prompt: str  # initial text of the prompt box
    default_strength: float  # initial value of both strength sliders
    default_height: int  # initial value of the size slider
    default_width: int
    models: list[str]  # choices offered in the model dropdown
    remove_bg: bool
# Maps each dropdown value to the relative path of its LoRA weight file
# inside the Hugging Face repo. Every entry shares the same file name and
# differs only in the sub-directory below ``weights/``.
MODEL_TO_LORA: dict[str, str] = {
    dropdown_value: f"weights/{subdir}/pytorch_lora_weights.safetensors"
    for dropdown_value, subdir in (
        ("zen2con_1024_10000", "zen2con_1024_10000"),
        ("zen2con_1440_17000", "zen2con_1440_17000"),
        ("zen_sub_sub_1024_10000", "zen_sub_sub_1024_10000"),
        ("zen_toys_1024_4000", "zen_toys_1024_4000/12000"),
        ("zen_toys_1024_15000", "zen_toys_1024_4000/zen_toys_1024_15000"),
        # add more as you upload them
    )
}
# One entry per UI tab, keyed by the tab title.
# A second mode ("Image fix", default model "zen_toys_1024_4000") is
# currently disabled and therefore not listed here.
MODE_DEFAULTS: dict[str, Mode] = {
    "Subject Generation": Mode(
        model="zen2con_1024_10000",
        prompt="A vibrant background with dynamic lighting and textures",
        default_strength=1.2,
        default_height=1024,
        default_width=1024,
        # Offer every known LoRA in the dropdown.
        models=list(MODEL_TO_LORA),
        remove_bg=True,
    ),
}
def get_pipeline():
    """Return the shared pipeline, building it with ``init_pipeline()`` on first use."""
    if pipe is not None:
        return pipe
    # init_pipeline() rebinds the module-level ``pipe``; read it again afterwards.
    init_pipeline()
    return pipe
def get_samples():
    """Return the bundled demo samples as a list of ``[image, prompt]`` pairs.

    Each image is read from ``samples/<n>.png`` relative to the working
    directory and fully decoded before its file is closed, so no file handle
    stays open for the lifetime of the returned images.

    Raises:
        FileNotFoundError: If a sample image is missing (raised by
            ``Image.open``).
    """
    sample_list = [
        {
            "image": "samples/1.png",
            "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
        },
        {
            "image": "samples/2.png",
            "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
        },
        {
            "image": "samples/3.png",
            "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
        },
        {
            "image": "samples/4.png",
            "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
        },
        {
            "image": "samples/5.png",
            # Typo fixed: "hehind" -> "behind".
            "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her.",
        },
    ]

    samples = []
    for sample in sample_list:
        # Image.open is lazy: decode inside the context manager so the file
        # is closed on exit while the pixel data stays usable.
        with Image.open(sample["image"]) as img:
            img.load()
        samples.append([img, sample["text"]])
    return samples
# =============== UI ===============
# Static HTML banner rendered at the top of the page: the title followed by
# two centred rows of badge links (GitHub, Hugging Face, Discord, landing
# page, X, Product Hunt).
# NOTE(review): the "π" in the title and the "π€" in the Hugging Face badge
# URL look like mis-decoded emoji - confirm the intended characters.
header = """
<h1>π ZenCtrl medium</h1>
<div align="center" style="line-height: 1;">
<a href="https://github.com/FotographerAI/ZenCtrl/tree/main" target="_blank" style="margin: 2px;" name="github_repo_link"><img src="https://img.shields.io/badge/GitHub-Repo-181717.svg" alt="GitHub Repo" style="display: inline-block; vertical-align: middle;"></a>
<a href="https://huggingface.co/fotographerai/zenctrl_tools" target="_blank" name="huggingface_space_link"><img src="https://img.shields.io/badge/π€_HuggingFace-Model-ffbd45.svg" alt="HuggingFace Model" style="display: inline-block; vertical-align: middle;"></a>
<a href="https://discord.com/invite/b9RuYQ3F8k" target="_blank" style="margin: 2px;" name="discord_link"><img src="https://img.shields.io/badge/Discord-Join-7289da.svg?logo=discord" alt="Discord" style="display: inline-block; vertical-align: middle;"></a>
<a href="https://fotographer.ai/zen-control" target="_blank" style="margin: 2px;" name="lp_link"><img src="https://img.shields.io/badge/Website-Landing_Page-blue" alt="LP" style="display: inline-block; vertical-align: middle;"></a>
<a href="https://x.com/FotographerAI" target="_blank" style="margin: 2px;" name="twitter_link"><img src="https://img.shields.io/twitter/follow/FotographerAI?style=social" alt="X" style="display: inline-block; vertical-align: middle;"></a>
</div>
<div align="center" style="line-height: 1; margin-top: 16px;">
<a href="https://www.producthunt.com/products/zenctrl?embed=true&utm_source=badge-featured&utm_medium=badge&utm_source=badge-zenctrl" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=969113&theme=light&t=1749428880088" alt="ZenCtrl - Framework to generate multi-view images | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>
"""
# Names of LoRA adapters the click handler has already loaded into the
# pipeline. diffusers raises ValueError when load_lora_weights() is called
# with an adapter name that is already in use, so each adapter is loaded
# once and merely re-activated on later selections.
_loaded_adapters: set[str] = set()

# NOTE(review): the leading "π" in the title looks like a mis-decoded emoji -
# confirm the intended character.
with gr.Blocks(title="π ZenCtrl-medium") as demo:
    # ---------- banner ----------
    gr.HTML(header)
    gr.Markdown(
        """
# ZenCtrl Demo
One framework to Generate multi-view, diverse-scene, and task-specific high-resolution images from a single subject image—without fine-tuning.
We are first releasing some of the task specific weights and will release the codes soon.
The goal is to unify all of the visual content generation tasks with a single LLM...
**Mode:**
- **Subject-driven Image Generation:** Generate in-context images of your subject with high fidelity and in different perspectives.
For more details, shoot us a message on discord.
"""
    )

    # ---------- tab bar ----------
    with gr.Tabs():
        for mode_name, defaults in MODE_DEFAULTS.items():
            with gr.Tab(mode_name):
                gr.Markdown(f"### {mode_name}")

                with gr.Row():
                    # -------- left (input) column --------
                    with gr.Column(scale=2):
                        input_image = gr.Image(label="Input Image", type="pil")
                        model_dropdown = gr.Dropdown(
                            label="Model (LoRA adapter)",
                            choices=defaults["models"],
                            value=defaults["model"],
                            interactive=True,
                        )
                        prompt_box = gr.Textbox(
                            label="Prompt", value=defaults["prompt"], lines=2
                        )
                        generate_btn = gr.Button("Generate")

                        with gr.Accordion("Generation Parameters", open=False):
                            step_slider = gr.Slider(
                                2, 28, value=12, step=2, label="Steps"
                            )
                            strength_sub_slider = gr.Slider(
                                0.0,
                                2.0,
                                value=defaults["default_strength"],
                                step=0.1,
                                label="Strength (subject)",
                            )
                            strength_spat_slider = gr.Slider(
                                0.0,
                                2.0,
                                value=defaults["default_strength"],
                                step=0.1,
                                label="Strength (spatial)",
                            )
                            size_slider = gr.Slider(
                                512,
                                2048,
                                value=defaults["default_height"],
                                step=64,
                                label="Size (px)",
                            )

                    # -------- right (output) column --------
                    with gr.Column(scale=2):
                        output_image = gr.Image(label="Output Image", type="pil")

                # ---------- click handler ----------
                def _run(image, model_name, prompt, steps, s_sub, s_spat, size):
                    """Activate the selected LoRA adapter, then generate the output image."""
                    global current_adapter

                    if image is None:
                        # Fail fast with a readable UI message before the
                        # (expensive) pipeline is touched.
                        raise gr.Error("Please upload an input image first.")

                    pipe = get_pipeline()

                    # -- switch adapter if needed --
                    if model_name != current_adapter:
                        if model_name not in _loaded_adapters:
                            pipe.load_lora_weights(
                                "fotographerai/zenctrl_tools",
                                weight_name=MODEL_TO_LORA[model_name],
                                adapter_name=model_name,
                            )
                            _loaded_adapters.add(model_name)
                        pipe.set_adapters([model_name])
                        current_adapter = model_name

                    # -- run generation --
                    return process_image_and_text(
                        image,
                        prompt,
                        steps=steps,
                        strength_sub=s_sub,
                        strength_spat=s_spat,
                        size=size,
                    )

                generate_btn.click(
                    fn=_run,
                    inputs=[
                        input_image,
                        model_dropdown,
                        prompt_box,
                        step_slider,
                        strength_sub_slider,
                        strength_spat_slider,
                        size_slider,
                    ],
                    outputs=[output_image],
                )

                # ---------------- Templates --------------------
                mode_examples = examples_db.MODE_EXAMPLES.get(mode_name)
                if mode_examples:
                    gr.Examples(
                        examples=mode_examples,
                        inputs=[
                            input_image,  # Image widget
                            model_dropdown,  # Dropdown for adapter
                            prompt_box,  # Textbox for prompt
                            output_image,  # Output image shown with the preset
                        ],
                        label="Presets (Image / Model / Prompt)",
                        examples_per_page=15,
                    )
# =============== launch ===============
if __name__ == "__main__":
    # init_pipeline() is not called here; get_pipeline() builds the pipeline
    # on first use instead.
    demo.launch(share=True, debug=True)