Spaces:
Configuration error
Configuration error
| #credit to shadowcz007 for this module | |
| #from https://github.com/shadowcz007/comfyui-mixlab-nodes/blob/main/nodes/TextGenerateNode.py | |
| import re | |
| import os | |
| import folder_paths | |
| import comfy.utils | |
| import torch | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from .utils import install_package | |
# The prompt parser needs lark; install it on first use when it is missing.
# Only ImportError triggers the install: a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and unrelated errors raised during the import.
try:
    from lark import Lark, Transformer, v_args
except ImportError:
    print('install lark-parser...')
    install_package('lark-parser')
    from lark import Lark, Transformer, v_args

# Folder under ComfyUI's models directory used for prompt-generator models.
model_path = os.path.join(folder_paths.models_dir, 'prompt_generator')
# Preferred local copy of the opus-mt-zh-en model; when this directory does not
# exist, translate()/zh_to_en() fall back to the hub id 'Helsinki-NLP/opus-mt-zh-en'.
zh_en_model_path = os.path.join(model_path, 'opus-mt-zh-en')
# Translation model and tokenizer, loaded lazily on first use.
zh_en_model, zh_en_tokenizer = None, None
# Map Chinese-IME punctuation to the ASCII symbols the prompt grammar understands:
# fullwidth ( ) , ; : and the ideographic full stop. ASCII ';' and '\' are also
# turned into ',' so they act as phrase separators. Escapes are used so the
# table survives any Unicode normalisation of this source file.
_PROMPT_PUNCT_TABLE = str.maketrans({
    '\uff08': '(',   # fullwidth left parenthesis
    '\uff09': ')',   # fullwidth right parenthesis
    '\uff0c': ',',   # fullwidth comma
    '\uff1b': ',',   # fullwidth semicolon
    ';': ',',
    '\u3002': '.',   # ideographic full stop
    '\uff1a': ':',   # fullwidth colon
    '\\': ',',
})


def correct_prompt_syntax(prompt=""):
    """Normalise a prompt so the Lark prompt grammar can parse it.

    Steps:
      1. Convert Chinese/fullwidth punctuation to ASCII (table above).
      2. Collapse runs of whitespace and drop spaces just inside brackets.
      3. Split on ',' into phrases; for each phrase strip surrounding
         whitespace and any leading ')' / ']' that could never match, skip
         phrases left empty, and append missing closing brackets when the
         phrase starts with '(', '[' or '<'.

    Args:
        prompt: raw prompt text; ``None`` or ``""`` yields ``""``.

    Returns:
        The corrected phrases joined with ',' (no space).
    """
    if not prompt:
        return ""

    prompt = prompt.translate(_PROMPT_PUNCT_TABLE)
    prompt = re.sub(r'\s+', ' ', prompt).strip()
    prompt = (prompt.replace("< ", "<").replace(" >", ">")
              .replace("( ", "(").replace(" )", ")")
              .replace("[ ", "[").replace(' ]', ']'))

    def balance_brackets(element, open_bracket, close_bracket):
        # Only appends; a surplus of closers gives a negative count -> ''.
        missing = element.count(open_bracket) - element.count(close_bracket)
        return element + close_bracket * missing

    closers = {'(': ')', '[': ']', '<': '>'}
    corrected_elements = []
    for element in prompt.split(','):
        # Leading closers cannot match anything; stripping them may empty the
        # phrase, which must be skipped or the result contains ",," and the
        # grammar (phrase ("," phrase)*) rejects it.
        element = element.strip().lstrip(')]').strip()
        if not element:
            continue
        opener = element[0]
        if opener in closers:
            element = balance_brackets(element, opener, closers[opener])
        corrected_elements.append(element)

    return ','.join(corrected_elements)
def detect_language(input_str):
    """Guess the dominant language of *input_str*.

    Characters in the CJK Unified Ideographs block (U+4E00..U+9FFF) are counted
    as Chinese; any other alphabetic character is counted as English.

    Returns:
        "cn" if Chinese characters outnumber the others, "en" if the others
        outnumber Chinese, and "unknow" on a tie (including empty input).
    """
    def is_cjk(ch):
        return '\u4e00' <= ch <= '\u9fff'

    cjk_total = sum(1 for ch in input_str if is_cjk(ch))
    latin_total = sum(1 for ch in input_str if not is_cjk(ch) and ch.isalpha())

    if cjk_total == latin_total:
        return "unknow"
    return "cn" if cjk_total > latin_total else "en"
def has_chinese(text):
    """Return True if *text* contains a CJK ideograph outside special syntax.

    Before scanning, these spans are removed so their contents are ignored:
    ``<...>`` tags, ``__...__`` wildcards, and an ``embedding:`` reference
    together with the rest of a single-line string after it.
    """
    stripped = text
    for pattern in (r'<.*?>', r'__.*?__', r'embedding:.*?$'):
        stripped = re.sub(pattern, '', stripped)
    # Only the CJK Unified Ideographs block (U+4E00..U+9FFF) counts as Chinese.
    return any('\u4e00' <= ch <= '\u9fff' for ch in stripped)
def translate(text):
    """Translate *text* from Chinese to English with the opus-mt-zh-en model.

    The model and tokenizer are loaded lazily into module globals on first use.
    If the local model directory ``zh_en_model_path`` does not exist, the hub id
    'Helsinki-NLP/opus-mt-zh-en' is used instead. On every call the model is
    moved to CUDA when available, otherwise to the CPU.

    Returns:
        The first decoded output sequence, with special tokens removed.
    """
    global zh_en_model_path, zh_en_model, zh_en_tokenizer

    if not os.path.exists(zh_en_model_path):
        zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'

    if zh_en_model is None:
        zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
        zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    zh_en_model.to(target_device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        batch = zh_en_tokenizer([text], return_tensors="pt")
        batch.to(zh_en_model.device)
        output_ids = zh_en_model.generate(**batch)
        decoded = zh_en_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return decoded[0]
# Decorator to flatten the tree directly into the function arguments:
# v_args(inline=True) passes a rule's children as separate positional
# arguments instead of one list. Every callback below relies on that
# convention (e.g. weight(word, number), "".join(args)).
@v_args(inline=True)
class ChinesePromptTranslate(Transformer):
    """Lark transformer that rebuilds a prompt with Chinese text translated.

    Used as the ``transformer=`` of an LALR ``Lark`` parser (as zh_to_en does),
    callbacks run as each rule is reduced. As a result, ``word`` children
    reaching sentence/emphasis/weight/schedule have already been processed by
    :meth:`word`.
    """

    def sentence(self, *args):
        # Comma-separated phrases, re-joined with ", ".
        return ", ".join(args)

    def phrase(self, *args):
        return "".join(args)

    def emphasis(self, *args):
        # Reconstruct the emphasis with translated content.
        return "(" + "".join(args) + ")"

    def weak_emphasis(self, *args):
        print('weak_emphasis:', args)
        return "[" + "".join(args) + "]"

    def embedding(self, *args):
        print('prompt embedding', args[0])
        embedding_name = str(args[0]).strip()
        # Up to two optional NUMBER tokens. The NUMBER terminal may carry
        # surrounding whitespace, so strip it as weight()/schedule() do.
        numbers = [str(n).strip() for n in args[1:3]]
        return ":".join([f"embedding:{embedding_name}", *numbers])

    def lora(self, *args):
        if len(args) == 1:
            return f"<lora:{args[0]}>"
        # args: tag type, name, then up to two NUMBERs. The tag type is
        # discarded; the output is always written as <lora:...>.
        _, lora_name, *numbers = args
        lora_name = str(lora_name).strip()
        numbers = [str(n).strip() for n in numbers[:2]]
        return "<" + ":".join(["lora", lora_name, *numbers]) + ">"

    def weight(self, word, number):
        # `word` is already the output of word() (translated, or protected
        # @...@ text with its markers removed). Translating it again would
        # re-translate English output and undo the @...@ protection.
        translated_word = str(word).rstrip('.')
        return f"({translated_word}:{str(number).strip()})"

    def schedule(self, *args):
        print('prompt schedule', args)
        data = [str(arg).strip() for arg in args]
        return f"[{':'.join(data)}]"

    def word(self, word):
        """Translate one WORD, honouring wildcard and @...@ protection."""
        text = str(word)
        # __wildcard__ references are passed through untranslated.
        if re.search(r'__.*?__', text):
            return text.rstrip('.')
        # Spans wrapped in @...@ must stay untranslated (markers removed). Only
        # the pieces around them are translated; whitespace-only gaps are kept
        # as-is. Every marked span is preserved, not just the first one.
        if re.search(r'@.*?@', text):
            pieces = re.split(r'(@.*?@)', text)
            out = []
            for index, piece in enumerate(pieces):
                if index % 2:  # captured @...@ span
                    out.append(piece.replace('@', '').rstrip('.'))
                elif piece.strip():
                    out.append(translate(piece).rstrip('.'))
                else:
                    out.append(piece)
            return "".join(out)
        if detect_language(text) == "cn":
            return translate(text).rstrip('.')
        return text.rstrip('.')
# Define the prompt grammar (Lark) consumed with ChinesePromptTranslate.
# A prompt is a comma-separated list of phrases; each phrase is one of:
#   (sentence)               -> emphasis
#   [sentence]               -> weak_emphasis
#   (word:NUMBER)            -> weight
#   [word:word:NUMBER]       -> schedule
#   <WORD:WORD[:N[:N]]>      -> lora
#   embedding:WORD[:N[:N]]   -> embedding
#   anything else            -> word (WORD excludes , : ( ) [ ] < >)
# NUMBER absorbs surrounding whitespace. This is a raw string because the
# regex terminals contain backslash sequences (\s, \d, \(, ...) that are
# invalid escapes in an ordinary literal (SyntaxWarning on Python 3.12+).
grammar = r"""
start: sentence
sentence: phrase ("," phrase)*
phrase: emphasis | weight | word | lora | embedding | schedule
emphasis: "(" sentence ")" -> emphasis
        | "[" sentence "]" -> weak_emphasis
weight: "(" word ":" NUMBER ")"
schedule: "[" word ":" word ":" NUMBER "]"
lora: "<" WORD ":" WORD (":" NUMBER)? (":" NUMBER)? ">"
embedding: "embedding" ":" WORD (":" NUMBER)? (":" NUMBER)?
word: WORD
NUMBER: /\s*-?\d+(\.\d+)?\s*/
WORD: /[^,:\(\)\[\]<>]+/
"""
def zh_to_en(text):
    """Translate a list of (possibly Chinese) prompts to English.

    Each prompt is normalised with correct_prompt_syntax() and parsed with the
    Lark prompt grammar; ChinesePromptTranslate translates the plain words and
    rebuilds emphasis, weights, lora/embedding tags and schedules around them.

    Args:
        text: list of prompt strings.

    Returns:
        List of translated prompts. Prompts that are empty after normalisation
        are skipped; if nothing remains, ``[""]`` is returned.

    Errors raised by the Lark parser for a prompt that does not match the
    grammar propagate to the caller; the model is still moved back to the CPU.
    """
    global zh_en_model_path, zh_en_model, zh_en_tokenizer

    # Progress: one tick for the translation phase plus one per result.
    pbar = comfy.utils.ProgressBar(len(text) + 1)
    texts = [correct_prompt_syntax(t) for t in text]

    # NOTE(review): presumably required by the opus-mt (Marian) tokenizer.
    install_package('sentencepiece', '0.2.0')

    if not os.path.exists(zh_en_model_path):
        zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'

    if zh_en_model is None:
        zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
        zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)

    zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")

    # Grammar and transformer are identical for every prompt, so build the
    # LALR parser once rather than once per prompt.
    parser = Lark(grammar, start="start", parser="lalr", transformer=ChinesePromptTranslate())

    en_texts = []
    try:
        for t in texts:
            if t:
                # The transformer has no `start` callback, so parse() returns a
                # `start` tree whose first child is the translated sentence.
                en_texts.append(parser.parse(t).children[0])
    finally:
        # Free accelerator memory even when a prompt fails to parse.
        zh_en_model.to('cpu')
    pbar.update(1)

    prompt_result = []
    for t in en_texts:
        prompt_result.append(t)
        pbar.update(1)

    return prompt_result or [""]