Commit ·
eefe43c
1
Parent(s): 6cc0f51
poc
Browse files
Signed-off-by: jupyterjazz <saba.sturua@jina.ai>
- embedding.py +1 -2
- mha.py +5 -3
- mlp.py +2 -2
- modeling_lora.py +33 -35
- modeling_xlm_roberta.py +1 -1
embedding.py
CHANGED
|
@@ -47,7 +47,6 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
| 47 |
token_type_ids: (batch, seqlen)
|
| 48 |
"""
|
| 49 |
batch_size, seqlen = input_ids.shape
|
| 50 |
-
print('input shape', input_ids.shape)
|
| 51 |
embeddings = self.word_embeddings(input_ids, task='sts')
|
| 52 |
if self.max_position_embeddings > 0:
|
| 53 |
if position_ids is None:
|
|
@@ -58,6 +57,6 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
| 58 |
if self.type_vocab_size > 0:
|
| 59 |
if token_type_ids is None:
|
| 60 |
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
|
| 61 |
-
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 62 |
embeddings = embeddings + token_type_embeddings
|
| 63 |
return embeddings
|
|
|
|
| 47 |
token_type_ids: (batch, seqlen)
|
| 48 |
"""
|
| 49 |
batch_size, seqlen = input_ids.shape
|
|
|
|
| 50 |
embeddings = self.word_embeddings(input_ids, task='sts')
|
| 51 |
if self.max_position_embeddings > 0:
|
| 52 |
if position_ids is None:
|
|
|
|
| 57 |
if self.type_vocab_size > 0:
|
| 58 |
if token_type_ids is None:
|
| 59 |
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
|
| 60 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids, task='sts')
|
| 61 |
embeddings = embeddings + token_type_embeddings
|
| 62 |
return embeddings
|
mha.py
CHANGED
|
@@ -341,6 +341,7 @@ class LinearResidual(nn.Linear):
|
|
| 341 |
"""Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
|
| 342 |
|
| 343 |
def forward(self, input: torch.Tensor, task=None) -> torch.Tensor:
|
|
|
|
| 344 |
return super().forward(input, task=task), input
|
| 345 |
|
| 346 |
|
|
@@ -450,7 +451,7 @@ class MHA(nn.Module):
|
|
| 450 |
|
| 451 |
if fused_bias_fc and FusedDense is None:
|
| 452 |
raise ImportError("fused_dense is not installed")
|
| 453 |
-
|
| 454 |
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
| 455 |
linear_resid_cls = (
|
| 456 |
LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
|
|
@@ -647,7 +648,8 @@ class MHA(nn.Module):
|
|
| 647 |
if not self.return_residual:
|
| 648 |
qkv = self.Wqkv(x)
|
| 649 |
else:
|
| 650 |
-
qkv, x = self.Wqkv(x, task='
|
|
|
|
| 651 |
if self.dwconv:
|
| 652 |
qkv = rearrange(
|
| 653 |
self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
|
@@ -732,5 +734,5 @@ class MHA(nn.Module):
|
|
| 732 |
context = self._update_kvcache_attention(q, kv, inference_params)
|
| 733 |
else:
|
| 734 |
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
| 735 |
-
out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
|
| 736 |
return out if not self.return_residual else (out, x)
|
|
|
|
| 341 |
"""Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
|
| 342 |
|
| 343 |
def forward(self, input: torch.Tensor, task=None) -> torch.Tensor:
|
| 344 |
+
print('aq vafshe ar modis?')
|
| 345 |
return super().forward(input, task=task), input
|
| 346 |
|
| 347 |
|
|
|
|
| 451 |
|
| 452 |
if fused_bias_fc and FusedDense is None:
|
| 453 |
raise ImportError("fused_dense is not installed")
|
| 454 |
+
|
| 455 |
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
| 456 |
linear_resid_cls = (
|
| 457 |
LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
|
|
|
|
| 648 |
if not self.return_residual:
|
| 649 |
qkv = self.Wqkv(x)
|
| 650 |
else:
|
| 651 |
+
qkv, x = self.Wqkv(x, task='query', residual=True)
|
| 652 |
+
|
| 653 |
if self.dwconv:
|
| 654 |
qkv = rearrange(
|
| 655 |
self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
|
|
|
| 734 |
context = self._update_kvcache_attention(q, kv, inference_params)
|
| 735 |
else:
|
| 736 |
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
| 737 |
+
out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), task='passage')
|
| 738 |
return out if not self.return_residual else (out, x)
|
mlp.py
CHANGED
|
@@ -48,9 +48,9 @@ class Mlp(nn.Module):
|
|
| 48 |
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
| 49 |
|
| 50 |
def forward(self, x):
|
| 51 |
-
y = self.fc1(x)
|
| 52 |
y = self.activation(y)
|
| 53 |
-
y = self.fc2(y)
|
| 54 |
return y if not self.return_residual else (y, x)
|
| 55 |
|
| 56 |
|
|
|
|
| 48 |
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
| 49 |
|
| 50 |
def forward(self, x):
|
| 51 |
+
y = self.fc1(x, task='clustering')
|
| 52 |
y = self.activation(y)
|
| 53 |
+
y = self.fc2(y, task='sts')
|
| 54 |
return y if not self.return_residual else (y, x)
|
| 55 |
|
| 56 |
|
modeling_lora.py
CHANGED
|
@@ -9,6 +9,7 @@ import torch
|
|
| 9 |
import torch.nn.utils.parametrize as parametrize
|
| 10 |
from torch import nn
|
| 11 |
from torch.nn import Parameter
|
|
|
|
| 12 |
from transformers import PretrainedConfig
|
| 13 |
|
| 14 |
from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
|
|
@@ -98,8 +99,7 @@ class LoRAParametrization(nn.Module):
|
|
| 98 |
# to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
|
| 99 |
return A * self.lora_dropout(self.lora_dropout_mask)
|
| 100 |
|
| 101 |
-
def lora_forward(self, X, current_task
|
| 102 |
-
print('lora input shape', X.shape)
|
| 103 |
return (
|
| 104 |
X
|
| 105 |
+ torch.matmul(
|
|
@@ -114,10 +114,7 @@ class LoRAParametrization(nn.Module):
|
|
| 114 |
)
|
| 115 |
|
| 116 |
def forward(self, X):
|
| 117 |
-
|
| 118 |
-
out = self.forward_fn(X)
|
| 119 |
-
print(out.shape)
|
| 120 |
-
return out
|
| 121 |
|
| 122 |
@property
|
| 123 |
def current_task(self):
|
|
@@ -195,13 +192,20 @@ class LoRAParametrization(nn.Module):
|
|
| 195 |
alpha=alpha,
|
| 196 |
),
|
| 197 |
)
|
| 198 |
-
original_forward = layer.forward
|
| 199 |
|
| 200 |
-
def new_forward(self, input, task):
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
layer.forward = new_forward.__get__(layer, layer.__class__)
|
| 207 |
|
|
@@ -217,20 +221,20 @@ class LoRAParametrization(nn.Module):
|
|
| 217 |
alpha=alpha,
|
| 218 |
),
|
| 219 |
)
|
| 220 |
-
original_forward = layer.forward
|
| 221 |
|
| 222 |
def new_forward(self, input, task):
|
| 223 |
-
print('input here', input, input.shape)
|
| 224 |
-
print('func', original_forward)
|
| 225 |
-
# original_forward['parametrizations'] = None
|
| 226 |
-
# print('funcc', original_forward.__dict__)
|
| 227 |
-
output = original_forward(input)
|
| 228 |
-
print(output.shape, 'output shape')
|
| 229 |
task_idx = adaptation_map[task] if task else None
|
| 230 |
if task_idx:
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
layer.forward = new_forward.__get__(layer, layer.__class__)
|
| 236 |
|
|
@@ -278,13 +282,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
| 278 |
self._task_idx = None
|
| 279 |
# By default, disable LoRA until it's specified which adapter/task to use
|
| 280 |
self.current_task = None
|
| 281 |
-
|
| 282 |
-
if name == 'roberta.encoder.layers.22.mixer.Wqkv.parametrizations.weight.0.lora_A':
|
| 283 |
-
print('A0', param[0])
|
| 284 |
-
print('A1', param[1])
|
| 285 |
-
if name == 'roberta.encoder.layers.22.mixer.Wqkv.parametrizations.weight.0.lora_B':
|
| 286 |
-
print('B0', param[0])
|
| 287 |
-
print('B1', param[1])
|
| 288 |
|
| 289 |
@property
|
| 290 |
def main_params_trainable(self):
|
|
@@ -364,12 +362,12 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
| 364 |
f"Alternatively, set `task` to `None` if you want to disable LoRA."
|
| 365 |
)
|
| 366 |
task_idx = self._adaptation_map[task_name] if task_name else None
|
| 367 |
-
if self._task_idx != task_idx:
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
|
| 374 |
def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
|
| 375 |
if task != LORA_NO_UPDATE:
|
|
|
|
| 9 |
import torch.nn.utils.parametrize as parametrize
|
| 10 |
from torch import nn
|
| 11 |
from torch.nn import Parameter
|
| 12 |
+
from torch.nn import functional as F
|
| 13 |
from transformers import PretrainedConfig
|
| 14 |
|
| 15 |
from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
|
|
|
|
| 99 |
# to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
|
| 100 |
return A * self.lora_dropout(self.lora_dropout_mask)
|
| 101 |
|
| 102 |
+
def lora_forward(self, X, current_task):
|
|
|
|
| 103 |
return (
|
| 104 |
X
|
| 105 |
+ torch.matmul(
|
|
|
|
| 114 |
)
|
| 115 |
|
| 116 |
def forward(self, X):
|
| 117 |
+
return X
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
@property
|
| 120 |
def current_task(self):
|
|
|
|
| 192 |
alpha=alpha,
|
| 193 |
),
|
| 194 |
)
|
|
|
|
| 195 |
|
| 196 |
+
def new_forward(self, input, task, residual=False):
|
| 197 |
+
task_idx = adaptation_map[task] if task else None
|
| 198 |
+
if task_idx:
|
| 199 |
+
weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
|
| 200 |
+
else:
|
| 201 |
+
weights = self.weight
|
| 202 |
+
|
| 203 |
+
out = F.linear(input, weights, self.bias)
|
| 204 |
+
|
| 205 |
+
print('lin', task_idx, input.shape, out.shape)
|
| 206 |
+
if residual:
|
| 207 |
+
return out, input
|
| 208 |
+
return out
|
| 209 |
|
| 210 |
layer.forward = new_forward.__get__(layer, layer.__class__)
|
| 211 |
|
|
|
|
| 221 |
alpha=alpha,
|
| 222 |
),
|
| 223 |
)
|
|
|
|
| 224 |
|
| 225 |
def new_forward(self, input, task):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
task_idx = adaptation_map[task] if task else None
|
| 227 |
if task_idx:
|
| 228 |
+
weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
|
| 229 |
+
else:
|
| 230 |
+
weights = self.weight
|
| 231 |
+
|
| 232 |
+
out = F.embedding(
|
| 233 |
+
input, weights, self.padding_idx, self.max_norm,
|
| 234 |
+
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
| 235 |
+
|
| 236 |
+
print('emb', task_idx, input.shape, out.shape)
|
| 237 |
+
return out
|
| 238 |
|
| 239 |
layer.forward = new_forward.__get__(layer, layer.__class__)
|
| 240 |
|
|
|
|
| 282 |
self._task_idx = None
|
| 283 |
# By default, disable LoRA until it's specified which adapter/task to use
|
| 284 |
self.current_task = None
|
| 285 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
@property
|
| 288 |
def main_params_trainable(self):
|
|
|
|
| 362 |
f"Alternatively, set `task` to `None` if you want to disable LoRA."
|
| 363 |
)
|
| 364 |
task_idx = self._adaptation_map[task_name] if task_name else None
|
| 365 |
+
# if self._task_idx != task_idx:
|
| 366 |
+
# # In this case, we need to update the LoRAs everywhere
|
| 367 |
+
# self._task_idx = task_idx
|
| 368 |
+
# self.apply(
|
| 369 |
+
# partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
|
| 370 |
+
# )
|
| 371 |
|
| 372 |
def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
|
| 373 |
if task != LORA_NO_UPDATE:
|
modeling_xlm_roberta.py
CHANGED
|
@@ -313,7 +313,7 @@ class XLMRobertaPooler(nn.Module):
|
|
| 313 |
# We "pool" the model by simply taking the hidden state corresponding
|
| 314 |
# to the first token.
|
| 315 |
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
| 316 |
-
pooled_output = self.dense(first_token_tensor)
|
| 317 |
pooled_output = self.activation(pooled_output)
|
| 318 |
return pooled_output
|
| 319 |
|
|
|
|
| 313 |
# We "pool" the model by simply taking the hidden state corresponding
|
| 314 |
# to the first token.
|
| 315 |
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
| 316 |
+
pooled_output = self.dense(first_token_tensor, task='passage')
|
| 317 |
pooled_output = self.activation(pooled_output)
|
| 318 |
return pooled_output
|
| 319 |
|