from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs
from models.attention import Attention
from models.embed import Embeddings
from models.mlp import Mlp
|
|
# Parameter name components inside a pre-trained JAX/Flax ViT checkpoint.
# Block.load_from joins these with the per-block root to locate each tensor.
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"
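

# NOTE: np2th is referenced by Block.load_from below but is not defined in this
# excerpt. The definition here follows the common ViT checkpoint-loading helper
# (numpy -> torch, with an optional HWIO -> OIHW transpose for conv kernels);
# adjust or remove it if this repo already defines np2th elsewhere.
def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)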
|
|
class Block(nn.Module):
    """Pre-norm Transformer encoder block.

    With mm=True, a parallel text stream gets its own LayerNorms and MLP
    while sharing the attention module with the image stream.
    """

    def __init__(self, config, vis, mm=True):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        if mm:
            # Text-stream-specific norms and MLP; attention itself is shared.
            self.att_norm_text = LayerNorm(config.hidden_size, eps=1e-6)
            self.ffn_norm_text = LayerNorm(config.hidden_size, eps=1e-6)
            self.ffn_text = Mlp(config)

        self.ffn = Mlp(config)
        self.attn = Attention(config, vis, mm)
|
|
    def forward(self, x, text=None):
        if text is None:
            # Image-only path: pre-norm self-attention followed by a pre-norm
            # MLP, each with a residual connection.
            h = x
            x = self.attention_norm(x)
            x, _, weights = self.attn(x)
            x = x + h

            h = x
            x = self.ffn_norm(x)
            x = self.ffn(x)
            x = x + h
            return x
        else:
            # Multimodal path: the image and text streams are normalized
            # separately, fused inside the shared attention module, and then
            # passed through their own MLPs, each with residual connections.
            h = x
            h_text = text
            x = self.attention_norm(x)
            text = self.att_norm_text(text)

            x, text, weights_img = self.attn(x, text)
            x = x + h
            text = text + h_text

            h = x
            h_text = text
            x = self.ffn_norm(x)
            text = self.ffn_norm_text(text)
            x = self.ffn(x)
            text = self.ffn_text(text)
            x = x + h
            text = text + h_text

            # Return the updated features for both streams.
            return x, text
|
|
    def load_from(self, weights, n_block):
        """Copy this block's attention, MLP, and LayerNorm weights from a
        pre-trained JAX/Flax ViT checkpoint. The text-specific layers added
        for mm=True are left untouched."""
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
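

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes models.configs exposes
# a ViT-B/16-style config via get_b16_config() and that Attention/Mlp consume
# the usual fields (hidden_size, transformer.num_heads, transformer.mlp_dim,
# dropout rates); the accessor name and the token counts below are assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = configs.get_b16_config()  # assumed accessor; adjust to this repo's configs module
    block = Block(config, vis=False, mm=True)

    img_tokens = torch.randn(2, 197, config.hidden_size)  # [batch, CLS + 14x14 patches, hidden]
    txt_tokens = torch.randn(2, 40, config.hidden_size)   # [batch, text tokens, hidden]

    img_out, txt_out = block(img_tokens, txt_tokens)  # multimodal path
    img_only = block(img_tokens)                      # image-only path
    print(img_out.shape, txt_out.shape, img_only.shape)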