Spaces:
Sleeping
Sleeping
File size: 6,569 Bytes
0299827 3ccb4f2 9a8a023 3ccb4f2 9dd410a 3ccb4f2 9a8a023 3ccb4f2 9dd410a 05cc74d faad11a 3ccb4f2 a3d0eee faad11a 3ccb4f2 9a8a023 3ccb4f2 9a8a023 3ccb4f2 faad11a 3ccb4f2 faad11a 9dd410a 05cc74d 3ccb4f2 a3d0eee 3ccb4f2 7c12218 3ccb4f2 9a8a023 3ccb4f2 9a8a023 3ccb4f2 6bc9ac3 3ccb4f2 0299827 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 | import spaces
import numpy as np
import torch
import gradio as gr
from PIL import Image, ImageDraw
from transformers import (
AutoImageProcessor,
Mask2FormerForUniversalSegmentation,
AutoModelForDepthEstimation,
)
from diffusers import (
StableDiffusionXLControlNetInpaintPipeline,
ControlNetModel,
AutoencoderKL,
)
# ─── Segmentation + Depth models ─────────────────
seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic"
)

# Words that mark a segmentation class as part of the floor mask.
FLOOR_KEYWORDS = {'floor', 'flooring', 'rug', 'carpet', 'mat'}


def _label_words(label):
    """Split a class label such as ``"rug, carpet, carpeting"`` into lowercase words."""
    return "".join(ch if ch.isalpha() else " " for ch in label.lower()).split()


# Collect every class id whose label contains one of FLOOR_KEYWORDS as a whole word.
# A plain substring test would also accept unrelated labels ('mat' inside
# "animate"/"automaton", 'rug' inside "drugstore").
FLOOR_IDS = set()
id2label = seg_model.config.id2label
for idx, label in id2label.items():
    if FLOOR_KEYWORDS.intersection(_label_words(label)):
        FLOOR_IDS.add(int(idx))
        print(f"Floor class: {idx} = {label}")
if not FLOOR_IDS:
    # Hard-coded fallback ids; presumably ADE20K "floor" (3) and "rug" (28) — verify
    # against the checkpoint's id2label if the model is changed.
    FLOOR_IDS = {3, 28}

depth_processor = AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf")
depth_model = AutoModelForDepthEstimation.from_pretrained(
    "depth-anything/Depth-Anything-V2-Large-hf", torch_dtype=torch.float16
)
# NOTE(review): neither seg_model nor depth_model is moved to CUDA here, so predict()
# runs them on whatever device from_pretrained placed them on (CPU unless configured
# otherwise); confirm this is intended, especially for the float16 depth model.
# ─── SDXL + ControlNet Tile for AI rendering ─────
def _load_inpaint_pipeline():
    """Build the SDXL inpainting pipeline conditioned on a ControlNet Tile model.

    Every component is loaded with ``torch_dtype=torch.float16``. The pipeline uses the
    ``madebyollin/sdxl-vae-fp16-fix`` VAE passed in explicitly. Device placement is
    delegated to diffusers' ``enable_model_cpu_offload``.

    Returns:
        Tuple ``(controlnet, vae, pipeline)``.
    """
    half = torch.float16
    tile_controlnet = ControlNetModel.from_pretrained(
        "xinsir/controlnet-tile-sdxl-1.0", torch_dtype=half
    )
    fp16_vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=half
    )
    pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        controlnet=tile_controlnet,
        vae=fp16_vae,
        torch_dtype=half,
        variant="fp16",
    )
    pipeline.enable_model_cpu_offload()
    return tile_controlnet, fp16_vae, pipeline


print("Loading ControlNet Tile + SDXL inpainting pipeline...")
controlnet, vae, inpaint_pipe = _load_inpaint_pipeline()
print("Pipeline loaded.")
@spaces.GPU(duration=60)
@torch.inference_mode()
def predict(image):
    """Segment the floor and estimate depth for a room photo.

    The photo is downscaled so its longer side is at most 1024 px before both models
    run; both outputs are resized back to the original resolution.

    Args:
        image: PIL image of the room (as delivered by ``gr.Image(type="pil")``).

    Returns:
        ``(mask_img, depth_img)``:
        - ``mask_img``: 8-bit single-channel mask, 255 where a class in ``FLOOR_IDS``
          was predicted and 0 elsewhere.
        - ``depth_img``: 8-bit depth map min-max normalised to 0..255 (all zeros if
          the prediction is constant).

    Raises:
        gr.Error: if no image is provided.
    """
    if image is None:
        raise gr.Error("No image provided")
    # Both processors expect 3-channel input; RGBA / grayscale uploads would break them.
    image = image.convert("RGB")

    orig_w, orig_h = image.size
    max_size = 1024
    scale = min(1.0, max_size / max(orig_w, orig_h))
    # Clamp to >= 1 px so a very thin image cannot be resized to an empty one.
    proc_w = max(1, int(orig_w * scale))
    proc_h = max(1, int(orig_h * scale))
    image_resized = image.resize((proc_w, proc_h), Image.LANCZOS)

    # --- Floor segmentation ---
    seg_device = seg_model.device
    seg_inputs = seg_processor(images=image_resized, return_tensors="pt")
    seg_inputs = {k: v.to(seg_device) for k, v in seg_inputs.items()}
    seg_outputs = seg_model(**seg_inputs)
    seg_map = seg_processor.post_process_semantic_segmentation(
        seg_outputs, target_sizes=[(proc_h, proc_w)]
    )[0].cpu().numpy()
    unique_classes = np.unique(seg_map)
    print(f"Detected classes: {[(int(c), id2label.get(int(c), '?')) for c in unique_classes]}")
    floor_mask = np.where(np.isin(seg_map, list(FLOOR_IDS)), 255, 0).astype(np.uint8)
    mask_img = Image.fromarray(floor_mask).resize((orig_w, orig_h), Image.NEAREST)

    # --- Depth estimation ---
    # Use the depth model's own device/dtype rather than assuming it matches seg_model.
    depth_device, depth_dtype = depth_model.device, depth_model.dtype
    depth_inputs = depth_processor(images=image_resized, return_tensors="pt")
    depth_inputs = {
        k: v.to(depth_device, dtype=depth_dtype) if v.is_floating_point() else v.to(depth_device)
        for k, v in depth_inputs.items()
    }
    depth_outputs = depth_model(**depth_inputs)
    # Normalise in float32: min-max scaling in float16 loses precision.
    depth_map = depth_outputs.predicted_depth.squeeze().float().cpu().numpy()
    depth_min, depth_max = depth_map.min(), depth_map.max()
    if depth_max > depth_min:
        depth_norm = ((depth_map - depth_min) / (depth_max - depth_min) * 255).astype(np.uint8)
    else:
        depth_norm = np.zeros_like(depth_map, dtype=np.uint8)
    depth_img = Image.fromarray(depth_norm).resize((orig_w, orig_h), Image.BILINEAR)
    return mask_img, depth_img
def create_tiled_control_image(tile_texture, width, height):
    """Return a ``width`` x ``height`` RGB canvas covered by repeated copies of ``tile_texture``.

    Copies are pasted row by row, left to right, starting at the top-left corner
    ``(0, 0)``. The step between copies equals the texture's own size.
    """
    step_x, step_y = tile_texture.size
    canvas = Image.new("RGB", (width, height))
    positions = [
        (left, top)
        for top in range(0, height, step_y)
        for left in range(0, width, step_x)
    ]
    for position in positions:
        canvas.paste(tile_texture, position)
    return canvas
def _sdxl_canvas_size(width, height, target=1024, multiple=64):
    """Return an SDXL working size of roughly ``target`` x ``target`` pixels with the input's aspect ratio.

    Both sides are rounded to a multiple of ``multiple`` (minimum ``multiple``), so a
    square input maps to exactly ``(target, target)``.
    """
    scale = (target * target / (width * height)) ** 0.5
    canvas_w = max(multiple, round(width * scale / multiple) * multiple)
    canvas_h = max(multiple, round(height * scale / multiple) * multiple)
    return canvas_w, canvas_h


@spaces.GPU(duration=120)
@torch.inference_mode()
def render_ai(room_image, tile_texture):
    """Re-render the floor of ``room_image`` with ``tile_texture`` using SDXL inpainting.

    Steps:
    1. Build the floor mask with ``predict``.
    2. Resize the photo to about 1 MP while keeping its aspect ratio.
    3. Inpaint the masked floor, conditioned on a tiled copy of the texture
       (ControlNet Tile).
    4. Paste the result back over the original, only inside the mask.

    Args:
        room_image: PIL image of the room.
        tile_texture: PIL image of a single tile.

    Returns:
        PIL RGB image at the original room size. Pixels outside the floor mask are
        taken from ``room_image``.

    Raises:
        gr.Error: if either input is missing, or no floor pixels are detected.
    """
    if room_image is None or tile_texture is None:
        raise gr.Error("Room image and tile texture are required")
    room_image = room_image.convert("RGB")
    tile_texture = tile_texture.convert("RGB")

    # Step 1: floor mask at the original resolution. The undecorated function is
    # called (when available), presumably to avoid nesting the spaces.GPU wrapper.
    segment = getattr(predict, "__wrapped__", predict)
    mask_img, _ = segment(room_image)
    if mask_img.getbbox() is None:
        # An all-black mask would spend a full diffusion run and change nothing.
        raise gr.Error("No floor detected in the room image")

    # Step 2: work at ~1 MP, keeping the photo's own aspect ratio. Squashing to a
    # square and stretching back afterwards would distort the tile grid.
    target = 1024
    orig_size = room_image.size
    canvas_w, canvas_h = _sdxl_canvas_size(orig_size[0], orig_size[1], target=target)
    room_resized = room_image.resize((canvas_w, canvas_h), Image.LANCZOS)
    mask_resized = mask_img.resize((canvas_w, canvas_h), Image.NEAREST)

    # Step 3: tiled control image from the tile texture.
    tile_size = max(64, target // 8)
    tile_resized = tile_texture.resize((tile_size, tile_size), Image.LANCZOS)
    control_image = create_tiled_control_image(tile_resized, canvas_w, canvas_h)

    # Step 4: SDXL inpainting with ControlNet Tile. Pass height/width explicitly so
    # the pipeline does not fall back to its default square resolution.
    result = inpaint_pipe(
        prompt="ceramic tile floor, tiled floor with repeating pattern, interior design photo, photorealistic",
        negative_prompt="blurry, distorted, low quality, watermark, text",
        image=room_resized,
        mask_image=mask_resized,
        control_image=control_image,
        height=canvas_h,
        width=canvas_w,
        num_inference_steps=25,
        guidance_scale=7.0,
        controlnet_conditioning_scale=0.9,
        strength=0.95,
        generator=torch.Generator(device="cuda").manual_seed(42),
    ).images[0]

    # Back to the original size. The pipeline returns a fully regenerated frame, so
    # pixels outside the mask can drift; keep the original everywhere but the floor.
    result = result.resize(orig_size, Image.LANCZOS)
    return Image.composite(result, room_image, mask_img)
# Gradio UI: one tab per endpoint. Nesting and statement order define the layout.
with gr.Blocks() as demo:
    gr.Markdown("# Tile Visualizer API")
    # Tab 1: room photo -> (floor mask, depth map) via predict().
    with gr.Tab("Segmentation"):
        with gr.Row():
            seg_input = gr.Image(type="pil", label="Room photo")
        with gr.Row():
            mask_output = gr.Image(type="pil", label="Floor mask")
            depth_output = gr.Image(type="pil", label="Depth map")
        seg_btn = gr.Button("Segment")
        seg_btn.click(fn=predict, inputs=seg_input, outputs=[mask_output, depth_output])
    # Tab 2: room photo + tile texture -> re-rendered room via render_ai().
    with gr.Tab("AI Render"):
        with gr.Row():
            render_room = gr.Image(type="pil", label="Room photo")
            render_tile = gr.Image(type="pil", label="Tile texture")
        render_output = gr.Image(type="pil", label="Result")
        render_btn = gr.Button("Render")
        render_btn.click(fn=render_ai, inputs=[render_room, render_tile], outputs=render_output)
# Module-level ASGI app handle, with permissive CORS (any origin, method, header).
# NOTE(review): `demo.app` is read before `launch()`. If the installed Gradio version
# builds a fresh app inside `launch()`, this middleware is attached to an app that is
# never served. Confirm against the Gradio version in use.
app = demo.app
from starlette.middleware.cors import CORSMiddleware
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
if __name__ == "__main__":
    # Server-side rendering is disabled explicitly.
    demo.launch(ssr_mode=False)
|