OnlyBiggg commited on
Commit
900edfa
·
1 Parent(s): 668f128

fix truncate

Browse files
Files changed (1) hide show
  1. app/ner/services/ner.py +17 -3
app/ner/services/ner.py CHANGED
@@ -22,9 +22,23 @@ class NER:
22
  async def predict(self, text: str, entity_tag: str = None):
23
  if self.pipeline is None:
24
  raise ValueError("Model not loaded. Please call load_model() first.")
25
- pred = self.pipeline(text,
26
- truncation=settings.TRUNCATE,
27
- max_length=settings.MAX_LENGTH)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  if entity_tag:
29
  return self.extract_entities(pred, entity_tag)
30
  return pred
 
22
  async def predict(self, text: str, entity_tag: str = None):
23
  if self.pipeline is None:
24
  raise ValueError("Model not loaded. Please call load_model() first.")
25
+
26
+ inputs = self.tokenizer(
27
+ text,
28
+ truncation=settings.TRUNCATE, # Enable truncation
29
+ max_length=settings.MAX_LENGTH, # Set maximum length
30
+ return_tensors="pt" # Make sure the tokenized inputs are returned as PyTorch tensors
31
+ )
32
+
33
+ # Get the prediction from the model
34
+ pred = self.model(**inputs)
35
+
36
+ # Convert model output to pipeline format
37
+ pred = self.pipeline.decode(pred.logits)
38
+
39
+ # pred = self.pipeline(text,
40
+ # truncation=settings.TRUNCATE,
41
+ # max_length=settings.MAX_LENGTH)
42
  if entity_tag:
43
  return self.extract_entities(pred, entity_tag)
44
  return pred