| """This file contains the SGL programs used for unit testing.""" | |
| import json | |
| import re | |
| import time | |
| import numpy as np | |
| import sglang as sgl | |
| from sglang.utils import download_and_cache_file, read_jsonl | |
def test_few_shot_qa():
    """Check few-shot capital-city QA through both ``run`` and ``run_batch``.

    One question is answered on its own (substring check for "washington"),
    then three questions are answered as a batch and compared, in order,
    against the expected lowercase city names.
    """

    # The decorator is required: the test calls ``.run()`` / ``.run_batch()``
    # on the program, which a plain Python function does not provide.
    @sgl.function
    def few_shot_qa(s, question):
        # Three worked examples establish the "Q: ... / A: ..." format.
        s += (
            "The following are questions with answers.\n\n"
            "Q: What is the capital of France?\n"
            "A: Paris\n"
            "Q: What is the capital of Germany?\n"
            "A: Berlin\n"
            "Q: What is the capital of Italy?\n"
            "A: Rome\n"
        )
        s += "Q: " + question + "\n"
        # Stop at the first newline so only the one-line answer is captured.
        s += "A:" + sgl.gen("answer", stop="\n", temperature=0)

    single = few_shot_qa.run(question="What is the capital of the United States?")
    assert (
        "washington" in single["answer"].strip().lower()
    ), f"answer: {single['answer']}"

    batch = few_shot_qa.run_batch(
        [
            {"question": "What is the capital of Japan?"},
            {"question": "What is the capital of the United Kingdom?"},
            {"question": "What is the capital city of China?"},
        ],
        temperature=0.1,
    )
    answers = [state["answer"].strip().lower() for state in batch]
    assert answers == ["tokyo", "london", "beijing"], f"answers: {answers}"
def test_mt_bench():
    """Run a two-turn chat and check the number of recorded messages.

    The first turn uses the ``sgl.user`` / ``sgl.assistant`` expression form,
    the second turn uses the ``s.user()`` / ``s.assistant()`` context-manager
    form, so both ways of marking roles are exercised.
    """

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def answer_mt_bench(s, question_1, question_2):
        s += sgl.system("You are a helpful assistant.")
        # Turn 1: expression-style role markers.
        s += sgl.user(question_1)
        s += sgl.assistant(sgl.gen("answer_1"))
        # Turn 2: context-manager-style role markers.
        with s.user():
            s += question_2
        with s.assistant():
            s += sgl.gen("answer_2")

    question_1 = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."
    question_2 = (
        "Rewrite your previous response. Start every sentence with the letter A."
    )
    ret = answer_mt_bench.run(
        question_1=question_1,
        question_2=question_2,
        temperature=0.7,
        max_new_tokens=64,
    )
    # Two user + two assistant messages, plus (presumably) an optional system
    # message depending on the backend's chat template -- TODO confirm.
    assert len(ret.messages()) in [4, 5]
def test_select(check_answer=True):
    """Check ``sgl.select`` on a three-way True/False/Unknown classification.

    Args:
        check_answer: When true, each selected label must equal the expected
            one. When false, the label only has to be one of the offered
            choices. Defaults to ``True``, matching the other tests in this
            file that take the same flag.
    """
    labels = ["True", "False", "Unknown"]

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def true_or_false(s, statement):
        s += "Determine whether the statement below is True, False, or Unknown.\n"
        # The demonstration must be factually correct for its "True" label to
        # be a valid example (the city name was previously misspelled).
        s += "Statement: The capital of France is Paris.\n"
        s += "Answer: True\n"
        s += "Statement: " + statement + "\n"
        s += "Answer:" + sgl.select("answer", labels)

    cases = [
        ("The capital of Germany is Berlin.", "True"),
        ("The capital of Canada is Tokyo.", "False"),
        ("Purple is a better color than green.", "Unknown"),
    ]
    for statement, expected in cases:
        ret = true_or_false.run(statement=statement)
        if check_answer:
            assert ret["answer"] == expected, ret.text()
        else:
            assert ret["answer"] in labels
def test_decode_int():
    """Check that ``sgl.gen_int`` yields values parseable as the right ints."""

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def decode_int(s):
        s += "The number of hours in a day is " + sgl.gen_int("hours") + "\n"
        s += "The number of days in a year is " + sgl.gen_int("days") + "\n"

    ret = decode_int.run(temperature=0.1)
    # On failure the full program text is shown to make debugging easier.
    assert int(ret["hours"]) == 24, ret.text()
    assert int(ret["days"]) == 365, ret.text()
def test_decode_json_regex():
    """Build a JSON object field by field with per-field regex constraints.

    Everything appended inside the ``json_output`` variable scope is read
    back as one string, parsed with ``json.loads`` and type-checked.

    Raises:
        json.decoder.JSONDecodeError: Re-raised (after printing the offending
            text) when the captured output is not valid JSON.
    """
    from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def decode_json(s):
        s += "Generate a JSON object to describe the basic city information of Paris.\n"
        s += "Here are the JSON object:\n"

        # NOTE: we recommend using dtype gen or whole regex string to control the output
        with s.var_scope("json_output"):
            s += "{\n"
            s += ' "name": ' + sgl.gen(regex=REGEX_STR) + ",\n"
            s += ' "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
            s += ' "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
            s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n"
            s += "}"

    ret = decode_json.run(temperature=0.0)
    raw_json = ret["json_output"]
    try:
        js_obj = json.loads(raw_json)
    except json.decoder.JSONDecodeError:
        # Show the malformed text before propagating the failure.
        print("JSONDecodeError", raw_json)
        raise
    assert isinstance(js_obj["name"], str)
    assert isinstance(js_obj["population"], int)
def test_decode_json():
    """Build a JSON object field by field with typed generation helpers.

    Both spellings are exercised: ``sgl.gen_string`` / ``sgl.gen_int`` and
    ``sgl.gen(dtype=...)``. The text captured in the ``json_output`` scope
    is parsed with ``json.loads`` and type-checked.

    Raises:
        json.decoder.JSONDecodeError: Re-raised (after printing the offending
            text) when the captured output is not valid JSON.
    """

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def decode_json(s):
        s += "Generate a JSON object to describe the basic city information of Paris.\n"

        with s.var_scope("json_output"):
            s += "{\n"
            s += ' "name": ' + sgl.gen_string() + ",\n"
            s += ' "population": ' + sgl.gen_int() + ",\n"
            s += ' "area": ' + sgl.gen(dtype=int) + ",\n"
            s += ' "country": ' + sgl.gen_string() + ",\n"
            s += ' "timezone": ' + sgl.gen(dtype=str) + "\n"
            s += "}"

    ret = decode_json.run(max_new_tokens=64)
    raw_json = ret["json_output"]
    try:
        js_obj = json.loads(raw_json)
    except json.decoder.JSONDecodeError:
        # Show the malformed text before propagating the failure.
        print("JSONDecodeError", raw_json)
        raise
    assert isinstance(js_obj["name"], str)
    assert isinstance(js_obj["population"], int)
def test_expert_answer(check_answer=True):
    """Generate an "expert" first, then reuse it inside the next prompt.

    The program reads back ``s["expert"]`` while it is still being built,
    so a generated variable is fed into a later part of the same prompt.

    Args:
        check_answer: When true, the final program text must mention
            "paris" (case-insensitive). When false, the program is only run.
    """

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def expert_answer(s, question):
        s += "Question: " + question + "\n"
        s += (
            "A good person to answer this question is"
            + sgl.gen("expert", stop=[".", "\n"])
            + ".\n"
        )
        # Reuse the expert generated above as part of the next prompt.
        s += (
            "For example,"
            + s["expert"]
            + " would answer that "
            + sgl.gen("answer", stop=".")
            + "."
        )

    ret = expert_answer.run(question="What is the capital of France?", temperature=0.1)

    if check_answer:
        full_text = ret.text()
        assert "paris" in full_text.lower(), f"Answer: {full_text}"
def test_tool_use():
    """Let the model write a calculator expression and evaluate it locally.

    The model fills in the argument of ``calculate(...)``; the expression is
    then evaluated in Python and appended inside the ``answer`` scope, which
    must equal the true product of the two operands.
    """

    def calculate(expression):
        # SECURITY: ``eval`` runs model-generated text as Python code. This
        # is tolerated only because it is a test against a trusted backend;
        # do not copy this pattern where the text may be untrusted.
        return f"{eval(expression)}"

    # Decorated so the program can be called with keyword arguments below.
    @sgl.function
    def tool_use(s, lhs, rhs):
        s += "Please perform computations using a calculator. You can use calculate(expression) to get the results.\n"
        s += "For example,\ncalculate(1+2)=3\ncalculate(3*4)=12\n"
        s += "Question: What is the product of " + str(lhs) + " and " + str(rhs) + "?\n"
        # Generation stops at ")" so only the expression itself is captured.
        s += (
            "Answer: The answer is calculate("
            + sgl.gen("expression", stop=")")
            + ") = "
        )
        with s.var_scope("answer"):
            s += calculate(s["expression"])

    lhs, rhs = 257, 983
    ret = tool_use(lhs=lhs, rhs=rhs, temperature=0)
    assert int(ret["answer"]) == lhs * rhs
def test_react():
    """Run a ReAct-style Thought/Action/Observation loop for up to four steps.

    After one worked example, each step generates a thought and selects
    either "Search" (followed by a query and an observation) or "Finish"
    (followed by the final answer, which ends the loop). The final answer
    must mention "finland" or "states".
    """

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def react(s, question):
        s += """
Question: Which country does the founder of Microsoft live in?
Thought 1: I need to search for the founder of Microsoft.
Action 1: Search [Founder of Microsoft].
Observation 1: The founder of Microsoft is Bill Gates.
Thought 2: I need to search for the country where Bill Gates lives in.
Action 2: Search [Where does Bill Gates live].
Observation 2: Bill Gates lives in the United States.
Thought 3: The answer is the United States.
Action 3: Finish [United States].\n
"""

        s += "Question: " + question + "\n"

        for step in range(1, 5):
            action_key = f"action_{step}"
            s += f"Thought {step}:" + sgl.gen(stop=[".", "\n"]) + ".\n"
            s += f"Action {step}: " + sgl.select(action_key, ["Search", "Finish"])

            if s[action_key] == "Search":
                s += " [" + sgl.gen(stop="]") + "].\n"
                s += f"Observation {step}:" + sgl.gen(stop=[".", "\n"]) + ".\n"
            else:
                # "Finish": capture the final answer and stop iterating.
                s += " [" + sgl.gen("answer", stop="]") + "].\n"
                break

    ret = react.run(
        question="What country does the creator of Linux live in?",
        temperature=0.1,
    )
    # "answer" is only set when the model chose "Finish" within four steps.
    answer = ret["answer"].lower()
    assert "finland" in answer or "states" in answer
def test_parallel_decoding():
    """Generate a skeleton of tips, expand each in a fork, then summarize.

    The program writes ``fork_size`` short tips sequentially, forks the state
    once per tip to expand it into a paragraph, joins the forks, appends the
    expansions to the main state and asks for a summary. Only the type of
    the summary is checked.
    """
    max_tokens = 64
    fork_size = 5

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def parallel_decoding(s, topic):
        s += "Act as a helpful assistant.\n"
        s += "USER: Give some tips for " + topic + ".\n"
        s += (
            "ASSISTANT: Okay. Here are "
            + str(fork_size)
            + " concise tips, each under 8 words:\n"
        )

        # Generate skeleton
        for tip_no in range(1, 1 + fork_size):
            s += f"{tip_no}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n"

        # Generate detailed tips
        forks = s.fork(fork_size)
        for i in range(fork_size):
            tip_no = i + 1
            forks[
                i
            ] += f"Now, I expand tip {tip_no} into a detailed paragraph:\nTip {tip_no}:"
            forks[i] += sgl.gen("detailed_tip", max_tokens=max_tokens, stop=["\n\n"])
        forks.join()

        # Concatenate tips and summarize
        s += "Here are these tips with detailed explanation:\n"
        for i in range(fork_size):
            s += f"Tip {i + 1}:" + forks[i]["detailed_tip"] + "\n"

        s += "\nIn summary," + sgl.gen("summary", max_tokens=512)

    ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
    assert isinstance(ret["summary"], str)
def test_parallel_encoding(check_answer=True):
    """Append three context statements through forks, then ask a question.

    Each context statement is added in its own fork; the forks are joined
    with ``mode="concate_and_append"`` before the question is asked.

    Args:
        check_answer: When true, the generated answer must contain "Noah".
            When false, the program is only run.
    """
    max_tokens = 64

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def parallel_encoding(s, question, context_0, context_1, context_2):
        s += "USER: I will ask a question based on some statements.\n"
        s += "ASSISTANT: Sure. I will give the answer.\n"
        s += "USER: Please memorize these statements.\n"

        contexts = [context_0, context_1, context_2]

        # One fork per statement; the lambda receives the fork index.
        forks = s.fork(len(contexts))
        forks += lambda i: f"Statement {i}: " + contexts[i] + "\n"
        forks.join(mode="concate_and_append")

        s += "Now, please answer the following question. Do not list options."
        s += "\nQuestion: " + question + "\n"
        s += "ASSISTANT:" + sgl.gen("answer", max_tokens=max_tokens)

    ret = parallel_encoding.run(
        question="Who is the father of Julian?",
        context_0="Ethan is the father of Liam.",
        context_1="Noah is the father of Julian.",
        context_2="Oliver is the father of Carlos.",
        temperature=0,
    )
    answer = ret["answer"]

    if check_answer:
        assert "Noah" in answer
def test_image_qa():
    """Ask for a description of ``example_image.png`` and check key words.

    The last chat message must mention "taxi" or "car". The image path is
    relative, so the file is looked up from the current working directory.
    """

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def image_qa(s, question):
        s += sgl.user(sgl.image("example_image.png") + question)
        s += sgl.assistant(sgl.gen("answer"))

    state = image_qa.run(
        question="Please describe this image in simple words.",
        temperature=0,
        max_new_tokens=64,
    )

    # The last message holds the assistant's reply.
    reply = state.messages()[-1]["content"]
    assert "taxi" in reply or "car" in reply, f"{reply}"
def test_stream():
    """Smoke-test streaming: drain ``text_iter`` with and without a variable.

    The program is run twice with ``stream=True``. The first run iterates the
    whole program text, the second iterates only the ``answer`` variable.
    Nothing is asserted about the content; the test passes when both streams
    can be consumed without error.
    """

    # Decorated so the program can be called with keyword arguments below.
    @sgl.function
    def qa(s, question):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user(question)
        s += sgl.assistant(sgl.gen("answer"))

    question = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."

    # Stream the full program text.
    ret = qa(question=question, stream=True)
    out = "".join(ret.text_iter())

    # Stream only the generated "answer" variable.
    ret = qa(question=question, stream=True)
    out = "".join(ret.text_iter("answer"))
def test_regex():
    """Constrain generation to an IPv4-shaped regex and re-check the output."""
    # Each octet is 0-255. The separator is an escaped literal dot: an
    # unescaped "." would accept any character between octets.
    octet = r"(25[0-5]|2[0-4]\d|[01]?\d\d?)"
    regex = rf"({octet}\.){{3}}{octet}"

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def regex_gen(s):
        s += "Q: What is the IP address of the Google DNS servers?\n"
        s += "A: " + sgl.gen(
            "answer",
            temperature=0,
            regex=regex,
        )

    state = regex_gen.run()
    answer = state["answer"]
    # ``re.match`` anchors at the start of the answer only.
    assert re.match(regex, answer)
def test_dtype_gen():
    """Generate str/int/float/bool values via ``dtype`` and convert them.

    Raises:
        ValueError: Re-raised (after printing the state) when the generated
            int or float text cannot be converted.
    """

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def dtype_gen(s):
        s += "Q: What is the full name of DNS?\n"
        s += "A: The full names is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n"
        s += "Q: Which year was DNS invented?\n"
        s += "A: " + sgl.gen("int_res", dtype=int) + "\n"
        s += "Q: What is the value of pi?\n"
        s += "A: " + sgl.gen("float_res", dtype=float) + "\n"
        s += "Q: Is the sky blue?\n"
        s += "A: " + sgl.gen("bool_res", dtype=bool) + "\n"

    state = dtype_gen.run()

    try:
        state["int_res"] = int(state["int_res"])
        state["float_res"] = float(state["float_res"])
        # NOTE: ``bool()`` of any non-empty string is True, so this line
        # never raises and does not validate the generated text.
        state["bool_res"] = bool(state["bool_res"])
    except ValueError:
        # Dump the state so the unparsable value is visible in the log.
        print(state)
        raise
def test_completion_speculative():
    """Compare prompt-token usage with and without API speculative execution.

    Two programs with an identical prompt are run against the default
    backend; only the decorator differs. The speculative variant must
    consume strictly fewer prompt tokens than the plain one.
    """

    # NOTE(review): speculation is enabled through ``num_api_spec_tokens``;
    # 64 is assumed to be enough to cover the three short fields -- confirm
    # against the backend configuration.
    @sgl.function(num_api_spec_tokens=64)
    def gen_character_spec(s):
        s += "Construct a character within the following format:\n"
        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        s += "\nPlease generate new Name, Birthday and Job.\n"
        s += (
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
        )
        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

    # Same prompt on purpose; this is the non-speculative baseline.
    @sgl.function
    def gen_character_no_spec(s):
        s += "Construct a character within the following format:\n"
        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        s += "\nPlease generate new Name, Birthday and Job.\n"
        s += (
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
        )
        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

    token_usage = sgl.global_config.default_backend.token_usage

    # Reset the shared counter before each run so the readings are per-run.
    token_usage.reset()
    gen_character_spec().sync()
    usage_with_spec = token_usage.prompt_tokens

    token_usage.reset()
    gen_character_no_spec().sync()
    usage_with_no_spec = token_usage.prompt_tokens

    assert (
        usage_with_spec < usage_with_no_spec
    ), f"{usage_with_spec} vs {usage_with_no_spec}"
def test_chat_completion_speculative():
    """Smoke-test speculative execution on a chat-formatted program.

    The three fields are generated inside a single assistant turn. The test
    passes when the program runs to completion; no output is checked.
    """

    # NOTE(review): speculation is enabled through ``num_api_spec_tokens``;
    # 256 is an assumed budget for the whole assistant turn -- confirm
    # against the backend configuration.
    @sgl.function(num_api_spec_tokens=256)
    def gen_character_spec(s):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user("Construct a character within the following format:")
        s += sgl.assistant(
            "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        )
        s += sgl.user("Please generate new Name, Birthday and Job.\n")
        s += sgl.assistant(
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
            + "\nJob:"
            + sgl.gen("job", stop="\n")
        )

    # ``sync`` blocks until the program has finished executing.
    gen_character_spec().sync()
def test_hellaswag_select():
    """Benchmark the accuracy of sgl.select on the HellaSwag dataset.

    The validation file is downloaded (and cached), a 20-shot prompt is built
    from its first examples, and the first 200 examples are classified with
    ``run_batch`` twice: once with ``generator_style=False`` and once with
    ``generator_style=True``. The two runs must agree within 0.1 accuracy
    and 1 second of latency.

    Returns:
        A ``(accuracy, latency)`` tuple for the non-generator run, where
        latency is in seconds.
    """

    def get_one_example(lines, i, include_answer):
        # "<activity>: <context> " optionally followed by the gold ending.
        example = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
        if include_answer:
            example += lines[i]["endings"][lines[i]["label"]]
        return example

    def get_few_shot_examples(lines, k):
        return "".join(get_one_example(lines, i, True) + "\n\n" for i in range(k))

    # Read data
    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
    filename = download_and_cache_file(url)
    lines = list(read_jsonl(filename))

    # Construct prompts
    num_questions = 200
    num_shots = 20
    few_shot_examples = get_few_shot_examples(lines, num_shots)

    # NOTE: the few-shot examples and the questions both start at index 0,
    # so the first ``num_shots`` questions also appear, answered, in the
    # prompt. Kept as is so reported accuracy stays comparable.
    num_evaluated = min(num_questions, len(lines))
    questions = [get_one_example(lines, i, False) for i in range(num_evaluated)]
    choices = [lines[i]["endings"] for i in range(num_evaluated)]
    labels = [lines[i]["label"] for i in range(num_evaluated)]
    arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    # Decorated so the program can be invoked through ``.run_batch()`` below.
    @sgl.function
    def few_shot_hellaswag(s, question, choices):
        s += few_shot_examples + question
        s += sgl.select("answer", choices=choices)

    #####################################
    ########## SGL Program End ##########
    #####################################

    def run_and_score(generator_style):
        # Time the batch including result consumption, so the measurement is
        # meaningful when results are produced lazily.
        tic = time.perf_counter()
        rets = few_shot_hellaswag.run_batch(
            arguments,
            temperature=0,
            num_threads=64,
            progress_bar=True,
            generator_style=generator_style,
        )
        preds = [choices[i].index(ret["answer"]) for i, ret in enumerate(rets)]
        elapsed = time.perf_counter() - tic
        # Fraction of predictions equal to the gold label.
        return np.mean(np.array(preds) == np.array(labels)), elapsed

    # Run requests
    accuracy, latency = run_and_score(generator_style=False)

    # Test generator style of run_batch
    accuracy_gen, latency_gen = run_and_score(generator_style=True)

    print(f"{accuracy=}, {accuracy_gen=}")
    assert np.abs(accuracy_gen - accuracy) < 0.1
    assert np.abs(latency_gen - latency) < 1

    return accuracy, latency
def test_gen_min_new_tokens():
    """
    Validate sgl.gen(min_tokens) functionality.

    The test asks a question where, without a min_tokens constraint, the generated answer is expected to be short.
    By enforcing the min_tokens parameter, we ensure the generated answer has at least the specified number of tokens.
    We verify that the number of tokens in the answer is >= the min_tokens threshold.
    """
    # Imported locally so the tokenizer utilities are only needed by this test.
    from sglang.srt.utils.hf_transformers_utils import get_tokenizer

    model_path = sgl.global_config.default_backend.endpoint.get_model_name()
    MIN_TOKENS, MAX_TOKENS = 64, 128

    # Decorated so the program can be invoked through ``.run()`` below.
    @sgl.function
    def convo_1(s):
        s += sgl.user("What is the capital of the United States?")
        s += sgl.assistant(
            sgl.gen("answer", min_tokens=MIN_TOKENS, max_tokens=MAX_TOKENS)
        )

    def assert_min_tokens(tokenizer, text):
        # Re-tokenize the answer with the served model's tokenizer and count.
        num_tokens = len(tokenizer.encode(text))
        assert (
            num_tokens >= MIN_TOKENS
        ), f"Generated {num_tokens} tokens, min required: {MIN_TOKENS}. Text: {text}"

    tokenizer = get_tokenizer(model_path)

    state = convo_1.run()
    assert_min_tokens(tokenizer, state["answer"])
Xet Storage Details
- Size:
- 18.9 kB
- Xet hash:
- af9fd7bc531934d215db458983a5b4d882e4c6da2ca3d820b0fbecbd6798fd9b
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.