# Benoit Favre
# select correct model
# 878b865
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList
import sys
# Hugging Face Hub checkpoint served by this app; the 1.3B variant is kept as a
# commented-out alternative.
#model_id = 'benoitfavre/nllb-200-distilled-1.3B_text2picto'
model_id = 'benoitfavre/nllb-200-distilled-600m_text2picto-multi'

# Pick the compute device: Apple MPS first, then CUDA, then CPU.
# `torch.backends.mps.is_available()` is used because `torch.mps.is_available()`
# is not present in older PyTorch releases and would fail there with
# AttributeError before anything else runs.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Private-use code point that replaces spaces in pictogram input and is stripped
# again from decoded pictogram tokens.
picto_prefix = '\uE000'
# Fetch the tokenizer and the seq2seq weights for `model_id`, then place the
# model on the selected `device`.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
model = model.to(device)
class RestrictTokensProcessor(LogitsProcessor):
    """Logits processor that leaves only a fixed set of token ids selectable.

    Every vocabulary entry that is not allowed gets ``-inf`` added to its
    score; allowed entries are left unchanged.

    Args:
        allowed_token_ids: iterable of token ids that may be produced.
        disallowed_token_ids: iterable of token ids removed from the allowed
            set (defaults to nothing).
    """

    def __init__(self, allowed_token_ids, disallowed_token_ids=()):
        # Immutable default instead of a shared mutable `set()`.
        self.allowed_token_ids = set(allowed_token_ids) - set(disallowed_token_ids)
        # Additive mask of shape (1, vocab_size), built lazily on first call.
        self.mask = None

    def _build_mask(self, scores):
        """Return a (1, vocab_size) mask matching ``scores`` in dtype and device."""
        vocab_size = scores.shape[-1]
        mask = scores.new_full((1, vocab_size), float("-inf"))
        # Ids outside the score dimension cannot be produced anyway; skipping
        # them avoids an IndexError when the tokenizer knows more ids than the
        # model scores.
        usable_ids = sorted(
            token_id for token_id in self.allowed_token_ids
            if 0 <= token_id < vocab_size
        )
        if usable_ids:
            # One indexed write instead of one tensor write per token id.
            index = torch.tensor(usable_ids, dtype=torch.long, device=scores.device)
            mask[:, index] = 0  # keep allowed tokens
        return mask

    def __call__(self, input_ids, scores):
        """Return ``scores`` with every non-allowed token pushed to ``-inf``.

        ``input_ids`` is accepted for interface compatibility and not used.
        ``scores`` has the vocabulary on its last dimension; the mask
        broadcasts over all leading dimensions.
        """
        # scores: (batch_size, vocab_size)
        mask = self.mask
        stale = (
            mask is None
            or mask.shape[-1] != scores.shape[-1]
            or mask.device != scores.device
            or mask.dtype != scores.dtype
        )
        if stale:
            # Rebuild when the vocabulary size, device or dtype of the scores
            # differs from the cached mask.
            mask = self.mask = self._build_mask(scores)
        return scores + mask
# Token ids the decoder may emit when the target is pictograms: every added
# token (presumably where the pictogram vocabulary lives -- confirm against the
# checkpoint), every special token, and the ids produced by encoding one space.
# The public `added_tokens_decoder` mapping is used instead of the private
# `_added_tokens_decoder` attribute, which not every tokenizer class defines.
picto_token_ids = RestrictTokensProcessor(
    list(tokenizer.added_tokens_decoder)
    + tokenizer.all_special_ids
    + tokenizer.encode(' ')
)  #, tokenizer.encode('<unk>'))
print('loaded')
def lookup(text, from_lid, to_lid, n):
    """Return up to ``n`` candidate single tokens for ``text``.

    Runs one forward pass of the model (no generation loop) and reads the
    top-``n`` tokens at the first decoder position.

    Args:
        text: source string; for ``from_lid == 'picto'`` a space-separated
            list of pictogram ids.
        from_lid: source language id, also set as ``tokenizer.src_lang``.
        to_lid: target language id; ``'picto'`` restricts candidates with
            ``picto_token_ids``.
        n: number of candidates requested.

    Returns:
        Decoded candidate strings with ``picto_prefix`` removed, excluding
        empty strings and strings starting with ``'<'``.
    """
    if from_lid == 'picto':
        # Same encoding as gen_dir: every pictogram id, including the first,
        # is preceded by picto_prefix.
        text = (' ' + text.strip()).replace(' ', picto_prefix)
    tokenizer.src_lang = from_lid

    # Build the encoder input by hand: language tag, text, end-of-sequence.
    input_ids = []
    for piece in (from_lid, text, '</s>'):
        input_ids.extend(tokenizer.encode(piece, add_special_tokens=False))
    # is it necessary to double the tgt token?
    decoder_input_ids = tokenizer.encode(to_lid, add_special_tokens=False) * 2

    with torch.no_grad():
        output = model(
            input_ids=torch.tensor([input_ids], device=model.device),
            decoder_input_ids=torch.tensor([decoder_input_ids], device=model.device),
        )

    scores = output.logits
    if to_lid == 'picto':
        # Only constrain to pictogram tokens when pictograms are the target,
        # as gen_dir does; other targets keep the full vocabulary.
        scores = picto_token_ids(decoder_input_ids, scores)

    # topk fails when k exceeds the vocabulary size, so clamp the request.
    k = min(n, scores.shape[-1])
    _, indices = torch.topk(scores, k)

    # indices[0][0]: first batch item, first decoder position.
    candidates = [
        tokenizer.decode(token_id).replace(picto_prefix, '')
        for token_id in indices[0][0]
    ]
    return [c for c in candidates if c != '' and not c.startswith('<')]
def gen_dir(texts, from_lid, to_lid, max_length=128, nbest=1):
    """Translate a batch of texts from ``from_lid`` to ``to_lid``.

    Args:
        texts: list of source strings; for ``from_lid == 'picto'`` each is a
            space-separated list of pictogram ids.
        from_lid: source language id, also set as ``tokenizer.src_lang``.
        to_lid: target language id; ``'picto'`` constrains generation with
            ``picto_token_ids`` and changes how the output is decoded.
        max_length: truncation length for the tokenized input.
        nbest: number of sequences returned per input; also the beam width
            when greater than 1 (4 beams are used when ``nbest == 1``).

    Returns:
        A flat list of decoded strings in the order ``model.generate`` returns
        them. For pictogram output each string is the space-joined list of
        generated tokens with ``picto_prefix`` removed.
    """
    if not texts:
        # Nothing to translate; skip tokenization and generation.
        return []
    if from_lid == 'picto':
        # Prefix every pictogram id (including the first) with picto_prefix.
        texts = [(' ' + text.strip()).replace(' ', picto_prefix) for text in texts]
    tokenizer.src_lang = from_lid
    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    enc = {k: v.to(device) for k, v in enc.items()}
    forced_id = tokenizer.convert_tokens_to_ids(to_lid)
    to_picto = to_lid == 'picto'
    with torch.no_grad():
        gen_out = model.generate(
            **enc,
            num_beams=4 if nbest == 1 else nbest,
            num_return_sequences=nbest,
            #max_new_tokens=48,
            min_new_tokens=2,
            no_repeat_ngram_size=3,
            repetition_penalty=1.2,
            length_penalty=1.05,
            forced_bos_token_id=forced_id,
            decoder_start_token_id=forced_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            logits_processor=LogitsProcessorList([picto_token_ids]) if to_picto else None,
        )
    if not to_picto:
        return tokenizer.batch_decode(gen_out, skip_special_tokens=True)

    # Ids dropped from pictogram output: special tokens plus the ids produced
    # by encoding a single space. Previously the encoded-space ids were
    # appended as one nested list, so they were never matched and never
    # filtered out.
    skipped_ids = set(tokenizer.all_special_ids) | set(tokenizer.encode(' '))
    return [
        ' '.join(
            tokenizer.decode(token_id).replace(picto_prefix, '')
            for token_id in sequence
            if token_id not in skipped_ids
        )
        for sequence in gen_out.tolist()
    ]
#print(gen_dir(['Le général est arrivé par la fenêtre et il est reparti en grande pompe'], 'fra_Latn', 'picto'))
#print(gen_dir(['24937 16807 3244 6606 4658 10334'], 'picto', 'fra_Latn'))
from quart import Quart, send_file, request
# Quart application object; the route decorators below register handlers on it.
app = Quart(__name__)
@app.route('/nbest/<int:n>/<string:src>/<string:tgt>', methods=['POST'])
async def nbest(n, src, tgt):
    """Return up to ``n`` single-token candidates for the posted text.

    The request body must be a JSON string. ``src`` and ``tgt`` are passed to
    ``lookup`` as source and target language ids. Responds with 400 when the
    body is missing or is not a JSON string.
    """
    text = await request.get_json()
    if not isinstance(text, str):
        # Reject bad input here instead of failing inside lookup with a 500.
        return {'error': 'request body must be a JSON string'}, 400
    # Earlier variant kept for reference:
    #result = gen_dir([text], src, tgt, nbest=n)
    #return [x.strip().split()[0] for x in result]
    return lookup(text, src, tgt, n)
@app.route('/translate/<string:src>/<string:tgt>', methods=['POST'])
async def translate(src, tgt):
    """Translate the posted text from ``src`` to ``tgt``.

    The request body must be a JSON string. Returns the first result of
    ``gen_dir`` for that single text. Responds with 400 when the body is
    missing or is not a JSON string.
    """
    text = await request.get_json()
    if not isinstance(text, str):
        # Reject bad input here instead of failing inside gen_dir with a 500.
        return {'error': 'request body must be a JSON string'}, 400
    return gen_dir([text], src, tgt)[0]
@app.route('/lexicon.js')
async def lexicon():
    """Serve the ``lexicon.js`` file."""
    response = await send_file('lexicon.js')
    return response
@app.route('/')
async def index():
    """Serve ``index.html`` at the site root."""
    page = await send_file('index.html')
    return page
if __name__ == '__main__':
    # Executed directly as a script: start the Quart server in debug mode.
    app.run(debug=True)