import sys
import fire
import gradio as gr
import json
import torch
from peft import PeftModel
from transformers import GenerationConfig, AutoModel, AutoTokenizer
import mdtex2html
import re
from textwrap import indent
from huggingface_hub import login

access_token_read = "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"
access_token_write = "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"
login(token=access_token_read)

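# Pick the compute device: prefer CUDA, otherwise CPU. Apple MPS is probed in a
# separate try/except because older torch builds do not expose the attribute.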
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except:
    pass

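# node_map.json is expected to provide the lookup tables used to resolve short node
# names from the prompt: node_type_map, node_name_map, and a node_name pattern.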
with open("node_map.json") as json_file:
    data = json.load(json_file)
    node_type_map = data.get('node_type_map')
    node_name_map = data.get('node_name_map')


def main(
    base_model: str = "THUDM/chatglm-6b",
    lora_weights: str = "JIAFENG7/BFF-workflow-glm",
    share_gradio: bool = True,
):
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='THUDM/chatglm-6b'"

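    # Load the ChatGLM tokenizer and base model, then apply the LoRA adapter weights.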
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if device == "cuda":
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        model = AutoModel.from_pretrained(base_model, trust_remote_code=True).half().cuda()
    else:
        model = AutoModel.from_pretrained(base_model, trust_remote_code=True).float()
    model = PeftModel.from_pretrained(model, lora_weights, torch_dtype=torch.float16)
    model.eval()

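    # Render chatbot messages as HTML: markdown via mdtex2html, dict responses as
    # highlighted JSON. Monkey-patched onto gr.Chatbot below.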
    def postprocess(self, y):
        if y is None:
            return []
        for i, (message, response) in enumerate(y):
            y[i] = (
                None if message is None else mdtex2html.convert(message),
                None if response is None
                else mdtex2html.convert(response) if isinstance(response, str)
                else format_json(json.dumps(response, indent=4, sort_keys=True)),
            )
        return y

    gr.Chatbot.postprocess = postprocess

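    # Convert plain text (including ``` code fences) into HTML-escaped markup for the chatbot.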
    def parse_text(text):
        lines = text.split("\n")
        lines = [line for line in lines if line != ""]
        count = 0
        for i, line in enumerate(lines):
            print('[line]:', line)
            if "```" in line:
                count += 1
                items = line.split('`')
                if count % 2 == 1:
                    lines[i] = f'<pre><code class="language-{items[-1]}">'
                else:
                    lines[i] = '<br></code></pre>'
            else:
                if i > 0:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = ("<br>" if i > 0 else "") + line
        text = "".join(lines)
        return text

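    # Wrap pretty-printed JSON in <span> tags so keys and values can be styled via style.css.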
    def format_json(json_data):
        key_pattern = re.compile("\"(.*)\"(?=:)")
        value_pattern = re.compile("(?<=: )(\"(.*)\"|\\d+)")
        a = re.sub(key_pattern, lambda m: f'<span class="json-key">{m.group(1)}</span>', json_data)
        b = re.sub(value_pattern, lambda m: f'<span class="json-value">{m.group(1)}</span>', a)
        return f'<pre id="json-code">{b}</pre>'

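    # Build connection entries for one side of an "A and B output C" clause. Nodes joined
    # by "and" are routed through a generated BFF-Merge node; single nodes map directly.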
    def parse_parallel_str(instruction, nodes, connections, type):
        result = {}
        if instruction.find("and") == -1:
            match_node = re.search(r'([a-zA-Z]+)(\d?)', instruction.strip())
            if match_node:
                match_node_name = match_node.group(1)
                match_node_version = match_node.group(2)
                node_name = node_name_map[match_node_name] + match_node_version

                if f"{node_name}" not in connections:
                    result.update({
                        f"{node_name}": {
                            "type": "main",
                            "index": 0,
                            "main": [[]]
                        }
                    })
                else:
                    result.update({
                        f"{node_name}": connections.get(f"{node_name}")
                    })
        else:
            merge_name = '-'.join([node.strip() for node in instruction.split('and')])
            for i, node in enumerate(instruction.split('and')):
                match_node = re.search(r'([a-zA-Z]+)(\d?)', node.strip())
                if match_node:
                    match_node_name = match_node.group(1)
                    match_node_version = match_node.group(2)
                    node_name = node_name_map[match_node_name] + match_node_version

                    result.update({
                        f"{node_name}": {
                            "main": [[{
                                "node": f"Merge-Node-{merge_name}",
                                "type": "main",
                                "index": i
                            }] if type == 'input' else []]
                        },
                    })
            if type == 'input':
                if f"Merge-Node-{merge_name}" not in connections:
                    result.update({
                        f"Merge-Node-{merge_name}": {
                            "main": [[]]
                        }
                    })
                else:
                    result.update({
                        f"Merge-Node-{merge_name}": connections.get(f"Merge-Node-{merge_name}")
                    })
                nodes.append({
                    "node": "BFF-Merge",
                    "name": f"Merge-Node-{merge_name}",
                    "param": "mode: chooseBranch\noutput: empty"
                })

        return result

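    # Append each output node to the "main" connection list of the relevant input nodes
    # (only Merge nodes when several inputs are present).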
    def parse_serial_str(input_nodes, output_nodes):
        try:
            if len(list(input_nodes.keys())) != 1:
                for input_key, input_value in input_nodes.items():
                    if not re.match(r"Merge-Node", input_key):
                        continue
                    for output_key, output_value in output_nodes.items():
                        input_value.get("main")[0].append({
                            "index": 0,
                            "node": output_key,
                            "type": "main"
                        })
            else:
                for input_key, input_value in input_nodes.items():
                    for output_key, output_value in output_nodes.items():
                        input_value.get("main")[0].append({
                            "index": 0,
                            "node": output_key,
                            "type": "main"
                        })

            return input_nodes
        except Exception:
            # Return whatever was built so the caller's connections.update() never receives None.
            return input_nodes

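    # Split the Details text into per-node parameter blocks (one per "#<node>" header) and
    # derive the node list plus the connection graph described by the Summary instruction.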
    def normalize_input(instruction, input):
        nodes = []
        connections = {}

        match_node_pattern = rf'({data.get("node_name")})(\d)'
        nodes_from_instruction = re.findall(fr"{match_node_pattern}", str(input))
        node_list = ["".join(node) for node in nodes_from_instruction] if nodes_from_instruction else []
        for i, node in enumerate(node_list):
            start_node_line = re.search(fr"#{node}(.*)\n", input)
            start_node_index = 0
            if start_node_line:
                start_node_line = start_node_line.group(0)
                start_node_index = input.index(start_node_line) + len(start_node_line)
            match_node = re.search(r'([a-zA-Z]+)(\d?)', node.strip())
            if match_node:
                match_node_name = match_node.group(1)
                match_node_version = match_node.group(2)
                next_node_index = node_list.index(node) + 1
                end_node_index = len(input) - 1 if next_node_index == len(node_list) else input.find(
                    f"#{node_list[next_node_index]}")
                params = re.split(rf"{match_node_pattern}.*?\n", str(input[start_node_index:end_node_index]))
                nodes.append({
                    "node": f"{node_type_map[match_node_name]}",
                    "name": f"{node_name_map[match_node_name]}{match_node_version}",
                    "param": list(filter(lambda x: x.strip() != '', params))[0],
                })

        for split_comma_node in instruction.split(','):
            if split_comma_node.find("output") == -1:
                continue
            [split_input, split_output] = split_comma_node.split("output")
            merge_input = parse_parallel_str(split_input, nodes, connections, 'input')
            merge_output = parse_parallel_str(split_output, nodes, connections, 'output')
            connections.update(parse_serial_str(merge_input, merge_output))

        return nodes, connections

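    # Build an "Instruction / Input / Answer" prompt and sample a completion from the model.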
    def predict(
            instruction,
            input_content,
            temperature,
            top_p,
            top_k,
            max_new_tokens,
    ):
        input_text = f"Instruction: {instruction}\n"
        if input_content is not None:
            input_text += f"Input: {input_content}\n"
        input_text += "Answer: "
        print('---', input_text)
        ids = tokenizer.encode(input_text)
        input_ids = torch.LongTensor([ids])
        inputs = input_ids.to(device)
        output = model.generate(
            input_ids=inputs,
            max_length=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
        decode_output = tokenizer.decode(output[0]).split("Answer:")[1]
        return decode_output

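    # Gradio handler: either answer the instruction directly, or generate one JSON node per
    # extracted block and return the assembled nodes/connections, appending to chat history.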
    def evaluate(
            instruction,
            input_content=None,
            temperature=0.1,
            top_p=0.75,
            top_k=40,
            max_new_tokens=256,
            history=[],
    ):
        with torch.autocast(device):
            merged_nodes, connections = normalize_input(instruction, input_content)
            nodes = []

            if len(merged_nodes) == 0:
                output = predict(
                    instruction,
                    input_content,
                    temperature,
                    top_p,
                    top_k,
                    max_new_tokens,
                )
                print('[normal output]:', output)
            else:
                for node_data in merged_nodes:
                    merged_instruction = node_data["node"]
                    merged_input = node_data["param"] + "\n" + f"name: \"{node_data['name']}\""
                    output = predict(
                        merged_instruction,
                        merged_input,
                        temperature,
                        top_p,
                        top_k,
                        max_new_tokens,
                    )
                    print('[node output]:', output)
                    nodes.append(json.loads(output[output.find("{"):]) if output else [{"error": "errorFormat"}])
                print('[nodes output]:', nodes)

            request = "Summary: \n" + \
                      f"{indent(instruction, ' ')} \n" + \
                      "Details: \n" + \
                      f"{indent(input_content, ' ') if input_content is not None else ''}"
            response = {'nodes': nodes, 'connections': connections} if len(merged_nodes) > 0 else output

            history.append((parse_text(request), response))

            return history, history

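    # Assemble the Gradio UI: Summary and Details textboxes plus session state as inputs,
    # a Chatbot and the updated state as outputs.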
    gr.Interface(
        fn=evaluate,
        inputs=[
            gr.components.Textbox(
                lines=2,
                label="Summary",
                placeholder="Tell me the task you want to do with bff.",
            ),
            gr.components.Textbox(lines=2, label="Details", placeholder="""Example:
#service1 gets ubtc_trip_in_aidsid key from cookies
serviceName: 'userInfoReportService'
serviceCode: '18768'
method: 'reportOrderAttribution'
#service2 transfers all of cookies
serviceName: 'userInfoReportService'
serviceCode: '18768'
"""),
            'state'
        ],
        outputs=[
            gr.Chatbot(),
            'state'
        ],
        examples=[
            ["How many nodes in tripflow?"],
            ["How many parameters in Cargo?"],
            ["How many parameters in shark?"]
        ],
        title="Workflow BFF Chat",
        description="""
        <span id="desc" style="display: block">The bot was trained to answer questions based on tripflow. Ask anything!</span>
        <img src="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQ47PUVxQeg6-zfBavc-i_1oeUlVR90LsFvyMfsVItHwRdFhqA4h3vY3GlobNNLWAWWOGk&usqp=CAU" id="desc-img" />
        """,
        css="style.css"
    ).launch(share=share_gradio)

    """
    # testing code for readme
    for instruction in [
        "What is the n8n",
        "Tell me about the president of Mexico in 2019.",
        "Tell me about the king of France in 2019.",
        "List all Canadian provinces in alphabetical order.",
        "Write a Python program that prints the first 10 Fibonacci numbers.",
        "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",  # noqa: E501
        "Tell me five words that rhyme with 'shock'.",
        "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
        "Count up from 1 to 500.",
    ]:
        print("Instruction:", instruction)
        print("Response:", evaluate(instruction))
        print()
    """

if __name__ == "__main__":
    fire.Fire(main)