"""Quart web service exposing an NLLB text <-> pictogram translation model.

The pseudo language code ``'picto'`` denotes pictogram sequences: on the wire
they are space-separated pictogram ids; before tokenization every id is
prefixed with ``picto_prefix``, and the prefix is stripped again from decoded
output. (Assumption, not visible here: the fine-tuned tokenizer registers the
prefixed ids as added tokens -- confirm against the checkpoint.)
"""
import sys

import torch
from quart import Quart, request, send_file
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
)

# Alternative (larger) checkpoint:
# model_id = 'benoitfavre/nllb-200-distilled-1.3B_text2picto'
model_id = 'benoitfavre/nllb-200-distilled-600m_text2picto-multi'

# Prefer Apple MPS, then CUDA, then CPU.
# torch.backends.mps.is_available() is the documented MPS availability check.
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

# Prepended to each pictogram id before tokenization; removed from decoded output.
picto_prefix = '\uE000'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)


class RestrictTokensProcessor(LogitsProcessor):
    """Logits processor that only lets a fixed set of token ids be predicted.

    Every id outside ``allowed_token_ids - disallowed_token_ids`` receives a
    score of ``-inf``. The additive mask is built lazily on first use and cached;
    it is rebuilt when the vocabulary size, device or dtype of the incoming
    scores differs from the cached mask.
    """

    def __init__(self, allowed_token_ids, disallowed_token_ids=()):
        # Immutable default: a shared ``set()`` default is a latent hazard.
        self.allowed_token_ids = set(allowed_token_ids) - set(disallowed_token_ids)
        self.mask = None

    def __call__(self, input_ids, scores):
        # scores: (..., vocab_size); the (1, vocab_size) mask broadcasts over
        # every leading dimension.
        if (
            self.mask is None
            or self.mask.shape[-1] != scores.shape[-1]
            or self.mask.device != scores.device
            or self.mask.dtype != scores.dtype
        ):
            mask = scores.new_full((1, scores.shape[-1]), float('-inf'))
            # One vectorized index assignment instead of one tensor write per id.
            allowed = torch.tensor(
                sorted(self.allowed_token_ids), dtype=torch.long, device=scores.device
            )
            mask[:, allowed] = 0  # keep allowed tokens
            self.mask = mask
        return scores + self.mask


# Restricts decoding to added tokens (presumably the pictogram vocabulary),
# special tokens, and the ids produced by encoding a single space.
picto_token_ids = RestrictTokensProcessor(
    list(tokenizer.added_tokens_decoder.keys())
    + tokenizer.all_special_ids
    + tokenizer.encode(' ')
)

print('loaded')


def _picto_source(text):
    """Prefix every space-separated pictogram id in ``text`` with ``picto_prefix``.

    E.g. ``'12 34'`` becomes ``picto_prefix + '12' + picto_prefix + '34'``;
    the first id is prefixed too.
    """
    return (' ' + text.strip()).replace(' ', picto_prefix)


def lookup(text, from_lid, to_lid, n):
    """Return up to ``n`` candidate first tokens for translating ``text``.

    Runs one forward pass (no generation) and takes the top-``n`` tokens
    predicted at the first decoder position. When ``to_lid == 'picto'`` the
    prediction is restricted with ``picto_token_ids``, as in ``gen_dir``.
    Candidates that decode to ``''`` or start with ``'<'`` (e.g. ``'</s>'``)
    are dropped, so fewer than ``n`` strings may be returned.
    """
    if from_lid == 'picto':
        # Same preprocessing as gen_dir: the first id is prefixed as well.
        text = _picto_source(text)
    tokenizer.src_lang = from_lid
    # Mirror the encoder input gen_dir gets from tokenizer(...):
    # [source language code] + text tokens + [</s>].
    input_ids = (
        tokenizer.encode(from_lid, add_special_tokens=False)
        + tokenizer.encode(text, add_special_tokens=False)
        + [tokenizer.eos_token_id]
    )
    # NOTE(review): the target language code is fed twice; unclear whether
    # doubling it is necessary.
    decoder_input_ids = tokenizer.encode(to_lid, add_special_tokens=False) * 2

    with torch.no_grad():
        output = model(
            input_ids=torch.tensor([input_ids], device=model.device),
            decoder_input_ids=torch.tensor([decoder_input_ids], device=model.device),
        )
    scores = output.logits
    if to_lid == 'picto':
        scores = picto_token_ids(decoder_input_ids, scores)
    _, indices = torch.topk(scores, n)
    candidates = [
        tokenizer.decode(token_id).replace(picto_prefix, '')
        for token_id in indices[0, 0].tolist()
    ]
    return [c for c in candidates if c != '' and not c.startswith('<')]


def gen_dir(texts, from_lid, to_lid, max_length=128, nbest=1):
    """Translate a batch of strings from ``from_lid`` to ``to_lid`` with beam search.

    Args:
        texts: source strings; for ``from_lid == 'picto'`` each one is a
            space-separated sequence of pictogram ids.
        from_lid: source language code (``'picto'`` for pictograms).
        to_lid: target language code (``'picto'`` for pictograms).
        max_length: source truncation length, in tokens.
        nbest: hypotheses returned per input; also the beam size when > 1
            (beam size is 4 otherwise).

    Returns:
        A flat list of ``len(texts) * nbest`` strings in the order produced by
        ``model.generate``. Pictogram output is the space-separated ids with
        ``picto_prefix`` removed.
    """
    if from_lid == 'picto':
        texts = [_picto_source(text) for text in texts]
    tokenizer.src_lang = from_lid
    enc = tokenizer(
        texts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    enc = {k: v.to(device) for k, v in enc.items()}
    forced_id = tokenizer.convert_tokens_to_ids(to_lid)
    with torch.no_grad():
        gen_out = model.generate(
            **enc,
            num_beams=4 if nbest == 1 else nbest,
            num_return_sequences=nbest,
            # max_new_tokens=48,
            min_new_tokens=2,
            no_repeat_ngram_size=3,
            repetition_penalty=1.2,
            length_penalty=1.05,
            forced_bos_token_id=forced_id,
            decoder_start_token_id=forced_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            logits_processor=LogitsProcessorList([picto_token_ids]) if to_lid == 'picto' else None,
        )

    if to_lid != 'picto':
        return tokenizer.batch_decode(gen_out, skip_special_tokens=True)

    # Ids left out of pictogram output: special tokens plus every id from
    # tokenizer.encode(' '). Kept flat in a set so each space id is matched
    # individually (a nested list would never compare equal to a single id).
    skip_ids = set(tokenizer.all_special_ids)
    skip_ids.update(tokenizer.encode(' '))
    return [
        ' '.join(
            tokenizer.decode(token_id).replace(picto_prefix, '')
            for token_id in sequence.tolist()
            if token_id not in skip_ids
        )
        for sequence in gen_out
    ]


# Example usage:
# print(gen_dir(['Le général est arrivé par la fenêtre et il est reparti en grande pompe'], 'fra_Latn', 'picto'))
# print(gen_dir(['24937 16807 3244 6606 4658 10334'], 'picto', 'fra_Latn'))

app = Quart(__name__)


@app.route('/nbest/<int:n>/<src>/<tgt>', methods=['POST'])
async def nbest(n, src, tgt):
    """Return up to ``n`` first-token candidates for the JSON-string request body."""
    text = await request.get_json()
    # Alternative: full n-best generation keeping each hypothesis' first word:
    # [x.strip().split()[0] for x in gen_dir([text], src, tgt, nbest=n)]
    return lookup(text, src, tgt, n)


@app.route('/translate/<src>/<tgt>', methods=['POST'])
async def translate(src, tgt):
    """Return the single best translation of the JSON-string request body."""
    text = await request.get_json()
    result = gen_dir([text], src, tgt)[0]
    return result


@app.route('/lexicon.js')
async def lexicon():
    """Serve the static ``lexicon.js`` file."""
    return await send_file('lexicon.js')


@app.route('/')
async def index():
    """Serve the static ``index.html`` page."""
    return await send_file('index.html')


if __name__ == '__main__':
    app.run(debug=True)