Fix recent error 'Cannot copy out of meta tensor' with flair
Browse files- transforms_cased.py +13 -1
transforms_cased.py
CHANGED
|
@@ -172,7 +172,11 @@ class FilterPOS(BaseTextTransform):
|
|
| 172 |
nltk.download("punkt", quiet=True)
|
| 173 |
self.tagger = lambda x: nltk.pos_tag(nltk.word_tokenize(x))
|
| 174 |
elif engine == "flair":
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
def __call__(self, text: str) -> str:
|
| 178 |
"""
|
|
@@ -184,6 +188,14 @@ class FilterPOS(BaseTextTransform):
|
|
| 184 |
text = " ".join([word for word, tag in word_tags if tag not in self.tags])
|
| 185 |
elif self.engine == "flair":
|
| 186 |
sentence = Sentence(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
self.tagger(sentence)
|
| 188 |
text = " ".join([token.text for token in sentence.tokens if token.tag in self.tags])
|
| 189 |
|
|
|
|
| 172 |
nltk.download("punkt", quiet=True)
|
| 173 |
self.tagger = lambda x: nltk.pos_tag(nltk.word_tokenize(x))
|
| 174 |
elif engine == "flair":
|
| 175 |
+
# post-pone loading the flair tagger to avoid a recent error with flair:
|
| 176 |
+
# NotImplementedError: Cannot copy out of meta tensor; no data! Please use
|
| 177 |
+
# torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module
|
| 178 |
+
# from meta to a different device.
|
| 179 |
+
self.tagger = None
|
| 180 |
|
| 181 |
def __call__(self, text: str) -> str:
|
| 182 |
"""
|
|
|
|
| 188 |
text = " ".join([word for word, tag in word_tags if tag not in self.tags])
|
| 189 |
elif self.engine == "flair":
|
| 190 |
sentence = Sentence(text)
|
| 191 |
+
|
| 192 |
+
# post-pone loading the flair tagger to avoid a recent error with flair:
|
| 193 |
+
# NotImplementedError: Cannot copy out of meta tensor; no data! Please use
|
| 194 |
+
# torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module
|
| 195 |
+
# from meta to a different device.
|
| 196 |
+
if self.tagger is None:
|
| 197 |
+
self.tagger = SequenceTagger.load("flair/pos-english-fast").predict
|
| 198 |
+
|
| 199 |
self.tagger(sentence)
|
| 200 |
text = " ".join([token.text for token in sentence.tokens if token.tag in self.tags])
|
| 201 |
|