# Scraped Hugging Face file-view header (not part of the program):
# Author: XciD (HF Staff)
# Commit 9a8a023 (unverified): "feat: add AI render mode with SDXL + ControlNet Tile inpainting"
# File size: 6.57 kB
import spaces
import numpy as np
import torch
import gradio as gr
from PIL import Image, ImageDraw
from transformers import (
AutoImageProcessor,
Mask2FormerForUniversalSegmentation,
AutoModelForDepthEstimation,
)
from diffusers import (
StableDiffusionXLControlNetInpaintPipeline,
ControlNetModel,
AutoencoderKL,
)
# ─── Segmentation + Depth models ─────────────────
# Mask2Former (ADE20K semantic checkpoint) produces the per-pixel class map used
# for the floor mask; Depth Anything V2 produces the relative depth map.
seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic"
)

# A segmentation class counts as "floor" when its label contains any of these
# substrings (case-insensitive).
FLOOR_KEYWORDS = {'floor', 'flooring', 'rug', 'carpet', 'mat'}
FLOOR_IDS = set()
id2label = seg_model.config.id2label
for idx, label in id2label.items():
    if any(kw in label.lower() for kw in FLOOR_KEYWORDS):
        FLOOR_IDS.add(int(idx))
        print(f"Floor class: {idx} = {label}")
if not FLOOR_IDS:
    # Fallback when no label matched. Presumably the ADE20K ids of "floor" (3)
    # and "rug" (28) — NOTE(review): verify against the checkpoint's id2label.
    FLOOR_IDS = {3, 28}

depth_processor = AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf")
depth_model = AutoModelForDepthEstimation.from_pretrained(
    "depth-anything/Depth-Anything-V2-Large-hf", torch_dtype=torch.float16
)

# Place both models on the GPU at import time (the ZeroGPU pattern: `spaces`
# defers the actual transfer until a @spaces.GPU call runs). Without this,
# `predict` (which reads `seg_model.device`) runs them on CPU, where the
# float16 depth model is very slow or unsupported. The rest of the file already
# requires CUDA (cuda generator, fp16 pipeline with model offload).
seg_model.to("cuda")
depth_model.to("cuda")
# ─── SDXL + ControlNet Tile for AI rendering ─────
# Every component is loaded in float16. The Tile ControlNet conditions the
# inpainting on a tiled texture image; the separately loaded VAE is passed to
# the pipeline in place of its default one.
print("Loading ControlNet Tile + SDXL inpainting pipeline...")
controlnet = ControlNetModel.from_pretrained("xinsir/controlnet-tile-sdxl-1.0", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
inpaint_pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    torch_dtype=torch.float16,
    variant="fp16",
    controlnet=controlnet,
    vae=vae,
)
# Keep sub-models on the CPU and move each to the GPU only while it executes.
inpaint_pipe.enable_model_cpu_offload()
print("Pipeline loaded.")
@spaces.GPU(duration=60)
@torch.inference_mode()
def predict(image):
    """Compute a floor mask and a depth map for a room photo.

    The photo is downscaled so its longest side is at most 1024 px. Mask2Former
    segmentation and Depth Anything depth estimation run on that copy, and both
    results are resized back to the original resolution.

    Args:
        image: PIL image of the room. Any mode is accepted; it is converted to RGB.

    Returns:
        ``(mask_img, depth_img)``:
        - ``mask_img`` is an 8-bit image, 255 where the predicted class is in
          ``FLOOR_IDS`` and 0 elsewhere.
        - ``depth_img`` is the model's depth output min-max normalised to 0-255
          (all zeros if the output is constant).
        Both are at the original image size.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("No image provided")
    # The processors expect 3-channel input; RGBA / greyscale / palette uploads
    # would otherwise fail channel-format inference.
    image = image.convert("RGB")
    orig_w, orig_h = image.size

    max_size = 1024
    scale = min(1.0, max_size / max(orig_w, orig_h))
    # Clamp so extremely thin images never collapse to a zero-sized side.
    proc_w = max(1, int(orig_w * scale))
    proc_h = max(1, int(orig_h * scale))
    image_resized = image.resize((proc_w, proc_h), Image.LANCZOS)

    # ── Floor segmentation ──
    seg_device = seg_model.device
    seg_inputs = seg_processor(images=image_resized, return_tensors="pt")
    seg_inputs = {k: v.to(seg_device) for k, v in seg_inputs.items()}
    seg_outputs = seg_model(**seg_inputs)
    seg_map = seg_processor.post_process_semantic_segmentation(
        seg_outputs, target_sizes=[(proc_h, proc_w)]
    )[0].cpu().numpy()
    unique_classes = np.unique(seg_map)
    print(f"Detected classes: {[(int(c), id2label.get(int(c), '?')) for c in unique_classes]}")
    floor_mask = np.where(np.isin(seg_map, list(FLOOR_IDS)), 255, 0).astype(np.uint8)
    mask_img = Image.fromarray(floor_mask).resize((orig_w, orig_h), Image.NEAREST)

    # ── Depth estimation ──
    # Use the depth model's own device/dtype rather than assuming they match the
    # segmentation model's.
    depth_device = depth_model.device
    depth_dtype = depth_model.dtype
    depth_inputs = depth_processor(images=image_resized, return_tensors="pt")
    depth_inputs = {
        k: v.to(depth_device, dtype=depth_dtype) if v.is_floating_point() else v.to(depth_device)
        for k, v in depth_inputs.items()
    }
    depth_outputs = depth_model(**depth_inputs)
    # Upcast before normalising: float16 arithmetic loses precision in the 0-255 scaling.
    depth_map = depth_outputs.predicted_depth.squeeze().float().cpu().numpy()
    depth_min, depth_max = depth_map.min(), depth_map.max()
    if depth_max - depth_min > 0:
        depth_norm = ((depth_map - depth_min) / (depth_max - depth_min) * 255).astype(np.uint8)
    else:
        depth_norm = np.zeros_like(depth_map, dtype=np.uint8)
    depth_img = Image.fromarray(depth_norm).resize((orig_w, orig_h), Image.BILINEAR)
    return mask_img, depth_img
def create_tiled_control_image(tile_texture, width, height):
    """Build a ``width`` x ``height`` RGB canvas covered by copies of ``tile_texture``.

    Copies are pasted row by row, left to right, starting at (0, 0) and stepping
    by the texture's own size. Copies overhanging the right/bottom edges are
    clipped by the canvas bounds.
    """
    step_x, step_y = tile_texture.size
    canvas = Image.new("RGB", (width, height))
    offsets = (
        (col, row)
        for row in range(0, height, step_y)
        for col in range(0, width, step_x)
    )
    for offset in offsets:
        canvas.paste(tile_texture, offset)
    return canvas
def _sdxl_canvas_size(width, height, target=1024, multiple=64):
    """Return an aspect-preserving ``(w, h)`` generation size for SDXL.

    The longer side is scaled to ``target``. Both sides are rounded to the nearest
    multiple of ``multiple`` (never below ``multiple``) so the latent dimensions
    divide evenly through the VAE/UNet downsampling.
    """
    scale = target / max(width, height)

    def snap(side):
        return max(multiple, int(round(side * scale / multiple)) * multiple)

    return snap(width), snap(height)


@spaces.GPU(duration=120)
@torch.inference_mode()
def render_ai(room_image, tile_texture):
    """Re-render the floor of ``room_image`` with ``tile_texture`` via SDXL inpainting.

    Steps:
    1. The floor mask comes from the segmentation step.
    2. The texture is tiled into a ControlNet Tile conditioning image.
    3. SDXL inpaints the masked floor.

    Only the floor region of the generated image is composited onto the
    original photo. Everything outside the mask stays pixel-identical.

    Args:
        room_image: PIL image of the room.
        tile_texture: PIL image of a single tile.

    Returns:
        A PIL RGB image at the room image's original size.

    Raises:
        gr.Error: if either input is missing, or no floor was detected.
    """
    if room_image is None or tile_texture is None:
        raise gr.Error("Room image and tile texture are required")
    room_image = room_image.convert("RGB")
    tile_texture = tile_texture.convert("RGB")

    # Step 1: floor mask at the room image's original size.
    # NOTE(review): bypasses the @spaces.GPU wrapper through `__wrapped__` (we are
    # already inside a GPU call). This relies on the decorators preserving
    # `__wrapped__`, and also computes a depth map that is discarded here.
    mask_img, _ = predict.__wrapped__(room_image)
    if mask_img.getbbox() is None:
        # An all-zero mask would burn a full diffusion run for no visible change.
        raise gr.Error("No floor detected in the room image")

    # Use an SDXL-friendly size that keeps the room's aspect ratio. Squashing
    # into a fixed square distorts perspective and tile geometry, and resizing
    # back would stretch the rendered tiles.
    target = 1024
    gen_w, gen_h = _sdxl_canvas_size(*room_image.size, target=target)
    room_resized = room_image.resize((gen_w, gen_h), Image.LANCZOS)
    mask_resized = mask_img.resize((gen_w, gen_h), Image.NEAREST)

    # Step 2: tile the texture over the whole canvas as the ControlNet condition.
    tile_size = max(64, target // 8)
    tile_resized = tile_texture.resize((tile_size, tile_size), Image.LANCZOS)
    control_image = create_tiled_control_image(tile_resized, gen_w, gen_h)

    # Step 3: SDXL inpainting guided by ControlNet Tile. A fixed seed keeps the
    # results reproducible.
    generated = inpaint_pipe(
        prompt="ceramic tile floor, tiled floor with repeating pattern, interior design photo, photorealistic",
        negative_prompt="blurry, distorted, low quality, watermark, text",
        image=room_resized,
        mask_image=mask_resized,
        control_image=control_image,
        height=gen_h,
        width=gen_w,
        num_inference_steps=25,
        guidance_scale=7.0,
        controlnet_conditioning_scale=0.9,
        strength=0.95,
        generator=torch.Generator(device="cuda").manual_seed(42),
    ).images[0]

    # The pipeline decodes the whole frame at the generation resolution, so
    # non-floor pixels would not match the input after resizing back. Keep the
    # original everywhere except under the floor mask.
    generated = generated.resize(room_image.size, Image.LANCZOS)
    return Image.composite(generated, room_image, mask_img)
# Gradio UI: one tab per endpoint. Component creation order inside each
# `with` context defines the on-page layout.
# NOTE(review): the original file lost its indentation; the nesting of
# `render_output` (inside or outside the Row) is reconstructed — confirm.
with gr.Blocks() as demo:
    gr.Markdown("# Tile Visualizer API")
    # Tab 1: room photo -> (floor mask, depth map) via `predict`.
    with gr.Tab("Segmentation"):
        with gr.Row():
            seg_input = gr.Image(type="pil", label="Room photo")
        with gr.Row():
            mask_output = gr.Image(type="pil", label="Floor mask")
            depth_output = gr.Image(type="pil", label="Depth map")
        seg_btn = gr.Button("Segment")
        seg_btn.click(fn=predict, inputs=seg_input, outputs=[mask_output, depth_output])
    # Tab 2: (room photo, tile texture) -> re-rendered room via `render_ai`.
    with gr.Tab("AI Render"):
        with gr.Row():
            render_room = gr.Image(type="pil", label="Room photo")
            render_tile = gr.Image(type="pil", label="Tile texture")
        render_output = gr.Image(type="pil", label="Result")
        render_btn = gr.Button("Render")
        render_btn.click(fn=render_ai, inputs=[render_room, render_tile], outputs=render_output)
# Expose the Blocks' underlying web app and attach a CORS middleware that
# accepts any origin, method and header, so browser clients on other origins
# can call the API.
# NOTE(review): allow_origins=["*"] opens the API to every site — confirm intended.
# NOTE(review): if `demo.launch()` (re)creates `demo.app`, this middleware may
# not be on the app that is actually served — verify with the Gradio version in use.
app = demo.app
from starlette.middleware.cors import CORSMiddleware
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
if __name__ == "__main__":
    # ssr_mode=False: serve without Gradio's server-side rendering.
    demo.launch(ssr_mode=False)