Dmitry Chaplinsky commited on
Commit
449f919
·
1 Parent(s): cbdc24e

Trying to tranform the words after the pipeline

Browse files
Files changed (1) hide show
  1. pipeline.py +20 -1
pipeline.py CHANGED
@@ -10,6 +10,17 @@ class PreTrainedPipeline:
10
  # This function is only called once, so do all the heavy processing I/O here"""
11
  self.model = PunctuationCapitalizationModel.from_pretrained("dchaplinsky/punctuation_uk_bert")
12
 
 
 
 
 
 
 
 
 
 
 
 
13
  def __call__(self, inputs: str) -> List[Dict[str, Any]]:
14
  """
15
  Args:
@@ -32,7 +43,15 @@ class PreTrainedPipeline:
32
  offset = 0
33
  for tok, lab in zip(tokens, labels):
34
  if lab != "OO":
35
- res.append({"entity_group": lab, "word": tok, "start": offset, "end": offset + len(tok), "score": 0.99})
 
 
 
 
 
 
 
 
36
 
37
  offset += len(tok) + 1
38
 
 
10
  # This function is only called once, so do all the heavy processing I/O here"""
11
  self.model = PunctuationCapitalizationModel.from_pretrained("dchaplinsky/punctuation_uk_bert")
12
 
13
+ def apply_label_to_token(self, token: str, label: str) -> str:
14
+ punct, upper = label
15
+
16
+ if punct != "O":
17
+ token += punct
18
+
19
+ if upper == "U":
20
+ token = token.title()
21
+
22
+ return token
23
+
24
  def __call__(self, inputs: str) -> List[Dict[str, Any]]:
25
  """
26
  Args:
 
43
  offset = 0
44
  for tok, lab in zip(tokens, labels):
45
  if lab != "OO":
46
+ res.append(
47
+ {
48
+ "entity_group": lab,
49
+ "word": self.apply_label_to_token(tok, lab),
50
+ "start": offset,
51
+ "end": offset + len(tok),
52
+ "score": 0.99,
53
+ }
54
+ )
55
 
56
  offset += len(tok) + 1
57