File size: 6,569 Bytes
0299827
3ccb4f2
 
 
9a8a023
3ccb4f2
 
9dd410a
3ccb4f2
 
9a8a023
 
 
 
 
 
 
3ccb4f2
9dd410a
 
05cc74d
faad11a
3ccb4f2
a3d0eee
 
 
 
 
 
 
 
 
 
faad11a
 
 
 
3ccb4f2
9a8a023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ccb4f2
9a8a023
3ccb4f2
faad11a
3ccb4f2
 
 
 
 
 
 
 
 
faad11a
 
9dd410a
05cc74d
3ccb4f2
 
 
 
 
 
 
a3d0eee
 
 
3ccb4f2
 
 
 
 
7c12218
3ccb4f2
 
 
 
 
 
 
 
 
 
 
 
 
 
9a8a023
 
 
 
 
 
 
 
3ccb4f2
 
9a8a023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ccb4f2
6bc9ac3
 
 
 
 
3ccb4f2
0299827
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import re

import spaces
import numpy as np
import torch
import gradio as gr
from PIL import Image, ImageDraw
from transformers import (
    AutoImageProcessor,
    Mask2FormerForUniversalSegmentation,
    AutoModelForDepthEstimation,
)
from diffusers import (
    StableDiffusionXLControlNetInpaintPipeline,
    ControlNetModel,
    AutoencoderKL,
)

# ─── Segmentation + Depth models ─────────────────

seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic"
)

# Words that mark a segmentation class as part of the floor region to re-tile.
FLOOR_KEYWORDS = {'floor', 'flooring', 'rug', 'carpet', 'carpeting', 'mat'}
# Class ids (keys of the model's id2label) whose label contains a floor keyword.
FLOOR_IDS = set()
id2label = seg_model.config.id2label
for idx, label in id2label.items():
    # Compare whole words, not substrings: a substring test would let 'mat'
    # match labels such as "mattress" or "automatic", and 'rug' match "drug".
    # Labels may be comma-separated synonym lists, so split on any non-letter.
    label_words = set(re.findall(r"[a-z]+", label.lower()))
    if label_words & FLOOR_KEYWORDS:
        FLOOR_IDS.add(int(idx))
        print(f"Floor class: {idx} = {label}")
if not FLOOR_IDS:
    # Hard-coded fallback if no label matched.
    # NOTE(review): presumably the ADE20K ids for "floor" and "rug" in this
    # checkpoint — confirm against seg_model.config.id2label.
    FLOOR_IDS = {3, 28}

# Depth-Anything V2 (large) depth estimator, loaded in half precision.
_DEPTH_CHECKPOINT = "depth-anything/Depth-Anything-V2-Large-hf"
depth_processor = AutoImageProcessor.from_pretrained(_DEPTH_CHECKPOINT)
depth_model = AutoModelForDepthEstimation.from_pretrained(_DEPTH_CHECKPOINT, torch_dtype=torch.float16)
# NOTE(review): neither seg_model nor depth_model is moved to a GPU anywhere in
# this module, so they run on whatever device from_pretrained placed them on —
# confirm this is intended for the Spaces GPU runtime.

# ─── SDXL + ControlNet Tile for AI rendering ─────

print("Loading ControlNet Tile + SDXL inpainting pipeline...")
# Every component of the rendering stack shares the same half-precision dtype.
_half_precision = {"torch_dtype": torch.float16}
controlnet = ControlNetModel.from_pretrained("xinsir/controlnet-tile-sdxl-1.0", **_half_precision)
# Replacement SDXL VAE checkpoint, used so decoding works in fp16.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", **_half_precision)
inpaint_pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    vae=vae,
    controlnet=controlnet,
    variant="fp16",
    **_half_precision,
)
# Enable diffusers' model-level CPU offloading for the whole pipeline.
inpaint_pipe.enable_model_cpu_offload()
print("Pipeline loaded.")


@spaces.GPU(duration=60)
@torch.inference_mode()
def predict(image):
    """Segment the floor and estimate depth for a room photo.

    The photo is downscaled (never upscaled) so that its longer side is at
    most 1024 px. It is then run through the semantic segmenter and the depth
    estimator, and both results are resized back to the original resolution.

    Args:
        image: PIL image of the room. Non-RGB inputs (e.g. RGBA or L) are
            converted to RGB first.

    Returns:
        ``(mask_img, depth_img)``: two single-channel PIL images at the input's
        original size.
        ``mask_img`` is 255 where the predicted class is in ``FLOOR_IDS`` and 0
        elsewhere.
        ``depth_img`` is the depth model's prediction, min-max scaled to 0-255.
        It is all zeros when the prediction is constant.

    Raises:
        gr.Error: if ``image`` is None.
    """
    if image is None:
        raise gr.Error("No image provided")

    # Both models expect 3-channel input; API uploads may be RGBA or grayscale.
    image = image.convert("RGB")

    orig_w, orig_h = image.size
    max_size = 1024
    scale = min(1.0, max_size / max(orig_w, orig_h))
    # Clamp to 1 px so an extreme aspect ratio cannot yield a zero-sized resize.
    proc_w, proc_h = max(1, int(orig_w * scale)), max(1, int(orig_h * scale))
    image_resized = image.resize((proc_w, proc_h), Image.LANCZOS)

    # ── Floor segmentation ──
    seg_device = seg_model.device
    seg_inputs = seg_processor(images=image_resized, return_tensors="pt")
    seg_inputs = {k: v.to(seg_device) for k, v in seg_inputs.items()}
    seg_outputs = seg_model(**seg_inputs)
    seg_result = seg_processor.post_process_semantic_segmentation(
        seg_outputs, target_sizes=[(proc_h, proc_w)]
    )[0]

    seg_map = seg_result.cpu().numpy()
    unique_classes = np.unique(seg_map)
    print(f"Detected classes: {[(int(c), id2label.get(int(c), '?')) for c in unique_classes]}")
    # One vectorised membership test instead of a full-array pass per class id.
    floor_mask = np.where(np.isin(seg_map, list(FLOOR_IDS)), 255, 0).astype(np.uint8)

    mask_img = Image.fromarray(floor_mask).resize((orig_w, orig_h), Image.NEAREST)

    # ── Depth estimation ──
    # Use the depth model's own device and dtype. The two models are loaded
    # independently (the depth model in float16), so the segmenter's device is
    # not guaranteed to apply here.
    depth_device, depth_dtype = depth_model.device, depth_model.dtype
    depth_inputs = depth_processor(images=image_resized, return_tensors="pt")
    depth_inputs = {
        k: v.to(depth_device, dtype=depth_dtype) if v.is_floating_point() else v.to(depth_device)
        for k, v in depth_inputs.items()
    }
    depth_outputs = depth_model(**depth_inputs)
    # Upcast before normalising: the half-precision output is too coarse for a
    # reliable min/max rescale.
    depth_map = depth_outputs.predicted_depth.squeeze().float().cpu().numpy()

    depth_min, depth_max = float(depth_map.min()), float(depth_map.max())
    depth_range = depth_max - depth_min
    if depth_range > 0:
        depth_norm = ((depth_map - depth_min) / depth_range * 255).astype(np.uint8)
    else:
        depth_norm = np.zeros_like(depth_map, dtype=np.uint8)

    depth_img = Image.fromarray(depth_norm).resize((orig_w, orig_h), Image.BILINEAR)

    return mask_img, depth_img


def create_tiled_control_image(tile_texture, width, height):
    """Return a new ``width`` x ``height`` RGB image covered by copies of ``tile_texture``.

    Copies are pasted row by row, left to right, starting at the top-left
    corner. Copies that overhang the right or bottom edge are clipped.
    """
    step_x, step_y = tile_texture.size
    canvas = Image.new("RGB", (width, height))
    # Same row-major visiting order as a nested y/x loop.
    origins = (
        (left, top)
        for top in range(0, height, step_y)
        for left in range(0, width, step_x)
    )
    for origin in origins:
        canvas.paste(tile_texture, origin)
    return canvas


@spaces.GPU(duration=120)
@torch.inference_mode()
def render_ai(room_image, tile_texture):
    """Re-render the floor of a room photo with a tiled texture.

    Steps:
    1. Detect the floor with ``predict``.
    2. Tile ``tile_texture`` into a control image.
    3. Run SDXL ControlNet-Tile inpainting on a 1024x1024 copy of the photo.
    4. Paste the generated floor back onto the original photo.

    Args:
        room_image: PIL image of the room.
        tile_texture: PIL image of a single tile.

    Returns:
        A PIL RGB image the same size as ``room_image``. Pixels outside the
        detected floor are taken unchanged from ``room_image``.

    Raises:
        gr.Error: if either image is missing, or if no floor pixels are found.
    """
    if room_image is None or tile_texture is None:
        raise gr.Error("Room image and tile texture are required")

    # Composite below requires both images in the same mode.
    room_image = room_image.convert("RGB")
    orig_size = room_image.size

    # Step 1: Get floor mask (mode "L", original size).
    # NOTE(review): __wrapped__ is used to skip predict's own @spaces.GPU
    # wrapper while already inside this GPU call; this assumes spaces.GPU sets
    # __wrapped__ (functools.wraps) — confirm.
    mask_img, _ = predict.__wrapped__(room_image)
    if mask_img.getbbox() is None:
        # All-zero mask: inpainting would spend every diffusion step for nothing.
        raise gr.Error("No floor detected in the room image")

    # Resize everything to 1024x1024 for SDXL (aspect ratio is not preserved).
    size = 1024
    room_resized = room_image.resize((size, size), Image.LANCZOS)
    mask_resized = mask_img.resize((size, size), Image.NEAREST)

    # Step 2: Create tiled control image from tile texture
    tile_size = max(64, size // 8)
    tile_resized = tile_texture.resize((tile_size, tile_size), Image.LANCZOS)
    control_image = create_tiled_control_image(tile_resized, size, size)

    # Step 3: Run SDXL inpainting with ControlNet Tile
    result = inpaint_pipe(
        prompt="ceramic tile floor, tiled floor with repeating pattern, interior design photo, photorealistic",
        negative_prompt="blurry, distorted, low quality, watermark, text",
        image=room_resized,
        mask_image=mask_resized,
        control_image=control_image,
        num_inference_steps=25,
        guidance_scale=7.0,
        controlnet_conditioning_scale=0.9,
        strength=0.95,
        # Fixed seed so the same inputs give the same render.
        generator=torch.Generator(device="cuda").manual_seed(42),
    ).images[0]

    # Step 4: Back to the original size, then keep only the generated floor.
    # Outside the mask, the pipeline output is a 1024x1024-squashed-and-
    # restretched copy of the photo that went through the pipeline, not the
    # original pixels. Compositing with the full-resolution mask restores them.
    result = result.resize(orig_size, Image.LANCZOS).convert("RGB")
    return Image.composite(result, room_image, mask_img)


# Gradio UI with two tabs. Each button is wired to one of the GPU functions above.
with gr.Blocks() as demo:
    gr.Markdown("# Tile Visualizer API")

    # Tab 1: room photo -> (floor mask, depth map) via predict().
    with gr.Tab("Segmentation"):
        with gr.Row():
            seg_input = gr.Image(type="pil", label="Room photo")
        with gr.Row():
            mask_output = gr.Image(type="pil", label="Floor mask")
            depth_output = gr.Image(type="pil", label="Depth map")
        seg_btn = gr.Button("Segment")
        seg_btn.click(fn=predict, inputs=seg_input, outputs=[mask_output, depth_output])

    # Tab 2: (room photo, tile texture) -> re-tiled room via render_ai().
    with gr.Tab("AI Render"):
        with gr.Row():
            render_room = gr.Image(type="pil", label="Room photo")
            render_tile = gr.Image(type="pil", label="Tile texture")
        render_output = gr.Image(type="pil", label="Result")
        render_btn = gr.Button("Render")
        render_btn.click(fn=render_ai, inputs=[render_room, render_tile], outputs=render_output)

# Underlying ASGI app, exposed so CORS middleware can be attached.
# NOTE(review): confirm that demo.app exists before launch() in the installed
# Gradio version, and that launch() does not replace it with a new app.
# If it does, the middleware below is never applied.
app = demo.app

# Allow cross-origin calls from any origin, method and header (the app is used as an API).
# NOTE(review): allow_origins=["*"] lets any website call the GPU endpoints;
# consider restricting it to the known front-end origin(s).
from starlette.middleware.cors import CORSMiddleware
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

if __name__ == "__main__":
    demo.launch(ssr_mode=False)