#!/usr/bin/env python
# langsplat-backup/eval/openclip_encoder.py
import torch
import torchvision
import open_clip


class OpenCLIPNetwork:
    def __init__(self, device):
        # Preprocessing for raw image tensors: resize to the CLIP input
        # resolution and normalize with the OpenAI CLIP statistics.
        self.process = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize((224, 224)),
                torchvision.transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        self.clip_model_type = "ViT-B-16"
        self.clip_model_pretrained = "laion2b_s34b_b88k"
        self.clip_n_dims = 512
        model, _, _ = open_clip.create_model_and_transforms(
            self.clip_model_type,
            pretrained=self.clip_model_pretrained,
            precision="fp16",
        )
        model.eval()
        self.tokenizer = open_clip.get_tokenizer(self.clip_model_type)
        self.model = model.to(device)
        # Canonical negative phrases that serve as the contrastive baseline;
        # the placeholder positive is replaced later via set_positives().
        self.negatives = ("object", "things", "stuff", "texture")
        self.positives = (" ",)
        with torch.no_grad():
            tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.positives]).to(device)
            self.pos_embeds = model.encode_text(tok_phrases)
            tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.negatives]).to(device)
            self.neg_embeds = model.encode_text(tok_phrases)
        # L2-normalize so the dot products below are cosine similarities.
        self.pos_embeds /= self.pos_embeds.norm(dim=-1, keepdim=True)
        self.neg_embeds /= self.neg_embeds.norm(dim=-1, keepdim=True)
    @torch.no_grad()
    def get_relevancy(self, embed: torch.Tensor, positive_id: int) -> torch.Tensor:
        # embed: (N, 512) language features, e.g. N = 32768 flattened pixels.
        phrases_embeds = torch.cat([self.pos_embeds, self.neg_embeds], dim=0)
        p = phrases_embeds.to(embed.dtype)
        output = torch.mm(embed, p.T)  # cosine similarity to every phrase
        positive_vals = output[..., positive_id : positive_id + 1]  # (N, 1)
        negative_vals = output[..., len(self.positives) :]  # (N, n_negatives)
        repeated_pos = positive_vals.repeat(1, len(self.negatives))
        # Pair the positive similarity against each negative and softmax with
        # temperature 10, giving a positive-vs-negative probability per pair.
        sims = torch.stack((repeated_pos, negative_vals), dim=-1)
        softmax = torch.softmax(10 * sims, dim=-1)  # (N, n_negatives, 2)
        # Keep the hardest pairing: the negative that minimizes the positive
        # probability (LERF-style relevancy score).
        best_id = softmax[..., 0].argmin(dim=1)
        return torch.gather(softmax, 1, best_id[..., None, None].expand(best_id.shape[0], len(self.negatives), 2))[
            :, 0, :
        ]
    def encode_image(self, input, mask=None):
        processed_input = self.process(input).half()  # match the fp16 weights
        # Note: `mask` is forwarded unchanged; stock open_clip image towers do
        # not accept it, so leave it as None unless the model has been patched.
        return self.model.encode_image(processed_input, mask=mask)
    def encode_text(self, text_list, device):
        text = self.tokenizer(text_list).to(device)
        return self.model.encode_text(text)
    def set_positives(self, text_list):
        # Replace the query phrases and re-embed them on the same device as
        # the cached negative embeddings.
        self.positives = text_list
        with torch.no_grad():
            tok_phrases = torch.cat(
                [self.tokenizer(phrase) for phrase in self.positives]
            ).to(self.neg_embeds.device)
            self.pos_embeds = self.model.encode_text(tok_phrases)
        self.pos_embeds /= self.pos_embeds.norm(dim=-1, keepdim=True)
    def set_semantics(self, text_list):
        self.semantic_labels = text_list
        with torch.no_grad():
            # Use the cached embeddings' device rather than hard-coding "cuda"
            # so the class also works when constructed on another device.
            tok_phrases = torch.cat(
                [self.tokenizer(phrase) for phrase in self.semantic_labels]
            ).to(self.neg_embeds.device)
            self.semantic_embeds = self.model.encode_text(tok_phrases)
        self.semantic_embeds /= self.semantic_embeds.norm(dim=-1, keepdim=True)
    def get_semantic_map(self, sem_map: torch.Tensor) -> torch.Tensor:
        # sem_map: (n_levels, h, w, 512) language-feature maps.
        n_levels, h, w, c = sem_map.shape
        pos_num = self.semantic_embeds.shape[0]
        phrases_embeds = torch.cat([self.semantic_embeds, self.neg_embeds], dim=0)
        p = phrases_embeds.to(sem_map.dtype)
        sem_pred = torch.zeros(n_levels, h, w)
        for i in range(n_levels):
            output = torch.mm(sem_map[i].view(-1, c), p.T)
            softmax = torch.softmax(10 * output, dim=-1)
            sem_pred[i] = torch.argmax(softmax, dim=-1).view(h, w)
            # Pixels whose best match is one of the negatives get label -1.
            sem_pred[i][sem_pred[i] >= pos_num] = -1
        return sem_pred.long()
    def get_max_across(self, sem_map):
        # sem_map: (n_levels, h, w, 512); returns relevancy maps of shape
        # (n_levels, n_phrases, h, w), one per level and positive phrase.
        n_phrases = len(self.positives)
        n_phrases_sims = [None for _ in range(n_phrases)]
        n_levels, h, w, _ = sem_map.shape
        clip_output = sem_map.permute(1, 2, 0, 3).flatten(0, 1)  # (h*w, n_levels, 512)
        n_levels_sims = [None for _ in range(n_levels)]
        for i in range(n_levels):
            for j in range(n_phrases):
                probs = self.get_relevancy(clip_output[..., i, :], j)
                pos_prob = probs[..., 0:1]  # probability of the positive phrase
                n_phrases_sims[j] = pos_prob
            n_levels_sims[i] = torch.stack(n_phrases_sims)
        relev_map = torch.stack(n_levels_sims).view(n_levels, n_phrases, h, w)
        return relev_map
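

# A minimal usage sketch (not part of the original file). It assumes a CUDA
# device for the fp16 weights and uses random tensors in place of rendered
# LangSplat feature maps; shapes and phrase choices here are illustrative.
if __name__ == "__main__":
    clip_net = OpenCLIPNetwork("cuda")

    # Query phrases to score against the canonical negatives.
    clip_net.set_positives(["flower", "leaves"])

    # Dummy (n_levels, h, w, 512) feature maps; real ones would come from a
    # trained language-embedded Gaussian field.
    sem_map = torch.randn(3, 64, 64, clip_net.clip_n_dims, device="cuda")
    sem_map = sem_map / sem_map.norm(dim=-1, keepdim=True)

    relev = clip_net.get_max_across(sem_map)  # (3, 2, 64, 64)
    print(relev.shape)

    # Scoring a flat (N, 512) batch against the first positive phrase:
    flat = sem_map[0].reshape(-1, clip_net.clip_n_dims)
    pair_probs = clip_net.get_relevancy(flat, positive_id=0)  # (N, 2)
    print(pair_probs.shape)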