# seg-app/api/model.py
# Uploaded by mahmed10 ("Upload 55 files", commit 19d78dd).
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from PIL import Image
from sentence_transformers import SentenceTransformer
from segformer_models import MeruSegformer
from MERU_utils import lorentz as L
# Shared input preprocessing: PIL image -> float tensor in [0,1], then
# channel-wise normalization with the standard ImageNet mean/std (the
# statistics the SegFormer backbone was pretrained with).
_pre = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
@torch.inference_mode()
def load_model(device: str = "cpu"):
    """Load the default (COCO, 183-class) segmentation stack.

    Delegates the vision-model construction to ``load_model_variant`` —
    the body previously duplicated its COCO branch verbatim — so the
    model is also cached per (variant, device).

    Args:
        device: torch device string, e.g. "cpu" or "cuda".

    Returns:
        (model, sentence_model, pca_params) — the MeruSegformer vision
        model in eval mode, the shared BGE sentence encoder, and the PCA
        projection parameters used to map sentence embeddings into the
        64-D class-embedding space.
    """
    # 'green' is the COCO-183 variant (see load_model_variant).
    model = load_model_variant("green", device)
    # Sentence encoder + PCA params are shared across variants.
    sentence_model = SentenceTransformer('saved_models/bge-small-en-v1.5', device=device)
    pca_params = torch.load("saved_models/pca_model_coco.pt", map_location=device)
    return model, sentence_model, pca_params
# Process-wide cache of built vision models, keyed by (variant, device),
# so repeated requests for the same variant don't reload checkpoints.
_MODEL_CACHE = {}
def load_model_variant(variant: str, device: str = "cpu"):
    """
    Return a MeruSegformer instance for the requested variant; cached per (variant, device).
    - 'green'  -> COCO (183 classes)
    - 'orange' -> ADE (151 classes)
    Any unrecognized variant falls back to 'green' (COCO), matching the
    original if/else behavior.
    NOTE: This only builds the vision model. Sentence model & PCA params remain shared.
    """
    key = (variant, device)
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key]
    # Per-variant configuration: (num_classes, class-embedding file, checkpoint file).
    # The two original branches were identical except for these three literals.
    specs = {
        "orange": (
            151,
            "class_embeddings_word/Ade_151_pca_bse1_5_64D.pt",
            "saved_models/MERU_segformer_b4_100_hub_ade_entail_only.pth",
        ),
        "green": (
            183,
            "class_embeddings_word/Coco_183_pca_bse1_5_64D.pt",
            "saved_models/MERU_segformer_b4_100_hub_coco_entail_only.pth",
        ),
    }
    num_classes, emb_path, ckpt_path = specs.get(variant, specs["green"])
    model = MeruSegformer(num_classes=num_classes, entail_weight=1.0).to(device)
    # Precomputed class text embeddings (PCA-reduced to 64-D); stored with an
    # extra singleton dim that we squeeze away.
    text_feats = torch.load(emb_path, map_location=device).squeeze(dim=1)
    model.text_protos = nn.Parameter(text_feats.clone().to(device), requires_grad=False)
    # strict=False: checkpoint may omit buffers/protos set up above.
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    _MODEL_CACHE[key] = model
    return model
@torch.inference_mode()
def segment_image(model, img: Image.Image, prompt=None, sentence_model=None, pca_params=None, device: str = "cpu"):
    """Segment *img* and return a per-pixel class-index mask at the original size.

    Args:
        model: MeruSegformer with hyperbolic image/text encoders.
        img: input PIL image (any size; resized to 512x512 for the network).
        prompt: optional free-text query; when given, a single-channel
            "distance to prompt" map is produced instead of per-class logits.
        sentence_model: BGE sentence encoder (required when prompt is set).
        pca_params: dict with "mean" and "components" projecting sentence
            embeddings into the class-embedding space (required with prompt).
        device: torch device string ("cpu" or "cuda").

    Returns:
        torch.Tensor [H, W] on CPU at the ORIGINAL image size, holding the
        argmin over the class/distance dimension.

    Fixes vs. previous version:
    - CPU path used `.cuda()` on the text embedding, crashing on CUDA-less
      hosts; now `.to(device)` (identical on the cuda path).
    - When the logits already matched the original size the permute was
      skipped, so `argmin(1)` reduced over height instead of classes; we now
      always permute to [B, C, H, W] before the (conditional) upsample.
    """
    orig_w, orig_h = img.size  # keep original size for the final mask
    x = _pre(img.resize((512, 512))).unsqueeze(0).to(device, non_blocking=True)

    def _distance_logits():
        # Shared by the fp16 (cuda autocast) and fp32 (cpu) paths.
        image_feats_d = model.image_encoder(x, True)  # [B,H,W,D] hyperbolic feats
        if prompt:
            emb = sentence_model.encode(
                prompt,
                normalize_embeddings=True,
                convert_to_tensor=True,
            )
            # PCA-project the sentence embedding into the 64-D class space.
            text_embed = (emb - pca_params["mean"]) @ pca_params["components"].T
            l_txt_embed = L.exp_map0(text_embed.to(device) * model.textual_alpha.exp(), model.curv.exp())
            return L.oxy_angle_full(image_feats_d, l_txt_embed)
        text_feats_d = model.text_encoder(True)  # [C,D]
        return L.pairwise_dist(image_feats_d, text_feats_d, model.curv.exp())  # [B,H,W,C]

    if device == "cuda":
        with torch.amp.autocast("cuda", dtype=torch.float16):
            logits1 = _distance_logits()
    else:
        logits1 = _distance_logits()

    # [B,H,W,C] -> [B,C,H,W] unconditionally, so channel dim 1 is classes.
    logits1 = logits1.permute(0, 3, 1, 2)
    if logits1.shape[2:] != (orig_h, orig_w):
        # Upsample distances to the ORIGINAL HxW before taking the argmin.
        logits1 = F.interpolate(logits1, size=(orig_h, orig_w), mode="nearest")
    # Distances: smaller = closer class, hence argmin over the class dim.
    out = logits1.argmin(1)  # [B,H,W]
    return out[0].to("cpu")  # [H, W] at original size