import os import io import hashlib import numpy as np import cv2 import torch from PIL import Image from fastapi import FastAPI, UploadFile, File, Form from fastapi.responses import Response, JSONResponse, HTMLResponse, FileResponse from skimage.measure import label, regionprops from sklearn.decomposition import PCA from transformers import ( OneFormerProcessor, OneFormerForUniversalSegmentation, Mask2FormerForUniversalSegmentation, AutoImageProcessor ) # ========================================================= # CONFIG # ========================================================= DEVICE = "cuda" if torch.cuda.is_available() else "cpu" ALPHA = 0.65 SEMANTIC_MODEL = "shi-labs/oneformer_ade20k_swin_large" INSTANCE_MODEL = "facebook/mask2former-swin-large-coco-instance" TEXTURE_ROOT = "textures" OBJECT_CLASSES = { "Wall": {"semantic": ["wall"], "panels": True}, "Floor": {"semantic": ["floor"], "panels": False}, "Door": {"semantic": ["door"], "panels": False}, "Cabinet": {"semantic": ["cabinet", "cupboard", "wardrobe"], "panels": True}, "Counter": {"semantic": ["counter"], "panels": False}, "Countertop": {"semantic": ["countertop", "worktop"], "panels": False}, } REMOVE_FROM_WALL_FLOOR = { "door", "window", "cabinet", "counter", "countertop", "island" } # ========================================================= # FASTAPI # ========================================================= app = FastAPI(title="Interior Texture API") # ========================================================= # GLOBAL CACHES (SAFE IF 1 WORKER) # ========================================================= DETECTION_CACHE = {} # image_hash → (image, objects) CURRENT_STATE = { # single user state "image_hash": None, "image": None, "objects": None, "object_textures": {}, "panel_textures": {} } # ========================================================= # LOAD MODELS ONCE # ========================================================= print("Loading models...") sem_proc = OneFormerProcessor.from_pretrained(SEMANTIC_MODEL) sem_model = OneFormerForUniversalSegmentation.from_pretrained( SEMANTIC_MODEL ).to(DEVICE).eval() inst_proc = AutoImageProcessor.from_pretrained(INSTANCE_MODEL) inst_model = Mask2FormerForUniversalSegmentation.from_pretrained( INSTANCE_MODEL ).to(DEVICE).eval() print("Models loaded") # ========================================================= # UTILITIES # ========================================================= def extract_semantic_mask(seg_map, id2label, keywords): mask = np.zeros_like(seg_map, dtype=np.uint8) for cid, name in id2label.items(): if any(k in name.lower() for k in keywords): mask[seg_map == cid] = 255 return mask def subtract_instances(mask, instances, remove_labels, coco_id2label): cleaned = mask.copy() inst_map = instances["segmentation"].cpu().numpy() for seg in instances["segments_info"]: if seg.get("score", 1.0) < 0.7: continue label_name = coco_id2label.get(seg["label_id"], "") if label_name in remove_labels: cleaned[inst_map == seg["id"]] = 0 return cleaned def edge_cleanup(mask, image_np): gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY) edges = cv2.Canny(gray, 80, 160) mask[edges > 0] = 0 kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)) mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) mask = cv2.medianBlur(mask, 7) return mask def extract_panels(mask, min_ratio=0.003): lbl = label(mask) panels = [] for r in regionprops(lbl): if r.area > mask.size * min_ratio: p = np.zeros_like(mask) p[lbl == r.label] = 255 panels.append(p.astype(bool)) return panels def detect_objects(image_np): inputs = sem_proc( images=image_np, task_inputs=["semantic"], return_tensors="pt" ).to(DEVICE) with torch.no_grad(): sem_out = sem_model(**inputs) seg_map = sem_proc.post_process_semantic_segmentation( sem_out, target_sizes=[image_np.shape[:2]] )[0].cpu().numpy() inst_inputs = inst_proc(images=image_np, return_tensors="pt").to(DEVICE) with torch.no_grad(): inst_out = inst_model(**inst_inputs) instances = inst_proc.post_process_instance_segmentation( inst_out, target_sizes=[image_np.shape[:2]] )[0] objects = {} for obj, cfg in OBJECT_CLASSES.items(): mask = extract_semantic_mask( seg_map, sem_model.config.id2label, cfg["semantic"] ) if np.count_nonzero(mask) < image_np.size * 0.002: continue if obj in {"Wall", "Floor"}: mask = subtract_instances( mask, instances, REMOVE_FROM_WALL_FLOOR, inst_model.config.id2label ) mask = edge_cleanup(mask, image_np) panels = extract_panels(mask) if cfg["panels"] else [mask.astype(bool)] objects[obj] = panels return objects def detect_cached(image_bytes: bytes): image_hash = hashlib.md5(image_bytes).hexdigest() if image_hash in DETECTION_CACHE: return image_hash, *DETECTION_CACHE[image_hash] image = np.array(Image.open(io.BytesIO(image_bytes)).convert("RGB")) objects = detect_objects(image) DETECTION_CACHE[image_hash] = (image, objects) return image_hash, image, objects def apply_texture_panel(image, mask, texture, tile_type): H, W = image.shape[:2] tile_w, tile_h = (280, 560) if "12" in tile_type else (560, 560) tile = cv2.resize(texture, (tile_w, tile_h), interpolation=cv2.INTER_NEAREST) canvas = np.zeros((H, W, 3), dtype=np.uint8) for y in range(0, H, tile_h): for x in range(0, W, tile_w): canvas[y:y+tile_h, x:x+tile_w] = tile[:H-y, :W-x] gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0 light = cv2.GaussianBlur(gray, (41, 41), 0) light = np.repeat(light[:, :, None], 3, axis=2) canvas = canvas.astype(np.float32) canvas *= (0.75 + 0.25 * light) out = image.astype(np.float32) out[mask] = (1 - ALPHA) * out[mask] + ALPHA * canvas[mask] return out.astype(np.uint8) # ========================================================= # API ENDPOINTS # ========================================================= @app.post("/upload-image") async def upload_image(file: UploadFile = File(...)): image_bytes = await file.read() image_hash, image, objects = detect_cached(image_bytes) CURRENT_STATE["image_hash"] = image_hash CURRENT_STATE["image"] = image CURRENT_STATE["objects"] = objects CURRENT_STATE["object_textures"].clear() CURRENT_STATE["panel_textures"].clear() return {"objects": {k: len(v) for k, v in objects.items()}} # ========================================================= # LIST TEXTURES FOR OBJECT # ========================================================= @app.get("/textures/{object_name}") def list_textures(object_name: str): folder = os.path.join(TEXTURE_ROOT, object_name.lower()) if not os.path.isdir(folder): return [] return [ f for f in os.listdir(folder) if f.lower().endswith((".png", ".jpg", ".jpeg")) ] # ========================================================= # SERVE TEXTURE FILE # ========================================================= @app.get("/texture-file/{object_name}/{filename}") def get_texture_file(object_name: str, filename: str): path = os.path.join(TEXTURE_ROOT, object_name.lower(), filename) if not os.path.exists(path): return JSONResponse({"error": "Texture not found"}, status_code=404) return FileResponse(path) # ========================================================= # APPLY TEXTURE # ========================================================= @app.post("/apply-texture") async def apply_texture( object_name: str = Form(...), filename: str = Form(...), panel_index: int | None = Form(None), tile_type: str = Form("12 x 24 inches") ): if CURRENT_STATE["image"] is None: return JSONResponse( {"error": "Upload image first"}, status_code=400 ) object_name = object_name.strip().title() if object_name not in CURRENT_STATE["objects"]: return JSONResponse( {"error": f"{object_name} not detected in image"}, status_code=400 ) # 🔹 LOAD TEXTURE FROM DISK texture_path = os.path.join( TEXTURE_ROOT, object_name.lower(), filename ) if not os.path.isfile(texture_path): return JSONResponse( {"error": f"Texture not found: {filename}"}, status_code=404 ) tex = np.array( Image.open(texture_path).convert("RGB") ) # 🔹 STORE TEXTURE if panel_index is None: CURRENT_STATE["object_textures"][object_name] = tex else: CURRENT_STATE["panel_textures"][(object_name, panel_index)] = tex # 🔹 APPLY TEXTURES output = CURRENT_STATE["image"].copy() for obj, panels in CURRENT_STATE["objects"].items(): obj_tex = CURRENT_STATE["object_textures"].get(obj) for i, mask in enumerate(panels): tex_use = CURRENT_STATE["panel_textures"].get((obj, i), obj_tex) if tex_use is not None: output = apply_texture_panel( output, mask, tex_use, tile_type ) _, buf = cv2.imencode( ".png", cv2.cvtColor(output, cv2.COLOR_RGB2BGR) ) return Response(buf.tobytes(), media_type="image/png") # ========================================================= # UI (IMAGE UPLOAD + TEXTURE PREVIEW) # ========================================================= @app.get("/", response_class=HTMLResponse) def ui(): return """