Spaces:
Build error
Build error
PeteBleackley
committed on
Commit
·
f5599c3
1
Parent(s):
4a7707c
Ensure consistency of device assignment when training
Browse files
- qarac/corpora/CombinedCorpus.py +4 -3
- scripts.py +2 -1
qarac/corpora/CombinedCorpus.py
CHANGED
|
@@ -58,6 +58,7 @@ class CombinedCorpus(torch.utils.data.IterableDataset):
|
|
| 58 |
{},
|
| 59 |
'consistency'),
|
| 60 |
n_samples)
|
|
|
|
| 61 |
self.batches = None
|
| 62 |
self.pad_token = tokenizer.token_to_id('<pad>')
|
| 63 |
self.max_lengths = {}
|
|
@@ -145,11 +146,11 @@ class CombinedCorpus(torch.utils.data.IterableDataset):
|
|
| 145 |
|
| 146 |
X={key:self.pad(value,self.max_lengths[key])
|
| 147 |
for (key,value) in X.items()}
|
| 148 |
-
Y={key:torch.tensor(value,device=
|
| 149 |
self.max_lengths[key],
|
| 150 |
False)
|
| 151 |
for (key,value) in Y.items()}
|
| 152 |
-
Y['question_answering'] = torch.zeros((n,768),device=
|
| 153 |
return (X,
|
| 154 |
tuple([Y[key]
|
| 155 |
for key in ('encode_decode',
|
|
@@ -176,7 +177,7 @@ class CombinedCorpus(torch.utils.data.IterableDataset):
|
|
| 176 |
sample.pad(maxlen,pad_id=self.pad_token)
|
| 177 |
input_ids = torch.tensor([sample.ids
|
| 178 |
for sample in batch],
|
| 179 |
-
device=
|
| 180 |
result = input_ids
|
| 181 |
if inputs:
|
| 182 |
attention_mask = torch.not_equal(input_ids,
|
|
|
|
| 58 |
{},
|
| 59 |
'consistency'),
|
| 60 |
n_samples)
|
| 61 |
+
self.device = kwargs['device']
|
| 62 |
self.batches = None
|
| 63 |
self.pad_token = tokenizer.token_to_id('<pad>')
|
| 64 |
self.max_lengths = {}
|
|
|
|
| 146 |
|
| 147 |
X={key:self.pad(value,self.max_lengths[key])
|
| 148 |
for (key,value) in X.items()}
|
| 149 |
+
Y={key:torch.tensor(value,device=self.device).float() if key=='consistency' else self.pad(value,
|
| 150 |
self.max_lengths[key],
|
| 151 |
False)
|
| 152 |
for (key,value) in Y.items()}
|
| 153 |
+
Y['question_answering'] = torch.zeros((n,768),device=self.device)
|
| 154 |
return (X,
|
| 155 |
tuple([Y[key]
|
| 156 |
for key in ('encode_decode',
|
|
|
|
| 177 |
sample.pad(maxlen,pad_id=self.pad_token)
|
| 178 |
input_ids = torch.tensor([sample.ids
|
| 179 |
for sample in batch],
|
| 180 |
+
device=self.device)
|
| 181 |
result = input_ids
|
| 182 |
if inputs:
|
| 183 |
attention_mask = torch.not_equal(input_ids,
|
scripts.py
CHANGED
|
@@ -131,7 +131,8 @@ def train_models(path,progress=gradio.Progress(track_tqdm=True)):
|
|
| 131 |
all_text='corpora/all_text.csv',
|
| 132 |
question_answering='corpora/question_answering.csv',
|
| 133 |
reasoning='corpora/reasoning_train.csv',
|
| 134 |
-
consistency='corpora/consistency.csv'
|
|
|
|
| 135 |
n_batches = len(training_data)
|
| 136 |
history = {}
|
| 137 |
for epoch in range(25):
|
|
|
|
| 131 |
all_text='corpora/all_text.csv',
|
| 132 |
question_answering='corpora/question_answering.csv',
|
| 133 |
reasoning='corpora/reasoning_train.csv',
|
| 134 |
+
consistency='corpora/consistency.csv',
|
| 135 |
+
device=trainer.device())
|
| 136 |
n_batches = len(training_data)
|
| 137 |
history = {}
|
| 138 |
for epoch in range(25):
|