# NOTE: Hugging Face Spaces page header residue removed ("Spaces: Runtime
# error") — it is not Python and broke the file; check the Space logs for
# the actual runtime traceback.
| import random | |
| import torch | |
| import gradio as gr | |
| from diffusers import ( | |
| FluxPipeline, | |
| DPMSolverMultistepScheduler, | |
| DPMSolverSDEScheduler, | |
| EulerDiscreteScheduler, | |
| EulerAncestralDiscreteScheduler, | |
| HeunDiscreteScheduler, | |
| DDIMScheduler, | |
| ) | |
| # ----------------------------------------------------------------------------- | |
| # Pipeline loading helpers | |
| # ----------------------------------------------------------------------------- | |
def _load_pipe(hf_token: str | None = None) -> FluxPipeline:
    """Load the FLUX pipeline once and keep it in memory.

    Args:
        hf_token: Optional Hugging Face token if the model is gated/private.

    Returns:
        A fully-initialised FluxPipeline with LoRA fused and memory-saving
        features enabled.
    """
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.float16,
        # Fix: `use_auth_token` is deprecated and removed in recent
        # huggingface_hub releases; `token` is the supported keyword.
        token=hf_token or None,
    )
    # Memory optimisations ----------------------------------------------------
    # Sequential CPU offload streams submodules to the accelerator one at a
    # time (slower, but fits in small VRAM); attention slicing further lowers
    # peak memory during the attention computation.
    pipe.enable_sequential_cpu_offload()
    pipe.enable_attention_slicing()
    # LoRA --------------------------------------------------------------------
    pipe.load_lora_weights(
        "kudzueye/boreal-flux-dev-v2", weight_name="boreal-v2.safetensors"
    )
    # Fusing bakes the LoRA into the base weights so inference needs no extra
    # adapter forward passes.
    pipe.fuse_lora(lora_scale=0.8)
    return pipe
# Module-level cache: the pipeline is created lazily on the first request and
# then reused for every subsequent one.
_pipe: FluxPipeline | None = None


def _get_pipe(hf_token: str | None = None) -> FluxPipeline:
    """Return the cached pipeline, loading it on first use."""
    global _pipe
    if _pipe is not None:
        return _pipe
    _pipe = _load_pipe(hf_token)
    return _pipe
| # ----------------------------------------------------------------------------- | |
| # Scheduler mapping | |
| # ----------------------------------------------------------------------------- | |
# UI sampler label -> diffusers scheduler class used to swap the pipeline's
# scheduler at request time.
SCHED_MAP = {
    "DPM++ 2M Karras": DPMSolverMultistepScheduler,
    "DPM++ SDE Karras": DPMSolverSDEScheduler,
    "Euler": EulerDiscreteScheduler,
    "Euler a": EulerAncestralDiscreteScheduler,
    "Heun": HeunDiscreteScheduler,
    "DDIM": DDIMScheduler,
}
| # ----------------------------------------------------------------------------- | |
| # Inference function | |
| # ----------------------------------------------------------------------------- | |
def query(
    prompt: str,
    negative_prompt: str,
    steps: int,
    cfg_scale: float,
    sampler: str,
    seed: int,
    strength: float,  # kept for future img2img support
    hf_token: str,
):
    """Run the generation and return a PIL image + the seed actually used.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Text describing what to avoid in the image.
        steps: Number of denoising steps.
        cfg_scale: Classifier-free guidance scale.
        sampler: One of the ``SCHED_MAP`` keys.
        seed: RNG seed; ``-1`` means "pick one at random".
        strength: Currently unused (reserved for img2img).
        hf_token: Optional Hugging Face token for gated models.

    Returns:
        Tuple of (PIL image, seed actually used as a string).
    """
    pipe = _get_pipe(hf_token or None)

    # Replace the scheduler if the user selected a different sampler.
    scheduler_cls = SCHED_MAP.get(sampler, DPMSolverMultistepScheduler)
    if not isinstance(pipe.scheduler, scheduler_cls):
        # Fix: the "Karras"-labelled presets previously instantiated their
        # schedulers without Karras sigma spacing, so the label lied about the
        # noise schedule. Enable it for those presets.
        extra = (
            {"use_karras_sigmas": True}
            if sampler in ("DPM++ 2M Karras", "DPM++ SDE Karras")
            else {}
        )
        pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config, **extra)

    # Handle the seed: -1 means random, and we report the value actually used
    # back to the UI so runs can be reproduced.
    if seed == -1:
        seed = random.randint(0, 1_000_000_000)
    generator = torch.Generator(device=pipe.device).manual_seed(seed)

    # NOTE(review): FluxPipeline only accepts `negative_prompt` in diffusers
    # versions with true-CFG support — confirm the pinned version handles it.
    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            generator=generator,
            height=512,
            width=512,
        )
    return result.images[0], str(seed)
| # ----------------------------------------------------------------------------- | |
| # Gradio UI | |
| # ----------------------------------------------------------------------------- | |
# Minimal page styling: centre the app column and style the title banner.
CSS = """
#app-container {
  max-width: 600px;
  margin-left: auto;
  margin-right: auto;
}
#title-container {
  display: flex;
  align-items: center;
  justify-content: center;
}
#title-icon {
  width: 32px;
  height: auto;
  margin-right: 10px;
}
#title-text {
  font-size: 24px;
  font-weight: bold;
}
"""
# ---------------------------------------------------------------------------
# Gradio UI: builds the layout, wires the Run button to `query`, and launches
# the app. Component creation order defines the on-screen layout, so the
# structure below is order-sensitive.
# ---------------------------------------------------------------------------
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=CSS) as app:
    # Title banner (styled via the #title-* CSS rules above).
    gr.HTML(
        """
        <center>
        <div id="title-container">
        <h1 id="title-text">Text-to-Image Generator App</h1>
        </div>
        </center>
        """
    )
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                with gr.Row():
                    txt_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter a prompt here",
                        lines=2,
                        elem_id="prompt-text-input",
                    )
        with gr.Row():
            # Collapsed by default; holds everything beyond the prompt.
            with gr.Accordion("Advanced Settings", open=False):
                neg_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="What should not be in the image",
                    value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos",
                    lines=3,
                    elem_id="negative-prompt-text-input",
                )
                steps_in = gr.Slider(
                    label="Sampling steps", value=35, minimum=1, maximum=100, step=1
                )
                cfg_in = gr.Slider(
                    label="CFG Scale", value=7, minimum=1, maximum=20, step=1
                )
                # Choices come straight from SCHED_MAP so UI and backend agree.
                sampler_in = gr.Radio(
                    label="Sampling method",
                    value="DPM++ 2M Karras",
                    choices=list(SCHED_MAP.keys()),
                )
                # NOTE(review): strength is collected but `query` does not use
                # it yet (reserved for img2img).
                strength_in = gr.Slider(
                    label="Strength", value=0.7, minimum=0, maximum=1, step=0.001
                )
                # -1 asks the backend to pick a random seed.
                seed_in = gr.Slider(
                    label="Seed", value=-1, minimum=-1, maximum=1_000_000_000, step=1
                )
                api_key_in = gr.Textbox(
                    label="Hugging Face API Key (required for private models)",
                    placeholder="Enter your Hugging Face API Key here",
                    type="password",
                    elem_id="api-key",
                )
        with gr.Row():
            run_button = gr.Button("Run", variant="primary", elem_id="gen-button")
        with gr.Row():
            img_out = gr.Image(type="pil", label="Image Output", elem_id="gallery")
            seed_out = gr.Textbox(label="Seed Used", elem_id="seed-output")
    # Input order here must match the parameter order of `query`.
    run_button.click(
        fn=query,
        inputs=[
            txt_prompt,
            neg_prompt,
            steps_in,
            cfg_in,
            sampler_in,
            seed_in,
            strength_in,
            api_key_in,
        ],
        outputs=[img_out, seed_out],
    )
app.launch(show_api=True, share=False)