|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import time |
|
|
|
|
|
from locust import HttpUser, between, task |
|
|
|
|
|
|
|
|
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Read the resource name of the deployed agent engine from the deployment
# metadata file in the current working directory.
with open("deployment_metadata.json", encoding="utf-8") as f:
    remote_agent_engine_id = json.load(f)["remote_agent_engine_id"]

# The resource name is split positionally, so it must look like
#   projects/{project}/locations/{location}/reasoningEngines/{engine}
# Validate the shape up front: a malformed value would otherwise surface as an
# opaque IndexError or silently produce a wrong request URL.
parts = remote_agent_engine_id.split("/")
if (
    len(parts) < 6
    or parts[0] != "projects"
    or parts[2] != "locations"
    or parts[4] != "reasoningEngines"
):
    raise ValueError(
        "remote_agent_engine_id must have the form "
        "'projects/{project}/locations/{location}/reasoningEngines/{engine}', "
        f"got: {remote_agent_engine_id!r}"
    )
project_id = parts[1]
location = parts[3]
engine_id = parts[5]

# Regional endpoint host and the streamQuery path for this engine.
base_url = f"https://{location}-aiplatform.googleapis.com"
url_path = f"/v1/projects/{project_id}/locations/{location}/reasoningEngines/{engine_id}:streamQuery"

logger.info("Using remote agent engine ID: %s", remote_agent_engine_id)
logger.info("Using base URL: %s", base_url)
logger.info("Using URL path: %s", url_path)
|
|
|
|
|
|
|
|
class ChatStreamUser(HttpUser):
    """Simulates a user interacting with the chat stream API.

    Each task run POSTs one ``async_stream_query`` request to ``url_path`` on
    ``base_url``, reads the streamed response line by line, and reports the
    outcome to Locust.
    """

    wait_time = between(1, 3)
    host = base_url

    @staticmethod
    def _parse_error_event(line_str):
        """Return ``(code, message)`` if a streamed line is an error event.

        A line counts as an error event when it is a JSON object whose
        ``"code"`` value is a number >= 400. Any other line (not JSON, not an
        object, no numeric code, or code below 400) yields ``None``.

        Args:
            line_str: One decoded, non-empty line of the streamed response.

        Returns:
            A ``(code, message)`` tuple, or ``None`` when the line is not an
            error event. ``message`` falls back to ``"Unknown error"``.
        """
        payload = line_str
        # The request asks for ``alt=sse``; tolerate SSE framing by dropping a
        # leading "data:" field name so the JSON payload can be parsed.
        if payload.startswith("data:"):
            payload = payload[len("data:"):].strip()

        try:
            event_data = json.loads(payload)
        except json.JSONDecodeError:
            # Not a JSON line; nothing to inspect.
            return None

        if not isinstance(event_data, dict):
            return None

        code = event_data.get("code")
        # Only compare numeric codes: a non-numeric "code" would otherwise
        # raise TypeError and abort the whole task.
        if isinstance(code, (int, float)) and code >= 400:
            return code, event_data.get("message", "Unknown error")
        return None

    @task
    def chat_stream(self) -> None:
        """Simulates a chat stream interaction.

        Sends the query with a bearer token taken from the ``_AUTH_TOKEN``
        environment variable, then consumes the streamed response:

        * a non-200 status marks the request as failed;
        * every line containing ``"429 Too Many Requests"`` fires an extra
          request event under its own ``rate_limited 429s`` name;
        * every error event in the stream marks the request as failed and is
          logged;
        * if no error event was seen, a ``/streamQuery end`` request event is
          fired with the total time until the stream was fully read (in
          milliseconds) and the number of received lines as its length.

        Raises:
            KeyError: If ``_AUTH_TOKEN`` is not set in the environment.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ['_AUTH_TOKEN']}",
        }

        data = {
            "class_method": "async_stream_query",
            "input": {
                "user_id": "test",
                "message": "What's the weather in San Francisco?",
            },
        }

        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments during the request.
        start_time = time.perf_counter()
        with self.client.post(
            url_path,
            headers=headers,
            json=data,
            catch_response=True,
            name="/streamQuery async_stream_query",
            stream=True,
            params={"alt": "sse"},
        ) as response:
            if response.status_code != 200:
                response.failure(f"Unexpected status code: {response.status_code}")
                return

            events = []
            has_error = False
            for line in response.iter_lines():
                if not line:
                    continue

                # Replace undecodable bytes instead of aborting the stream on
                # a single malformed chunk.
                line_str = line.decode("utf-8", errors="replace")
                events.append(line_str)

                if "429 Too Many Requests" in line_str:
                    # Report rate limiting under a distinct request name.
                    self.environment.events.request.fire(
                        request_type="POST",
                        name=f"{url_path} rate_limited 429s",
                        response_time=0,
                        response_length=len(line),
                        response=response,
                        context={},
                    )

                error = self._parse_error_event(line_str)
                if error is not None:
                    code, error_msg = error
                    has_error = True
                    response.failure(f"Error in response: {error_msg}")
                    logger.error(
                        "Received error response: code=%s, message=%s",
                        code,
                        error_msg,
                    )

            if not has_error:
                # Time from sending the request until the stream was drained.
                total_time = time.perf_counter() - start_time
                self.environment.events.request.fire(
                    request_type="POST",
                    name="/streamQuery end",
                    response_time=total_time * 1000,
                    response_length=len(events),
                    response=response,
                    context={},
                )
|
|
|