Spaces:
Build error
Build error
PeteBleackley
commited on
Commit
·
7a9be99
1
Parent(s):
c8625dc
Modified CombinedCorpus to use PyTorch
Browse files- qarac/corpora/CombinedCorpus.py +20 -35
qarac/corpora/CombinedCorpus.py
CHANGED
|
@@ -6,14 +6,11 @@ Created on Wed Sep 20 14:12:34 2023
|
|
| 6 |
@author: peter
|
| 7 |
"""
|
| 8 |
|
| 9 |
-
import itertools
|
| 10 |
import collections
|
| 11 |
-
import
|
| 12 |
-
import tensorflow
|
| 13 |
-
import keras
|
| 14 |
from qarac.corpora import CorpusLoader, CorpusRepeater
|
| 15 |
|
| 16 |
-
class CombinedCorpus(
|
| 17 |
|
| 18 |
def __init__(self,tokenizer,**kwargs):
|
| 19 |
"""
|
|
@@ -82,23 +79,7 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
| 82 |
"""
|
| 83 |
return self.n_batches
|
| 84 |
|
| 85 |
-
|
| 86 |
-
"""
|
| 87 |
-
Retrieves a batch of data
|
| 88 |
-
|
| 89 |
-
Parameters
|
| 90 |
-
----------
|
| 91 |
-
n : int
|
| 92 |
-
index of batch to retrieve
|
| 93 |
-
|
| 94 |
-
Returns
|
| 95 |
-
-------
|
| 96 |
-
tupe(dict,dict)
|
| 97 |
-
Batch of data
|
| 98 |
-
|
| 99 |
-
"""
|
| 100 |
-
|
| 101 |
-
return self.batch(next(self.batches))
|
| 102 |
|
| 103 |
def samples(self):
|
| 104 |
"""
|
|
@@ -123,14 +104,14 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
| 123 |
Y.update(y)
|
| 124 |
yield (X,Y)
|
| 125 |
|
| 126 |
-
def
|
| 127 |
batch = []
|
| 128 |
n=0
|
| 129 |
for sample in self.samples():
|
| 130 |
batch.append(sample)
|
| 131 |
n+=1
|
| 132 |
if n==32:
|
| 133 |
-
yield(batch)
|
| 134 |
batch = []
|
| 135 |
n=0
|
| 136 |
|
|
@@ -149,9 +130,9 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
| 149 |
|
| 150 |
Returns
|
| 151 |
-------
|
| 152 |
-
X : dict[str,
|
| 153 |
Batched input samples
|
| 154 |
-
Y : dict[str,
|
| 155 |
Batched output samples
|
| 156 |
|
| 157 |
"""
|
|
@@ -167,12 +148,16 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
| 167 |
|
| 168 |
X={key:self.pad(value,self.max_lengths[key])
|
| 169 |
for (key,value) in X.items()}
|
| 170 |
-
Y={key:
|
| 171 |
-
|
| 172 |
-
|
| 173 |
for (key,value) in Y.items()}
|
| 174 |
-
Y['question_answering'] =
|
| 175 |
-
return (X,Y
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
def pad(self,batch,maxlen,inputs=True):
|
| 178 |
"""
|
|
@@ -191,12 +176,12 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
| 191 |
"""
|
| 192 |
for sample in batch:
|
| 193 |
sample.pad(maxlen,pad_id=self.pad_token)
|
| 194 |
-
input_ids =
|
| 195 |
-
|
| 196 |
result = input_ids
|
| 197 |
if inputs:
|
| 198 |
-
attention_mask =
|
| 199 |
-
|
| 200 |
result = {'input_ids':input_ids,
|
| 201 |
'attention_mask':attention_mask}
|
| 202 |
return result
|
|
|
|
| 6 |
@author: peter
|
| 7 |
"""
|
| 8 |
|
|
|
|
| 9 |
import collections
|
| 10 |
+
import torch
|
|
|
|
|
|
|
| 11 |
from qarac.corpora import CorpusLoader, CorpusRepeater
|
| 12 |
|
| 13 |
+
class CombinedCorpus(torch.utils.data.IterableDataset()):
|
| 14 |
|
| 15 |
def __init__(self,tokenizer,**kwargs):
|
| 16 |
"""
|
|
|
|
| 79 |
"""
|
| 80 |
return self.n_batches
|
| 81 |
|
| 82 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
def samples(self):
|
| 85 |
"""
|
|
|
|
| 104 |
Y.update(y)
|
| 105 |
yield (X,Y)
|
| 106 |
|
| 107 |
+
def __iter__(self):
|
| 108 |
batch = []
|
| 109 |
n=0
|
| 110 |
for sample in self.samples():
|
| 111 |
batch.append(sample)
|
| 112 |
n+=1
|
| 113 |
if n==32:
|
| 114 |
+
yield(self.batch(batch))
|
| 115 |
batch = []
|
| 116 |
n=0
|
| 117 |
|
|
|
|
| 130 |
|
| 131 |
Returns
|
| 132 |
-------
|
| 133 |
+
X : dict[str,torch.Tensor]
|
| 134 |
Batched input samples
|
| 135 |
+
Y : dict[str,torch.Tensor]
|
| 136 |
Batched output samples
|
| 137 |
|
| 138 |
"""
|
|
|
|
| 148 |
|
| 149 |
X={key:self.pad(value,self.max_lengths[key])
|
| 150 |
for (key,value) in X.items()}
|
| 151 |
+
Y={key:torch.tensor(value) if key=='consistency' else self.pad(value,
|
| 152 |
+
self.max_lengths[key],
|
| 153 |
+
False)
|
| 154 |
for (key,value) in Y.items()}
|
| 155 |
+
Y['question_answering'] = torch.zeros((n,768))
|
| 156 |
+
return (X,tuple([Y[key]
|
| 157 |
+
for key in ('encode_decode',
|
| 158 |
+
'question_answering',
|
| 159 |
+
'reasoning',
|
| 160 |
+
'consistency')]))
|
| 161 |
|
| 162 |
def pad(self,batch,maxlen,inputs=True):
|
| 163 |
"""
|
|
|
|
| 176 |
"""
|
| 177 |
for sample in batch:
|
| 178 |
sample.pad(maxlen,pad_id=self.pad_token)
|
| 179 |
+
input_ids = torch.tensor([sample.ids
|
| 180 |
+
for sample in batch])
|
| 181 |
result = input_ids
|
| 182 |
if inputs:
|
| 183 |
+
attention_mask = torch.not_equal(input_ids,
|
| 184 |
+
self.pad_token)
|
| 185 |
result = {'input_ids':input_ids,
|
| 186 |
'attention_mask':attention_mask}
|
| 187 |
return result
|