BFF / app.py
JIAFENG7's picture
Update App
741bb73
import sys
import fire
import gradio as gr
import json
import torch
from peft import PeftModel
from transformers import GenerationConfig, AutoModel, AutoTokenizer
import mdtex2html
import re
from textwrap import indent
from huggingface_hub import login
access_token_read = "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"
access_token_write = "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"
login(token = access_token_read)
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
try:
if torch.backends.mps.is_available():
device = "mps"
except: # noqa: E722
pass
with open("node_map.json") as json_file:
data = json.load(json_file)
node_type_map = data.get('node_type_map')
node_name_map = data.get('node_name_map')
def main(
base_model: str = "THUDM/chatglm-6b",
lora_weights: str = "JIAFENG7/BFF-workflow-glm",
share_gradio: bool = True,
):
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='THUDM/chatglm-6b'"
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
if device == "cuda":
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = AutoModel.from_pretrained(base_model, trust_remote_code=True).half().cuda()
else:
model = AutoModel.from_pretrained(base_model, trust_remote_code=True).float()
model = PeftModel.from_pretrained(model, lora_weights, torch_dtype=torch.float16)
model.eval()
def postprocess(self, y):
if y is None:
return []
for i, (message, response) in enumerate(y):
y[i] = (
None if message is None else mdtex2html.convert(message),
None if response is None else mdtex2html.convert(response) if isinstance(response,
str) else format_json(
json.dumps(response, indent=4, sort_keys=True)),
)
return y
gr.Chatbot.postprocess = postprocess
def parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
print('[line]:', line)
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f'<br></code></pre>'
else:
if i > 0:
# if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = ("<br>" if i > 0 else "") + line
text = "".join(lines)
return text
def format_json(json_data):
key_pattern = re.compile("\"(.*)\"(?=:)")
value_pattern = re.compile("(?<=: )(\"(.*)\"|\\d+)")
a = re.sub(key_pattern, lambda m: f'<span class="json-key">{m.group(1)}</span>', json_data)
b = re.sub(value_pattern, lambda m: f'<span class="json-value">{m.group(1)}</span>', a)
return f'<pre id="json-code">{b}</pre>'
def parse_parallel_str(instruction, nodes, connections, type):
result = {}
if instruction.find("and") == -1:
match_node = re.search(r'([a-zA-Z]+)(\d?)', instruction.strip())
if match_node:
match_node_name = match_node.group(1)
match_node_version = match_node.group(2)
node_name = node_name_map[match_node_name] + match_node_version
if f"{node_name}" not in connections:
result.update({
f"{node_name}": {
"type": "main",
"index": 0,
"main": [[]]
}
})
else:
result.update({
f"{node_name}": connections.get(f"{node_name}")
})
else:
merge_name = '-'.join([node.strip() for node in instruction.split('and')])
for i, node in enumerate(instruction.split('and')):
match_node = re.search(r'([a-zA-Z]+)(\d?)', node.strip())
if match_node:
match_node_name = match_node.group(1)
match_node_version = match_node.group(2)
node_name = node_name_map[match_node_name] + match_node_version
result.update({
f"{node_name}": {
"main": [[{
"node": f"Merge-Node-{merge_name}",
"type": "main",
"index": i
}] if type == 'input' else []]
},
})
if type == 'input':
if f"Merge-Node-{merge_name}" not in connections:
result.update({
f"Merge-Node-{merge_name}": {
"main": [[]]
}
})
else:
result.update({
f"Merge-Node-{merge_name}": connections.get(f"Merge-Node-{merge_name}")
})
nodes.append({
"node": "BFF-Merge",
"name": f"Merge-Node-{merge_name}",
"param": "mode: chooseBranch\noutput: empty"
})
return result
def parse_serial_str(input_nodes, output_nodes):
try:
if len(list(input_nodes.keys())) != 1:
for input_key, input_value in input_nodes.items():
if not re.match(r"Merge-Node", input_key):
continue
for output_key, output_value in output_nodes.items():
input_value.get("main")[0].append({
"index": 0,
"node": output_key,
"type": "main"
})
else:
for input_key, input_value in input_nodes.items():
for output_key, output_value in output_nodes.items():
input_value.get("main")[0].append({
"index": 0,
"node": output_key,
"type": "main"
})
return input_nodes
except:
pass
def normalize_input(instruction, input):
nodes = []
connections = {}
# get nodes
match_node_pattern = f'({data.get("node_name")})(\d)'
nodes_from_instruction = re.findall(fr"{match_node_pattern}", str(input))
node_list = ["".join(node) for node in nodes_from_instruction] if nodes_from_instruction else []
for i, node in enumerate(node_list):
start_node_line = re.search(fr"#{node}(.*)\n", input)
start_node_index = 0
if not start_node_line:
pass
else:
start_node_line = start_node_line.group(0)
start_node_index = input.index(start_node_line) + len(start_node_line)
match_node = re.search(r'([a-zA-Z]+)(\d?)', node.strip())
if match_node:
match_node_name = match_node.group(1)
match_node_version = match_node.group(2)
next_node_index = node_list.index(node) + 1
end_node_index = len(input) - 1 if next_node_index == len(node_list) else input.find(
f"#{node_list[next_node_index]}")
params = re.split(rf"{match_node_pattern}.*?\n", str(input[start_node_index:end_node_index]))
nodes.append({
"node": f"{node_type_map[match_node_name]}",
"name": f"{node_name_map[match_node_name]}{match_node_version}",
"param": list(filter(lambda x: x.strip() != '', params))[0],
})
else:
pass
# get connections
for split_comma_node in instruction.split(','):
if split_comma_node.find("output") == -1:
continue
[split_input, split_output] = split_comma_node.split("output")
merge_input = parse_parallel_str(split_input, nodes, connections, 'input')
merge_output = parse_parallel_str(split_output, nodes, connections, 'output')
connections.update(parse_serial_str(merge_input, merge_output))
return nodes, connections
def predict(
instruction,
input_content,
temperature,
top_p,
top_k,
max_new_tokens,
):
input_text = f"Instruction: {instruction}\n"
if input_content is not None:
input_text += f"Input: {input_content}\n"
input_text += "Answer: "
print('---', input_text)
ids = tokenizer.encode(input_text)
input_ids = torch.LongTensor([ids])
inputs = input_ids.to(device)
output = model.generate(
input_ids=inputs,
max_length=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p = top_p,
top_k = top_k
)
decode_output = tokenizer.decode(output[0]).split("Answer:")[1]
return decode_output
def evaluate(
instruction,
input_content=None,
temperature=0.1,
top_p=0.75,
top_k=40,
max_new_tokens=256,
history = []
):
# import ipdb; ipdb.set_trace()
with torch.autocast(device):
merged_nodes, connections = normalize_input(instruction, input_content)
nodes = [];
if len(merged_nodes) == 0:
output = predict(
instruction,
input_content,
temperature,
top_p,
top_k,
max_new_tokens,
)
print('[normal output]:', output)
else:
for node_data in merged_nodes:
merged_instruction = node_data["node"]
merged_input = node_data["param"] + "\n" + f"name: \"{node_data['name']}\""
output = predict(
merged_instruction,
merged_input,
temperature,
top_p,
top_k,
max_new_tokens,
)
print('[node output]:', output)
nodes.append(json.loads(output[output.find("{"):]) if output else [{"error": "errorFormat"}])
print('[nodes output]:', nodes)
request = f"Summary: \n" + \
f"{indent(instruction, ' ')} \n" + \
f"Details: \n" + \
f"{indent(input_content, ' ') if input_content is not None else ''}"
response = {'nodes': nodes, 'connections': connections} if len(merged_nodes) > 0 else output
history.append((parse_text(request), response))
return history, history
gr.Interface(
fn=evaluate,
inputs=[
gr.components.Textbox(
lines=2,
label="Summary",
placeholder="Tell me the task you want to do with bff.",
),
gr.components.Textbox(lines=2, label="Details", placeholder="""Example:
#service1 gets ubtc_trip_in_aidsid key from cookies
serviceName: 'userInfoReportService'
serviceCode: '18768'
method: 'reportOrderAttribution'
#service2 transfers all of cookies
serviceName: 'userInfoReportService'
serviceCode: '18768'
"""),
'state'
# gr.components.Slider(
# minimum=0, maximum=1, value=0.1, label="Temperature",
# info="Controls randomness, higher values increase diversity."
# ),
# gr.components.Slider(
# minimum=0, maximum=1, value=0.75, label="Top p",
# info="The cumulative probability cutoff for token selection. Lower values mean sampling from a smaller, more top-weighted nucleus."
# ),
# gr.components.Slider(
# minimum=0, maximum=100, step=1, value=40, label="Top k",
# info="Sample from the k most likely next tokens at each step. Lower k focuses on higher probability tokens."
# ),
# # gr.components.Slider(
# minimum=1, maximum=4, step=1, value=4, label="Beams"
# ),
# gr.components.Slider(
# minimum=1, maximum=2000, step=1, value=512, label="Max tokens"
# )
],
# outputs=['json'],
outputs=[
gr.Chatbot(),
'state'
],
examples=[
["How many nodes in tripflow?"],
["How many parameters in Cargo?"],
["How many parameters in shark?"]
],
title="Workflow BFF Chat",
description = """
<span id="desc" style="display: block">The bot was trained to answer questions based on tripflow. Ask anything!</p>
<img src="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQ47PUVxQeg6-zfBavc-i_1oeUlVR90LsFvyMfsVItHwRdFhqA4h3vY3GlobNNLWAWWOGk&usqp=CAU" id="desc-img" />
""",
css="style.css"
).launch(share=share_gradio)
"""
# testing code for readme
for instruction in [
"What is the n8n",
"Tell me about the president of Mexico in 2019.",
"Tell me about the king of France in 2019.",
"List all Canadian provinces in alphabetical order.",
"Write a Python program that prints the first 10 Fibonacci numbers.",
"Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.", # noqa: E501
"Tell me five words that rhyme with 'shock'.",
"Translate the sentence 'I have no mouth but I must scream' into Spanish.",
"Count up from 1 to 500.",
]:
print("Instruction:", instruction)
print("Response:", evaluate(instruction))
print()
"""
if __name__ == "__main__":
fire.Fire(main)