# api/utils.py
# -----------------------------------------------------------------------------
# Color utilities for mask visualization (COCO-183 and ADE-151 aware)
# - Name-driven colors (e.g., water -> blue, sky -> sky blue)
# - Prompt-aware palettes (supports multi-term like "human and horse")
# - Legends for classes present in a mask
#
# Public functions:
# - colorize_mask(mask_tensor, classes=None, dataset=None) -> PIL.Image
# - overlay_mask(image, color, alpha=0.5) -> PIL.Image
# - build_legend_from_mask(mask_tensor, classes=None, dataset=None) -> list[dict]
# -----------------------------------------------------------------------------
from __future__ import annotations

import re
from typing import List, Tuple, Dict

import numpy as np
from PIL import Image

# =============================================================================
# COCO-183 (green cone) CLASS NAMES
# NOTE: This is the dataset order you expect from the COCO-183 model.
# If your model's index order differs, update this list accordingly.
# =============================================================================
# Index i of this list is the name of label id i in a COCO-183 mask:
# slot 0 is "unlabeled", followed by 91 "thing" and 91 "stuff" classes.
CLASS_NAMES: List[str] = [
    "unlabeled",
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
    "truck", "boat",
    "traffic light", "fire hydrant", "street sign", "stop sign",
    "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear",
    "zebra", "giraffe",
    "hat", "backpack", "umbrella", "shoe", "eyeglasses", "handbag", "tie",
    "suitcase",
    "frisbee", "skis", "snowboard", "ball", "kite", "baseball_bat",
    "baseball_glove",
    "skateboard", "surfboard", "tennis_racket", "bottle", "plate",
    "wine_glass", "cup",
    "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
    "orange", "broccoli",
    "carrot", "hot_dog", "pizza", "donut", "cake", "chair", "couch",
    "potted_plant", "bed",
    "mirror", "dining_table", "window", "desk", "toilet", "door", "tv",
    "laptop", "mouse",
    "remote", "keyboard", "cell_phone", "microwave", "oven", "toaster",
    "sink", "refrigerator",
    "blender", "book", "clock", "vase", "scissors", "teddy_bear",
    "hair_dryer", "toothbrush",
    "hair_brush",
    # "stuff" classes (COCO-Stuff-like)
    "banner", "blanket", "branch", "bridge", "building-other", "bush",
    "cabinet", "cage",
    "cardboard", "carpet", "ceiling-other", "ceiling-tile", "cloth",
    "clothes", "clouds",
    "counter", "cupboard", "curtain", "desk-stuff", "dirt", "door-stuff",
    "fence",
    "floor-marble", "floor-other", "floor-stone", "floor-tile", "floor-wood",
    "flower",
    "fog", "food-other", "fruit", "furniture-other", "grass", "gravel",
    "ground-other",
    "hill", "house", "leaves", "light", "mat", "metal", "mirror-stuff",
    "moss", "mountain",
    "mud", "napkin", "net", "paper", "pavement", "pillow", "plant-other",
    "plastic",
    "platform", "playingfield", "railing", "railroad", "river", "road",
    "rock", "roof", "rug",
    "salad", "sand", "sea", "shelf", "sky-other", "skyscraper", "snow",
    "solid-other",
    "stairs", "stone", "straw", "structural-other", "table", "tent",
    "textile-other",
    "towel", "tree", "vegetable", "wall-brick", "wall-concrete",
    "wall-other", "wall-panel",
    "wall-stone", "wall-tile", "wall-wood", "water", "waterdrops",
    "window_blind",
    # NOTE: "window" appears twice (thing slot and stuff slot); positions are
    # what matter for index lookup, so the duplicate name is kept as-is.
    "window", "wood",
]

# Normalize COCO names to internal canonical form (underscored).
# Hyphens are kept: only whitespace is folded into "_".
CLASS_NAMES = [re.sub(r"\s+", "_", n.strip().lower()) for n in CLASS_NAMES]

# =============================================================================
# ADE-151 (orange cone) CLASS NAMES (index order given by user)
# =============================================================================
# Index i of this list is the name of label id i in an ADE-151 mask.
ADE_151_CLASS_NAMES: List[str] = [
    "unlabeled", "wall", "building", "blue_sky", "floor", "tree", "ceiling",
    "road", "bed", "window",
    "grass", "cabinet", "sidewalk", "person", "ground", "door", "table",
    "mountain", "flora", "curtain",
    "chair", "car", "water", "painting", "sofa", "shelf", "house", "sea",
    "mirror", "rug",
    "field", "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp",
    "bathtub", "rail",
    "cushion", "pedestal", "box", "pillar", "signboard", "dresser", "counter",
    "sand", "sink", "skyscraper",
    "fireplace", "refrigerator", "grandstand", "path", "stairs", "runway",
    "display", "snooker", "pillow", "screen_door",
    "stairway", "river", "bridge", "bookcase", "blind", "tea_table",
    "commode", "flower", "book", "hill",
    "bench", "countertop", "stove", "palm_tree", "kitchen", "computer",
    "swivel_chair", "boat", "bar", "console",
    "hovel", "bus", "towel", "light", "truck", "tower", "chandelier",
    "sunshade", "streetlight", "booth",
    "television", "aeroplane", "dirt", "apparel", "pole", "land",
    "bannister", "escalator", "ottoman", "bottle",
    "sideboard", "poster", "stage", "van", "ship", "fountain",
    "conveyer_belt", "canopy", "washer", "plaything",
    "swimming_pool", "stool", "barrel", "basket", "waterfall", "tent", "bag",
    "motorcycle", "cradle", "oven",
    "ball", "food", "stair", "tank", "marque", "microwave", "flowerpot",
    "animal", "bicycle", "lake",
    "dishwasher", "projector", "blanket", "sculpture", "exhaust", "sconce",
    "vase", "traffic_light", "tray", "ashcan",
    "fan", "pier", "screen", "plate", "monitor", "notice_board", "shower",
    "radiator", "glass", "clock", "flag",
]
# Same canonical form as the COCO table (lowercase, whitespace -> "_").
ADE_151_CLASS_NAMES = [
    re.sub(r"\s+", "_", n.strip().lower()) for n in ADE_151_CLASS_NAMES
]

# =============================================================================
# Color dictionary
# (seeded with explicit choices; everything else inferred)
# =============================================================================
# Base named colors; extend freely. Keys are canonical underscored names;
# a few COCO-Stuff keys keep their dataset hyphen ("wall-brick") and are
# matched separator-insensitively by _color_for_name().
NAMED_COLORS: Dict[str, Tuple[int, int, int]] = {
    # universal
    "unlabeled": (0, 0, 0),
    # people/animals/vehicles — COCO
    "person": (220, 20, 60),
    "human": (220, 20, 60),  # alias
    "horse": (90, 60, 30),  # per user's requested color
    "dog": (184, 134, 11),
    "cat": (255, 160, 122),
    "bird": (30, 144, 255),
    "sheep": (245, 222, 179),
    "cow": (139, 69, 19),
    "elephant": (128, 128, 128),
    "bear": (92, 64, 51),
    "zebra": (200, 200, 200),
    "giraffe": (218, 165, 32),
    "bicycle": (60, 180, 75),
    "car": (0, 90, 190),
    "motorcycle": (255, 80, 80),
    "airplane": (120, 120, 255),
    "aeroplane": (120, 120, 255),
    "bus": (255, 140, 0),
    "train": (70, 130, 180),
    "truck": (200, 120, 0),
    "boat": (0, 120, 170),
    "van": (80, 140, 220),
    "ship": (30, 100, 160),
    # nature / environment
    "water": (64, 164, 223),
    "river": (64, 164, 223),
    "lake": (64, 164, 223),
    "sea": (0, 105, 148),
    "waterfall": (120, 170, 230),
    "swimming_pool": (100, 200, 230),
    "sky": (135, 206, 235),
    "blue_sky": (135, 206, 235),
    "clouds": (220, 230, 240),
    "tree": (34, 139, 34),
    "palm_tree": (44, 159, 44),
    "flora": (52, 168, 83),
    "flower": (233, 84, 150),
    "grass": (76, 187, 23),
    "leaves": (76, 187, 23),
    "moss": (107, 142, 35),
    "hill": (88, 120, 80),
    "mountain": (96, 108, 118),
    "sand": (194, 178, 128),
    "ground": (120, 72, 48),
    "land": (120, 72, 48),
    "dirt": (115, 74, 53),
    "mud": (110, 74, 57),
    "rock": (101, 110, 120),
    "stone": (112, 128, 144),
    # roads / man-made terrain
    "road": (128, 128, 128),
    "sidewalk": (170, 170, 170),
    "pavement": (150, 150, 150),
    "path": (150, 150, 150),
    "playingfield": (100, 180, 100),
    "runway": (160, 160, 160),
    "stairs": (145, 145, 145),
    "stair": (145, 145, 145),
    "stairway": (145, 145, 145),
    "railroad": (100, 100, 100),
    "bridge": (120, 120, 140),
    "pier": (120, 120, 140),
    # buildings / structures
    "building": (160, 160, 160),
    "building-other": (160, 160, 160),
    "house": (170, 160, 160),
    "skyscraper": (120, 130, 140),
    "roof": (150, 120, 100),
    "wall": (180, 180, 180),
    "wall-brick": (178, 34, 34),
    "wall-concrete": (190, 190, 190),
    "wall-other": (170, 170, 170),
    "wall-panel": (160, 160, 160),
    "wall-stone": (135, 135, 135),
    "wall-tile": (200, 200, 200),
    "wall-wood": (181, 101, 29),
    "ceiling": (210, 210, 210),
    "ceiling-other": (210, 210, 210),
    "ceiling-tile": (220, 220, 220),
    "door": (150, 120, 90),
    "door-stuff": (150, 120, 90),
    "window": (175, 215, 230),
    "window_blind": (170, 210, 230),
    "mirror": (210, 220, 230),
    "mirror-stuff": (210, 220, 230),
    "light": (255, 230, 140),
    "streetlight": (240, 210, 120),
    "tower": (140, 140, 160),
    "fence": (189, 183, 107),
    "railing": (170, 170, 150),
    "pillar": (180, 180, 170),
    "signboard": (255, 200, 80),
    "poster": (255, 200, 140),
    "traffic_light": (50, 205, 50),
    # furniture / interior
    "chair": (205, 133, 63),
    "armchair": (200, 120, 80),
    "seat": (205, 133, 63),
    "bench": (160, 120, 70),
    "sofa": (160, 82, 45),
    "stool": (175, 125, 80),
    "table": (181, 101, 29),
    "dining_table": (181, 101, 29),
    "desk": (170, 100, 40),
    "desk-stuff": (170, 100, 40),
    "bed": (180, 130, 100),
    "cabinet": (145, 110, 70),
    "cupboard": (145, 110, 70),
    "wardrobe": (130, 90, 60),
    "dresser": (135, 95, 65),
    "sideboard": (135, 95, 65),
    "shelf": (140, 105, 65),
    "carpet": (150, 80, 60),
    "rug": (150, 80, 60),
    "curtain": (200, 180, 160),
    "pillow": (230, 200, 170),
    "cushion": (230, 200, 170),
    "blanket": (200, 170, 150),
    "towel": (220, 220, 200),
    "kitchen": (170, 170, 160),
    "counter": (150, 140, 130),
    "countertop": (160, 150, 140),
    "sink": (200, 210, 220),
    "stove": (140, 140, 140),
    "oven": (140, 140, 150),
    "microwave": (155, 160, 170),
    "dishwasher": (190, 200, 210),
    "washer": (190, 200, 210),
    "refrigerator": (200, 220, 235),
    # electronics
    "television": (70, 100, 160),
    "tv": (70, 100, 160),
    "monitor": (70, 100, 160),
    "screen": (70, 100, 160),
    "screen_door": (170, 210, 230),
    "projector": (100, 120, 160),
    "laptop": (70, 100, 160),
    "keyboard": (70, 90, 120),
    "mouse": (80, 80, 90),
    "remote": (60, 60, 70),
    "cell_phone": (100, 120, 140),
    # decor / smalls
    "vase": (186, 85, 211),
    "flowerpot": (170, 100, 60),
    "lamp": (255, 230, 140),
    "chandelier": (255, 220, 120),
    "sconce": (255, 225, 140),
    # materials / stuff
    "paper": (240, 240, 220),
    "plastic": (200, 200, 220),
    "metal": (180, 180, 190),
    "cloth": (220, 200, 190),
    "textile-other": (220, 200, 190),
    "glass": (200, 220, 240),
    "wood": (181, 101, 29),
    # foods
    "banana": (255, 225, 53),
    "apple": (220, 30, 30),
    "sandwich": (222, 184, 135),
    "orange": (255, 165, 0),
    "broccoli": (67, 160, 71),
    "carrot": (255, 127, 80),
    "pizza": (255, 180, 100),
    "donut": (210, 180, 140),
    "cake": (255, 218, 185),
    "hot_dog": (204, 102, 0),
    "salad": (143, 188, 143),
    "fruit": (255, 160, 122),
    "vegetable": (85, 139, 47),
    "food-other": (200, 160, 120),
    "food": (200, 160, 120),
    # utensils / containers
    "bottle": (135, 206, 250),
    "plate": (245, 245, 245),
    "wine_glass": (230, 230, 250),
    "cup": (250, 250, 250),
    "fork": (192, 192, 192),
    "knife": (192, 192, 192),
    "spoon": (192, 192, 192),
    "bowl": (255, 239, 213),
    "bag": (170, 120, 70),
    "box": (170, 120, 70),
    "barrel": (165, 105, 58),
    "basket": (170, 120, 70),
    "tray": (210, 210, 210),
    # misc (signage, banners)
    "banner": (255, 215, 0),
    "flag": (220, 20, 60),
    # other ADE things ("signboard" is already defined under buildings; the
    # repeated key that used to sit here was a silent duplicate)
    "booth": (160, 160, 160),
    "display": (100, 120, 160),
    "notice_board": (210, 180, 140),
}


def _canonical_key(name: str) -> str:
    """Fold case and treat runs of whitespace, '_' and '-' as one '_'."""
    return re.sub(r"[\s_\-]+", "_", name.strip().lower())


# Separator-insensitive view of NAMED_COLORS. Lookups are always normalized to
# underscores, so without this view hyphenated keys such as "wall-brick" could
# never be reached and fell through to the keyword heuristics.
_NAMED_COLORS_CANON: Dict[str, Tuple[int, int, int]] = {
    _canonical_key(k): v for k, v in NAMED_COLORS.items()
}

# =============================================================================
# Aliases & normalization
# =============================================================================
# Map user tokens to canonical dataset names
_ALIASES: Dict[str, str] = {
    "human": "person",
    "humans": "person",
    "man": "person",
    "men": "person",
    "woman": "person",
    "women": "person",
    "people": "person",
    "tv": "television",
    "tv_monitor": "television",
    "monitor_tv": "television",
    "cell phone": "cell_phone",
    "cellphone": "cell_phone",
    "mobile": "cell_phone",
    "phone": "cell_phone",
    "teddy bear": "teddy_bear",
    "wine glass": "wine_glass",
    "baseball bat": "baseball_bat",
    "baseball glove": "baseball_glove",
    "tennis racket": "tennis_racket",
    "blue sky": "blue_sky",
    "traffic light": "traffic_light",
    "water fall": "waterfall",
    "window blind": "window_blind",
    "street light": "streetlight",
    # ADE terms mapping to close COCO terms (used in heuristics)
    "aeroplane": "airplane",
}

# Splits a prompt item on the WORD "and" or on &, /, + (optional spaces).
_PROMPT_SPLIT_RE = re.compile(r"\s*(?:\band\b|&|/|\+)\s*", flags=re.IGNORECASE)

# Short vehicle keywords are substrings of unrelated names ("bush",
# "cardboard", "carpet"), so they must END a word (optionally pluralized).
_SHORT_VEHICLE_RE = re.compile(r"(?:car|van|bus)(?:e?s)?\b")


def _normalize_token(s: str) -> str:
    """Return the canonical underscored, alias-resolved form of one term.

    Case is folded, runs of whitespace / '_' / '-' become a single separator,
    leading and trailing separators are dropped (so "wall-" yields "wall", not
    "wall_"), and the result is mapped through ``_ALIASES``.
    """
    spaced = re.sub(r"[\s_\-]+", " ", s.lower()).strip()
    if spaced in _ALIASES:
        canon = _ALIASES[spaced]
    else:
        # Some alias keys are written with underscores ("tv_monitor"); the
        # space-separated form alone would never match them.
        canon = _ALIASES.get(spaced.replace(" ", "_"), spaced)
    return canon.replace(" ", "_")


def _resolve_prompt_item_to_names(item: str) -> List[str]:
    """
    Turn one prompt item into one or more canonical names.
    Splits ONLY on 'and' as a WORD, or on &, /, + (with optional spaces).
    Critically, it won't split inside words like 'sand'.

    "background" is reported as "unlabeled"; an item that yields no usable
    term resolves to ["unlabeled"].
    """
    resolved: List[str] = []
    for part in _PROMPT_SPLIT_RE.split(item.strip()):
        tok = _normalize_token(part)
        if not tok:
            continue
        resolved.append("unlabeled" if tok == "background" else tok)
    return resolved or ["unlabeled"]


# =============================================================================
# Color selection fallback (heuristics)
# =============================================================================
def _infer_color_from_name(name: str) -> Tuple[int, int, int]:
    """Heuristic fallback: choose a sensible color by keyword.

    Rules are checked in order and the first hit wins; a name that matches
    nothing gets neutral grey (128, 128, 128).
    """
    n = name.lower().replace("_", " ")

    def has(*keys: str) -> bool:
        return any(k in n for k in keys)

    # water/sky
    if has("sky"):
        return (135, 206, 235)
    if has("sea", "ocean"):
        return (0, 105, 148)
    if has("river", "lake", "waterfall", "pool", "water"):
        return (64, 164, 223)
    # vegetation / land
    if has("tree", "palm", "flora", "grass", "plant", "field", "hill", "land",
           "bush"):
        return (52, 168, 83)
    if has("sand", "beach", "desert"):
        return (194, 178, 128)
    if has("ground", "dirt", "soil", "mud"):
        return (120, 72, 48)
    if has("rock", "mountain", "stone", "skyscraper"):
        return (120, 130, 140)
    # man-made ground ("railroad" contains "road", so it is tested first)
    if has("railroad"):
        return (100, 100, 100)
    if has("road", "street", "sidewalk", "path", "runway", "stairs", "stair"):
        return (150, 150, 150)
    # humans & vehicles
    if has("person", "people", "human"):
        return (220, 20, 60)
    if has("truck") or _SHORT_VEHICLE_RE.search(n):
        return (0, 90, 190)
    if has("bicycle", "bike", "motorcycle"):
        return (60, 180, 75)
    if has("boat", "ship", "ferry"):
        return (0, 120, 170)
    if has("aeroplane", "airplane", "aircraft"):
        return (120, 120, 255)
    # buildings / structures
    if has("building", "house", "wall", "ceiling", "door", "window", "bridge",
           "tower"):
        return (170, 170, 170)
    # furniture
    if has("sofa", "chair", "stool", "bench", "table", "desk", "bed",
           "cabinet", "wardrobe", "dresser", "shelf"):
        return (181, 101, 29)
    # electronics / lighting
    if has("television", "monitor", "computer", "screen", "projector", "tv"):
        return (70, 100, 160)
    if has("lamp", "light", "chandelier", "sconce", "streetlight"):
        return (255, 230, 140)
    # reflective / transparent
    if has("mirror", "glass"):
        return (200, 220, 240)
    # decorative / misc
    if has("flower", "vase", "sculpture", "poster", "painting", "flag"):
        return (186, 85, 211)
    # containers
    if has("bag", "bottle", "barrel", "basket", "box"):
        return (170, 120, 70)
    # kitchen / appliances
    if has("kitchen", "sink", "stove", "oven", "microwave", "dishwasher",
           "washer", "refrigerator", "counter", "countertop"):
        return (175, 185, 195)
    # default neutral
    return (128, 128, 128)


def _color_for_name(name: str) -> Tuple[int, int, int]:
    """Return the RGB color for a class or prompt term.

    Explicit ``NAMED_COLORS`` entries win (matched separator-insensitively,
    first under the alias-resolved name, then under the spelling as given);
    anything else goes through the keyword heuristics.
    """
    key = _normalize_token(name)
    if key in NAMED_COLORS:
        return NAMED_COLORS[key]
    for candidate in (key, _canonical_key(name)):
        if candidate in _NAMED_COLORS_CANON:
            return _NAMED_COLORS_CANON[candidate]
    return _infer_color_from_name(key)


# =============================================================================
# Palettes (LUTs)
# =============================================================================
def _build_lut_for_names(names: List[str]) -> np.ndarray:
    """Return a [len(names), 3] uint8 table; row i is the color of names[i]."""
    lut = np.zeros((len(names), 3), dtype=np.uint8)
    for i, class_name in enumerate(names):
        lut[i] = _color_for_name(class_name)
    return lut


# Dataset palettes are built on first use and then reused.
_COCO_LUT: np.ndarray | None = None
_ADE_LUT: np.ndarray | None = None


def _is_ade(dataset) -> bool:
    """True when ``dataset`` names the ADE palette (case-insensitive)."""
    return isinstance(dataset, str) and dataset.strip().lower() == "ade"


def _palette_for_dataset(dataset: str) -> np.ndarray:
    """Return [N,3] palette for dataset: 'ade', otherwise the COCO default."""
    global _COCO_LUT, _ADE_LUT
    if _is_ade(dataset):
        if _ADE_LUT is None:
            _ADE_LUT = _build_lut_for_names(ADE_151_CLASS_NAMES)
        return _ADE_LUT
    # default: coco
    if _COCO_LUT is None:
        _COCO_LUT = _build_lut_for_names(CLASS_NAMES)
    return _COCO_LUT


def _palette_for_prompt_classes(classes: List[str]) -> np.ndarray:
    """
    Build a per-request palette given a prompt class list.
    Index 0 is always the background slot and gets the 'unlabeled' color,
    whatever classes[0] says.
    Supports entries like 'human and horse' -> average of person + horse.
    """
    pal = np.zeros((len(classes), 3), dtype=np.uint8)
    for idx, raw in enumerate(classes):
        if idx == 0:
            # background slot convention
            pal[idx] = NAMED_COLORS.get("unlabeled", (0, 0, 0))
            continue
        # Names are already alias-resolved, and there is always at least one.
        names = _resolve_prompt_item_to_names(raw)
        cols = [np.array(_color_for_name(nm), dtype=np.float32) for nm in names]
        rgb = np.mean(cols, axis=0)
        pal[idx] = np.clip(rgb, 0, 255).astype(np.uint8)
    return pal


# Display name for legend in prompt mode
def _display_name_for_prompt_item(item: str) -> str:
    """Return the canonical terms of ``item`` joined by '+'.

    Aliases are shown resolved (human -> person) and "background" is shown as
    "unlabeled".
    """
    return "+".join(_resolve_prompt_item_to_names(item))


# =============================================================================
# Public API
# =============================================================================
def colorize_mask(mask_tensor, classes: List[str] | None = None, dataset: str | None = None) -> Image.Image:
    """
    Colorize a [H,W] mask.
    - If `classes` is provided (prompt mode), use prompt palette:
      index 0 is background (unlabeled), others per item or averaged
    - Else, choose dataset palette: 'ade' (151) or default 'coco' (183)

    Label ids that are negative or beyond the palette are painted black.
    Raises ValueError when the mask is not 2-D.
    """
    mask = np.array(mask_tensor, dtype=np.int32)
    if mask.ndim != 2:
        raise ValueError(f"mask must be 2-D [H,W], got shape {mask.shape}")
    h, w = mask.shape

    if classes is not None:
        pal = _palette_for_prompt_classes(classes)
    else:
        pal = _palette_for_dataset(dataset)

    color = np.zeros((h, w, 3), dtype=np.uint8)
    valid = (mask >= 0) & (mask < pal.shape[0])
    color[valid] = pal[mask[valid]]
    # A uint8 HxWx3 array is read as RGB; the explicit `mode` argument is
    # deprecated in recent Pillow releases.
    return Image.fromarray(color)


def overlay_mask(image: Image.Image, color: Image.Image, alpha: float = 0.5) -> Image.Image:
    """Blend the colorized mask `color` over `image`.

    `color` is resized to `image`'s size with nearest-neighbour sampling when
    the sizes differ, both are converted to RGB, and `alpha` is the weight of
    `color` in the blend (0 -> only image, 1 -> only color).
    """
    if color.size != image.size:
        color = color.resize(image.size, resample=Image.NEAREST)
    return Image.blend(image.convert("RGB"), color.convert("RGB"), alpha)


def build_legend_from_mask(mask_tensor, classes: List[str] | None = None, dataset: str | None = None) -> List[Dict]:
    """
    Build a compact legend for the classes PRESENT in the mask.
    Returns a list of entries: {'index': int, 'name': str, 'color': [r,g,b]}
    - In prompt mode, names are prompt-derived (with '+' for multi-terms)
    - In dataset mode, names come from the dataset class list (COCO or ADE)

    Entries are ordered by index. Negative ids and ids beyond the palette are
    left out.
    """
    mask = np.array(mask_tensor, dtype=np.int64)
    # np.unique returns sorted values, so the legend comes out ordered by index
    # with the background slot (0) first.
    present = [int(i) for i in np.unique(mask[mask >= 0])]

    if classes is not None:
        pal = _palette_for_prompt_classes(classes)

        def name_of(i: int) -> str:
            raw_item = classes[i]
            # Only slot 0 can hold a non-string here (the palette never parses
            # it), so fall back to str() instead of swallowing exceptions.
            if isinstance(raw_item, str):
                return _display_name_for_prompt_item(raw_item)
            return str(raw_item)

        limit = pal.shape[0]
    else:
        if _is_ade(dataset):
            names = ADE_151_CLASS_NAMES
            pal = _palette_for_dataset("ade")
        else:
            names = CLASS_NAMES
            pal = _palette_for_dataset("coco")

        def name_of(i: int) -> str:
            return names[i]

        limit = len(names)

    legend: List[Dict] = []
    for idx in present:
        if idx >= limit:
            continue
        col = pal[idx]
        legend.append({
            "index": idx,
            "name": name_of(idx),
            "color": [int(col[0]), int(col[1]), int(col[2])],
        })
    return legend