altndrr commited on
Commit
0e10a1c
·
1 Parent(s): 32ba294

Fix recent error 'Cannot copy out of meta tensor' with flair

Browse files
Files changed (1) hide show
  1. transforms_cased.py +13 -1
transforms_cased.py CHANGED
@@ -172,7 +172,11 @@ class FilterPOS(BaseTextTransform):
172
  nltk.download("punkt", quiet=True)
173
  self.tagger = lambda x: nltk.pos_tag(nltk.word_tokenize(x))
174
  elif engine == "flair":
175
- self.tagger = SequenceTagger.load("flair/pos-english-fast").predict
 
 
 
 
176
 
177
  def __call__(self, text: str) -> str:
178
  """
@@ -184,6 +188,14 @@ class FilterPOS(BaseTextTransform):
184
  text = " ".join([word for word, tag in word_tags if tag not in self.tags])
185
  elif self.engine == "flair":
186
  sentence = Sentence(text)
 
 
 
 
 
 
 
 
187
  self.tagger(sentence)
188
  text = " ".join([token.text for token in sentence.tokens if token.tag in self.tags])
189
 
 
172
  nltk.download("punkt", quiet=True)
173
  self.tagger = lambda x: nltk.pos_tag(nltk.word_tokenize(x))
174
  elif engine == "flair":
175
+ # post-pone loading the flair tagger to avoid a recent error with flair:
176
+ # NotImplementedError: Cannot copy out of meta tensor; no data! Please use
177
+ # torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module
178
+ # from meta to a different device.
179
+ self.tagger = None
180
 
181
  def __call__(self, text: str) -> str:
182
  """
 
188
  text = " ".join([word for word, tag in word_tags if tag not in self.tags])
189
  elif self.engine == "flair":
190
  sentence = Sentence(text)
191
+
192
+ # post-pone loading the flair tagger to avoid a recent error with flair:
193
+ # NotImplementedError: Cannot copy out of meta tensor; no data! Please use
194
+ # torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module
195
+ # from meta to a different device.
196
+ if self.tagger is None:
197
+ self.tagger = SequenceTagger.load("flair/pos-english-fast").predict
198
+
199
  self.tagger(sentence)
200
  text = " ".join([token.text for token in sentence.tokens if token.tag in self.tags])
201