| import aiohttp |
| import asyncio |
| import os |
| import time |
| from datetime import datetime |
| import random |
|
|
| def _base_url() -> str: |
| host = os.environ.get("SGLANG_HTTP_HOST", "127.0.0.1") |
| port = os.environ.get("SGLANG_HTTP_PORT", "30000") |
| return f"http://{host}:{port}" |
|
|
|
|
# Scripted conversations used by the regression run in ``main``. Each
# conversation is written as (user prompt, exact expected assistant reply)
# turns; ``database`` flattens every conversation into one list alternating
# prompt / expected reply, which is the layout ``main`` slices with [0::2]
# and [1::2].
_CONVERSATIONS = (
    (
        ("who are you?",
         "i'm just a little fish. i have little body. they work. i eat the red flakes when i can."),
        ("hi",
         "i'm just a snail would be nice."),
        ("what is the capital of France?",
         "i am a bubble wall."),
    ),
    (
        ("is there a cat in the room?",
         "i don't like it. it puts its face on the glass by the bubbles."),
        ("I'm sorry. are you hungry?",
         "i don't eat it."),
    ),
)

database = [
    [text for turn in conversation for text in turn]
    for conversation in _CONVERSATIONS
]
|
|
def _new_request_body() -> dict:
    """Return a fresh chat request payload with an empty message history."""
    return {
        "model": "guppylm-9M",
        "messages": [],
        # top_k=1 with a tiny temperature makes decoding (near-)greedy, which
        # presumably is what lets replies be compared verbatim to ``database``.
        "temperature": 0.001,
        "top_k": 1,
        "stream": False,
    }


def _timestamp() -> str:
    """Return the current local time formatted as ``HH:MM:SS.mmm``."""
    # %f is microseconds (6 digits); trim the last 3 to get milliseconds.
    return datetime.now().strftime("%H:%M:%S.%f")[:-3]


async def _post_chat(session: "aiohttp.ClientSession", url: str, headers: dict, body: dict) -> dict:
    """POST the conversation in *body* and record the assistant's reply.

    The reply is appended to ``body["messages"]`` so the next turn sends the
    full history, and the reply message dict (``role``/``content``) is returned.

    Raises:
        aiohttp.ClientResponseError: if the server answers with an HTTP error
            status, rather than failing later while indexing the response body.
    """
    async with session.post(url, json=body, headers=headers) as response:
        response.raise_for_status()
        message = (await response.json())["choices"][0]["message"]
    body["messages"].append({"role": message["role"], "content": message["content"]})
    return message


async def _interactive_chat(session: "aiohttp.ClientSession", url: str, headers: dict) -> None:
    """Read user turns from stdin and print each reply until stdin reaches EOF."""
    body = _new_request_body()
    while True:
        try:
            # A blocking read is acceptable here: no other coroutine runs in
            # interactive mode, so there is nothing for it to starve.
            user_text = input("user: ")
        except EOFError:
            # Ctrl-D / closed stdin ends the session instead of a traceback.
            print()
            return
        body["messages"].append({"role": "user", "content": user_text})
        reply = await _post_chat(session, url, headers, body)
        print(f"{reply['role']}: {reply['content']}")


async def _run_scripted_tests(
    session: "aiohttp.ClientSession",
    url: str,
    headers: dict,
    num_tests: int,
    concurrency: int,
) -> int:
    """Run *num_tests* random ``database`` conversations concurrently.

    At most ``max(1, concurrency)`` requests are in flight at once. Each
    conversation prints a colored log (green = reply matched, red = it did
    not). Returns the total number of mismatched replies.
    """
    red = "\033[0;31m"
    green = "\033[0;32m"
    end = "\033[0m"
    limiter = asyncio.Semaphore(max(1, concurrency))

    async def run_one_conversation() -> int:
        body = _new_request_body()
        conversation = random.choice(database)
        mismatches = 0
        logs = []
        # Entries alternate user prompt / expected assistant reply.
        for user_text, expected in zip(conversation[0::2], conversation[1::2]):
            body["messages"].append({"role": "user", "content": user_text})
            logs.append(_timestamp() + " " + " user: " + user_text)
            # Only the HTTP round-trip is throttled; turns of one conversation
            # stay sequential because each needs the previous reply.
            async with limiter:
                reply = await _post_chat(session, url, headers, body)
            output = reply["content"]
            matched = output == expected
            if not matched:
                mismatches += 1
            color = green if matched else red
            logs.append(_timestamp() + " " + "assistant: " + color + output + end)
        print("=========\n", "\n".join(logs))
        return mismatches

    results = await asyncio.gather(*(run_one_conversation() for _ in range(num_tests)))
    return sum(results)


async def main(num_tests: int, concurrency: int):
    """Exercise the ``/v1/chat/completions`` endpoint at ``_base_url()``.

    Args:
        num_tests: ``0`` starts an interactive chat loop on stdin/stdout; any
            other value runs that many scripted conversations drawn at random
            from ``database`` (a negative value runs none and returns 0).
        concurrency: maximum number of chat requests in flight at once in
            scripted mode; values below 1 are treated as 1.

    Returns:
        In scripted mode, the total number of assistant replies that did not
        exactly match the expected text. In interactive mode, ``None`` once
        stdin is closed.

    Raises:
        aiohttp.ClientResponseError: if the server answers with an HTTP error
            status.
    """
    url = f"{_base_url()}/v1/chat/completions"
    headers = {"Content-Type": "application/json"}
    timeout = aiohttp.ClientTimeout(total=600, connect=10, sock_read=30)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        if num_tests == 0:
            await _interactive_chat(session, url, headers)
            return None
        return await _run_scripted_tests(session, url, headers, num_tests, concurrency)
|
|
if __name__ == "__main__":
    # Set to False to skip the scripted regression run and go straight to the
    # interactive chat (replaces the previous always-true ``if 1:`` toggle).
    run_scripted_tests = True
    if run_scripted_tests:
        # 10 random scripted conversations, up to 10 requests in flight.
        unexpected_count = asyncio.run(main(10, 10))
        print(f"unexpected count: {unexpected_count}")

    # num_tests == 0 selects the interactive stdin/stdout chat loop.
    asyncio.run(main(0, 0))
|
|
|
|
|
|