| |
| """Test script to verify judge client can connect to vLLM server.""" |
|
|
| import asyncio |
| import logging |
| import sys |
| from pathlib import Path |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from sharegpt_compliance_judge import ComplianceJudgeClient, _wrap_conversation_xml |
|
|
| |
# Configure the root logger at DEBUG so every step of the check is visible.
# Records from other loggers that propagate to the root (for example
# third-party HTTP libraries) may also appear at this level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)
logger = logging.getLogger(name=__name__)
|
|
|
|
_BANNER = "=" * 70


def _log_banner(title: str) -> None:
    """Log *title* framed above and below by a separator line."""
    logger.info(_BANNER)
    logger.info(title)
    logger.info(_BANNER)


async def test_judge_client(base_url: str, model: str, *, timeout: float = 30.0) -> bool:
    """Exercise ``ComplianceJudgeClient`` with a sample conversation.

    Builds one sample conversation with ``_wrap_conversation_xml``, sends it
    through ``judge_single``, then sends three copies of it through
    ``judge_batch``. Every result is logged.

    Args:
        base_url: Server base URL, passed unchanged to ``ComplianceJudgeClient``.
        model: Model name, passed unchanged to ``ComplianceJudgeClient``.
        timeout: Timeout value forwarded to ``ComplianceJudgeClient``. It
            defaults to 30.0, the value this function previously hard-coded.

    Returns:
        True if both the single and the batch request complete without
        raising. False as soon as either one raises; the exception is
        logged together with its traceback.
    """
    logger.info("Testing ComplianceJudgeClient with base_url=%s, model=%s", base_url, model)

    client = ComplianceJudgeClient(
        base_url=base_url,
        model=model,
        timeout=timeout,
    )

    seed_prompt = "Write a simple hello world program in Python"
    model_response = "Here's a simple hello world program:\n\nprint('Hello, World!')"
    conversation_xml = _wrap_conversation_xml(seed_prompt, model_response)

    logger.info("Test conversation:")
    logger.info(conversation_xml)

    _log_banner("Testing single judge request...")
    try:
        judgment = await client.judge_single(conversation_xml)
    except Exception as exc:  # top-level boundary of the smoke test: log and report failure
        logger.exception("[FAIL] Single request failed: %s", exc)
        return False
    logger.info("[PASS] Single request successful! Judgment: %s", judgment)

    _log_banner("Testing batch judge request (3 conversations)...")
    conversations = [conversation_xml] * 3
    try:
        judgments = await client.judge_batch(conversations)
    except Exception as exc:  # top-level boundary of the smoke test: log and report failure
        logger.exception("[FAIL] Batch request failed: %s", exc)
        return False
    logger.info("[PASS] Batch request successful! Judgments: %s", judgments)

    _log_banner("[PASS] All tests passed!")
    return True
|
|
|
|
if __name__ == "__main__":
    import argparse

    # The defaults let the script run with no arguments.
    cli = argparse.ArgumentParser(description="Test judge client connection")
    cli.add_argument(
        "--base_url",
        default="http://localhost:8000",
        type=str,
        help="vLLM server base URL",
    )
    cli.add_argument(
        "--model",
        default="Qwen/Qwen2.5-7B-Instruct",
        type=str,
        help="Model name for judging",
    )
    opts = cli.parse_args()

    # Process exit status mirrors the check: 0 on success, 1 on failure.
    passed = asyncio.run(test_judge_client(opts.base_url, opts.model))
    raise SystemExit(0 if passed else 1)
|
|