| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """Concept projector.""" |
| |
|
| | import pickle |
| |
|
| | import numpy as np |
| | import torch |
| | from torch import nn |
| |
|
| |
|
class ConceptProjector(nn.Module):
    """Project embeddings onto a concept vocabulary and decode top concepts.

    Two projection matrices may be loaded, one for "source" embeddings and
    one for "target" embeddings. Each is read from a pickle file holding a
    ``(weights, concepts)`` pair, where ``weights`` is a NumPy array and
    ``concepts`` is a sequence of concept labels.

    Attributes (only set once the corresponding file has been loaded):
        src_weights: ``torch.Tensor`` used to project source embeddings.
        tgt_weights: ``torch.Tensor`` used to project target embeddings.
        concepts: ``np.ndarray`` of labels from the most recently loaded
            file (the target file wins when both are given).
    """

    def __init__(self, src_weights=None, tgt_weights=None):
        """Build the projector, optionally loading weights from disk.

        Args:
            src_weights: Path to the pickled source weights, or a falsy value
                to skip loading.
            tgt_weights: Path to the pickled target weights, or a falsy value
                to skip loading.
        """
        super().__init__()
        self.reset_weights(src_weights, tgt_weights)

    @staticmethod
    def _load_weights(path):
        """Read a pickled ``(weights, concepts)`` pair from ``path``."""
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at weight files that come from a trusted source.
        with open(path, "rb") as f:
            weights, concepts = pickle.load(f)
        return torch.from_numpy(weights), np.array(concepts)

    def reset_weights(self, src_weights=None, tgt_weights=None):
        """Reload the projection weights from the given pickle files.

        A falsy path leaves the corresponding weights untouched. Weights are
        presumably stored already normalized -- TODO confirm against the
        script that produces the pickle files.

        Args:
            src_weights: Path to the pickled source weights.
            tgt_weights: Path to the pickled target weights.
        """
        if src_weights:
            self.src_weights, self.concepts = self._load_weights(src_weights)
        if tgt_weights:
            # Loaded second, so its concept list overrides the source one.
            self.tgt_weights, self.concepts = self._load_weights(tgt_weights)

    @staticmethod
    def maybe_convert(embeds, proj):
        """Make ``embeds`` and ``proj`` compatible for a matrix product.

        Args:
            embeds: Embedding tensor; cast to float32 if it is not already.
            proj: Projection tensor; moved to the device of ``embeds`` and
                cast to its dtype.

        Returns:
            The ``(embeds, proj)`` pair after conversion.
        """
        if embeds.dtype != torch.float32:
            embeds = embeds.float()
        # torch.matmul does not promote mixed dtypes, so weights loaded as
        # float64/float16 must be cast along with the device move. ``Tensor.to``
        # returns the same tensor when nothing needs to change.
        proj = proj.to(device=embeds.device, dtype=embeds.dtype)
        return embeds, proj

    def _project(self, embeds, attr):
        """Return logits of L2-normalized ``embeds`` against ``self.<attr>``."""
        embeds, weights = self.maybe_convert(embeds, getattr(self, attr))
        # Keep the converted weights so later calls skip the conversion.
        setattr(self, attr, weights)
        return nn.functional.normalize(embeds, dim=-1) @ weights

    def encode_src(self, src_embeds, logpi=True):
        """Encode source embeddings via concept projection.

        Args:
            src_embeds: Tensor whose last dimension matches the first
                dimension of ``src_weights``.
            logpi: If true, return log-softmax over the last dimension;
                otherwise return the raw logits.

        Returns:
            A float32 tensor of log-probabilities or logits.
        """
        logits = self._project(src_embeds, "src_weights")
        if logpi:
            return nn.functional.log_softmax(logits, dim=-1)
        return logits

    def encode_tgt(self, tgt_embeds):
        """Encode target embeddings via concept projection.

        Args:
            tgt_embeds: Tensor whose last dimension matches the first
                dimension of ``tgt_weights``.

        Returns:
            A float32 tensor of log-probabilities over the last dimension.
        """
        logits = self._project(tgt_embeds, "tgt_weights")
        return nn.functional.log_softmax(logits, dim=-1)

    def decode(self, src_embeds, k=1, return_index=False, return_prob=False):
        """Return the top-k concepts of source embeddings.

        Args:
            src_embeds: Tensor whose last dimension matches the first
                dimension of ``src_weights``.
            k: Number of top entries to keep along the last dimension.
            return_index: If true, return concept indices instead of labels.
            return_prob: If true, return only the full probability array and
                ignore ``k`` and ``return_index``.

        Returns:
            When ``return_prob`` is true, a NumPy array of softmax
            probabilities. Otherwise a ``(concepts_or_indices, scores)``
            tuple of NumPy arrays with ``k`` entries along the last axis.
        """
        logits = self._project(src_embeds, "src_weights")
        probs = nn.functional.softmax(logits, dim=-1)
        if return_prob:
            return probs.cpu().numpy()
        score, index = [x.cpu().numpy() for x in probs.topk(k, -1)]
        return (index if return_index else self.concepts[index]), score
| |
|