|
|
import math |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
class MultiheadAttention(torch.nn.Module):
    """Multi-head attention with optional multiplicative masking and per-reference
    re-weighting of the attention distribution.

    When the key/value sequence length equals ``n_token``, this is standard
    scaled dot-product attention (optionally re-weighted per key). When it is
    longer, scores are split into a "target" part (first ``n_token`` keys) and
    a "reference" part (remaining keys, grouped into chunks of ``n_token``);
    each part is softmax-normalized separately and blended with ``alpha``.
    """

    def __init__(self, d_model, n_head, n_token = 77, dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model // n_head  # per-head width; assumes d_model % n_head == 0
        self.n_token = n_token

        self.query = torch.nn.Linear(d_model, d_model)
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.proj = torch.nn.Linear(d_model, d_model)

        # Score scale (sqrt of head dim). FIX: previously a plain tensor
        # attribute, which is neither a parameter nor a registered buffer and
        # therefore ignored by .to() / state_dict(); a Python float has the
        # same numerics and no device/serialization pitfalls.
        self.div = math.sqrt(self.d_head)

        self.softmax = torch.nn.Softmax(dim = -1)
        self.dropout = torch.nn.Dropout(dropout)

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-uniform weights, zero biases for all four projections.
        torch.nn.init.xavier_uniform_(self.query.weight)
        torch.nn.init.xavier_uniform_(self.key.weight)
        torch.nn.init.xavier_uniform_(self.value.weight)
        torch.nn.init.xavier_uniform_(self.proj.weight)

        torch.nn.init.constant_(self.query.bias, 0.)
        torch.nn.init.constant_(self.key.bias, 0.)
        torch.nn.init.constant_(self.value.bias, 0.)
        torch.nn.init.constant_(self.proj.bias, 0.)

    def forward(self, q, k, v, mask = None, weight = None, alpha = None):
        """Attend ``q`` over ``k``/``v``.

        q: (b, s, d_model); k, v: (b2, s2, d_model).
        mask: multiplicative mask broadcast over heads (NOT additive).
        weight: per-key weights; when s2 > n_token its last dim is
            n_token + n_ref, one weight per target token / reference chunk.
        alpha: target-vs-reference blend, defaults to 0.5.
        """
        b, s = q.shape[:2]
        b2, s2 = k.shape[:2]

        q = self.query(q)
        k = self.key(k)
        v = self.value(v)

        # (batch, seq, d_model) -> (batch, n_head, seq, d_head)
        q = q.view(-1, s, self.n_head, self.d_head).transpose(1, 2)
        k = k.view(-1, s2, self.n_head, self.d_head).transpose(1, 2)
        v = v.view(-1, s2, self.n_head, self.d_head).transpose(1, 2)

        score = torch.matmul(q, k.transpose(-2, -1)) / self.div

        if mask is not None:
            # Broadcast over heads (and query positions, if mask is 2-D).
            mask = mask.unsqueeze(1)
            if mask.dim() != score.dim():
                mask = mask.unsqueeze(2)
            score = score * mask

        if weight is not None:
            weight = weight.unsqueeze(1)
            if weight.dim() != score.dim():
                weight = weight.unsqueeze(2)

        if self.n_token == s2:
            # Plain attention; optionally re-weight keys and re-normalize.
            w = self.softmax(score)
            if weight is not None:
                w = w * weight
                w = w / (w.sum(dim = -1, keepdim = True) + 1e-12)
        else:
            # Split target (first n_token keys) from reference keys.
            target, ref = torch.split(score, [self.n_token, s2 - self.n_token], dim = -1)
            target = self.softmax(target)
            if alpha is None:
                alpha = 0.5
            if weight is not None:
                ws = weight.shape[-1]
                target_weight, ref_weight = torch.split(weight, [self.n_token, ws - self.n_token], dim = -1)
                # Softmax within each reference chunk of n_token keys, then
                # weight each chunk by its reference weight. Requires
                # s2 - n_token == (ws - n_token) * n_token.
                ref = ref.view(b2, self.n_head, s, ws - self.n_token, self.n_token)
                ref = self.softmax(ref)
                ref = ref * ref_weight.unsqueeze(-1)
                ref = ref.view(b2, self.n_head, s, s2 - self.n_token)
                ref = alpha * (ref / (ref.sum(dim = -1, keepdim = True) + 1e-12))
                target = target * (1 - alpha) * target_weight
            # NOTE(review): without `weight`, `ref` is concatenated as raw
            # (un-softmaxed) scores before re-normalization — preserved as-is.
            w = torch.cat([target, ref], dim = -1)
            w = w / (w.sum(dim = -1, keepdim = True) + 1e-12)
        w = self.dropout(w)

        out = torch.matmul(w, v)
        out = out.transpose(1, 2).contiguous().view(b, s, self.d_model)
        out = self.proj(out)
        return out
|
|
|
|
|
class QuickGELU(torch.nn.Module):
    """Sigmoid-based fast approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
|
|
|
|
class TransformerBlock(torch.nn.Module):
    """Post-norm transformer block: cross/self-attention + feed-forward.

    Layout: ``norm1(x + dropout(attn))`` then ``norm2(out + dropout(ff(out)))``.
    ``activation`` may be a name ("gelu" / "relu" / "quick_gelu"), ``None``
    (falls back to GELU), or a ready-made activation module.
    """

    def __init__(self, emb_dim, n_head, ff_dim, n_token = 77, activation = "quick_gelu", dropout = 0.1):
        super().__init__()
        self.attn = MultiheadAttention(emb_dim, n_head, n_token = n_token, dropout = dropout)
        # FIX: test for None before calling .lower() (the original evaluated
        # activation.lower() first, so activation=None raised AttributeError
        # and the None branch was unreachable); isinstance guards also let
        # non-string activations reach the fallback branch without crashing.
        if activation is None or (isinstance(activation, str) and activation.lower() == "gelu"):
            self.act = torch.nn.GELU()
        elif isinstance(activation, str) and activation.lower() == "relu":
            self.act = torch.nn.ReLU()
        elif isinstance(activation, str) and activation.lower() == "quick_gelu":
            self.act = QuickGELU()
        else:
            self.act = activation
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, ff_dim),
            self.act,
            torch.nn.Linear(ff_dim, emb_dim),
        )
        self.norm1 = torch.nn.LayerNorm(emb_dim)
        self.norm2 = torch.nn.LayerNorm(emb_dim)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier weights / zero biases for the two feed-forward linears;
        # the attention module initializes its own projections.
        torch.nn.init.xavier_uniform_(self.ff[0].weight)
        torch.nn.init.xavier_uniform_(self.ff[2].weight)

        torch.nn.init.constant_(self.ff[0].bias, 0.)
        torch.nn.init.constant_(self.ff[2].bias, 0.)

    def forward(self, x, context = None, mask = None, weight = None, alpha = None):
        """Attend ``x`` over ``context`` (defaults to self-attention on ``x``)."""
        context = context if context is not None else x
        out = self.attn(x, context, context, mask = mask, weight = weight, alpha = alpha)
        out = x + self.dropout1(out)
        out = self.norm1(out)

        ff_out = self.ff(out)
        out = out + self.dropout2(ff_out)
        out = self.norm2(out)
        return out
|
|
|
|
|
class PersonalizedAdapter(torch.nn.Module):
    """Learned-query transformer adapter that attends over reference context.

    A bank of ``n_token`` learned query tokens (plus an optional trailing CLS
    slot) is refined through ``n_layer`` TransformerBlocks with the given
    context as key/value. With ``encode_ratio`` set (> 0 and != 1), inputs are
    projected down to ``emb_dim // encode_ratio`` before the stack and back up
    afterwards.
    """

    def __init__(self, emb_dim, n_head, ff_dim, n_layer = 4, n_token = 77, proj = False, extra_proj = False, pos = True, cls_pos = False, cls_token = True, encode_ratio = None, activation = "quick_gelu", dropout = 0.1):
        super().__init__()
        self.n_layer = n_layer
        self.n_token = n_token
        self.cls_pos = cls_pos
        self.cls_token = cls_token
        self.encode_ratio = encode_ratio

        # Optional bottleneck projections (created only for a real ratio).
        self.pre_proj = self.post_proj = None
        if encode_ratio and encode_ratio != 1:
            self.pre_proj = torch.nn.Linear(emb_dim, int(emb_dim // encode_ratio))
            self.post_proj = torch.nn.Linear(int(emb_dim // encode_ratio), emb_dim)
            emb_dim = int(emb_dim // encode_ratio)
            n_head = int(n_head // encode_ratio)

        # FIX: test for None before calling .lower() (the original raised
        # AttributeError for activation=None); isinstance guards let module
        # activations reach the fallback branch.
        if activation is None or (isinstance(activation, str) and activation.lower() == "gelu"):
            self.act = torch.nn.GELU()
        elif isinstance(activation, str) and activation.lower() == "relu":
            self.act = torch.nn.ReLU()
        elif isinstance(activation, str) and activation.lower() == "quick_gelu":
            self.act = QuickGELU()
        else:
            self.act = activation
        # Learned query bank (+ optional CLS slot) and positional embedding.
        self.base_query = torch.nn.Parameter(torch.empty(1, n_token + int(cls_token), emb_dim))
        self.pos = torch.nn.Parameter(torch.empty(1, n_token + int(cls_pos and cls_token), emb_dim)) if pos else None
        self.init_query = None  # optional fixed query set via set_base_query()

        self.proj = None
        if proj:
            # Optional MLP applied to the context before the transformer stack.
            self.proj = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, ff_dim),
                self.act,
                torch.nn.Linear(ff_dim, emb_dim),
            )

        self.extra_proj = None
        self.tf = torch.nn.ModuleList([TransformerBlock(emb_dim, n_head, ff_dim, n_token = n_token, activation = activation, dropout = dropout) for _ in range(n_layer)])
        if extra_proj:
            # Optional per-layer linear applied to the activated context.
            self.extra_proj = torch.nn.ModuleList([torch.nn.Linear(emb_dim, emb_dim) for _ in range(n_layer)])

        self._reset_parameters()

    def _reset_parameters(self):
        torch.nn.init.normal_(self.base_query, std = 0.02)
        if self.pos is not None:
            torch.nn.init.normal_(self.pos, std = 0.01)

        for proj in [self.pre_proj, self.post_proj]:
            if proj is not None:
                torch.nn.init.xavier_uniform_(proj.weight)
                torch.nn.init.constant_(proj.bias, 0.)
        for proj in [self.proj]:
            if proj is not None:
                torch.nn.init.xavier_uniform_(proj[0].weight)
                torch.nn.init.xavier_uniform_(proj[2].weight)

                torch.nn.init.constant_(proj[0].bias, 0.)
                torch.nn.init.constant_(proj[2].bias, 0.)
        if self.extra_proj is not None:
            for l in self.extra_proj:
                torch.nn.init.xavier_uniform_(l.weight)
                torch.nn.init.constant_(l.bias, 0.)

    def set_base_query(self, x):
        """Install a fixed (non-learned) query tensor used instead of base_query."""
        if not torch.is_tensor(x):
            x = torch.tensor(x, dtype = self.base_query.dtype).to(self.base_query.device)
        if x.dim() == 2:
            x = x.unsqueeze(0)  # add a leading batch dimension
        self.init_query = x

    def normal_forward(self, x, context, mask = None, weight = None, alpha = None):
        """Run the transformer stack; returns (tokens, cls_or_None)."""
        out = x
        for i in range(self.n_layer):
            if self.extra_proj is not None:
                _context = self.extra_proj[i](self.act(context))
            else:
                _context = context
            out = self.tf[i](out, _context, mask = mask, weight = weight, alpha = alpha)
        if self.cls_token:
            # Last position is the CLS slot.
            return out[:, :-1], out[:, -1]
        else:
            return out, None

    def forward(self, context, mask = None, weight = None, alpha = None, base_query = None):
        """Produce adapted token embeddings (and CLS, if enabled) from context.

        context: (batch, ctx_len, dim) — assumed non-None here (its batch size
        drives query replication); TODO confirm against callers.
        """
        dtype = self.base_query.dtype
        if base_query is not None:
            x = base_query
        else:
            x = self.base_query if self.init_query is None else self.init_query
        x = x.type(dtype)
        if context is not None:
            context = context.type(dtype)
        if weight is not None:
            weight = weight.type(dtype)
        # FIX: gate bottleneck projections on the modules themselves rather
        # than on encode_ratio — with encode_ratio in (0, 1) the original
        # tried to call pre_proj/post_proj even though none were created.
        if self.pre_proj is not None and x.shape[-1] != self.base_query.shape[-1]:
            x = self.pre_proj(x)
        if self.n_token < x.shape[1]:
            # Query already carries a CLS slot: split it off.
            x, cls = x[:, :self.n_token], x[:, self.n_token:]
        else:
            cls = self.base_query[:, self.n_token:] if self.cls_token else None
        if self.pos is not None:
            if self.cls_pos and self.cls_token:
                x = x + self.pos[:, :self.n_token]
                if cls is not None:
                    cls = cls + self.pos[:, self.n_token:]
            else:
                x = x + self.pos
        if self.cls_token:
            x = torch.cat([x, cls], dim = 1)
        # Replicate the (batch-1) query for every context item.
        x = x.repeat_interleave(context.shape[0], dim = 0)
        if self.pre_proj is not None:
            if context is not None:
                context = self.pre_proj(context)
        if self.proj is not None:
            context = self.proj(context)
        out = self.normal_forward(x, context, mask = mask, weight = weight, alpha = alpha)
        if self.post_proj is not None:
            out = (self.post_proj(out[0]), self.post_proj(out[1]) if out[1] is not None else out[1])
        return out
|
|
|
|
|
class DrUM:
    """Personalization wrapper pairing a frozen text encoder with a trainable adapter.

    ``model``/``processor`` are Hugging Face-style objects — presumably either a
    CLIP model (exposes ``config.text_config``, ``text_projection``,
    ``get_image_features``) or a T5 encoder (``config.model_type == "t5"``);
    TODO confirm against callers. Only ``self.adapter`` is trained; the
    backbone is always kept in eval mode.
    """

    def __init__(self, model, processor, n_layer = 8, proj = False, extra_proj = False, mlp_ratio = 4, pos = True, cls_pos = False, cls_token = True, encode_ratio = None, max_token_size = 256, activation = "quick_gelu", dropout = 0.1):
        # CLIP nests the text settings under config.text_config; T5 keeps them on config.
        config = model.config.text_config if hasattr(model.config, "text_config") else model.config
        if hasattr(config, "model_type") and config.model_type == "t5":
            self.d_model = config.d_model
            self.n_head = config.num_heads
            # Cap sequence length for T5 (tokenizer max length can be very large).
            self.n_token = min(processor.model_max_length, max_token_size)
            self.clip = False
            self.cls_token = False  # no pooled/CLS representation for T5
        else:
            self.d_model = config.hidden_size
            self.n_head = config.num_attention_heads
            self.n_token = config.max_position_embeddings
            self.clip = True
            self.cls_token = cls_token
        self.n_layer = n_layer
        self.proj = proj
        self.extra_proj = extra_proj
        self.mlp_ratio = mlp_ratio
        self.pos = pos
        self.cls_pos = cls_pos
        self.encode_ratio = encode_ratio
        self.activation = activation
        self.dropout = dropout

        self.model = model
        self.processor = processor
        self.adapter = PersonalizedAdapter(self.d_model, self.n_head, self.d_model // mlp_ratio, n_layer, self.n_token, proj = proj, extra_proj = extra_proj, pos = pos, cls_pos = cls_pos, cls_token = self.cls_token, encode_ratio = encode_ratio, activation = activation, dropout = dropout).to(model.device)

        # Adapter in train mode, backbone frozen; to() also records self.device.
        self.train()
        self.to(model.device)

    def preprocess(self, text = None, image = None, return_tensors = "pt", padding = "max_length", truncation = True, **kwargs):
        """Tokenize text and/or prepare images via the wrapped processor."""
        # Scalars are wrapped into a one-element list; text=None skips tokenization.
        feed = {"text":([text] if np.ndim(text) == 0 else list(text)) if text is not None else None,
                "return_tensors":return_tensors,
                "max_length":self.n_token,
                "padding":padding,
                "truncation":truncation,
                **kwargs}
        if not self.clip:
            # presumably to guarantee EOS for the T5 tokenizer — confirm
            feed["add_special_tokens"] = True
        if image is not None:
            feed["images"] = image
        return self.processor(**feed)

    def pool_text_hidden_state(self, hidden_state, x, padding = "max_length", truncation = True, **kwargs):
        """Select the EOS-position hidden state per sequence (CLIP-style pooling).

        Raises TypeError for non-CLIP (T5) backbones.
        """
        if not self.clip:
            raise TypeError("T5 encoder does not support this function (pool_text_hidden_state).")
        if not hasattr(x, "items"):
            x = self.preprocess(text = x, padding = padding, truncation = truncation, **kwargs)
        if self.model.text_model.eos_token_id == 2:
            # Legacy CLIP vocab: argmax over input_ids locates the EOS token
            # (assumes no id exceeds eos_token_id — TODO confirm).
            out = hidden_state[torch.arange(hidden_state.shape[0], device = hidden_state.device),
                               x["input_ids"].to(dtype = torch.int, device = hidden_state.device).argmax(dim = -1),]
        else:
            # General case: first position whose id equals eos_token_id.
            out = hidden_state[torch.arange(hidden_state.shape[0], device = hidden_state.device),
                               (x["input_ids"].to(dtype = torch.int, device = hidden_state.device) == self.model.text_model.eos_token_id).int().argmax(dim = -1),]
        return out

    def normalize_text_hidden_state(self, hidden_state):
        """Apply the backbone's final LayerNorm (CLIP only; identity otherwise)."""
        out = self.model.text_model.final_layer_norm(hidden_state.type(self.model.dtype)) if self.clip and hasattr(self.model.text_model, "final_layer_norm") else hidden_state
        return out

    def projection_text_hidden_state(self, hidden_state):
        """Apply the backbone's text projection head (CLIP only; identity otherwise)."""
        out = self.model.text_projection(hidden_state.type(self.model.dtype)) if self.clip and hasattr(self.model, "text_projection") else hidden_state
        return out

    def encode_prompt(self, x, pooling = True, skip = -1, skip_pool = None, padding = "max_length", truncation = True, use_attn_mask = False, normalize = True, normalize_pool = True, **kwargs):
        """Encode prompts with the frozen backbone.

        Returns (hidden_state, pool) when pooling else hidden_state; ``skip``
        indexes hidden_states (e.g. -1 = last layer, CLIP skip-layer trick).
        """
        if not hasattr(x, "items"):
            x = self.preprocess(text = x, padding = padding, truncation = truncation, **kwargs)
        input_ids = x["input_ids"].to(self.device)
        attention_mask = x["attention_mask"].to(self.device) if use_attn_mask else None
        with torch.no_grad():
            if self.clip:
                hidden_state = self.model.text_model(output_hidden_states = True, input_ids = input_ids, attention_mask = attention_mask)["hidden_states"]
                # Possibly different layers for the pooled vs. token outputs.
                pool, hidden_state = hidden_state[skip_pool if skip_pool is not None else skip], hidden_state[skip]
                hidden_state = self.normalize_text_hidden_state(hidden_state) if normalize else hidden_state
            else:
                hidden_state = self.model(input_ids = input_ids, attention_mask = attention_mask)[0]
                pool = None  # T5 has no pooled output
        if pooling:
            if self.clip:
                with torch.no_grad():
                    pool = self.pool_text_hidden_state(self.normalize_text_hidden_state(pool) if normalize_pool else pool, x, **kwargs)
            return (hidden_state, pool)
        return hidden_state

    def get_text_feature(self, x, ref_x = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, padding = "max_length", truncation = True, use_attn_mask = False, **kwargs):
        """CLIP text feature (projected pooled state), optionally personalized with ref_x."""
        if not self.clip:
            raise TypeError("T5 encoder does not support this function (get_text_feature).")
        with torch.no_grad():
            # __call__ returns (hidden_state, pool); take the pooled output.
            pool_hidden_state = self(x, ref_x, weight = weight, alpha = alpha, pooling = True, skip_pool = skip, batch_size = batch_size, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize_pool = True, **kwargs)[1]
            result = self.projection_text_hidden_state(pool_hidden_state)
        return result

    def get_image_feature(self, x, return_tensors = "pt", **kwargs):
        """CLIP image feature for raw images, preprocessed inputs, or pixel tensors."""
        if not self.clip:
            raise TypeError("T5 encoder does not support this function (get_image_feature).")
        if hasattr(x, "items"):
            x = x["pixel_values"]
        elif not torch.is_tensor(x):
            x = self.preprocess(image = x, return_tensors = return_tensors, **kwargs)["pixel_values"]
        with torch.no_grad():
            result = self.model.get_image_features(pixel_values = x.to(self.device))
        return result

    def encode_context(self, ref_x, pooling = False, skip = -1, skip_pool = None, batch_size = 64, padding = "max_length", truncation = True, use_attn_mask = False, normalize = False, normalize_pool = False, **kwargs):
        """Encode reference prompts, concatenating per-reference token states.

        ref_x: scalar / 1-D / 2-D nested text (outer = batch, inner = refs per
        item), or preprocessed tensors of shape (b, ref_size, seq).
        Returns (b, ref_size * seq, dim), plus pooled states when pooling.
        """
        if not hasattr(ref_x, "items"):
            # Normalize to a 2-D [batch][ref] structure.
            if np.ndim(ref_x) == 0:
                ref_x = [[ref_x]]
            elif np.ndim(ref_x) == 1:
                ref_x = [ref_x]
            b, ref_size = len(ref_x), len(ref_x[0])
            ref_x = np.reshape(ref_x, [b * ref_size])
            ref_x = self.preprocess(text = list(ref_x), padding = padding, truncation = truncation, **kwargs)
            ref_x = {k:v for k, v in ref_x.items() if k in (["input_ids", "attention_mask"] if use_attn_mask else ["input_ids"])}
        else:
            b, ref_size = ref_x["input_ids"].shape[:2]
            # Flatten (b, ref_size, seq) -> (b * ref_size, seq) for batched encoding.
            ref_x = {k:v.view(b * ref_size, -1) for k, v in ref_x.items() if k in (["input_ids", "attention_mask"] if use_attn_mask else ["input_ids"])}
        hidden_state, pool_hidden_state = [], []
        # Mini-batch the b * ref_size sequences to bound memory.
        batch_indices = [(i * batch_size, min((b * ref_size), (i + 1) * batch_size)) for i in range(int(np.ceil((b * ref_size) / batch_size)))]
        for start, end in batch_indices:
            h, p = self.encode_prompt({k:v[start:end] for k, v in ref_x.items()}, pooling = True, skip = skip, skip_pool = skip_pool, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = normalize, normalize_pool = normalize_pool, **kwargs)
            hidden_state.append(h)
            if p is not None:
                pool_hidden_state.append(p)
        hidden_state = torch.cat(hidden_state, dim = 0) if 1 < len(hidden_state) else hidden_state[0]
        pool_hidden_state = torch.cat(pool_hidden_state, dim = 0) if 1 < len(pool_hidden_state) else (pool_hidden_state[0] if len(pool_hidden_state) == 1 else None)
        with torch.no_grad():
            # Concatenate each item's references along the sequence axis.
            hidden_state = hidden_state.view(b, ref_size * hidden_state.shape[1], -1)
            if pooling:
                if self.clip:
                    pool_hidden_state = pool_hidden_state.view(b, ref_size, -1)
                hidden_state = (hidden_state, pool_hidden_state)
        return hidden_state

    def __call__(self, x, ref_x = None, weight = None, alpha = 0.3, pooling = True, skip = -1, skip_pool = None, batch_size = 64, padding = "max_length", truncation = True, use_attn_mask = False, normalize = True, normalize_pool = True, training = False, **kwargs):
        """Encode x; personalize via the adapter when ref_x is given (or training).

        Without ref_x/training this is a plain backbone encode_prompt call.
        """
        if ref_x is not None or training:
            if training:
                # Training path: adapter self-attends on x alone.
                context = weight = None
            else:
                _context, _context_pool = self.encode_context(ref_x, pooling = True, skip = skip, skip_pool = None, batch_size = batch_size, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = False, normalize_pool = False, **kwargs)
                if weight is not None:
                    # Coerce to a (batch, ref) weight tensor on the right device.
                    if not torch.is_tensor(weight):
                        weight = torch.tensor(weight)
                    if weight.dim() == 0:
                        weight = weight.unsqueeze(0).unsqueeze(0)
                    elif weight.dim() == 1:
                        weight = weight.unsqueeze(0)
                    weight = weight.to(self.device)
                else:
                    # Default: uniform weight per reference (context_len / n_token refs).
                    weight = torch.ones((1, _context.shape[1] // self.n_token), dtype = torch.float32, device = _context.device)
                context = _context
                del _context, _context_pool
            result = self.encode_personalized_prompt(x, context, weight = weight, alpha = alpha, pooling = pooling, skip = skip, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = normalize, normalize_pool = normalize_pool, **kwargs)
            return result
        else:
            return self.encode_prompt(x, pooling = pooling, skip = skip, skip_pool = skip_pool, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = normalize, normalize_pool = normalize_pool, **kwargs)

    def encode_personalized_prompt(self, x, context = None, weight = None, alpha = 0.3, pooling = True, skip = -1, padding = "max_length", truncation = True, use_attn_mask = False, normalize = True, normalize_pool = True, **kwargs):
        """Run the adapter on the prompt's hidden states plus reference context."""
        if not torch.is_tensor(x):
            if not hasattr(x, "items"):
                x = self.preprocess(text = x, padding = padding, truncation = truncation, **kwargs)
            x = self.encode_prompt(x, pooling = False, skip = skip, skip_pool = None, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = False, normalize_pool = False, **kwargs)
        if context is None:
            context = x  # training path: self-conditioning
        else:
            batch_size, n_token = x.shape[:2]
            # Broadcast a single context/weight across the prompt batch.
            if context.shape[0] == 1 and batch_size != 1:
                context = context.repeat_interleave(batch_size, dim = 0)
            if weight is not None and weight.shape[0] == 1:
                weight = weight.repeat_interleave(batch_size, dim = 0)
            context_size = context.shape[1]  # NOTE(review): unused
            # Prompt tokens first, then reference tokens (matches the
            # target/reference split inside MultiheadAttention).
            context = torch.cat([x, context], dim = 1)
            if weight is not None:
                # Prompt tokens get unit weight ahead of the per-reference weights.
                extra_weight = torch.ones((batch_size, n_token), dtype = torch.float32, device = weight.device)
                weight = torch.cat([extra_weight, weight], dim = 1)
        hidden_state, pool = self.adapter(context, weight = weight, alpha = alpha)
        hidden_state = self.normalize_text_hidden_state(hidden_state) if normalize else hidden_state
        if pooling:
            pool = self.normalize_text_hidden_state(pool) if normalize_pool else pool
            return (hidden_state, pool)
        return hidden_state

    def to(self, device):
        """Move backbone and adapter; remember the device for later tensors."""
        self.model.to(device)
        self.adapter.to(device)
        self.device = device
        return self

    def eval(self):
        """Evaluation mode: backbone and adapter in eval, CLIP heads frozen."""
        self.model.eval()
        if self.clip and hasattr(self.model, "text_projection"):
            self.model.text_model.final_layer_norm.requires_grad_(False)
            self.model.text_projection.requires_grad_(False)
        self.adapter.eval()
        return self

    def train(self):
        """Training mode: only the adapter trains; backbone stays in eval."""
        self.model.eval()
        if self.clip and hasattr(self.model, "text_projection"):
            self.model.text_model.final_layer_norm.requires_grad_(False)
            self.model.text_projection.requires_grad_(False)
        self.adapter.train()
        return self

    def parameters(self):
        """Trainable parameters — adapter only."""
        return list(self.adapter.parameters())

    def named_parameters(self):
        """Named trainable parameters — adapter only."""
        return list(self.adapter.named_parameters())