# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import logging import os import time from locust import HttpUser, between, task # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) # Initialize Vertex AI and load agent config with open("deployment_metadata.json") as f: remote_agent_engine_id = json.load(f)["remote_agent_engine_id"] parts = remote_agent_engine_id.split("/") project_id = parts[1] location = parts[3] engine_id = parts[5] # Convert remote agent engine ID to streaming URL. base_url = f"https://{location}-aiplatform.googleapis.com" url_path = f"/v1/projects/{project_id}/locations/{location}/reasoningEngines/{engine_id}:streamQuery" logger.info("Using remote agent engine ID: %s", remote_agent_engine_id) logger.info("Using base URL: %s", base_url) logger.info("Using URL path: %s", url_path) class ChatStreamUser(HttpUser): """Simulates a user interacting with the chat stream API.""" wait_time = between(1, 3) # Wait 1-3 seconds between tasks host = base_url # Set the base host URL for Locust @task def chat_stream(self) -> None: """Simulates a chat stream interaction.""" headers = {"Content-Type": "application/json"} headers["Authorization"] = f"Bearer {os.environ['_AUTH_TOKEN']}" data = { "class_method": "async_stream_query", "input": { "user_id": "test", "message": "What's the weather in San Francisco?", }, } start_time = time.time() with self.client.post( url_path, headers=headers, json=data, catch_response=True, name="/streamQuery async_stream_query", stream=True, params={"alt": "sse"}, ) as response: if response.status_code == 200: events = [] has_error = False for line in response.iter_lines(): if line: line_str = line.decode("utf-8") events.append(line_str) if "429 Too Many Requests" in line_str: self.environment.events.request.fire( request_type="POST", name=f"{url_path} rate_limited 429s", response_time=0, response_length=len(line), response=response, context={}, ) # Check for error responses in the JSON payload try: event_data = json.loads(line_str) if isinstance(event_data, dict) and "code" in event_data: # Flag any non-2xx codes as errors if event_data["code"] >= 400: has_error = True error_msg = event_data.get( "message", "Unknown error" ) response.failure(f"Error in response: {error_msg}") logger.error( "Received error response: code=%s, message=%s", event_data["code"], error_msg, ) except json.JSONDecodeError: # If it's not valid JSON, continue processing pass end_time = time.time() total_time = end_time - start_time # Only fire success event if no errors were found if not has_error: self.environment.events.request.fire( request_type="POST", name="/streamQuery end", response_time=total_time * 1000, # Convert to milliseconds response_length=len(events), response=response, context={}, ) else: response.failure(f"Unexpected status code: {response.status_code}")