gcaillaut committed on
Commit
5437ff2
·
1 Parent(s): 91126af

Better handle missing source-language and domain tokens

Browse files
Files changed (1) hide show
  1. app.py +101 -18
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
3
  import torch
4
  import itertools
5
 
@@ -75,11 +76,11 @@ CODE2LANG = {v: k for k, v in LANG2CODE.items()}
75
  LANGUAGES = sorted(LANG2CODE.keys())
76
 
77
 
78
- def language_token(lang):
79
  return f"<lang_{lang}>"
80
 
81
 
82
- def domain_token(dom):
83
  return f"<dom_{dom}>"
84
 
85
 
@@ -92,7 +93,7 @@ def domain_token_to_str(token):
92
 
93
 
94
  def format_input(src, tgt_lang, src_lang, domain):
95
- tgt_lang_token = language_token(tgt_lang)
96
 
97
  prefix = TOKENIZER.eos_token
98
 
@@ -100,13 +101,13 @@ def format_input(src, tgt_lang, src_lang, domain):
100
  if src_lang is None:
101
  return base_input
102
  else:
103
- src_lang_token = language_token(src_lang)
104
  base_input = f"{base_input}{src_lang_token}"
105
 
106
  if domain is None:
107
  return base_input
108
  else:
109
- dom_token = domain_token(domain)
110
  base_input = f"{base_input}{dom_token}"
111
 
112
  return base_input
@@ -115,27 +116,109 @@ def format_input(src, tgt_lang, src_lang, domain):
115
  def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
116
  model = MODELS[model_name]
117
  formatted_text = format_input(text, tgt_lang, src_lang, domain)
118
-
119
- inputs = TOKENIZER(formatted_text, return_tensors="pt", return_token_type_ids=False)
 
 
 
 
120
  for k, v in inputs.items():
121
  inputs[k] = v.to(DEVICE)
 
 
 
122
 
123
- if src_lang is None:
124
- domain_token_pos = inputs["input_ids"].size(1) + 1
125
- elif domain is None:
126
- domain_token_pos = inputs["input_ids"].size(1)
127
- else:
128
- domain_token_pos = inputs["input_ids"].size(1) - 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  src_lang_token_pos = domain_token_pos - 1
130
  _tgt_lang_token_pos = src_lang_token_pos - 1
131
 
132
  outputs = model.generate(
133
- **inputs,
134
- num_beams=5,
135
- length_penalty=0.65,
136
  max_new_tokens=500,
137
  pad_token_id=TOKENIZER.pad_token_id,
138
  eos_token_id=TOKENIZER.eos_token_id,
 
139
  )
140
 
141
  generated_translation = TOKENIZER.decode(
@@ -145,12 +228,12 @@ def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
145
  source_language_token = TOKENIZER.convert_ids_to_tokens(
146
  outputs[0, src_lang_token_pos].item()
147
  )
148
- domain_token = TOKENIZER.convert_ids_to_tokens(outputs[0, domain_token_pos].item())
149
 
150
  return {
151
  "model": model_name,
152
  "source_lang": CODE2LANG[language_token_to_str(source_language_token)],
153
- "domain": DOMAIN_MAPPING_REVERSED[domain_token_to_str(domain_token)],
154
  "translation": generated_translation,
155
  }
156
 
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from transformers.cache_utils import DynamicCache
4
  import torch
5
  import itertools
6
 
 
76
  LANGUAGES = sorted(LANG2CODE.keys())
77
 
78
 
79
+ def build_language_token(lang):
80
  return f"<lang_{lang}>"
81
 
82
 
83
+ def build_domain_token(dom):
84
  return f"<dom_{dom}>"
85
 
86
 
 
93
 
94
 
95
  def format_input(src, tgt_lang, src_lang, domain):
96
+ tgt_lang_token = build_language_token(tgt_lang)
97
 
98
  prefix = TOKENIZER.eos_token
99
 
 
101
  if src_lang is None:
102
  return base_input
103
  else:
104
+ src_lang_token = build_language_token(src_lang)
105
  base_input = f"{base_input}{src_lang_token}"
106
 
107
  if domain is None:
108
  return base_input
109
  else:
110
+ dom_token = build_domain_token(domain)
111
  base_input = f"{base_input}{dom_token}"
112
 
113
  return base_input
 
116
  def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
117
  model = MODELS[model_name]
118
  formatted_text = format_input(text, tgt_lang, src_lang, domain)
119
+ inputs = TOKENIZER(
120
+ formatted_text,
121
+ return_attention_mask=True,
122
+ return_tensors="pt",
123
+ return_token_type_ids=False,
124
+ )
125
  for k, v in inputs.items():
126
  inputs[k] = v.to(DEVICE)
127
+ src_lang_provided = src_lang is not None
128
+ domain_provided = domain is not None
129
+ need_format_again = not (src_lang_provided and domain_provided)
130
 
131
+ past_key_values = DynamicCache()
132
+ cache_position = torch.arange(
133
+ inputs["input_ids"].size(1), dtype=torch.int64, device=DEVICE
134
+ )
135
+
136
+ if not src_lang_provided:
137
+ # Need to predict src lang
138
+ with torch.inference_mode():
139
+ outputs = model(
140
+ input_ids=inputs["input_ids"],
141
+ attention_mask=inputs["attention_mask"],
142
+ use_cache=True,
143
+ past_key_values=past_key_values,
144
+ cache_position=cache_position,
145
+ )
146
+ src_lang_token_id = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(0)
147
+ src_lang = language_token_to_str(
148
+ TOKENIZER.convert_ids_to_tokens(src_lang_token_id.squeeze().item())
149
+ )
150
+
151
+ cache_position = cache_position[-1:] + 1
152
+
153
+ attention_mask = inputs["attention_mask"]
154
+ attention_mask = torch.cat(
155
+ [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))],
156
+ dim=-1,
157
+ )
158
+ inputs = {"input_ids": src_lang_token_id, "attention_mask": attention_mask}
159
+
160
+ if not domain_provided:
161
+ # Need to predict domain
162
+ with torch.inference_mode():
163
+ outputs = model(
164
+ input_ids=inputs["input_ids"],
165
+ attention_mask=inputs["attention_mask"],
166
+ use_cache=True,
167
+ past_key_values=past_key_values,
168
+ )
169
+ domain_token_id = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(0)
170
+ domain = domain_token_to_str(
171
+ TOKENIZER.convert_ids_to_tokens(domain_token_id.squeeze().item())
172
+ )
173
+
174
+ cache_position = cache_position[-1:] + 1
175
+
176
+ attention_mask = inputs["attention_mask"]
177
+ attention_mask = torch.cat(
178
+ [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))],
179
+ dim=-1,
180
+ )
181
+
182
+ inputs = {"input_ids": domain_token_id, "attention_mask": attention_mask}
183
+ elif not src_lang_provided:
184
+ # in this case, src_lang was not provided, but domain was.
185
+ # So we still need to run a forward pass to build the kv cache for the domain token
186
+ dom_token = build_domain_token(domain)
187
+ # dom_token = "<dom_general>"
188
+ domain = domain_token_to_str(dom_token)
189
+
190
+ domain_token_id = TOKENIZER.convert_tokens_to_ids(dom_token)
191
+ inputs["input_ids"] = torch.hstack(
192
+ [inputs["input_ids"], torch.tensor([[domain_token_id]], device=DEVICE)]
193
+ )
194
+ inputs["attention_mask"] = torch.hstack(
195
+ [inputs["attention_mask"], inputs["attention_mask"].new_ones((1, 1))]
196
+ )
197
+ cache_position = torch.hstack([cache_position, cache_position[-1:] + 1])
198
+
199
+ if need_format_again:
200
+ formatted_text = format_input(text, tgt_lang, src_lang, domain)
201
+ inputs = TOKENIZER(
202
+ formatted_text,
203
+ return_attention_mask=True,
204
+ return_tensors="pt",
205
+ return_token_type_ids=False,
206
+ )
207
+ for k, v in inputs.items():
208
+ inputs[k] = v.to(DEVICE)
209
+
210
+ domain_token_pos = inputs["input_ids"].size(1) - 1
211
  src_lang_token_pos = domain_token_pos - 1
212
  _tgt_lang_token_pos = src_lang_token_pos - 1
213
 
214
  outputs = model.generate(
215
+ input_ids=inputs["input_ids"],
216
+ attention_mask=inputs["attention_mask"],
217
+ num_beams=1,
218
  max_new_tokens=500,
219
  pad_token_id=TOKENIZER.pad_token_id,
220
  eos_token_id=TOKENIZER.eos_token_id,
221
+ past_key_values=past_key_values,
222
  )
223
 
224
  generated_translation = TOKENIZER.decode(
 
228
  source_language_token = TOKENIZER.convert_ids_to_tokens(
229
  outputs[0, src_lang_token_pos].item()
230
  )
231
+ dom_token = TOKENIZER.convert_ids_to_tokens(outputs[0, domain_token_pos].item())
232
 
233
  return {
234
  "model": model_name,
235
  "source_lang": CODE2LANG[language_token_to_str(source_language_token)],
236
+ "domain": DOMAIN_MAPPING_REVERSED[domain_token_to_str(dom_token)],
237
  "translation": generated_translation,
238
  }
239