| |
| from concurrent.futures import ThreadPoolExecutor |
|
|
| from json_decode import character_regex |
|
|
| from sglang.utils import http_request |
|
|
# Characters to query concurrently against the local server.
character_names = ["Hermione Granger", "Ron Weasley", "Harry Potter"]

# Address of the locally running sglang server.
base_url = "http://localhost:30000"

# Prompt suffix appended after each character name.
prompt = "is a character in Harry Potter. Please fill in the following information about this character.\n"
|
|
|
|
def openai_api_request(name):
    """Query the OpenAI-compatible /v1/completions endpoint for one character.

    Sends a regex-constrained completion request with top-3 logprobs enabled,
    sanity-checks the logprob bookkeeping in the response, and returns the raw
    response dict.
    """
    payload = {
        "model": "",
        "prompt": name + prompt,
        "temperature": 0,
        "max_tokens": 128,
        "regex": character_regex,
        "logprobs": 3,
    }
    res = http_request(base_url + "/v1/completions", json=payload).json()

    choice_logprobs = res["choices"][0]["logprobs"]
    usage = res["usage"]
    # All three per-token arrays must be aligned; the server reports one fewer
    # logprob entry than completion tokens — NOTE(review): presumably the first
    # generated token carries no logprob here; confirm against server behavior.
    num_logprobs = len(choice_logprobs["token_logprobs"])
    assert num_logprobs == len(choice_logprobs["tokens"])
    assert num_logprobs == len(choice_logprobs["top_logprobs"])
    assert num_logprobs == usage["completion_tokens"] - 1

    return res
|
|
|
|
def srt_api_request(name):
    """Query sglang's native /generate endpoint for one character.

    Requests token logprobs for the full prompt (logprob_start_len=0) plus the
    regex-constrained generation, verifies that the returned logprob arrays are
    mutually consistent with the reported token counts, and returns the raw
    response dict.
    """
    payload = {
        "text": name + prompt,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 128,
            "regex": character_regex,
        },
        "return_logprob": True,
        "logprob_start_len": 0,
        "top_logprobs_num": 3,
        "return_text_in_logprobs": True,
    }

    res = http_request(base_url + "/generate", json=payload).json()

    meta_info = res["meta_info"]
    # Token-level and top-k arrays must be aligned for both phases.
    input_logprobs = meta_info["input_token_logprobs"]
    output_logprobs = meta_info["output_token_logprobs"]
    assert len(input_logprobs) == len(meta_info["input_top_logprobs"])
    assert len(output_logprobs) == len(meta_info["output_top_logprobs"])
    # Prefill covers every prompt token; decode reports one fewer entry than
    # completion tokens — NOTE(review): mirrors the OpenAI-endpoint invariant;
    # confirm against server behavior.
    assert len(input_logprobs) == meta_info["prompt_tokens"]
    assert len(output_logprobs) == meta_info["completion_tokens"] - 1

    return res
|
|
|
|
def pretty_print(res):
    """Print token-by-token logprob tables for the prefill and decode phases.

    Each row shows the chosen token (as its bytes repr, left-padded to 20
    columns) followed by its top-k candidate tokens (15 columns each), then
    the full generated text.
    """
    meta_info = res["meta_info"]

    print("\n\n", "=" * 30, "Prefill", "=" * 30)
    for idx, token_entry in enumerate(meta_info["input_token_logprobs"]):
        print(f"{str(token_entry[2].encode()): <20}", end="")
        top_entries = meta_info["input_top_logprobs"][idx]
        # The first prompt position may have no top-k candidates.
        if top_entries:
            for candidate in top_entries:
                print(f"{str(candidate[2].encode()): <15}", end="")
        print()

    print("\n\n", "=" * 30, "Decode", "=" * 30)
    for idx, token_entry in enumerate(meta_info["output_token_logprobs"]):
        print(f"{str(token_entry[2].encode()): <20}", end="")
        for candidate in meta_info["output_top_logprobs"][idx]:
            print(f"{str(candidate[2].encode()): <15}", end="")
        print()

    print(res["text"])
|
|
|
|
if __name__ == "__main__":
    # Fan the native-API requests out over a thread pool (I/O bound), then
    # print each result in order once all workers have finished.
    with ThreadPoolExecutor() as executor:
        results = executor.map(srt_api_request, character_names)

    for result in results:
        pretty_print(result)

    # Also exercise the OpenAI-compatible endpoint once.
    openai_api_request("Hermione Granger")
|
|