Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import json | |
| import re | |
| import torch | |
| from transformers import GPT2Tokenizer, T5ForConditionalGeneration | |
# A token is one of the following, always together with the whitespace that
# follows it:
#   * a single sentence punctuation mark (. , ! ?),
#   * a run of non-space characters starting with a Cyrillic letter,
#   * a run of non-space characters starting with a digit,
#   * a run of characters that are neither Cyrillic, digits nor whitespace.
# "ё"/"Ё" are listed explicitly: they lie outside the а-я / А-Я code point
# ranges, so without them a word starting with "ё" is split after its first
# letter.
re_tokens = re.compile(
    r"(?:[.,!?]|[а-яА-ЯёЁ]\S*|\d\S*(?:\.\d+)?|[^а-яА-ЯёЁ\d\s]+)\s*"
)


def tokenize(text):
    """Split *text* into a list of tokens that keep their trailing whitespace.

    Every non-whitespace character of *text* ends up in exactly one token and
    whitespace is attached to the token that precedes it, so only whitespace
    at the very beginning of *text* is not part of any token.

    Returns an empty list for an empty (or whitespace-only) string.
    """
    return re_tokens.findall(text)
def strip_numbers(s):
    """Break long digit-only words into space-separated groups of three.

    From `1234567` to `1 234 567`.

    The string is split on whitespace; every part consisting only of digits
    and longer than three characters is cut into groups counted from the
    right (leading zeros are kept, e.g. `0012345` -> `0 012 345`). All other
    parts are passed through unchanged. The parts are joined back with single
    spaces, so any run of whitespace in *s* collapses to one space.
    """
    chunks = []
    for part in s.split():
        if not part.isdigit():
            chunks.append(part)
            continue
        # Length of the leftmost group: 1..3 characters, so that everything
        # after it divides evenly into triples. For parts of up to three
        # digits this is the whole part.
        head = (len(part) - 1) % 3 + 1
        chunks.append(part[:head])
        chunks.extend(part[i:i + 3] for i in range(head, len(part), 3))
    return " ".join(chunks)
def construct_prompt(text):
    """Build the model prompt by marking spans that contain Latin letters or digits.

    From `я купил iphone 12X за 142 990 руб без 3-x часов 12:00, и т.д.`
    to `<SC1>я купил [iphone 12X]<extra_id_0> за [142 990]<extra_id_1> руб без [3-x]<extra_id_2> часов [12:00]<extra_id_3>, и т.д.`.

    Consecutive tokens that contain a Latin letter or a digit are merged into
    one span. Each span is wrapped in square brackets (with long digit runs
    grouped by `strip_numbers`) and followed by a numbered `<extra_id_N>`
    marker; trailing non-word characters of the span (spaces, punctuation)
    are placed after the marker. All other tokens are copied as they are.
    """
    pieces = ["<SC1>"]
    etid = 0
    pending = ""
    # The trailing empty token guarantees that a span still pending at the
    # end of the text is flushed.
    for token in tokenize(text) + [""]:
        if re.search(r"[a-zA-Z\d]", token):
            pending += token
            continue
        if pending:
            # Split the span into its core and the maximal trailing run of
            # non-word characters. Anchoring with \Z (not a multiline `$`)
            # keeps spans with an interior newline intact instead of
            # silently dropping part of them.
            tail_start = re.search(r"\W*\Z", pending).start()
            core, tail = pending[:tail_start], pending[tail_start:]
            pieces.append(f"[{strip_numbers(core)}]<extra_id_{etid}>{tail}")
            etid += 1
            pending = ""
        pieces.append(token)
    return "".join(pieces)
def construct_answer(prompt: str, prediction: str) -> str:
    """Substitute the model's predictions back into the prompt.

    Args:
        prompt: Text containing `[original]<extra_id_N>` placeholders and,
            optionally, the `<SC1>` prefix.
        prediction: Model output of the form
            `<extra_id_0> text <extra_id_1> text ... </s>`. Newlines are
            treated as spaces. A segment counts only if it is followed by
            another `<extra_id_N>` marker or by `</s>`.

    Returns:
        The prompt with every placeholder replaced by the stripped predicted
        text for its id, or by the original bracketed text when there is no
        prediction for that id, and with `<SC1>` removed.
    """
    re_prompt = re.compile(r"\[([^\]]+)\]<extra_id_(\d+)>")
    re_pred = re.compile(r"<extra_id_(\d+)>(.+?)(?=<extra_id_\d+>|</s>)")
    pred_data = {
        m[1]: m[2].strip()
        for m in re_pred.finditer(prediction.replace("\n", " "))
    }
    # A single left-to-right pass: substituted text is never scanned again,
    # so model output cannot be mistaken for another placeholder.
    answer = re_prompt.sub(lambda m: pred_data.get(m[2], m[1]), prompt)
    return answer.replace("<SC1>", "")
# Example inputs offered in the chat UI. The file is read as UTF-8 explicitly
# so that loading does not depend on the platform's default locale encoding.
with open("examples.json", encoding="utf-8") as f:
    test_examples = json.load(f)

# Tokenizer and model are loaded once at import time and shared by all requests.
tokenizer = GPT2Tokenizer.from_pretrained("saarus72/russian_text_normalizer", eos_token='</s>')
model = T5ForConditionalGeneration.from_pretrained("saarus72/russian_text_normalizer")
def predict(text):
    """Generate the model's raw output string for an already built prompt.

    The prompt is encoded with the module-level tokenizer, passed to the
    module-level model as a batch of one, and the generated ids are decoded
    back to text. At most 50 new tokens are generated.
    """
    encoded = tokenizer.encode(text)
    batch = torch.tensor([encoded])
    generated = model.generate(
        batch,
        max_new_tokens=50,
        eos_token_id=tokenizer.eos_token_id,
        early_stopping=True,
    )
    # The first generated id is skipped before decoding
    # (presumably the decoder start token -- confirm against the model config).
    return tokenizer.decode(generated[0][1:])
def norm(message, history):
    """Chat handler: normalize *message* and stream the reply in two steps.

    Yields first a placeholder that shows the constructed prompt while the
    model is still running, then the final message with the prompt, the raw
    prediction and the reconstructed answer. *history* is accepted but not
    used.
    """
    prompt = construct_prompt(message)
    # Interim message shown until the prediction is ready.
    placeholder = f"```Prompt:\n{prompt}\nPrediction:\n...```\n..."
    yield placeholder

    prediction = predict(prompt)
    answer = construct_answer(prompt, prediction)
    final_message = f"Prompt:\n```{prompt}```\nPrediction:\n```\n{prediction}\n```\n{answer}"
    yield final_message
# Chat UI around `norm`; the keys of the loaded examples mapping are offered
# as example inputs. Queuing is enabled before the app is launched.
demo = gr.ChatInterface(
    fn=norm,
    stop_btn=None,
    examples=list(test_examples.keys()),
).queue()
demo.launch()
| # |