Spaces:
Sleeping
Sleeping
Commit ·
d774d01
1
Parent(s): 51312b1
Simplify interface, remove mask display and advanced settings, ensure GPU compatibility
Browse files
app.py
CHANGED
|
@@ -1,85 +1,88 @@
|
|
| 1 |
-
import
|
|
|
|
| 2 |
import torch
|
| 3 |
from diffusers import FluxFillPipeline
|
| 4 |
-
import gradio as gr
|
| 5 |
from PIL import Image
|
| 6 |
-
import
|
| 7 |
-
import numpy as np
|
| 8 |
-
from huggingface_hub import login
|
| 9 |
-
|
| 10 |
-
# Authenticate with HF token from Spaces Secrets
|
| 11 |
-
hf_token = os.getenv("HF_TOKEN")
|
| 12 |
-
if not hf_token:
|
| 13 |
-
print(
|
| 14 |
-
"Warning: HF_TOKEN not found in environment. Please set it in Spaces Secrets."
|
| 15 |
-
)
|
| 16 |
-
hf_token = (
|
| 17 |
-
input("Enter your HF_TOKEN for local testing (leave blank to skip): ") or None
|
| 18 |
-
)
|
| 19 |
-
|
| 20 |
-
if hf_token:
|
| 21 |
-
login(token=hf_token)
|
| 22 |
-
else:
|
| 23 |
-
raise ValueError(
|
| 24 |
-
"HF_TOKEN is required to access the gated FLUX.1-Fill-dev model. Set it in Spaces Secrets or locally."
|
| 25 |
-
)
|
| 26 |
|
|
|
|
| 27 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 28 |
|
| 29 |
# Load the pipeline
|
| 30 |
pipe = FluxFillPipeline.from_pretrained(
|
| 31 |
-
"black-forest-labs/FLUX.1-Fill-dev",
|
| 32 |
-
torch_dtype=torch.bfloat16,
|
| 33 |
-
token=hf_token,
|
| 34 |
).to(device)
|
| 35 |
-
if torch.cuda.is_available():
|
| 36 |
-
pipe.enable_model_cpu_offload()
|
| 37 |
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
mask_np = np.array(mask)
|
| 48 |
-
mask_np = cv2.cvtColor(mask_np, cv2.COLOR_RGB2GRAY)
|
| 49 |
-
_, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
|
| 50 |
-
processed_mask = Image.fromarray(mask_np)
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
output = pipe(
|
| 53 |
-
prompt=prompt
|
| 54 |
image=image,
|
| 55 |
-
mask_image=
|
| 56 |
-
|
|
|
|
| 57 |
guidance_scale=7.5,
|
| 58 |
-
|
|
|
|
| 59 |
).images[0]
|
| 60 |
-
return output
|
| 61 |
except Exception as e:
|
| 62 |
-
|
| 63 |
|
| 64 |
|
|
|
|
| 65 |
with gr.Blocks() as demo:
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
output_image = gr.Image(label="Inpainted Output")
|
| 76 |
-
processed_mask_display = gr.Image(label="Processed Mask")
|
| 77 |
-
error_label = gr.Markdown()
|
| 78 |
-
submit.click(
|
| 79 |
-
inpaint,
|
| 80 |
-
inputs=[base_image, mask_image, prompt],
|
| 81 |
-
outputs=[output_image, processed_mask_display, error_label],
|
| 82 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
import torch
|
| 4 |
from diffusers import FluxFillPipeline
|
|
|
|
| 5 |
from PIL import Image
|
| 6 |
+
import random
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
# Set device (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pipeline.
# Weights are loaded in bfloat16 (torch_dtype) to reduce memory use.
# NOTE(review): FLUX.1-Fill-dev is a gated repo and no token is passed here
# (the explicit HF_TOKEN login was removed in this commit) — presumably this
# relies on ambient Hugging Face credentials on the Space; confirm access.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to(device)
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
+
# Function to calculate optimal dimensions for the output image
def calculate_optimal_dimensions(image: "Image.Image") -> tuple[int, int]:
    """Pick a generation-friendly (width, height) for *image*.

    The longer side is fixed at 1024 px and the shorter side is scaled to
    preserve the input aspect ratio; both are then snapped down to multiples
    of 8 and clamped into the 9:16 .. 16:9 aspect-ratio range.

    Args:
        image: Any object with a PIL-style ``size`` attribute ``(w, h)``.

    Returns:
        ``(width, height)`` as ints, each a multiple of 8.
    """
    original_width, original_height = image.size
    MIN_ASPECT_RATIO = 9 / 16
    MAX_ASPECT_RATIO = 16 / 9
    FIXED_DIMENSION = 1024

    original_aspect_ratio = original_width / original_height
    if original_aspect_ratio > 1:  # Wider than tall: fix the width.
        width = FIXED_DIMENSION
        height = round(FIXED_DIMENSION / original_aspect_ratio)
    else:  # Taller than wide (or square): fix the height.
        height = FIXED_DIMENSION
        width = round(FIXED_DIMENSION * original_aspect_ratio)

    # Ensure dimensions are multiples of 8.
    width = (width // 8) * 8
    height = (height // 8) * 8

    # Enforce aspect ratio limits.
    # BUG FIX: the float arithmetic here previously leaked floats
    # (e.g. 904.0) into the returned size; cast back to int before scaling.
    calculated_aspect_ratio = width / height
    if calculated_aspect_ratio > MAX_ASPECT_RATIO:
        width = int(height * MAX_ASPECT_RATIO // 8) * 8
    elif calculated_aspect_ratio < MIN_ASPECT_RATIO:
        height = int(width / MIN_ASPECT_RATIO // 8) * 8

    # Ensure width and height remain above minimum dimensions.
    # NOTE(review): as written these two lines are no-ops — max(1024, 576)
    # is always 1024 — kept unchanged for behavior compatibility; the
    # intended clamp (probably on the *non*-fixed side) is unclear.
    width = max(width, 576) if width == FIXED_DIMENSION else width
    height = max(height, 576) if height == FIXED_DIMENSION else height
    return width, height
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
+
# Inpainting function
def infer(edit_images, prompt):
    """Run FLUX.1-Fill inpainting on the masked region of an edited image.

    Args:
        edit_images: gr.ImageEditor value dict containing a "background"
            PIL image and a "layers" list whose first entry is the drawn mask.
        prompt: Text prompt describing what to paint into the masked area.

    Returns:
        The inpainted PIL image produced by the pipeline.

    Raises:
        gr.Error: if no mask layer was drawn, or if the pipeline fails.
    """
    image = edit_images["background"]
    # The mask is the first editor layer; without one there is nothing to fill.
    if not edit_images["layers"]:
        raise gr.Error("Please draw a mask.")
    mask = edit_images["layers"][0]
    # Target size: aspect-preserving, multiples of 8 (see helper above... no —
    # computed by calculate_optimal_dimensions from the background image).
    width, height = calculate_optimal_dimensions(image)
    # Fresh random seed per call so repeated runs give varied results.
    seed = random.randint(0, np.iinfo(np.int32).max)
    generator = torch.Generator(device=device).manual_seed(seed)
    try:
        output = pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
            height=height,
            width=width,
            guidance_scale=7.5,
            num_inference_steps=50,
            generator=generator,
        ).images[0]
        return output
    except Exception as e:
        # Surface pipeline failures in the UI; chain the original exception
        # so the underlying traceback is preserved (BUG FIX: was missing
        # `from e`, which discarded the cause).
        raise gr.Error(f"Error during inpainting: {str(e)}") from e
|
| 68 |
|
| 69 |
|
| 70 |
+
# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("# FLUX.1 Fill [dev]")
    # Single editor widget: the user uploads (or captures via webcam) the
    # base image and paints the inpainting mask on top with a fixed white brush.
    edit_image = gr.ImageEditor(
        label="Upload and draw mask for inpainting",
        type="pil",
        sources=["upload", "webcam"],
        image_mode="RGB",
        # NOTE(review): layers are disabled here, yet infer() reads
        # edit_images["layers"][0] for the mask — confirm brush strokes
        # still arrive as a layer with layers=False.
        layers=False,
        brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
        height=600,
    )
    prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt")
    run_button = gr.Button("Run")
    result = gr.Image(label="Result")
    # Wire the button to the inpainting function defined above.
    run_button.click(infer, inputs=[edit_image, prompt], outputs=result)

# Launch the demo
demo.launch()
|