zaddyzaddy committed on
Commit
1bea27b
·
verified ·
1 Parent(s): cc71426

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +33 -40
handler.py CHANGED
@@ -8,17 +8,40 @@ from transformers import (
8
  RobertaTokenizer,
9
  RobertaForSequenceClassification,
10
  )
 
 
 
11
 
12
- class DipperParaphraser(object):
13
- def __init__(self, model="", verbose=True):
14
- time1 = time.time()
15
- self.tokenizer = T5Tokenizer.from_pretrained('google/t5-v1_1-xxl')
16
- self.model = T5ForConditionalGeneration.from_pretrained(model, device_map="auto", load_in_8bit=True)
17
- if verbose:
18
- print(f"{model} model loaded in {time.time() - time1}")
19
- # self.model.cuda()
20
- self.model.eval()
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def paraphrase(self, input_text, lex_diversity, order_diversity, prefix="", sent_interval=3, **kwargs):
23
  """Paraphrase a text using the DIPPER model.
24
 
@@ -55,34 +78,4 @@ class DipperParaphraser(object):
55
  prefix += " " + outputs[0]
56
  output_text += " " + outputs[0]
57
 
58
- return output_text
59
-
60
-
61
- class EndpointHandler:
62
- def __init__(self, path=""):
63
- self.pipeline = DipperParaphraser(model=path)
64
-
65
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
66
- """
67
- data args:
68
- inputs (:obj: `str`)
69
- date (:obj: `str`)
70
- Return:
71
- A :obj:`list` | `dict`: will be serialized and returned
72
- """
73
- input_text = data.get("input_text", "")
74
- lex_diversity = data.get("lex_diversity", 80)
75
- order_diversity = data.get("order_diversity", 20)
76
- prefix = data.get("prefix", "")
77
- prediction = self.pipeline.paraphrase(
78
- input_text,
79
- lex_diversity,
80
- order_diversity,
81
- prefix=prefix,
82
- do_sample=True,
83
- top_p=0.75,
84
- max_length=512
85
- )
86
-
87
- prediction = {'prediction': prediction}
88
- return prediction
 
8
  RobertaTokenizer,
9
  RobertaForSequenceClassification,
10
  )
11
+ import nltk
12
+ from nltk.tokenize import sent_tokenize
13
+ nltk.download('punkt')
14
 
 
 
 
 
 
 
 
 
 
15
 
16
+ class EndpointHandler:
17
+ def __init__(self, path=""):
18
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
19
+ self.model = T5ForConditionalGeneration.from_pretrained(path, device_map="auto", load_in_8bit=True)
20
+
21
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
22
+ """
23
+ data args:
24
+ inputs (:obj: `str`)
25
+ date (:obj: `str`)
26
+ Return:
27
+ A :obj:`list` | `dict`: will be serialized and returned
28
+ """
29
+ input_text = data.get("input_text", "")
30
+ lex_diversity = data.get("lex_diversity", 80)
31
+ order_diversity = data.get("order_diversity", 20)
32
+ prefix = data.get("prefix", "")
33
+ prediction = self.paraphrase(
34
+ input_text,
35
+ lex_diversity,
36
+ order_diversity,
37
+ prefix=prefix,
38
+ do_sample=True,
39
+ top_p=0.75,
40
+ max_length=512
41
+ )
42
+
43
+ prediction = {'prediction': prediction}
44
+ return prediction
45
  def paraphrase(self, input_text, lex_diversity, order_diversity, prefix="", sent_interval=3, **kwargs):
46
  """Paraphrase a text using the DIPPER model.
47
 
 
78
  prefix += " " + outputs[0]
79
  output_text += " " + outputs[0]
80
 
81
+ return output_text