Spaces:
Running
Running
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| import argparse | |
| from collections import defaultdict | |
| import json | |
| import os | |
| import platform | |
| import re | |
| import string | |
| from typing import List | |
| from project_settings import project_path | |
# Redirect the Hugging Face hub cache into the project tree. This assignment
# sits before the gradio/transformers imports -- presumably so those libraries
# see the setting when they are first imported (TODO confirm).
os.environ["HUGGINGFACE_HUB_CACHE"] = (
    project_path / "cache/huggingface/hub"
).as_posix()
| import gradio as gr | |
| from threading import Thread | |
| from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel | |
| from transformers.models.bert.tokenization_bert import BertTokenizer | |
| from transformers.generation.streamers import TextIteratorStreamer | |
| import torch | |
def get_args():
    """Parse the demo's command-line options from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``max_new_tokens``,
        ``top_p``, ``temperature``, ``repetition_penalty``, ``device`` and
        ``examples_json_file``.
    """
    # Default to the GPU only when torch reports that CUDA is usable.
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, type, default) for every supported option.
    option_table = (
        ("--max_new_tokens", int, 512),
        ("--top_p", float, 0.9),
        ("--temperature", float, 0.35),
        ("--repetition_penalty", float, 1.0),
        ("--device", str, fallback_device),
        ("--examples_json_file", str, "examples.json"),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, default=default_value, type=value_type)

    return parser.parse_args()
def repl1(match):
    """Return capture groups 1 and 2 of ``match`` joined with nothing between.

    Has the shape of a ``re.sub`` replacement callback: it takes a match
    object and returns the replacement string.
    """
    return f"{match.group(1)}{match.group(2)}"
def repl2(match):
    """Return capture group 1 of ``match`` as a string.

    Has the shape of a ``re.sub`` replacement callback: it takes a match
    object and returns the replacement string.
    """
    return f"{match.group(1)}"
# Matches when the text ends in an ASCII letter, digit or punctuation mark.
# ``re.escape`` keeps every punctuation character literal inside the class;
# unescaped, the ``\]`` pair in ``string.punctuation`` is parsed as an escaped
# ``]`` and the backslash itself drops out of the set.
_ASCII_TAIL_RE = re.compile(f"[a-zA-Z0-9{re.escape(string.punctuation)}]$")
# Matches when a token starts with an ASCII letter or digit.
_ASCII_HEAD_RE = re.compile("^[a-zA-Z0-9]")


def remove_space_between_cn_en(text):
    """Drop spaces from ``text`` except between ASCII-looking neighbours.

    The text is split on single spaces and the non-empty pieces are glued
    back together. A single space is re-inserted between two pieces only when
    the text built so far ends in an ASCII letter, digit or punctuation mark
    AND the next piece starts with an ASCII letter or digit. Every other
    space (e.g. between two CJK characters, or between CJK and ASCII) is
    removed, and runs of spaces collapse.

    Args:
        text: The string to clean up.

    Returns:
        The cleaned string. Text without any space is returned unchanged.
        If ``text`` ends with a space, exactly one trailing space is kept.
    """
    if " " not in text:
        return text

    result = ""
    for piece in text.split(" "):
        if not piece:
            # Empty pieces come from leading, trailing or repeated spaces.
            continue
        if result and _ASCII_TAIL_RE.search(result) and _ASCII_HEAD_RE.search(piece):
            result += " "
        result += piece

    # Keep one trailing space so a following chunk can still be separated.
    if text.endswith(" "):
        result += " "
    return result
def main():
    """Build and launch the Gradio demo that streams GPT-2 generated text.

    Parses the CLI options with ``get_args``, loads the UI examples from the
    JSON file named by ``--examples_json_file``, resolves the torch device
    (``--device auto`` picks CUDA when available) and serves a
    ``gr.Interface`` whose handler yields the generated text incrementally.
    """
    # Function-scope import: only this function needs it.
    from functools import lru_cache

    args = get_args()
    # NOTE(review): only ``examples_json_file`` and ``device`` are read from
    # ``args``; the sampling options parsed by ``get_args`` are not wired
    # into the sliders below.

    description = "\n## GPT2 Chat\n"

    # example json
    with open(args.examples_json_file, "r", encoding="utf-8") as f:
        examples = json.load(f)

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    input_text_box = gr.Text(label="text")
    output_text_box = gr.Text(lines=4, label="generated_content")

    @lru_cache(maxsize=1)
    def load_model(model_name: str):
        """Load tokenizer and model for ``model_name``, placed on ``device``.

        The most recently used pair is cached so repeated requests for the
        same model do not reload the weights; ``maxsize=1`` keeps at most
        one model in memory.
        """
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        # The model must live on the same device as the input ids.
        model = model.eval().to(device)
        return tokenizer, model

    def fn_stream(text: str,
                  max_new_tokens: int = 200,
                  top_p: float = 0.85,
                  temperature: float = 0.35,
                  repetition_penalty: float = 1.2,
                  model_name: str = "qgyd2021/lip_service_4chan",
                  is_chat: bool = True,
                  ):
        """Generate a continuation of ``text`` and yield it as it grows.

        Args:
            text: Prompt entered by the user.
            max_new_tokens: Upper bound on the number of generated tokens.
            top_p: Nucleus-sampling probability mass.
            temperature: Sampling temperature.
            repetition_penalty: Penalty passed through to ``generate``.
            model_name: Model identifier given to ``from_pretrained``.
            is_chat: When true, a ``[SEP]`` token is appended to the prompt
                and also used as the end-of-sequence token.

        Yields:
            The cleaned text generated so far (cumulative, not a delta).
        """
        tokenizer, model = load_model(model_name)

        # Prompt layout: [CLS] + prompt tokens (+ [SEP] in chat mode).
        prompt_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        input_ids = [tokenizer.cls_token_id]
        input_ids.extend(prompt_ids)
        if is_chat:
            input_ids.append(tokenizer.sep_token_id)
        input_ids = torch.tensor([input_ids], dtype=torch.long)
        input_ids = input_ids.to(device)

        streamer = TextIteratorStreamer(tokenizer=tokenizer)

        generation_kwargs = dict(
            inputs=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.sep_token_id if is_chat else None,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
        # Generation runs in a background thread; this generator consumes
        # the streamer while it is being filled.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output: str = ""
        first_answer = True
        for chunk in streamer:
            if first_answer:
                # The first chunk is dropped -- presumably the echoed prompt,
                # since the streamer is created without ``skip_prompt``.
                first_answer = False
                continue

            # Strip special tokens that should never reach the user.
            for special in ("[UNK] ", "[UNK]", "[CLS] ", "[CLS]"):
                chunk = chunk.replace(special, "")

            output += chunk
            if output.startswith("[SEP]"):
                output = output[5:]

            output = output.lstrip(" ,.!?")
            output = remove_space_between_cn_en(output)

            # Any remaining [SEP] is rendered as a line break.
            output = output.replace("[SEP] ", "\n")
            output = output.replace("[SEP]", "\n")

            yield output

    # Windows offers local checkpoints; other platforms use hub model ids.
    if platform.system() == "Windows":
        model_name_choices = [
            "trained_models/lip_service_4chan",
            "trained_models/chinese_porn_novel"
        ]
    else:
        model_name_choices = [
            "qgyd2021/lip_service_4chan", "qgyd2021/chinese_chitchat",
            "qgyd2021/chinese_porn_novel", "qgyd2021/few_shot_intent_gpt2_base",
            "qgyd2021/similar_question_generation",
        ]

    demo = gr.Interface(
        fn=fn_stream,
        inputs=[
            input_text_box,
            gr.Slider(minimum=0, maximum=512, value=512, step=1, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
            gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
            gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
            gr.Dropdown(choices=model_name_choices, value=model_name_choices[0], label="model_name"),
            gr.Checkbox(value=True, label="is_chat")
        ],
        outputs=[output_text_box],
        examples=examples,
        cache_examples=False,
        examples_per_page=50,
        title="GPT2 Chat",
        description=description,
    )
    # ``queue()`` is enabled before launching; the handler is a generator.
    demo.queue().launch()
# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()