import contextlib
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sentence_transformers import SentenceTransformer
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights

from MERU_utils import lorentz as L
from segformer_models import MeruSegformer
|
|
# Shared image preprocessing: PIL -> float tensor in [0, 1], then normalized
# with the standard ImageNet mean/std used by torchvision backbones.
_pre = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
|
|
@torch.inference_mode()
def load_model(device: str = "cpu"):
    """Build the default COCO (183-class) MeruSegformer together with the
    shared sentence encoder and PCA projection parameters.

    Args:
        device: torch device string the models and tensors are placed on.

    Returns:
        Tuple of ``(model, sentence_model, pca_params)`` where ``pca_params``
        is the dict loaded from the saved PCA model file.
    """
    # Vision model; class text prototypes are installed as a frozen parameter.
    seg_model = MeruSegformer(num_classes=183, entail_weight=1.0).to(device)
    proto_raw = torch.load("class_embeddings_word/Coco_183_pca_bse1_5_64D.pt", map_location=device)
    seg_model.text_protos = nn.Parameter(
        proto_raw.squeeze(dim=1).clone().to(device), requires_grad=False
    )

    # strict=False: the checkpoint and the module's key sets are allowed to
    # differ (e.g. text_protos is set manually above) — TODO confirm intent.
    checkpoint = torch.load("saved_models/MERU_segformer_b4_100_hub_coco_entail_only.pth", map_location=device)
    seg_model.load_state_dict(checkpoint, strict=False)
    seg_model.eval()

    # Shared open-vocabulary prompt encoder and its PCA reduction parameters.
    sentence_model = SentenceTransformer('saved_models/bge-small-en-v1.5', device=device)
    pca_params = torch.load("saved_models/pca_model_coco.pt", map_location=device)

    return seg_model, sentence_model, pca_params
|
|
| _MODEL_CACHE = {} |
|
|
def load_model_variant(variant: str, device: str = "cpu"):
    """
    Return a MeruSegformer instance for the requested variant; cached per (variant, device).

    - 'green'  -> COCO, 183 classes (also the fallback for any unknown
                  variant, matching the original if/else behavior)
    - 'orange' -> ADE, 151 classes

    NOTE: This only builds the vision model. Sentence model & PCA params remain shared.
    """
    key = (variant, device)
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key]

    # Per-variant configuration: (num_classes, class-embedding file, checkpoint).
    # Collapses the two previously duplicated construction branches.
    configs = {
        "orange": (
            151,
            "class_embeddings_word/Ade_151_pca_bse1_5_64D.pt",
            "saved_models/MERU_segformer_b4_100_hub_ade_entail_only.pth",
        ),
        "green": (
            183,
            "class_embeddings_word/Coco_183_pca_bse1_5_64D.pt",
            "saved_models/MERU_segformer_b4_100_hub_coco_entail_only.pth",
        ),
    }
    num_classes, emb_path, ckpt_path = configs.get(variant, configs["green"])

    model = MeruSegformer(num_classes=num_classes, entail_weight=1.0).to(device)

    # Pre-computed PCA-reduced class text embeddings; drop the singleton dim
    # and install as a frozen parameter.
    text_feats_pca = torch.load(emb_path, map_location=device)
    text_feats = text_feats_pca.squeeze(dim=1)
    model.text_protos = nn.Parameter(text_feats.clone().to(device), requires_grad=False)

    # strict=False: checkpoint keys may not exactly match the module (e.g. the
    # manually installed text_protos above) — TODO confirm intent.
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)

    model.eval()
    _MODEL_CACHE[key] = model
    return model
|
|
@torch.inference_mode()
def segment_image(model, img: Image.Image, prompt = None, sentence_model = None, pca_params = None, device: str = "cpu"):
    """Segment a PIL image; return a per-pixel class-index mask on CPU.

    Args:
        model: MeruSegformer with ``image_encoder`` / ``text_encoder``,
            ``textual_alpha`` and ``curv`` attributes.
        img: input PIL image (any size; resized to 512x512 internally).
        prompt: optional free-text prompt. When given, ``sentence_model`` and
            ``pca_params`` must be provided and the prompt embedding replaces
            the model's class prototypes.
        sentence_model: SentenceTransformer used to embed ``prompt``.
        pca_params: dict with ``"mean"`` and ``"components"`` tensors for the
            PCA projection of the prompt embedding.
        device: torch device string; fp16 autocast is used when ``"cuda"``.

    Returns:
        LongTensor of shape (orig_h, orig_w) with the winning class per pixel.
    """
    orig_w, orig_h = img.size

    # The model consumes a fixed 512x512 input.
    img_proc = img.resize((512, 512))

    x = _pre(img_proc).unsqueeze(0).to(device, non_blocking=True)

    # fp16 autocast on CUDA, plain fp32 otherwise. One code path replaces the
    # previously duplicated cuda/cpu branches.
    amp_ctx = (
        torch.amp.autocast("cuda", dtype=torch.float16)
        if device == "cuda"
        else contextlib.nullcontext()
    )
    with amp_ctx:
        image_feats_d = model.image_encoder(x, True)
        if prompt:
            emb = sentence_model.encode(
                prompt,
                normalize_embeddings=True,
                convert_to_tensor=True,
            )
            # Project the prompt embedding into the PCA-reduced text space.
            x_centered = emb - pca_params["mean"]
            text_embed = x_centered @ pca_params["components"].T
            # BUG FIX: was text_embed.cuda() on both paths, which crashed on
            # CPU-only machines (and mixed devices otherwise); honor `device`.
            l_txt_embed = L.exp_map0(
                text_embed.to(device) * model.textual_alpha.exp(), model.curv.exp()
            )
            logits1 = L.oxy_angle_full(image_feats_d, l_txt_embed)
        else:
            text_feats_d = model.text_encoder(True)
            logits1 = L.pairwise_dist(image_feats_d, text_feats_d, model.curv.exp())

    # logits1 is assumed channels-last (B, H, W, C) — inferred from the
    # original permute(0, 3, 1, 2); TODO confirm against L.oxy_angle_full /
    # L.pairwise_dist.
    # BUG FIX: always move channels first, then upsample only when the spatial
    # size differs. The old code skipped the permute for 512x512 inputs, so
    # argmin(1) reduced over H instead of over the class axis.
    logits1 = logits1.permute(0, 3, 1, 2)
    if logits1.shape[-2:] != (orig_h, orig_w):
        logits1 = F.interpolate(logits1, size=(orig_h, orig_w), mode="nearest")

    # Distances/angles: smaller is better, hence argmin over the class axis.
    out = logits1.argmin(1)

    mask = out[0].to("cpu")
    return mask