from contextlib import nullcontext
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from PIL import Image
from sentence_transformers import SentenceTransformer

from segformer_models import MeruSegformer
from MERU_utils import lorentz as L

# ImageNet normalization, applied after the 512x512 resize in segment_image.
_pre = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Per-variant build configuration: (num_classes, class-prototype path, checkpoint path).
# 'green' (COCO, 183 classes) is the default; 'orange' is ADE (151 classes).
_VARIANT_CONFIG = {
    "orange": (
        151,
        "class_embeddings_word/Ade_151_pca_bse1_5_64D.pt",
        "saved_models/MERU_segformer_b4_100_hub_ade_entail_only.pth",
    ),
    "green": (
        183,
        "class_embeddings_word/Coco_183_pca_bse1_5_64D.pt",
        "saved_models/MERU_segformer_b4_100_hub_coco_entail_only.pth",
    ),
}


def _build_meru_segformer(num_classes: int, protos_path: str, ckpt_path: str,
                          device: str) -> "MeruSegformer":
    """Build a MeruSegformer, attach frozen class prototypes, load weights, set eval.

    Shared helper for load_model / load_model_variant — the three call sites
    previously duplicated this sequence verbatim.
    """
    model = MeruSegformer(num_classes=num_classes, entail_weight=1.0).to(device)
    # Stored prototypes are [C, 1, D]; squeeze to [C, D] before registering.
    text_feats = torch.load(protos_path, map_location=device).squeeze(dim=1)
    model.text_protos = nn.Parameter(text_feats.clone().to(device), requires_grad=False)
    state_dict = torch.load(ckpt_path, map_location=device)
    # strict=False: checkpoint keys may not cover text_protos set above — TODO confirm
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model


@torch.inference_mode()
def load_model(device: str = "cpu"):
    """Load the default COCO-183 vision model plus the shared text stack.

    Returns:
        (model, sentence_model, pca_params): the segmentation model, the
        bge-small sentence encoder, and the PCA projection parameters used
        to map sentence embeddings into the model's 64-D text space.
    """
    model = _build_meru_segformer(*_VARIANT_CONFIG["green"], device=device)
    sentence_model = SentenceTransformer('saved_models/bge-small-en-v1.5', device=device)
    pca_params = torch.load("saved_models/pca_model_coco.pt", map_location=device)
    return model, sentence_model, pca_params


# Cache of built vision models, keyed by (variant, device).
_MODEL_CACHE = {}


def load_model_variant(variant: str, device: str = "cpu"):
    """ Return a MeruSegformer instance for the requested variant; cached per (variant, device). - 'green' -> COCO 183 - 'orange' -> ADE 151 NOTE: This only builds the vision model. Sentence model & PCA params remain shared. """
    key = (variant, device)
    if key not in _MODEL_CACHE:
        # Any unrecognized variant falls back to 'green' (COCO), as before.
        cfg = _VARIANT_CONFIG.get(variant, _VARIANT_CONFIG["green"])
        _MODEL_CACHE[key] = _build_meru_segformer(*cfg, device=device)
    return _MODEL_CACHE[key]


@torch.inference_mode()
def segment_image(model, img: Image.Image, prompt=None, sentence_model=None,
                  pca_params=None, device: str = "cpu"):
    """Segment `img`, returning a [H, W] class-index mask at the original size.

    Args:
        model: a MeruSegformer from load_model / load_model_variant.
        img: input PIL image (any size; resized to 512x512 for the encoder).
        prompt: optional free-text query. When given, `sentence_model` and
            `pca_params` are required and a single-prompt score map is used
            instead of the model's per-class prototypes.
        device: 'cpu' or 'cuda'; fp16 autocast is enabled only on cuda.

    Returns:
        torch.Tensor of shape [orig_h, orig_w] on CPU, each entry the index
        of the class (or prompt channel) with the smallest hyperbolic
        distance / score at that pixel.
    """
    orig_w, orig_h = img.size  # keep original size for the final mask
    x = _pre(img.resize((512, 512))).unsqueeze(0).to(device, non_blocking=True)

    # The two device paths were identical apart from autocast; run one path
    # under a no-op context on CPU instead of duplicating it.
    amp_ctx = (torch.amp.autocast("cuda", dtype=torch.float16)
               if device == "cuda" else nullcontext())
    with amp_ctx:
        image_feats_d = model.image_encoder(x, True)  # [B, H, W, D]
        if prompt:
            emb = sentence_model.encode(
                prompt,
                normalize_embeddings=True,
                convert_to_tensor=True,
            )
            # Project the sentence embedding through the stored PCA, then lift
            # onto the hyperboloid at the model's text scale.
            text_embed = (emb - pca_params["mean"]) @ pca_params["components"].T
            # BUGFIX: was text_embed.cuda(), which crashed on CPU-only runs;
            # honor the requested device instead.
            l_txt_embed = L.exp_map0(
                text_embed.to(device) * model.textual_alpha.exp(), model.curv.exp()
            )
            logits1 = L.oxy_angle_full(image_feats_d, l_txt_embed)
        else:
            text_feats_d = model.text_encoder(True)  # [C, D]
            logits1 = L.pairwise_dist(
                image_feats_d, text_feats_d, model.curv.exp()
            )  # [B, H, W, C]

    # [B,H,W,C] -> [B,C,H,W] so both interpolation and argmin act on the
    # spatial/class dims they expect.
    # BUGFIX: the original only permuted inside the size-mismatch branch, so
    # when feature size already equaled the original size, argmin(1) reduced
    # over H instead of the class dimension.
    logits1 = logits1.permute(0, 3, 1, 2)
    if logits1.shape[-2:] != (orig_h, orig_w):
        logits1 = F.interpolate(logits1, size=(orig_h, orig_w), mode="nearest")

    # Scores are distances: smaller is better, hence argmin over classes.
    out = logits1.argmin(1)  # [B, H, W]
    return out[0].to("cpu")  # [H, W] at original resolution