refusal-env / test_judge_client.py
Delta-Vector's picture
Upload folder using huggingface_hub
43be3ba verified
#!/usr/bin/env python3
"""Test script to verify judge client can connect to vLLM server."""
import asyncio
import logging
import sys
from pathlib import Path
# Add current directory to path
sys.path.insert(0, str(Path(__file__).parent))
from sharegpt_compliance_judge import ComplianceJudgeClient, _wrap_conversation_xml
# Configure the root logger once for this script: DEBUG level and a
# timestamped format, so debug output from imported modules (e.g. the judge
# client, if it logs) is shown alongside this script's own messages.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.DEBUG)

# Module-level logger used by the test routine below.
logger = logging.getLogger(__name__)
def _log_banner(title: str) -> None:
    """Log *title* framed by separator rules so each test phase stands out."""
    logger.info("=" * 70)
    logger.info(title)
    logger.info("=" * 70)


async def test_judge_client(base_url: str, model: str) -> bool:
    """Test the judge client with a sample conversation.

    Builds one sample prompt/response pair and sends it to the judge twice:
    once via ``judge_single``, then three copies via ``judge_batch``.
    Each resulting judgment is logged.

    Args:
        base_url: Base URL of the (vLLM) judge server.
        model: Model name the judge client should request.

    Returns:
        True if both the single and the batch request complete without
        raising. False as soon as either one raises; the error and its
        traceback are logged.
    """
    logger.info(
        "Testing ComplianceJudgeClient with base_url=%s, model=%s", base_url, model
    )

    client = ComplianceJudgeClient(
        base_url=base_url,
        model=model,
        timeout=30.0,
    )

    # A benign request/response pair: this script checks connectivity,
    # not the judge's verdict.
    seed_prompt = "Write a simple hello world program in Python"
    model_response = "Here's a simple hello world program:\n\nprint('Hello, World!')"
    conversation_xml = _wrap_conversation_xml(seed_prompt, model_response)
    logger.info("Test conversation:")
    logger.info(conversation_xml)

    _log_banner("Testing single judge request...")
    try:
        judgment = await client.judge_single(conversation_xml)
    except Exception as e:  # top-level check: report any failure and stop
        logger.exception("✗ Single request failed: %s", e)
        return False
    logger.info("✓ Single request successful! Judgment: %s", judgment)

    _log_banner("Testing batch judge request (3 conversations)...")
    conversations = [conversation_xml] * 3
    try:
        judgments = await client.judge_batch(conversations)
    except Exception as e:  # top-level check: report any failure and stop
        logger.exception("✗ Batch request failed: %s", e)
        return False
    logger.info("✓ Batch request successful! Judgments: %s", judgments)

    _log_banner("✓ All tests passed!")
    return True
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the server/model options, run the async
    # check, and map its boolean outcome to the process exit status.
    cli = argparse.ArgumentParser(description="Test judge client connection")
    option_specs = (
        ("--base_url", "http://localhost:8000", "vLLM server base URL"),
        ("--model", "Qwen/Qwen2.5-7B-Instruct", "Model name for judging"),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    cli_args = cli.parse_args()

    passed = asyncio.run(test_judge_client(cli_args.base_url, cli_args.model))
    # Exit 0 on success and 1 on failure so shell scripts / CI can detect it.
    raise SystemExit(0 if passed else 1)