"""Gradio demo for FlowDIS: precise, optionally text-guided background removal.

The FlowDIS models are loaded once at import time (when a CUDA device is
available). Each request runs ``flowdis_predict`` on the uploaded image. An
optional text prompt can first be expanded with ``qwen.expand_prompt``. The app
shows the input next to an RGBA cut-out, which can be downloaded as a PNG.
"""

import csv
import logging
import os
import shutil
import uuid
from copy import deepcopy
from pathlib import Path

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Directory containing this script. Bundled paths (temp dir, assets) are resolved
# against it so the app does not depend on the current working directory.
APP_DIR = Path(__file__).parent

# Set Gradio temp directory BEFORE importing gradio to avoid permission issues.
# The directory is wiped on every start, so PNGs saved by earlier runs are removed.
TEMP_DIR = APP_DIR / "gradio_temp"
if TEMP_DIR.exists():
    shutil.rmtree(str(TEMP_DIR))
TEMP_DIR.mkdir(exist_ok=True)
os.environ["GRADIO_TEMP_DIR"] = str(TEMP_DIR)
os.environ["TMPDIR"] = str(TEMP_DIR)

import gradio as gr  # noqa: E402  (must be imported after GRADIO_TEMP_DIR is set)
import numpy as np  # noqa: E402
import torch  # noqa: E402
from PIL import Image  # noqa: E402

IS_HF_SPACE = os.environ.get("SPACE_ID") is not None

try:
    import spaces

    # Decorator that runs the wrapped function on a Hugging Face ZeroGPU device.
    zero_gpu = spaces.GPU(size="xlarge", duration=15)
except ImportError:

    def zero_gpu(fn):
        """No-op stand-in for ``spaces.GPU`` when ``spaces`` is not installed."""
        return fn


from flowdis.sampling import flowdis_predict  # noqa: E402
from flowdis.util import load_models  # noqa: E402
from qwen import expand_prompt  # noqa: E402

# Models are loaded once and shared by all requests. They stay ``None`` when no
# CUDA device is available, and ``process_image`` then rejects requests.
models = None
device = "cuda"
if torch.cuda.is_available():
    models = load_models(device=device)
else:
    logger.warning("No GPU available, the demo will not be able to run.")


def disable_download_btn():
    """Return a Gradio update that makes the download button non-interactive."""
    return gr.update(interactive=False)


def _make_cutout(image, mask):
    """Compose an RGBA cut-out from an RGB image and a foreground mask.

    Pixels where ``mask > 0`` keep their RGB value; all others become black.
    The mask itself is used as the alpha channel.

    Assumes ``mask`` is a single-channel 8-bit image with the same height and
    width as ``image`` (as returned by ``flowdis_predict``) — TODO confirm.
    """
    rgb = np.asarray(image)
    alpha = np.asarray(mask)
    keep = (alpha > 0)[:, :, np.newaxis].astype(np.uint8)
    return Image.fromarray(np.dstack([rgb * keep, alpha]))


@zero_gpu
def process_image(image, prompt, expand_prompt_enabled, resolution, num_inference_steps):
    """Segment the foreground of ``image`` and build the outputs for the UI.

    Args:
        image: PIL Image or numpy array. It is converted to an RGB PIL image.
        prompt: str, optional text describing what to retain. Empty or
            whitespace-only means "no prompt".
        expand_prompt_enabled: bool, whether to expand a non-empty prompt
            with ``expand_prompt`` before inference.
        resolution: the inference resolution (cast to int).
        num_inference_steps: the number of inference steps (cast to int).

    Returns:
        ``(None, None)`` when no image is given. Otherwise a pair of Gradio
        updates:
        - for the ImageSlider, showing the input next to the RGBA cut-out;
        - for the DownloadButton, pointing at the saved PNG and re-enabled.

    Raises:
        gr.Error: if the models were not loaded (no GPU at startup).
    """
    if image is None:
        return None, None
    if models is None:
        raise gr.Error("No GPU available, the demo will not be able to run.")

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # The cut-out stacks exactly 3 colour channels plus alpha, so normalise
    # RGBA / greyscale / palette inputs to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    prompt = (prompt or "").strip()
    logger.info("Original prompt: %s", prompt)
    if prompt and expand_prompt_enabled:
        prompt = expand_prompt(image, prompt)
        logger.info("Expanded prompt: %s", prompt)

    pred_mask = flowdis_predict(
        image=image,
        prompt=prompt,
        models=models,
        resolution=int(resolution),
        num_inference_steps=int(num_inference_steps),
        device=device,
    )
    transparent_png = _make_cutout(image, pred_mask)

    # A unique file name per request so concurrent users never share a PNG.
    uid = uuid.uuid4().hex
    png_path = TEMP_DIR / f"{uid}.png"
    transparent_png.save(png_path)

    return (
        gr.update(value=[image, transparent_png], key=uid),
        gr.update(value=str(png_path), interactive=True),
    )


# Load examples from assets/examples/examples.csv: image_name, prompt, resolution, num_steps
_example_dir = APP_DIR / "assets" / "examples"
_examples_csv = _example_dir / "examples.csv"
examples = []
if _examples_csv.exists():
    with open(_examples_csv, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            examples.append(
                [
                    str(_example_dir / row["image_name"].strip()),
                    row["prompt"].strip(),
                    True,  # expand prompt (default for examples)
                    int(row["resolution"].strip()),
                    int(row["num_steps"].strip()),
                ]
            )

# NOTE(review): this was empty in the reviewed copy; presumably it held custom
# <script>/<style> markup for the page head — confirm against the original.
_head_js = """ """

with gr.Blocks(
    title="FlowDIS – Precise Background Removal",
    head=_head_js,
    theme=gr.themes.Default(
        font=gr.themes.GoogleFont("Inter"),
    ).set(
        button_primary_background_fill="#C209C1",
        button_primary_background_fill_dark="#C209C1",
        button_primary_background_fill_hover="#d63bd5",
        button_primary_background_fill_hover_dark="#d63bd5",
        button_primary_text_color="#ffffff",
        button_primary_text_color_dark="#ffffff",
    ),
    delete_cache=(1800, 1800),
) as demo:
    # NOTE(review): only the text of this header survived in the reviewed copy.
    # The markup below is a minimal reconstruction — restore any original
    # styling or links.
    gr.HTML(
        """
<div id="header" style="text-align:center">
  <h1>FlowDIS Demo</h1>
  <p>FlowDIS performs precise foreground segmentation, optionally guided by a text prompt to only preserve the specified objects.</p>
</div>
"""
    )

    with gr.Row(elem_id="main-row"):
        # Left column: Input image, text field, and submit button
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=500,
                elem_id="input-image",
            )
            text_input = gr.Textbox(
                label="Text Prompt (Optional)",
                placeholder="Enter what you want to retain...",
                lines=1,
                elem_id="text-prompt",
            )
            expand_prompt_check = gr.Checkbox(
                label="Expand prompt",
                value=True,
                elem_id="expand-prompt",
                info="Use Qwen3-VL-4B-Instruct model to expand the prompt for better text-guided segmentation.",
            )

            # Sliders for resolution and steps
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    resolution_slider = gr.Slider(
                        minimum=1024,
                        maximum=2048,
                        value=1536,
                        step=64,
                        label="Inference Resolution",
                        info="Higher resolution preserves more details.",
                    )
                with gr.Column(scale=1, min_width=300):
                    steps_slider = gr.Slider(
                        minimum=1,
                        maximum=12,
                        value=4,
                        step=1,
                        label="Number of Steps",
                        info="More steps generate sharper results.",
                    )

            submit_btn = gr.Button("🚀 Remove Background", variant="primary")

        # Right column: Output image
        with gr.Column(scale=1):
            output_image = gr.ImageSlider(
                label="FlowDIS prediction",
                type="pil",
                format="webp",
                height=500,
                slider_position=10,
                elem_id="output-slider",
            )

            # Background swatches shown behind the transparent result.
            # Each entry is (button CSS background, value applied to the slider).
            _checker = "repeating-conic-gradient(#ccc 0% 25%,#fff 0% 50%) 50%/12px 12px"
            _bg_buttons = [
                (_checker, _checker),
                ("#ffffff", "#ffffff"),
                ("#000000", "#000000"),
                ("#00ff00", "#00ff00"),
                ("#0000ff", "#0000ff"),
                ("#ff0000", "#ff0000"),
                ("#ffff00", "#ffff00"),
                ("#ff00ff", "#ff00ff"),
                ("#00ffff", "#00ffff"),
            ]
            # Inline JS: create (once) a <style id="slider-bg-style"> element and
            # set the slider image/canvas background to the clicked button's data-bg.
            _onclick = (
                "var s=document.getElementById('slider-bg-style');"
                "if(!s){s=document.createElement('style');"
                "s.id='slider-bg-style';document.head.appendChild(s);}"
                "s.textContent='#output-slider img,#output-slider canvas"
                "{background:'+this.dataset.bg+' !important}';"
            )
            # NOTE(review): the button markup was not visible in the reviewed copy.
            # This reconstruction supplies the `data-bg` attribute `_onclick` reads.
            _bg_buttons_html = (
                '<div id="bg-buttons" style="display:flex;gap:6px;flex-wrap:wrap">'
                + "".join(
                    f'<button type="button" class="bg-btn" data-bg="{bg}" '
                    f'style="width:28px;height:28px;border:1px solid #888;'
                    f'border-radius:4px;cursor:pointer;background:{style}" '
                    f'onclick="{_onclick}"></button>'
                    for style, bg in _bg_buttons
                )
                + "</div>"
            )
            gr.HTML(value=_bg_buttons_html)

            download_btn = gr.DownloadButton(
                label="📥 Download PNG", variant="primary", interactive=False
            )

    _process_inputs = [
        input_image,
        text_input,
        expand_prompt_check,
        resolution_slider,
        steps_slider,
    ]
    _process_outputs = [output_image, download_btn]

    # The submit button and the Enter key in the prompt box share one pipeline.
    # First disable the download button so a stale PNG cannot be downloaded
    # while the new prediction runs, then run the model.
    for _trigger in (submit_btn.click, text_input.submit):
        _trigger(disable_download_btn, outputs=download_btn).then(
            fn=process_image,
            inputs=_process_inputs,
            outputs=_process_outputs,
        )

    # Examples table, placed below the main row (only when examples were found).
    if examples:
        examples_component = gr.Examples(
            examples=examples,
            inputs=_process_inputs,
            label="Examples",
            elem_id="examples-table",
        )
        # Clicking an example also runs the model on it.
        # NOTE(review): this reads the inputs that gr.Examples' own click handler
        # fills in, so it relies on that handler finishing first — confirm.
        examples_component.dataset.click(
            disable_download_btn, outputs=download_btn
        ).then(
            process_image,
            inputs=_process_inputs,
            outputs=_process_outputs,
        )

# Launch the app
if __name__ == "__main__":
    demo.queue(max_size=20)
    # Absolute paths, so file serving does not depend on the working directory.
    _allowed_paths = [str(TEMP_DIR), str(APP_DIR / "assets")]
    if IS_HF_SPACE:
        demo.launch(allowed_paths=_allowed_paths)
    else:
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
            allowed_paths=_allowed_paths,
        )