Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList | |
| import sys | |
# Checkpoint loaded for both the tokenizer and the model; the commented-out
# line names an alternative checkpoint.
#model_id = 'benoitfavre/nllb-200-distilled-1.3B_text2picto'
model_id = 'benoitfavre/nllb-200-distilled-600m_text2picto-multi'

# Accelerator preference: Apple MPS, then CUDA, then CPU.
# `torch.backends.mps.is_available()` is the long-standing documented check;
# `torch.mps.is_available` only exists in newer PyTorch releases
# (NOTE(review): confirm the minimum torch version this app targets).
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Private-use code point substituted for spaces in pictogram input and
# stripped again from decoded pictogram tokens (see lookup / gen_dir).
picto_prefix = '\uE000'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
class RestrictTokensProcessor(LogitsProcessor):
    """Logits processor that keeps only a fixed set of token ids selectable.

    An additive mask holding ``0`` for allowed ids and ``-inf`` everywhere
    else is built lazily from the first scores tensor seen and then cached.

    Args:
        allowed_token_ids: Iterable of token ids that stay selectable.
        disallowed_token_ids: Optional iterable of ids removed from the
            allowed set. ``None`` (the default) removes nothing.
    """

    def __init__(self, allowed_token_ids, disallowed_token_ids=None):
        # ``None`` instead of a shared mutable ``set()`` default.
        excluded = () if disallowed_token_ids is None else disallowed_token_ids
        self.allowed_token_ids = set(allowed_token_ids) - set(excluded)
        self.mask = None

    def _build_mask(self, scores):
        """Return a ``(1, vocab_size)`` mask matching *scores* dtype/device."""
        vocab_size = scores.shape[-1]
        mask = scores.new_full((1, vocab_size), float("-inf"))
        # Ids outside the scored vocabulary cannot be indexed; skip them
        # instead of failing on the first call.
        keep = sorted(
            token_id
            for token_id in self.allowed_token_ids
            if 0 <= token_id < vocab_size
        )
        if keep:
            index = torch.tensor(keep, dtype=torch.long, device=scores.device)
            mask[:, index] = 0  # keep allowed tokens
        return mask

    def __call__(self, input_ids, scores):
        """Add the mask to *scores* and return the result.

        ``input_ids`` is accepted for interface compatibility and not used.
        The last dimension of *scores* is the vocabulary; any leading
        dimensions broadcast against the ``(1, vocab_size)`` mask.
        """
        mask = self.mask
        # Rebuild when the cached mask no longer matches the incoming scores
        # (different vocabulary size, device or dtype).
        if (
            mask is None
            or mask.shape[-1] != scores.shape[-1]
            or mask.device != scores.device
            or mask.dtype != scores.dtype
        ):
            mask = self.mask = self._build_mask(scores)
        return scores + mask
# Processor applied when generating pictograms. Despite its name this is a
# RestrictTokensProcessor, not a list of ids. Allowed ids are: every added
# token, the special tokens, and whatever the tokenizer emits for a single
# space. `added_tokens_decoder` is the public accessor for the added-token
# table, used here instead of the private `_added_tokens_decoder` attribute.
picto_token_ids = RestrictTokensProcessor(
    list(tokenizer.added_tokens_decoder)
    + tokenizer.all_special_ids
    + tokenizer.encode(' ')
)  #, tokenizer.encode('<unk>'))
print('loaded')
def lookup(text, from_lid, to_lid, n):  # lookup single token from text
    """Return up to *n* candidate tokens for *text* from one forward pass.

    The source is encoded as ``from_lid`` + *text* + ``</s>`` and the decoder
    is fed the target language token twice. The resulting logits go through
    the ``picto_token_ids`` processor and the top-*n* entries at the first
    decoder position are decoded.

    Args:
        text: Source text. When ``from_lid == 'picto'`` it is a
            space-separated pictogram sequence.
        from_lid: Source language identifier, also set as
            ``tokenizer.src_lang``.
        to_lid: Target language identifier.
        n: Maximum number of candidates; coerced with ``int()`` and clamped
            to the vocabulary size.

    Returns:
        Decoded candidates with the pictogram prefix removed, excluding empty
        strings, strings starting with ``<`` and tokens masked out by the
        processor.
    """
    if from_lid == 'picto':
        # Same encoding as gen_dir: every pictogram, including the first one,
        # is introduced by the prefix.
        text = (' ' + text.strip()).replace(' ', picto_prefix)
    tokenizer.src_lang = from_lid

    input_ids = [
        token_id
        for piece in (from_lid, text, '</s>')
        for token_id in tokenizer.encode(piece, add_special_tokens=False)
    ]
    decoder_input_ids = tokenizer.encode(to_lid, add_special_tokens=False) * 2  # is it necessary to double the tgt token?

    with torch.no_grad():
        output = model(
            input_ids=torch.tensor([input_ids], device=model.device),
            decoder_input_ids=torch.tensor([decoder_input_ids], device=model.device),
        )
    scores = picto_token_ids(decoder_input_ids, output.logits)

    # torch.topk rejects k larger than the scored dimension.
    k = max(0, min(int(n), scores.shape[-1]))
    values, indices = torch.topk(scores, k)

    # [0][0]: first batch item, first decoder position.
    # NOTE(review): with two decoder tokens the last position may be the one
    # intended here; confirm before changing.
    top_scores = values[0][0].tolist()
    top_ids = indices[0][0].tolist()

    masked = float('-inf')
    candidates = [
        tokenizer.decode(token_id).replace(picto_prefix, '')
        for score, token_id in zip(top_scores, top_ids)
        # When k exceeds the number of allowed tokens, topk pads the result
        # with masked entries; those must not be reported as candidates.
        if score != masked
    ]
    return [c for c in candidates if c != '' and not c.startswith('<')]
def gen_dir(texts, from_lid, to_lid, max_length=128, nbest=1):
    """Translate a batch of texts from *from_lid* to *to_lid* with beam search.

    Args:
        texts: List of source strings. When ``from_lid == 'picto'`` each one
            is a space-separated pictogram sequence.
        from_lid: Source language identifier, also set as
            ``tokenizer.src_lang``.
        to_lid: Target language identifier; its token id is used both as
            decoder start token and as forced BOS token.
        max_length: Truncation length for the tokenized source.
        nbest: Number of hypotheses returned per input text.

    Returns:
        A flat list of decoded strings, as returned by ``model.generate``
        with ``num_return_sequences=nbest``. For ``to_lid == 'picto'`` each
        string is the space-joined generated tokens with the pictogram prefix
        removed.
    """
    # Nothing to translate: return early rather than handing an empty batch
    # to the tokenizer and the model.
    if isinstance(texts, (list, tuple)) and not texts:
        return []

    if from_lid == 'picto':
        # Each pictogram, including the first, is introduced by the prefix.
        texts = [(' ' + text.strip()).replace(' ', picto_prefix) for text in texts]
    tokenizer.src_lang = from_lid

    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    enc = {k: v.to(device) for k, v in enc.items()}
    forced_id = tokenizer.convert_tokens_to_ids(to_lid)

    with torch.no_grad():
        gen_out = model.generate(
            **enc,
            num_beams=4 if nbest == 1 else nbest,
            num_return_sequences=nbest,
            #max_new_tokens=48,
            min_new_tokens=2,
            no_repeat_ngram_size=3,
            repetition_penalty=1.2,
            length_penalty=1.05,
            forced_bos_token_id=forced_id,
            decoder_start_token_id=forced_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            # Only pictogram output is restricted to the allowed token set.
            logits_processor=LogitsProcessorList([picto_token_ids]) if to_lid == 'picto' else None,
        )

    if to_lid != 'picto':
        return tokenizer.batch_decode(gen_out, skip_special_tokens=True)

    # Ids dropped from pictogram output: special tokens plus the ids the
    # tokenizer produces for a single space. Built once as a flat set; the
    # previous `all_special_ids + [tokenizer.encode(' ')]` appended a nested
    # list, so the space ids were never actually filtered.
    skipped_ids = set(tokenizer.all_special_ids) | set(tokenizer.encode(' '))
    return [
        ' '.join(
            tokenizer.decode(token_id).replace(picto_prefix, '')
            for token_id in sequence
            if token_id not in skipped_ids
        )
        for sequence in gen_out.tolist()
    ]
| #print(gen_dir(['Le général est arrivé par la fenêtre et il est reparti en grande pompe'], 'fra_Latn', 'picto')) | |
| #print(gen_dir(['24937 16807 3244 6606 4658 10334'], 'picto', 'fra_Latn')) | |
| from quart import Quart, send_file, request | |
# Quart application object; also used by the __main__ guard to start the server.
app = Quart(__name__)
async def nbest(n, src, tgt):
    """Return up to *n* candidate tokens for the JSON string in the request.

    Args:
        n: Maximum number of candidates, forwarded to ``lookup``.
        src: Source language identifier.
        tgt: Target language identifier.

    Returns:
        The list produced by ``lookup``, or a ``400`` response when the
        request body is not a JSON string.

    NOTE(review): no route decorator is visible on this handler in this
    view; confirm where it is registered with ``app``.
    """
    text = await request.get_json()
    # get_json() yields None for a non-JSON request and arbitrary JSON values
    # otherwise; only a string can be tokenized.
    if not isinstance(text, str):
        return {'error': 'request body must be a JSON string'}, 400
    #result = gen_dir([text], src, tgt, nbest=n)
    #return [x.strip().split()[0] for x in result]
    return lookup(text, src, tgt, n)
async def translate(src, tgt):
    """Translate the JSON string in the request body from *src* to *tgt*.

    Args:
        src: Source language identifier.
        tgt: Target language identifier.

    Returns:
        The first (best) hypothesis from ``gen_dir``, or a ``400`` response
        when the request body is not a JSON string.

    NOTE(review): no route decorator is visible on this handler in this
    view; confirm where it is registered with ``app``.
    """
    text = await request.get_json()
    # get_json() yields None for a non-JSON request and arbitrary JSON values
    # otherwise; only a string can be tokenized.
    if not isinstance(text, str):
        return {'error': 'request body must be a JSON string'}, 400
    return gen_dir([text], src, tgt)[0]
async def lexicon():
    """Send the ``lexicon.js`` file as the response.

    NOTE(review): no route decorator is visible on this handler in this
    view; confirm where it is registered with ``app``.
    """
    response = await send_file('lexicon.js')
    return response
async def index():
    """Send the ``index.html`` file as the response.

    NOTE(review): no route decorator is visible on this handler in this
    view; confirm where it is registered with ``app``.
    """
    response = await send_file('index.html')
    return response
# Start the Quart server only when this file is executed as a script.
if __name__ == '__main__':
    # NOTE(review): debug=True turns on Quart's debug mode; confirm this entry
    # point is not what serves a public deployment.
    app.run(debug=True)