Commit
·
65e9690
1
Parent(s):
4ee2970
fix: device
Browse files
Signed-off-by: Meow <ongjackm@gmail.com>
- embedding.py +1 -1
- mha.py +2 -2
- mlp.py +2 -2
- modeling_xlm_roberta.py +1 -1
embedding.py
CHANGED
|
@@ -51,7 +51,7 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
| 51 |
unique_tasks = torch.unique(adapter_mask).tolist()
|
| 52 |
embedding_dtype = next(self.word_embeddings.parameters()).dtype
|
| 53 |
embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
|
| 54 |
-
dtype=embedding_dtype)
|
| 55 |
for task_id in unique_tasks:
|
| 56 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
| 57 |
task_input_ids = input_ids[task_indices]
|
|
|
|
| 51 |
unique_tasks = torch.unique(adapter_mask).tolist()
|
| 52 |
embedding_dtype = next(self.word_embeddings.parameters()).dtype
|
| 53 |
embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
|
| 54 |
+
dtype=embedding_dtype, device=input_ids.device)
|
| 55 |
for task_id in unique_tasks:
|
| 56 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
| 57 |
task_input_ids = input_ids[task_indices]
|
mha.py
CHANGED
|
@@ -650,7 +650,7 @@ class MHA(nn.Module):
|
|
| 650 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
| 651 |
qkv_dtype = next(self.Wqkv.parameters()).dtype
|
| 652 |
qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
|
| 653 |
-
dtype=qkv_dtype)
|
| 654 |
for task_id in unique_tasks:
|
| 655 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
| 656 |
task_tensor = x[task_indices]
|
|
@@ -755,7 +755,7 @@ class MHA(nn.Module):
|
|
| 755 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
| 756 |
out_dtype = next(self.out_proj.parameters()).dtype
|
| 757 |
out = torch.empty(inp.shape[0], self.out_proj.out_features,
|
| 758 |
-
dtype=out_dtype)
|
| 759 |
for task_id in unique_tasks:
|
| 760 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
| 761 |
task_tensor = inp[task_indices]
|
|
|
|
| 650 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
| 651 |
qkv_dtype = next(self.Wqkv.parameters()).dtype
|
| 652 |
qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
|
| 653 |
+
dtype=qkv_dtype, device=x.device)
|
| 654 |
for task_id in unique_tasks:
|
| 655 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
| 656 |
task_tensor = x[task_indices]
|
|
|
|
| 755 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
| 756 |
out_dtype = next(self.out_proj.parameters()).dtype
|
| 757 |
out = torch.empty(inp.shape[0], self.out_proj.out_features,
|
| 758 |
+
dtype=out_dtype, device=inp.device)
|
| 759 |
for task_id in unique_tasks:
|
| 760 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
| 761 |
task_tensor = inp[task_indices]
|
mlp.py
CHANGED
|
@@ -52,7 +52,7 @@ class Mlp(nn.Module):
|
|
| 52 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
| 53 |
fc1_dtype = next(self.fc1.parameters()).dtype
|
| 54 |
y = torch.empty(x.shape[0], self.fc1.out_features,
|
| 55 |
-
dtype=fc1_dtype)
|
| 56 |
for task_id in unique_tasks:
|
| 57 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
| 58 |
task_tensor = x[task_indices]
|
|
@@ -67,7 +67,7 @@ class Mlp(nn.Module):
|
|
| 67 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
| 68 |
fc2_dtype = next(self.fc2.parameters()).dtype
|
| 69 |
out = torch.empty(y.shape[0], self.fc2.out_features,
|
| 70 |
-
dtype=fc2_dtype)
|
| 71 |
for task_id in unique_tasks:
|
| 72 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
| 73 |
task_tensor = y[task_indices]
|
|
|
|
| 52 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
| 53 |
fc1_dtype = next(self.fc1.parameters()).dtype
|
| 54 |
y = torch.empty(x.shape[0], self.fc1.out_features,
|
| 55 |
+
dtype=fc1_dtype, device=x.device)
|
| 56 |
for task_id in unique_tasks:
|
| 57 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
| 58 |
task_tensor = x[task_indices]
|
|
|
|
| 67 |
unique_tasks = torch.unique(cu_adapter_mask).tolist()
|
| 68 |
fc2_dtype = next(self.fc2.parameters()).dtype
|
| 69 |
out = torch.empty(y.shape[0], self.fc2.out_features,
|
| 70 |
+
dtype=fc2_dtype, device=y.device)
|
| 71 |
for task_id in unique_tasks:
|
| 72 |
task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
| 73 |
task_tensor = y[task_indices]
|
modeling_xlm_roberta.py
CHANGED
|
@@ -317,7 +317,7 @@ class XLMRobertaPooler(nn.Module):
|
|
| 317 |
unique_tasks = torch.unique(adapter_mask).tolist()
|
| 318 |
pool_dtype = next(self.dense.parameters()).dtype
|
| 319 |
pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
|
| 320 |
-
dtype=pool_dtype)
|
| 321 |
for task_id in unique_tasks:
|
| 322 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
| 323 |
task_first_token_tensor = first_token_tensor[task_indices]
|
|
|
|
| 317 |
unique_tasks = torch.unique(adapter_mask).tolist()
|
| 318 |
pool_dtype = next(self.dense.parameters()).dtype
|
| 319 |
pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
|
| 320 |
+
dtype=pool_dtype, device=first_token_tensor.device)
|
| 321 |
for task_id in unique_tasks:
|
| 322 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
| 323 |
task_first_token_tensor = first_token_tensor[task_indices]
|