Spaces:
Runtime error
Runtime error
| import random | |
| import re | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM | |
| from transformers import AutoModelForSeq2SeqLM | |
| from transformers import AutoTokenizer | |
| from transformers import AutoProcessor | |
| from transformers import pipeline | |
| from transformers import set_seed | |
# Run the image-captioning model on the GPU when one is available, else on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Processor/model pair used by get_prompt_from_image to caption an uploaded image.
big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

# Three text-generation pipelines, one per "pipeline" mode offered in the UI.
# Each is configured to produce at most 256 new tokens per call.
pipeline_01 = pipeline('text-generation', model='succinctly/text2image-prompt-generator', max_new_tokens=256)
pipeline_02 = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', max_new_tokens=256)
pipeline_03 = pipeline('text-generation', model='johnsu6616/ModelExport', max_new_tokens=256)

# Translation models (zh -> en and en -> zh), switched to eval mode at load time.
zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
def translate_zh2en(text):
    """Translate text (Chinese, or Chinese mixed with English) into English.

    Before translation the input is normalised: a fixed set of punctuation
    characters is removed, every boundary between a CJK character
    (U+4E00..U+9FA5) and a non-CJK character becomes a comma, whitespace that
    is not directly between two ASCII letters is dropped, and runs of commas
    are collapsed.

    Args:
        text: The text to translate. ``None``, empty and whitespace-only
            values are accepted.

    Returns:
        The translated string with ``<...>`` spans removed and immediately
        repeated words collapsed to one occurrence. Returns ``""`` when there
        is nothing to translate, without invoking the model.
    """
    # Nothing to translate: skip the model instead of feeding it an empty prompt.
    if not text or not text.strip():
        return ""

    # Strip punctuation that should not reach the translator.
    text = re.sub(r"[:\-–.!;?_#]", '', text)
    # Mark every CJK/non-CJK boundary with a newline ...
    text = re.sub(r'([^\u4e00-\u9fa5])([\u4e00-\u9fa5])', r'\1\n\2', text)
    text = re.sub(r'([\u4e00-\u9fa5])([^\u4e00-\u9fa5])', r'\1\n\2', text)
    # ... then turn all newlines (inserted or original) into commas.
    text = text.replace('\n', ',')
    # Remove whitespace unless it sits directly between two ASCII letters.
    text = re.sub(r'(?<![a-zA-Z])\s+|\s+(?![a-zA-Z])', '', text)
    text = re.sub(r',+', ',', text)

    # The input consisted only of removable characters.
    if not text:
        return ""

    with torch.no_grad():
        encoded = zh2en_tokenizer([text], return_tensors='pt')
        sequences = zh2en_model.generate(**encoded)
    result = zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
    result = result.strip()

    # "No,no," is treated as a failed translation; fall back to the normalised
    # input. NOTE(review): presumably a degenerate model output - confirm.
    if result == "No,no,":
        result = text

    # Drop any <...> spans, then collapse immediately repeated words
    # (case-insensitive), e.g. "cat cat Cat" -> "cat".
    result = re.sub(r'<.*?>', '', result)
    result = re.sub(r'\b(\w+)\b(?:\W+\1\b)+', r'\1', result, flags=re.IGNORECASE)
    return result
def translate_en2zh(text):
    """Translate English text into Chinese.

    Args:
        text: The text to translate. ``None``, empty and whitespace-only
            values are accepted.

    Returns:
        The translated string with immediately repeated words collapsed to a
        single occurrence (case-insensitive). Returns ``""`` when there is
        nothing to translate, without invoking the model.
    """
    # Nothing to translate: skip the model instead of feeding it an empty prompt.
    if not text or not text.strip():
        return ""

    with torch.no_grad():
        encoded = en2zh_tokenizer([text], return_tensors="pt")
        sequences = en2zh_model.generate(**encoded)
    decoded = en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]

    # Collapse a word that is repeated back-to-back into one occurrence.
    return re.sub(r'\b(\w+)\b(?:\W+\1\b)+', r'\1', decoded, flags=re.IGNORECASE)
def load_prompter():
    """Load the prompt-rewriting model and its tokenizer.

    The model comes from the "microsoft/Promptist" checkpoint and the
    tokenizer from "gpt2". The tokenizer is configured to pad on the left,
    using its end-of-sequence token as the padding token.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
    gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # Reuse EOS as the pad token and pad on the left side of the sequence.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    gpt2_tokenizer.padding_side = "left"
    return model, gpt2_tokenizer


# Loaded once at import time and shared by generate_prompter_tokenizer_01.
prompter_model, prompter_tokenizer = load_prompter()
def generate_prompter_pipeline_01(text):
    """Expand *text* into prompt candidates using ``pipeline_01``.

    The input is first passed through ``translate_zh2en``; three generated
    sequences are then requested from the pipeline under a freshly drawn
    random seed. Every accepted candidate contributes three lines to the
    output: its Chinese translation, the English text, and a blank line.

    Args:
        text: The source text entered by the user.

    Returns:
        The concatenated candidates as one string, or ``None`` when no
        candidate survives filtering.
    """
    # A new seed per call; presumably so repeated requests yield different text.
    set_seed(random.randint(100, 1000000))
    text_in_english = translate_zh2en(text)
    candidates = pipeline_01(text_in_english, num_return_sequences=3)

    min_length = len(text_in_english) + 4
    pieces = []
    for candidate in candidates:
        generated = candidate['generated_text'].strip()
        # Skip outputs that echo the prompt or add fewer than five characters.
        if generated == text_in_english or len(generated) <= min_length:
            continue
        pieces.append(f"{translate_en2zh(generated)}\n{generated}\n\n")

    result = "".join(pieces)
    # Delete runs of non-space characters that contain an interior dot.
    # Raw string: '\.' in a plain literal is an invalid escape sequence.
    result = re.sub(r'[^ ]+\.[^ ]+', '', result)
    result = result.replace("<", "").replace(">", "")
    return result if result != "" else None
def generate_prompter_tokenizer_01(text):
    """Rewrite *text* into prompt candidates with the prompter model.

    The input is passed through ``translate_zh2en``, suffixed with
    " Rephrase:" and fed to ``prompter_model.generate`` with sampling disabled
    (3 beams, 3 returned sequences). For each returned sequence, only the part
    after the first "Rephrase:" marker is kept.

    Args:
        text: The source text entered by the user.

    Returns:
        One string in which every candidate contributes three lines: its
        Chinese translation, the English text, and a blank line.
    """
    english = translate_zh2en(text)
    prompt = english.strip() + " Rephrase:"
    encoded_prompt = prompter_tokenizer(prompt, return_tensors="pt")
    beams = prompter_model.generate(
        encoded_prompt.input_ids,
        do_sample=False,
        num_beams=3,
        num_return_sequences=3,
        pad_token_id=50256,
        eos_token_id=50256,
        length_penalty=-1.0,
    )
    decoded = prompter_tokenizer.batch_decode(beams, skip_special_tokens=True)

    def _format(raw):
        # Strip angle brackets, then keep only the text after the marker.
        cleaned = raw.replace('<', '').replace('>', '')
        rewritten = cleaned.split("Rephrase:", 1)[-1].strip()
        return translate_en2zh(rewritten) + "\n" + rewritten + "\n" + "\n"

    return "".join(_format(raw) for raw in decoded)
def generate_prompter_pipeline_02(text):
    """Expand *text* into prompt candidates using ``pipeline_02``.

    The input is first passed through ``translate_zh2en``; three generated
    sequences are then requested from the pipeline under a freshly drawn
    random seed. Every accepted candidate contributes three lines to the
    output: its Chinese translation, the English text, and a blank line.

    Args:
        text: The source text entered by the user.

    Returns:
        The concatenated candidates as one string, or ``None`` when no
        candidate survives filtering.
    """
    # A new seed per call; presumably so repeated requests yield different text.
    set_seed(random.randint(100, 1000000))
    text_in_english = translate_zh2en(text)
    candidates = pipeline_02(text_in_english, num_return_sequences=3)

    min_length = len(text_in_english) + 4
    pieces = []
    for candidate in candidates:
        generated = candidate['generated_text'].strip()
        # Skip outputs that echo the prompt or add fewer than five characters.
        if generated == text_in_english or len(generated) <= min_length:
            continue
        pieces.append(f"{translate_en2zh(generated)}\n{generated}\n\n")

    result = "".join(pieces)
    # Delete runs of non-space characters that contain an interior dot.
    # Raw string: '\.' in a plain literal is an invalid escape sequence.
    result = re.sub(r'[^ ]+\.[^ ]+', '', result)
    result = result.replace("<", "").replace(">", "")
    return result if result != "" else None
def generate_prompter_pipeline_03(text):
    """Expand *text* into prompt candidates using ``pipeline_03``.

    The input is first passed through ``translate_zh2en``; three generated
    sequences are then requested from the pipeline under a freshly drawn
    random seed. Every accepted candidate contributes three lines to the
    output: its Chinese translation, the English text, and a blank line.

    Args:
        text: The source text entered by the user.

    Returns:
        The concatenated candidates as one string, or ``None`` when no
        candidate survives filtering.
    """
    # A new seed per call; presumably so repeated requests yield different text.
    set_seed(random.randint(100, 1000000))
    text_in_english = translate_zh2en(text)
    candidates = pipeline_03(text_in_english, num_return_sequences=3)

    min_length = len(text_in_english) + 4
    pieces = []
    for candidate in candidates:
        generated = candidate['generated_text'].strip()
        # Skip outputs that echo the prompt or add fewer than five characters.
        if generated == text_in_english or len(generated) <= min_length:
            continue
        pieces.append(f"{translate_en2zh(generated)}\n{generated}\n\n")

    result = "".join(pieces)
    # Delete runs of non-space characters that contain an interior dot.
    # Raw string: '\.' in a plain literal is an invalid escape sequence.
    result = re.sub(r'[^ ]+\.[^ ]+', '', result)
    result = result.replace("<", "").replace(">", "")
    return result if result != "" else None
def generate_render(text, choice):
    """Run the generator selected by *choice* on *text*.

    Args:
        text: The source text entered by the user.
        choice: The label of the selected generation mode.

    Returns:
        A ``(generated_text, choice)`` tuple, or ``None`` when *choice* does
        not match any known mode.
    """
    # Pick the generator first, then call it once at the end.
    if choice == '★pipeline模式(succinctly)':
        generator = generate_prompter_pipeline_01
    elif choice == '★★tokenizer模式':
        generator = generate_prompter_tokenizer_01
    elif choice == '★★★pipeline模型(Gustavosta)':
        generator = generate_prompter_pipeline_02
    elif choice == 'pipeline模型(John)_自訓測試,資料不穩定':
        generator = generate_prompter_pipeline_03
    else:
        return None
    return generator(text), choice
def get_prompt_from_image(input_image, choice):
    """Caption an image and expand the caption with the selected generator.

    The image is converted to RGB and captioned with ``big_model``; a fixed
    set of punctuation characters is removed from the caption, which is then
    handed to the generator selected by *choice*.

    Args:
        input_image: The uploaded image (an object with a ``convert`` method),
            or ``None`` when no image was provided.
        choice: The label of the selected generation mode.

    Returns:
        The generator's output; ``""`` when *input_image* is ``None``; or
        ``None`` when *choice* does not match any known mode.
    """
    # No image uploaded: return an empty result instead of failing on None.
    if input_image is None:
        return ""

    rgb_image = input_image.convert('RGB')
    pixel_values = big_processor(images=rgb_image, return_tensors="pt").to(device).pixel_values
    generated_ids = big_model.to(device).generate(pixel_values=pixel_values)
    generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Strip the same punctuation set that translate_zh2en removes.
    caption = re.sub(r"[:\-–.!;?_#]", '', generated_caption)

    if choice == '★pipeline模式(succinctly)':
        generator = generate_prompter_pipeline_01
    elif choice == '★★tokenizer模式':
        generator = generate_prompter_tokenizer_01
    elif choice == '★★★pipeline模型(Gustavosta)':
        generator = generate_prompter_pipeline_02
    elif choice == 'pipeline模型(John)_自訓測試,資料不穩定':
        generator = generate_prompter_pipeline_03
    else:
        return None
    return generator(caption)
# Gradio UI. The first tab is the working area (text/image input, mode
# selection, generated prompts); the second tab exposes each translation and
# generation function on its own for testing.
with gr.Blocks() as block:
    with gr.Column():
        with gr.Tab('工作區'):
            with gr.Row():
                input_text = gr.Textbox(lines=12, label='輸入文字', placeholder='在此输入文字...')
                input_image = gr.Image(type='pil', label="選擇圖片(辨識度不佳)")
            with gr.Row():
                txt_prompter_btn = gr.Button('文生文')
                pic_prompter_btn = gr.Button('圖生文')
            with gr.Row():
                radio_btn = gr.Radio(
                    label="請選擇產出方式",
                    choices=[
                        '★pipeline模式(succinctly)',
                        '★★tokenizer模式',
                        '★★★pipeline模型(Gustavosta)',
                        'pipeline模型(John)_自訓測試,資料不穩定',
                    ],
                    value='★pipeline模式(succinctly)',
                )
            with gr.Row():
                Textbox_1 = gr.Textbox(lines=6, label='提示詞生成')
            with gr.Row():
                # Receives the selected mode returned by generate_render.
                Textbox_2 = gr.Textbox(lines=6, label='測試資訊')

        with gr.Tab('測試區'):
            with gr.Row():
                input_test01 = gr.Textbox(lines=2, label='中英翻譯', placeholder='在此输入文字...')
                test01_btn = gr.Button('執行')
                Textbox_test01 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test02 = gr.Textbox(lines=2, label='英中翻譯(不精準)', placeholder='在此输入文字...')
                test02_btn = gr.Button('執行')
                Textbox_test02 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test03 = gr.Textbox(lines=2, label='★pipeline模式(succinctly)', placeholder='在此输入文字...')
                test03_btn = gr.Button('執行')
                Textbox_test03 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test04 = gr.Textbox(lines=2, label='★★tokenizer模式', placeholder='在此输入文字...')
                test04_btn = gr.Button('執行')
                Textbox_test04 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test05 = gr.Textbox(lines=2, label='★★★pipeline模型(Gustavosta)', placeholder='在此输入文字...')
                test05_btn = gr.Button('執行')
                Textbox_test05 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test06 = gr.Textbox(lines=2, label='pipeline模型(John)_自訓測試,資料不穩定', placeholder='在此输入文字...')
                test06_btn = gr.Button('執行')
                Textbox_test06 = gr.Textbox(lines=2, label='輸出結果')

    # Event wiring: (button, handler, inputs, outputs), registered in this order.
    for _button, _handler, _sources, _targets in (
        (txt_prompter_btn, generate_render, [input_text, radio_btn], [Textbox_1, Textbox_2]),
        (pic_prompter_btn, get_prompt_from_image, [input_image, radio_btn], Textbox_1),
        (test01_btn, translate_zh2en, input_test01, Textbox_test01),
        (test02_btn, translate_en2zh, input_test02, Textbox_test02),
        (test03_btn, generate_prompter_pipeline_01, input_test03, Textbox_test03),
        (test04_btn, generate_prompter_tokenizer_01, input_test04, Textbox_test04),
        (test05_btn, generate_prompter_pipeline_02, input_test05, Textbox_test05),
        (test06_btn, generate_prompter_pipeline_03, input_test06, Textbox_test06),
    ):
        _button.click(fn=_handler, inputs=_sources, outputs=_targets)

# Serve on all network interfaces, without the API page and without a share link.
block.launch(show_api=False, debug=True, share=False, server_name='0.0.0.0')