Commit ·
814cbbb
1
Parent(s): 65e9690
some fixes and suggestions
Browse files
Signed-off-by: Meow <ongjackm@gmail.com>
- embedding.py +2 -2
- mha.py +6 -3
- mlp.py +2 -2
- modeling_lora.py +5 -3
- modeling_xlm_roberta.py +2 -1
embedding.py
CHANGED
|
@@ -48,7 +48,7 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
| 48 |
"""
|
| 49 |
batch_size, seqlen = input_ids.shape
|
| 50 |
if adapter_mask is not None:
|
| 51 |
-
unique_tasks = torch.unique(adapter_mask)
|
| 52 |
embedding_dtype = next(self.word_embeddings.parameters()).dtype
|
| 53 |
embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
|
| 54 |
dtype=embedding_dtype, device=input_ids.device)
|
|
@@ -71,7 +71,7 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
| 71 |
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
|
| 72 |
|
| 73 |
if adapter_mask is not None:
|
| 74 |
-
unique_tasks = torch.unique(adapter_mask)
|
| 75 |
for task_id in unique_tasks:
|
| 76 |
task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_id=task_id)
|
| 77 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
|
|
|
| 48 |
"""
|
| 49 |
batch_size, seqlen = input_ids.shape
|
| 50 |
if adapter_mask is not None:
|
| 51 |
+
unique_tasks = torch.unique(adapter_mask)
|
| 52 |
embedding_dtype = next(self.word_embeddings.parameters()).dtype
|
| 53 |
embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
|
| 54 |
dtype=embedding_dtype, device=input_ids.device)
|
|
|
|
| 71 |
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
|
| 72 |
|
| 73 |
if adapter_mask is not None:
|
| 74 |
+
unique_tasks = torch.unique(adapter_mask)
|
| 75 |
for task_id in unique_tasks:
|
| 76 |
task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_id=task_id)
|
| 77 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
mha.py
CHANGED
|
@@ -647,7 +647,7 @@ class MHA(nn.Module):
|
|
| 647 |
assert x_kv is None and mixer_subset is None
|
| 648 |
|
| 649 |
if cu_adapter_mask is not None:
|
| 650 |
-
unique_tasks = torch.unique(cu_adapter_mask)
|
| 651 |
qkv_dtype = next(self.Wqkv.parameters()).dtype
|
| 652 |
qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
|
| 653 |
dtype=qkv_dtype, device=x.device)
|
|
@@ -663,7 +663,10 @@ class MHA(nn.Module):
|
|
| 663 |
if not self.return_residual:
|
| 664 |
qkv = self.Wqkv(x)
|
| 665 |
else:
|
| 666 |
-
|
|
|
|
|
|
|
|
|
|
| 667 |
|
| 668 |
if self.dwconv:
|
| 669 |
qkv = rearrange(
|
|
@@ -752,7 +755,7 @@ class MHA(nn.Module):
|
|
| 752 |
|
| 753 |
inp = rearrange(context, "... h d -> ... (h d)")
|
| 754 |
if cu_adapter_mask is not None:
|
| 755 |
-
unique_tasks = torch.unique(cu_adapter_mask)
|
| 756 |
out_dtype = next(self.out_proj.parameters()).dtype
|
| 757 |
out = torch.empty(inp.shape[0], self.out_proj.out_features,
|
| 758 |
dtype=out_dtype, device=inp.device)
|
|
|
|
| 647 |
assert x_kv is None and mixer_subset is None
|
| 648 |
|
| 649 |
if cu_adapter_mask is not None:
|
| 650 |
+
unique_tasks = torch.unique(cu_adapter_mask)
|
| 651 |
qkv_dtype = next(self.Wqkv.parameters()).dtype
|
| 652 |
qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
|
| 653 |
dtype=qkv_dtype, device=x.device)
|
|
|
|
| 663 |
if not self.return_residual:
|
| 664 |
qkv = self.Wqkv(x)
|
| 665 |
else:
|
| 666 |
+
if hasattr(self.Wqkv, 'parametrizations'):
|
| 667 |
+
qkv, x = self.Wqkv(x, residual=True)
|
| 668 |
+
else:
|
| 669 |
+
qkv, x = self.Wqkv(x)
|
| 670 |
|
| 671 |
if self.dwconv:
|
| 672 |
qkv = rearrange(
|
|
|
|
| 755 |
|
| 756 |
inp = rearrange(context, "... h d -> ... (h d)")
|
| 757 |
if cu_adapter_mask is not None:
|
| 758 |
+
unique_tasks = torch.unique(cu_adapter_mask)
|
| 759 |
out_dtype = next(self.out_proj.parameters()).dtype
|
| 760 |
out = torch.empty(inp.shape[0], self.out_proj.out_features,
|
| 761 |
dtype=out_dtype, device=inp.device)
|
mlp.py
CHANGED
|
@@ -49,7 +49,7 @@ class Mlp(nn.Module):
|
|
| 49 |
|
| 50 |
def forward(self, x, cu_adapter_mask=None):
|
| 51 |
if cu_adapter_mask is not None:
|
| 52 |
-
unique_tasks = torch.unique(cu_adapter_mask)
|
| 53 |
fc1_dtype = next(self.fc1.parameters()).dtype
|
| 54 |
y = torch.empty(x.shape[0], self.fc1.out_features,
|
| 55 |
dtype=fc1_dtype, device=x.device)
|
|
@@ -64,7 +64,7 @@ class Mlp(nn.Module):
|
|
| 64 |
y = self.activation(y)
|
| 65 |
|
| 66 |
if cu_adapter_mask is not None:
|
| 67 |
-
unique_tasks = torch.unique(cu_adapter_mask)
|
| 68 |
fc2_dtype = next(self.fc2.parameters()).dtype
|
| 69 |
out = torch.empty(y.shape[0], self.fc2.out_features,
|
| 70 |
dtype=fc2_dtype, device=y.device)
|
|
|
|
| 49 |
|
| 50 |
def forward(self, x, cu_adapter_mask=None):
|
| 51 |
if cu_adapter_mask is not None:
|
| 52 |
+
unique_tasks = torch.unique(cu_adapter_mask)
|
| 53 |
fc1_dtype = next(self.fc1.parameters()).dtype
|
| 54 |
y = torch.empty(x.shape[0], self.fc1.out_features,
|
| 55 |
dtype=fc1_dtype, device=x.device)
|
|
|
|
| 64 |
y = self.activation(y)
|
| 65 |
|
| 66 |
if cu_adapter_mask is not None:
|
| 67 |
+
unique_tasks = torch.unique(cu_adapter_mask)
|
| 68 |
fc2_dtype = next(self.fc2.parameters()).dtype
|
| 69 |
out = torch.empty(y.shape[0], self.fc2.out_features,
|
| 70 |
dtype=fc2_dtype, device=y.device)
|
modeling_lora.py
CHANGED
|
@@ -355,7 +355,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
| 355 |
f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
|
| 356 |
f"Alternatively, don't pass the `task_type` argument to disable LoRA."
|
| 357 |
)
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
|
|
|
|
|
|
| 361 |
return self.roberta.encode(sentences, *args, adapter_mask=adapter_mask, **kwargs)
|
|
|
|
| 355 |
f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
|
| 356 |
f"Alternatively, don't pass the `task_type` argument to disable LoRA."
|
| 357 |
)
|
| 358 |
+
adapter_mask = None
|
| 359 |
+
if task_type:
|
| 360 |
+
task_id = self._adaptation_map[task_type]
|
| 361 |
+
num_examples = 1 if isinstance(sentences, str) else len(sentences)
|
| 362 |
+
adapter_mask = torch.full((num_examples,), task_id, dtype=torch.int32, device=self.device)
|
| 363 |
return self.roberta.encode(sentences, *args, adapter_mask=adapter_mask, **kwargs)
|
modeling_xlm_roberta.py
CHANGED
|
@@ -314,7 +314,7 @@ class XLMRobertaPooler(nn.Module):
|
|
| 314 |
# to the first token.
|
| 315 |
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
| 316 |
if adapter_mask is not None:
|
| 317 |
-
unique_tasks = torch.unique(adapter_mask)
|
| 318 |
pool_dtype = next(self.dense.parameters()).dtype
|
| 319 |
pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
|
| 320 |
dtype=pool_dtype, device=first_token_tensor.device)
|
|
@@ -465,6 +465,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 465 |
normalize_embeddings: bool = False,
|
| 466 |
truncate_dim: Optional[int] = None,
|
| 467 |
adapter_mask: Optional[torch.Tensor] = None,
|
|
|
|
| 468 |
**tokenizer_kwargs,
|
| 469 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
| 470 |
"""
|
|
|
|
| 314 |
# to the first token.
|
| 315 |
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
| 316 |
if adapter_mask is not None:
|
| 317 |
+
unique_tasks = torch.unique(adapter_mask)
|
| 318 |
pool_dtype = next(self.dense.parameters()).dtype
|
| 319 |
pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
|
| 320 |
dtype=pool_dtype, device=first_token_tensor.device)
|
|
|
|
| 465 |
normalize_embeddings: bool = False,
|
| 466 |
truncate_dim: Optional[int] = None,
|
| 467 |
adapter_mask: Optional[torch.Tensor] = None,
|
| 468 |
+
task_type: Optional[str] = None,
|
| 469 |
**tokenizer_kwargs,
|
| 470 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
| 471 |
"""
|