File size: 5,304 Bytes
19d78dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from PIL import Image

from sentence_transformers import SentenceTransformer
from segformer_models import MeruSegformer
from MERU_utils import lorentz as L

# Standard ImageNet normalization statistics used by the backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: PIL image -> float tensor -> ImageNet-normalized.
_pre = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

@torch.inference_mode()
def load_model(device: str = "cpu"):
    """Build the default COCO-183 segmentation model plus shared text components.

    Loads the PCA-reduced class prototype embeddings and the trained
    checkpoint into a fresh MeruSegformer, then loads the sentence encoder
    and PCA projection parameters used for free-text prompts.

    Args:
        device: torch device string ("cpu" or "cuda").

    Returns:
        Tuple of (segmentation model, sentence-transformer model, PCA params dict).
    """
    seg_model = MeruSegformer(num_classes=183, entail_weight=1.0).to(device)

    # Class prototypes: stored with a singleton middle dim (presumably [C, 1, D]
    # — TODO confirm), squeezed to [C, D] before registering as a frozen parameter.
    protos = torch.load(
        "class_embeddings_word/Coco_183_pca_bse1_5_64D.pt", map_location=device
    )
    protos = protos.squeeze(dim=1)
    seg_model.text_protos = nn.Parameter(protos.clone().to(device), requires_grad=False)

    # strict=False: checkpoint may omit keys such as the prototypes set above.
    checkpoint = torch.load(
        "saved_models/MERU_segformer_b4_100_hub_coco_entail_only.pth",
        map_location=device,
    )
    seg_model.load_state_dict(checkpoint, strict=False)
    seg_model.eval()

    # Shared text pipeline: sentence encoder + PCA projection for prompts.
    sentence_model = SentenceTransformer('saved_models/bge-small-en-v1.5', device=device)
    pca_params = torch.load("saved_models/pca_model_coco.pt", map_location=device)

    return seg_model, sentence_model, pca_params

# Cache of built vision models, keyed by (variant, device).
_MODEL_CACHE = {}

# Per-variant build configuration:
#   variant -> (num_classes, prototype-embedding path, checkpoint path)
_VARIANT_CONFIG = {
    "orange": (
        151,
        "class_embeddings_word/Ade_151_pca_bse1_5_64D.pt",
        "saved_models/MERU_segformer_b4_100_hub_ade_entail_only.pth",
    ),
    "green": (
        183,
        "class_embeddings_word/Coco_183_pca_bse1_5_64D.pt",
        "saved_models/MERU_segformer_b4_100_hub_coco_entail_only.pth",
    ),
}


def load_model_variant(variant: str, device: str = "cpu"):
    """
    Return a MeruSegformer instance for the requested variant; cached per (variant, device).
      - 'green'  -> COCO 183
      - 'orange' -> ADE 151
    Any other variant string falls back to 'green' (matches original behavior).
    NOTE: This only builds the vision model. Sentence model & PCA params remain shared.
    """
    key = (variant, device)
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key]

    # Single table-driven load path replaces the previous copy-pasted branches,
    # which differed only in these three constants.
    num_classes, proto_path, ckpt_path = _VARIANT_CONFIG.get(
        variant, _VARIANT_CONFIG["green"]
    )

    model = MeruSegformer(num_classes=num_classes, entail_weight=1.0).to(device)
    # Prototype tensor has a singleton middle dim (presumably [C, 1, D] — TODO
    # confirm); squeeze to [C, D] and register as a frozen parameter.
    text_feats = torch.load(proto_path, map_location=device).squeeze(dim=1)
    model.text_protos = nn.Parameter(text_feats.clone().to(device), requires_grad=False)
    # strict=False: checkpoint may not cover keys set manually above.
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)

    model.eval()
    _MODEL_CACHE[key] = model
    return model

@torch.inference_mode()
def segment_image(model, img: Image.Image, prompt = None, sentence_model = None, pca_params = None, device: str = "cpu"):
    """Segment an image, either against the model's class prototypes or a free-text prompt.

    Args:
        model: MeruSegformer with `image_encoder`, `text_encoder`, `curv`, `textual_alpha`.
        img: input PIL image (any size; resized to 512x512 for the encoder).
        prompt: optional free-text query; when given, `sentence_model` and
            `pca_params` (dict with "mean" and "components") must be provided.
        sentence_model: SentenceTransformer used to embed `prompt`.
        pca_params: PCA projection parameters for the prompt embedding.
        device: "cpu" or "cuda"; cuda runs the forward pass under fp16 autocast.

    Returns:
        [H, W] CPU tensor of per-pixel argmin indices at the ORIGINAL image size
        (class index for the prototype path; 0 everywhere for a single prompt).
    """
    # Keep original size for final output.
    orig_w, orig_h = img.size

    x = _pre(img.resize((512, 512))).unsqueeze(0).to(device, non_blocking=True)

    def _forward():
        # Shared forward logic; previously duplicated verbatim in the
        # autocast and non-autocast branches.
        image_feats_d = model.image_encoder(x, True)   # [B,H,W,D]
        if prompt:
            emb = sentence_model.encode(
                prompt,
                normalize_embeddings=True,
                convert_to_tensor=True,
            )
            x_centered = emb - pca_params["mean"]
            text_embed = x_centered @ pca_params["components"].T
            # BUG FIX: was `text_embed.cuda()` even on the CPU path, which
            # crashes on CPU-only machines; move to the requested device instead.
            l_txt_embed = L.exp_map0(
                text_embed.to(device) * model.textual_alpha.exp(), model.curv.exp()
            )
            # Assumed to return a 4D tensor like the class path — the original
            # code also permuted/argmin'd it as 4D. TODO confirm.
            return L.oxy_angle_full(image_feats_d, l_txt_embed)
        text_feats_d = model.text_encoder(True)        # [C,D]
        return L.pairwise_dist(image_feats_d, text_feats_d, model.curv.exp())  # [B,H,W,C]

    if device == "cuda":
        with torch.amp.autocast("cuda", dtype=torch.float16):
            logits1 = _forward()
    else:
        logits1 = _forward()

    # BUG FIX: permute to channels-first unconditionally. Previously the
    # permute only happened inside the resize branch, so when the feature map
    # already matched the original size, argmin(1) reduced over H instead of
    # the class dimension.
    logits1 = logits1.permute(0, 3, 1, 2)              # [B,H,W,C] -> [B,C,H,W]

    # Upsample distance tensor to ORIGINAL HxW before argmin over classes.
    if logits1.shape[2:] != (orig_h, orig_w):
        logits1 = F.interpolate(logits1, size=(orig_h, orig_w), mode="nearest")

    # Argmin over class dimension -> [B,H,W].
    out = logits1.argmin(1)

    mask = out[0].to("cpu")  # [H, W] (original size)
    return mask