jaimin committed on
Commit
985ce1e
·
1 Parent(s): 417a8d3

Create styleformer.py

Browse files
Files changed (1) hide show
  1. styleformer.py +160 -0
styleformer.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
class Styleformer():
    """Front-end for four directional text style-transfer seq2seq models.

    The ``style`` argument chosen at construction time selects which single
    model pair (tokenizer + model) is downloaded and loaded:

        0 -- Casual to Formal  (CTF)
        1 -- Formal to Casual  (FTC)
        2 -- Active to Passive (ATP)
        3 -- Passive to Active (PTA)

    Styles 0 and 1 generate several candidates and re-rank them with an
    adequacy scorer; styles 2 and 3 return a single generated sentence.
    """

    def __init__(
        self,
        style=0,
        ctf_model_tag="jaimin/Informal_to_formal",
        ftc_model_tag="jaimin/formal_to_informal",
        atp_model_tag="jaimin/Active_to_passive",
        pta_model_tag="jaimin/Passive_to_active",
        adequacy_model_tag="jaimin/parrot_adequacy_model",
    ):
        """Load the tokenizer/model pair for the requested ``style``.

        Each ``*_model_tag`` is a Hugging Face Hub model id. Only the pair
        matching ``style`` is fetched; unsupported styles only print a notice
        and leave ``self.model_loaded`` False.
        """
        from transformers import AutoTokenizer
        from transformers import AutoModelForSeq2SeqLM

        self.style = style
        # Adequacy is defined elsewhere in this project; it scores candidate
        # sentences for semantic adequacy against the source sentence.
        # NOTE(review): "access" looks like a placeholder auth token -- confirm
        # it is a valid Hub credential (left unchanged: it is runtime data).
        self.adequacy = adequacy_model_tag and Adequacy(model_tag=adequacy_model_tag, use_auth_token="access")
        self.model_loaded = False

        # Per-style loading table: (display label, attribute prefix, model tag).
        # Collapses four duplicated load branches into one code path.
        loaders = {
            0: ("Casual to Formal", "ctf", ctf_model_tag),
            1: ("Formal to Casual", "ftc", ftc_model_tag),
            2: ("Active to Passive", "atp", atp_model_tag),
            3: ("Passive to Active", "pta", pta_model_tag),
        }
        if self.style in loaders:
            label, prefix, model_tag = loaders[self.style]
            setattr(self, prefix + "_tokenizer",
                    AutoTokenizer.from_pretrained(model_tag, use_auth_token="access"))
            setattr(self, prefix + "_model",
                    AutoModelForSeq2SeqLM.from_pretrained(model_tag, use_auth_token="access"))
            print(label + " model loaded...")
            self.model_loaded = True
        else:
            print("Only CTF, FTC, ATP and PTA are supported in the pre-release...stay tuned")

    def transfer(self, input_sentence, inference_on=-1, quality_filter=0.95, max_candidates=5):
        """Apply the configured style transfer to ``input_sentence``.

        inference_on   -- -1 for CPU; 0..998 selects a numbered device;
                          >= 999 falls back to CPU with a notice.
        quality_filter -- adequacy threshold passed to the scorer (styles 0/1).
        max_candidates -- number of sampled candidates to re-rank (styles 0/1).

        Returns the transferred sentence, or None when no candidate survives
        the adequacy filter (styles 0/1) or when no model is loaded.
        """
        if self.model_loaded:
            if inference_on == -1:
                device = "cpu"
            elif 0 <= inference_on < 999:
                # NOTE(review): upstream Styleformer builds "cuda:<n>" here;
                # "cpu:<n>" preserves this file's original behavior -- confirm
                # whether GPU inference was intended.
                device = "cpu:" + str(inference_on)
            else:
                device = "cpu"
                print("Onnx + Quantisation is not supported in the pre-release...stay tuned.")

            if self.style == 0:
                return self._casual_to_formal(input_sentence, device, quality_filter, max_candidates)
            elif self.style == 1:
                return self._formal_to_casual(input_sentence, device, quality_filter, max_candidates)
            elif self.style == 2:
                return self._active_to_passive(input_sentence, device)
            elif self.style == 3:
                return self._passive_to_active(input_sentence, device)
        else:
            print("Models aren't loaded for this style, please use the right style during init")

    def _generate(self, tokenizer, model, task_prefix, input_sentence, device, num_return_sequences):
        """Run one sampled seq2seq generation pass.

        Returns the decoded, whitespace-stripped candidate strings. The model
        is assumed to already be on ``device`` (callers move and re-store it).
        """
        input_ids = tokenizer.encode(task_prefix + input_sentence, return_tensors='pt')
        input_ids = input_ids.to(device)
        preds = model.generate(
            input_ids,
            do_sample=True,
            max_length=32,
            top_k=50,
            top_p=0.95,
            early_stopping=True,
            num_return_sequences=num_return_sequences)
        return [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]

    def _rank_by_adequacy(self, src_sentence, candidates, quality_filter, device):
        """Score de-duplicated candidates for adequacy; return the best or None."""
        adequacy_scored_phrases = self.adequacy.score(src_sentence, list(set(candidates)), quality_filter, device)
        ranked_sentences = sorted(adequacy_scored_phrases.items(), key=lambda x: x[1], reverse=True)
        if ranked_sentences:
            return ranked_sentences[0][0]
        return None

    def _formal_to_casual(self, input_sentence, device, quality_filter, max_candidates):
        """Formal -> casual: sample candidates, return the most adequate one."""
        self.ftc_model = self.ftc_model.to(device)
        candidates = self._generate(self.ftc_tokenizer, self.ftc_model,
                                    "transfer Formal to Casual: ",
                                    input_sentence, device, max_candidates)
        return self._rank_by_adequacy(input_sentence, candidates, quality_filter, device)

    def _casual_to_formal(self, input_sentence, device, quality_filter, max_candidates):
        """Casual -> formal: sample candidates, return the most adequate one."""
        self.ctf_model = self.ctf_model.to(device)
        candidates = self._generate(self.ctf_tokenizer, self.ctf_model,
                                    "transfer Casual to Formal: ",
                                    input_sentence, device, max_candidates)
        return self._rank_by_adequacy(input_sentence, candidates, quality_filter, device)

    def _active_to_passive(self, input_sentence, device):
        """Active -> passive: single sampled generation, no adequacy re-rank."""
        self.atp_model = self.atp_model.to(device)
        return self._generate(self.atp_tokenizer, self.atp_model,
                              "transfer Active to Passive: ",
                              input_sentence, device, 1)[0]

    def _passive_to_active(self, input_sentence, device):
        """Passive -> active: single sampled generation, no adequacy re-rank."""
        self.pta_model = self.pta_model.to(device)
        return self._generate(self.pta_tokenizer, self.pta_model,
                              "transfer Passive to Active: ",
                              input_sentence, device, 1)[0]