Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from transformers import CLIPModel | |
| from transformers.models.clip.modeling_clip import _expand_mask | |
| from .utils import drop_sequence_mask | |
def position_embedding(input, d_model):
    """Compute Transformer-style sinusoidal embeddings for given positions.

    Args:
        input: tensor of positions; flattened to a column vector.
        d_model: embedding width (assumed even — sin/cos pairs interleave).

    Returns:
        Float tensor of shape ``(input.numel(), d_model)`` where even columns
        hold ``sin(pos / 10000^(2i/d_model))`` and odd columns the matching cos.
    """
    positions = input.view(-1, 1)
    # One frequency index per sin/cos pair.
    freqs = torch.arange(
        d_model // 2, dtype=torch.float32, device=positions.device
    ).view(1, -1)
    angles = positions / 10000 ** (2 * freqs / d_model)
    table = torch.zeros((positions.shape[0], d_model), device=positions.device)
    table[:, 0::2] = torch.sin(angles)
    table[:, 1::2] = torch.cos(angles)
    return table
def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
    """Build a ``(max_len, d_model)`` sinusoidal position-encoding table.

    Args:
        max_len: number of positions (rows) to generate.
        d_model: embedding width (assumed even).
        padding_idx: optional row index to zero out (padding position).

    Returns:
        Float tensor of shape ``(max_len, d_model)``.
    """
    # Inlined sinusoidal embedding: sin on even columns, cos on odd columns.
    positions = torch.arange(max_len, dtype=torch.float32).view(-1, 1)
    freqs = torch.arange(d_model // 2, dtype=torch.float32).view(1, -1)
    angles = positions / 10000 ** (2 * freqs / d_model)
    table = torch.zeros((max_len, d_model))
    table[:, 0::2] = torch.sin(angles)
    table[:, 1::2] = torch.cos(angles)
    if padding_idx is not None:
        # Padding positions carry no positional signal.
        table[padding_idx] = 0
    return table
class KnwlModel(nn.Module):
    """Fuse retrieved knowledge (objects / attributes / actions) and query
    tokens into a single embedded sequence with an attention mask.

    Each knowledge field is a dict with keys ``"embed"``, ``"pos"``,
    ``"score"`` and ``"query"`` (schema inferred from `prepare_input`;
    confirm against the caller).

    Args:
        d_knwl: dimensionality of the incoming knowledge/query embeddings.
        d_out: output embedding dimensionality.
        pt: drop probability passed to `drop_sequence_mask` during training.
    """

    def __init__(self, d_knwl, d_out, pt=0.1):
        super().__init__()
        self.pt = pt
        self.fc_knwl = nn.Linear(d_knwl, d_out, bias=False)
        self.fc_query = nn.Linear(d_knwl, d_out)
        # Learned position embedding; supports positions 0..8.
        self.pos = nn.Embedding(9, d_out)
        # Learned scale/shift applied to the scalar relevance score.
        self.score1 = nn.Parameter(torch.randn(1, 1, d_out))
        self.score2 = nn.Parameter(torch.randn(1, 1, d_out))
        # Learned type embeddings, one per knowledge category plus query.
        self.obj = nn.Parameter(torch.randn(1, 1, d_out))
        self.attr = nn.Parameter(torch.randn(1, 1, d_out))
        self.act = nn.Parameter(torch.randn(1, 1, d_out))
        self.query = nn.Parameter(torch.randn(1, 1, d_out))

    @property
    def device(self):
        """Device the module's parameters live on.

        BUG FIX: this was a plain method, but `prepare_input` accesses
        ``self.device`` as an attribute (e.g. ``device=self.device``), which
        handed torch a bound method instead of a ``torch.device``. Exposing it
        as a property makes the attribute access return the actual device.
        """
        return self.score1.device

    def prepare_input(self, knowledge):
        """Embed one knowledge field.

        Returns:
            Tuple ``(e_knwl, m_knwl, e_query, m_query)`` of knowledge
            embeddings + drop mask and query embeddings + all-ones mask.
        """
        # Knowledge tokens: projection + position + score-conditioned bias.
        e = self.fc_knwl(knowledge["embed"])
        p = self.pos(knowledge["pos"])
        s = knowledge["score"].unsqueeze(-1) * self.score1 + self.score2
        e_knwl = e + p + s
        # Randomly drop knowledge tokens during training (regularization).
        m_knwl = drop_sequence_mask(
            *e_knwl.shape[:2], self.device, self.pt, self.training
        )
        # Query tokens: projection + sequential position embedding.
        e = self.fc_query(knowledge["query"])
        p = torch.arange(knowledge["query"].shape[1], device=self.device)
        p = self.pos(p[None, :])
        e_query = e + p
        # Query tokens are never dropped.
        m_query = torch.ones(
            e_query.shape[:2], dtype=torch.long, device=self.device
        )
        return e_knwl, m_knwl, e_query, m_query

    def forward(self, knowledge):
        """Embed all knowledge categories and concatenate along the sequence.

        Args:
            knowledge: dict with keys ``"obj"``, ``"attr"``, ``"act"``, each a
                knowledge-field dict accepted by `prepare_input`.

        Returns:
            Tuple ``(embeds, masks)`` concatenated as
            ``[query, obj, attr, act]`` along dim 1.
        """
        e_obj, m_obj, e_query, m_query = self.prepare_input(knowledge["obj"])
        e_attr, m_attr, _, _ = self.prepare_input(knowledge["attr"])
        e_act, m_act, _, _ = self.prepare_input(knowledge["act"])
        # Tag each segment with its learned type embedding.
        e_obj = e_obj + self.obj
        e_attr = e_attr + self.attr
        e_act = e_act + self.act
        e_query = e_query + self.query
        embeds = torch.cat([e_query, e_obj, e_attr, e_act], dim=1)
        masks = torch.cat([m_query, m_obj, m_attr, m_act], dim=1)
        return embeds, masks
class KnwlEncoder(nn.Module):
    """Encode pre-embedded knowledge tokens with a pretrained CLIP vision tower.

    Loads the vision transformer of a LAION CLIP checkpoint (fp16), optionally
    keeps only its last ``num_layers`` encoder layers, and projects the encoded
    hidden states to ``d_out``.

    Args:
        d_out: output feature dimensionality of the final projection.
        num_layers: if given, keep only the last ``num_layers`` encoder layers.
        grad_ckpt: enable gradient checkpointing inside the encoder.
    """

    def __init__(self, d_out, num_layers=None, grad_ckpt=True):
        super().__init__()
        vision = CLIPModel.from_pretrained(
            "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
        ).vision_model
        vision.encoder.gradient_checkpointing = grad_ckpt
        if num_layers is not None:
            # Truncate the stack to its final `num_layers` layers, in order.
            vision.encoder.layers = nn.ModuleList(
                list(vision.encoder.layers)[-num_layers:]
            )
        self.model = vision
        self.fc = nn.Linear(vision.config.hidden_size, d_out, bias=False)
        # Hidden size of the backbone, exposed for callers.
        self.d = vision.config.hidden_size

    def forward(self, inputs_embeds, attention_mask):
        """Run pre-norm -> encoder -> post-norm -> projection.

        Args:
            inputs_embeds: token embeddings, shape ``(batch, seq, hidden)``
                (presumably matching the backbone hidden size — confirm).
            attention_mask: 2-D padding mask expanded to additive form.

        Returns:
            Projected hidden states of shape ``(batch, seq, d_out)``.
        """
        hidden = self.model.pre_layrnorm(inputs_embeds)
        extended_mask = _expand_mask(attention_mask, hidden.dtype)
        encoded = self.model.encoder(
            inputs_embeds=hidden,
            attention_mask=extended_mask,
            return_dict=True,
        )[0]
        return self.fc(self.model.post_layernorm(encoded))