import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, StableDiffusionUpscalePipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
from huggingface_hub import hf_hub_download
import os
import requests

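# Run in bfloat16 and prefer the GPU when one is available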
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

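# The tiny VAE (taef1) gives fast, low-fidelity decodes for live previews;
# the full FLUX VAE (good_vae) is used for the final decode.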
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)

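# Memory optimizations, guarded with hasattr since availability varies across
# diffusers versions. Note that enable_model_cpu_offload manages device placement
# itself, so combining it with the .to(device) above may be redundant.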
if hasattr(pipe, "enable_model_cpu_offload"):
    pipe.enable_model_cpu_offload()
if hasattr(pipe, "enable_attention_slicing"):
    pipe.enable_attention_slicing(1)
if hasattr(pipe, "enable_vae_slicing"):
    pipe.enable_vae_slicing()
if hasattr(pipe, "enable_vae_tiling"):
    pipe.enable_vae_tiling()

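# torch.compile can substantially speed up the transformer's forward pass, at the
# cost of a slow first call while the graph is compiled; fall back if unsupported.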
try:
    pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
    print("✓ Transformer compiled for faster inference")
except Exception as e:
    print(f"Warning: Could not compile transformer: {e}")

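# Separate Stable Diffusion x4 upscaler, optionally applied to the final image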
upscaler = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", torch_dtype=dtype).to(device)
if hasattr(upscaler, "enable_model_cpu_offload"):
    upscaler.enable_model_cpu_offload()
if hasattr(upscaler, "enable_attention_slicing"):
    upscaler.enable_attention_slicing(1)
if hasattr(upscaler, "enable_vae_slicing"):
    upscaler.enable_vae_slicing()

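# LoRA registry: values are either Hugging Face repo IDs or direct .safetensors
# download URLs ("None" means no LoRA)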
LORAS = {
    "None": None,
    "AntiBlur": "Shakker-Labs/FLUX.1-dev-LoRA-AntiBlur",
    "Add Details": "Shakker-Labs/FLUX.1-dev-LoRA-add-details",
    "Ultra Realism": "https://huggingface.co/its-magick/merlin-test-loras/resolve/main/Canopus-LoRA-Flux-UltraRealism.safetensors",
    "Face Realism": "https://huggingface.co/its-magick/merlin-test-loras/resolve/main/Canopus-LoRA-Flux-FaceRealism.safetensors",
}

loaded_loras = {}

def download_lora_from_url(url, filename):
    """Download a LoRA file from a direct URL, skipping the download if the file already exists locally."""
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        response = requests.get(url)
        response.raise_for_status()  # fail loudly on a bad URL instead of writing an HTML error page to disk
        with open(filename, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded {filename}")
    return filename

def preload_and_apply_all_loras():
    """Download all LoRAs at startup and activate them simultaneously with per-LoRA scales."""
    global loaded_loras

    print("Downloading and applying all LoRAs...")
    adapter_names = []
    adapter_scales = []

    for lora_name, lora_path in LORAS.items():
        if lora_name == "None" or lora_path is None:
            continue

        # Direct URLs are downloaded to local files; hub repo IDs pass straight to diffusers
        if lora_path.startswith('http'):
            filename = f"{lora_name.lower().replace(' ', '_')}_lora.safetensors"
            lora_path = download_lora_from_url(lora_path, filename)

        loaded_loras[lora_name] = lora_path
        print(f"Downloaded {lora_name}")

        try:
            adapter_name = lora_name.lower().replace(' ', '_')
            optimal_scale = get_optimal_lora_scale(lora_name)
            pipe.load_lora_weights(lora_path, adapter_name=adapter_name)
            adapter_names.append(adapter_name)
            adapter_scales.append(optimal_scale)
            print(f"Applied {lora_name} with scale {optimal_scale}")
        except Exception as e:
            print(f"Failed to apply {lora_name}: {e}")

    # Activate every successfully loaded adapter at once, applying the per-LoRA
    # scales, so that all of them are active together rather than only the most
    # recently loaded one.
    if adapter_names:
        pipe.set_adapters(adapter_names, adapter_weights=adapter_scales)

    print(f"All {len(loaded_loras)} LoRAs downloaded and applied!")

def get_optimal_lora_scale(lora_name):
    """Return a per-LoRA weight chosen for a good quality/speed balance."""
    lora_scales = {
        "AntiBlur": 0.8,
        "Add Details": 1.2,
        "Ultra Realism": 0.9,
        "Face Realism": 1.1,
    }
    return lora_scales.get(lora_name, 1.0)

preload_and_apply_all_loras()

# Free any VRAM left over from LoRA loading before serving requests
if torch.cuda.is_available():
    torch.cuda.empty_cache()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

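# Bind the live-preview helper to the pipeline instance via __get__ so it can be
# called like a regular method and yield intermediate images while denoising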
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)

@spaces.GPU(duration=75)
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, enable_upscale=False, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    try:
        # Stream intermediate previews, keeping the last image for optional upscaling
        final_img = None
        for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
            output_type="pil",
            good_vae=good_vae,
        ):
            final_img = img
            yield img, seed

        if enable_upscale and final_img is not None:
            try:
                upscaled_img = upscaler(
                    prompt=prompt,
                    image=final_img,
                    num_inference_steps=15,
                    guidance_scale=6.0,
                    generator=generator,
                ).images[0]
                yield upscaled_img, seed
            except Exception as e:
                print(f"Error during upscaling: {e}")
                yield final_img, seed

    except Exception as e:
        # Fall back to the standard, non-streaming pipeline call
        print(f"Error during generation: {e}")
        img = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

        if enable_upscale:
            try:
                img = upscaler(
                    prompt=prompt,
                    image=img,
                    num_inference_steps=20,
                    guidance_scale=7.5,
                    generator=generator,
                ).images[0]
            except Exception as e:
                print(f"Error during upscaling: {e}")

        yield img, seed

examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# FLUX.1 [dev]
12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
""")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):

            gr.Markdown("**LoRAs Active:** All LoRAs are loaded and active simultaneously")

            enable_upscale = gr.Checkbox(
                label="Enable 4x Upscaling",
                value=False,
                info="Upscale the final image using the Stable Diffusion x4 upscaler",
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                    info="Lower values = faster generation, higher values = closer prompt adherence",
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=4,
                    maximum=50,
                    step=1,
                    value=20,
                    info="Lower values = faster generation, higher values = better quality",
                )

        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples="lazy",
        )

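    # Generate on both the Run button click and Enter in the prompt box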
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, enable_upscale],
        outputs=[result, seed],
    )

demo.launch(share=True)