Spaces:
Build error
Build error
PeteBleackley
committed on
Commit
·
fcfc2b3
1
Parent(s):
dd9c3ed
TPUs need constant batch shapes
Browse files
- qarac/corpora/CombinedCorpus.py +13 -7
- qarac/corpora/CorpusLoader.py +11 -0
qarac/corpora/CombinedCorpus.py
CHANGED
|
@@ -39,7 +39,7 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
| 39 |
'encode_decode')})
|
| 40 |
n_samples = len(self.all_text)
|
| 41 |
|
| 42 |
-
self.n_batches =
|
| 43 |
self.question_answering = CorpusRepeater.CorpusRepeater(CorpusLoader.CorpusLoader(kwargs['question_answering'],
|
| 44 |
tokenizer,
|
| 45 |
['question',
|
|
@@ -63,6 +63,12 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
| 63 |
self.batches = None
|
| 64 |
self.pad_token = tokenizer.token_to_id('<pad>')
|
| 65 |
self.on_epoch_end()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
def __len__(self):
|
| 68 |
"""
|
|
@@ -127,8 +133,7 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
| 127 |
yield(batch)
|
| 128 |
batch = []
|
| 129 |
n=0
|
| 130 |
-
|
| 131 |
-
yield batch
|
| 132 |
|
| 133 |
def on_epoch_end(self):
|
| 134 |
self.batches = self.make_batches()
|
|
@@ -161,13 +166,15 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
| 161 |
n+=1
|
| 162 |
|
| 163 |
for (key,value) in X.items():
|
| 164 |
-
X[key] = self.pad(value)
|
| 165 |
for (key,value) in Y.items():
|
| 166 |
-
Y[key] = tensorflow.constant(value) if key=='consistency' else self.pad(value,
|
|
|
|
|
|
|
| 167 |
Y['question_answering'] = tensorflow.zeros((n,768))
|
| 168 |
return (X,Y)
|
| 169 |
|
| 170 |
-
def pad(self,batch,inputs=True):
|
| 171 |
"""
|
| 172 |
Pads a batch of samples to uniform length
|
| 173 |
|
|
@@ -182,7 +189,6 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
| 182 |
Padded data
|
| 183 |
|
| 184 |
"""
|
| 185 |
-
maxlen = max((len(sample) for sample in batch))
|
| 186 |
for sample in batch:
|
| 187 |
sample.pad(maxlen,pad_id=self.pad_token)
|
| 188 |
input_ids = tensorflow.constant([sample.ids
|
|
|
|
| 39 |
'encode_decode')})
|
| 40 |
n_samples = len(self.all_text)
|
| 41 |
|
| 42 |
+
self.n_batches = n_samples//32
|
| 43 |
self.question_answering = CorpusRepeater.CorpusRepeater(CorpusLoader.CorpusLoader(kwargs['question_answering'],
|
| 44 |
tokenizer,
|
| 45 |
['question',
|
|
|
|
| 63 |
self.batches = None
|
| 64 |
self.pad_token = tokenizer.token_to_id('<pad>')
|
| 65 |
self.on_epoch_end()
|
| 66 |
+
self.max_lengths = {}
|
| 67 |
+
for corpus in (self.all_text,
|
| 68 |
+
self.question_answering,
|
| 69 |
+
self.reasoning,
|
| 70 |
+
self.consistency):
|
| 71 |
+
self.max_lengths.update(corpus.max_lengths())
|
| 72 |
|
| 73 |
def __len__(self):
|
| 74 |
"""
|
|
|
|
| 133 |
yield(batch)
|
| 134 |
batch = []
|
| 135 |
n=0
|
| 136 |
+
|
|
|
|
| 137 |
|
| 138 |
def on_epoch_end(self):
|
| 139 |
self.batches = self.make_batches()
|
|
|
|
| 166 |
n+=1
|
| 167 |
|
| 168 |
for (key,value) in X.items():
|
| 169 |
+
X[key] = self.pad(value,self.max_lengths[key])
|
| 170 |
for (key,value) in Y.items():
|
| 171 |
+
Y[key] = tensorflow.constant(value) if key=='consistency' else self.pad(value,
|
| 172 |
+
self.max_lengths[key],
|
| 173 |
+
False)
|
| 174 |
Y['question_answering'] = tensorflow.zeros((n,768))
|
| 175 |
return (X,Y)
|
| 176 |
|
| 177 |
+
def pad(self,batch,maxlen,inputs=True):
|
| 178 |
"""
|
| 179 |
Pads a batch of samples to uniform length
|
| 180 |
|
|
|
|
| 189 |
Padded data
|
| 190 |
|
| 191 |
"""
|
|
|
|
| 192 |
for sample in batch:
|
| 193 |
sample.pad(maxlen,pad_id=self.pad_token)
|
| 194 |
input_ids = tensorflow.constant([sample.ids
|
qarac/corpora/CorpusLoader.py
CHANGED
|
@@ -99,4 +99,15 @@ class CorpusLoader(object):
|
|
| 99 |
if self.label is not None:
|
| 100 |
Y[self.label]=row[self.label]
|
| 101 |
yield (X,Y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
|
|
|
| 99 |
if self.label is not None:
|
| 100 |
Y[self.label]=row[self.label]
|
| 101 |
yield (X,Y)
|
| 102 |
+
|
| 103 |
+
def max_lengths(self):
|
| 104 |
+
result = {column:max((row[column]
|
| 105 |
+
for row in self.dataset))
|
| 106 |
+
for column in self.text_inputs}
|
| 107 |
+
for (column,(inside,outside)) in self.text_outputs.items():
|
| 108 |
+
n = result[column] if column in result else max((len(row[column]
|
| 109 |
+
for row in self.dataset)))
|
| 110 |
+
result[inside] = n+1
|
| 111 |
+
result[outside] = n+1
|
| 112 |
+
return result
|
| 113 |
|