Spaces:
Running
Running
| """CUGA SDK agent for BPO benchmark evaluation with Langfuse tracking.""" | |
| import asyncio | |
| import logging | |
| import os | |
| import re | |
| import threading | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import uvicorn | |
logger = logging.getLogger(__name__)

# Process-wide guard used by start_servers(): the flag records whether the
# BPO API / Registry server threads were launched, the lock serializes callers.
_servers_started: bool = False
_servers_lock = threading.Lock()
| # ============================================================================ | |
| # Provider Configuration | |
| # ============================================================================ | |
# Per-provider settings. Keys of each entry:
#   env_var        - environment variable that receives the API key
#   settings_file  - value exported as AGENT_SETTING_CONFIG
#   default_model  - model used when the caller does not pick one
#   models         - models offered for this provider
#   placeholder    - example API key shown as an input placeholder
PROVIDER_CONFIGS = {
    "groq": {
        "env_var": "GROQ_API_KEY",
        "settings_file": "settings.groq.toml",
        "default_model": "openai/gpt-oss-120b",
        "models": [
            "openai/gpt-oss-120b",
            "llama-3.3-70b-versatile",
            "llama-3.1-8b-instant",
            "mixtral-8x7b-32768",
        ],
        "placeholder": "gsk_...",
    },
    "openai": {
        "env_var": "OPENAI_API_KEY",
        "settings_file": "settings.openai.toml",
        "default_model": "gpt-4o-mini",
        "models": [
            "gpt-4o-mini",
            "gpt-4.1",
            "gpt-5",
            "gpt-4o",
        ],
        "placeholder": "sk-...",
    },
}


def _get_provider_config(provider: Optional[str]) -> Dict[str, Any]:
    """Return the PROVIDER_CONFIGS entry for *provider* (case-insensitive), or {} if unknown."""
    return PROVIDER_CONFIGS.get((provider or "").lower(), {})


def get_provider_models(provider: str) -> List[str]:
    """Get available models for a provider.

    Returns a new list each call, so callers may sort or extend it without
    mutating PROVIDER_CONFIGS. Unknown providers yield an empty list.
    """
    return list(_get_provider_config(provider).get("models", []))


def get_provider_placeholder(provider: str) -> str:
    """Get the API key placeholder for a provider ("..." if unknown)."""
    return _get_provider_config(provider).get("placeholder", "...")


def get_default_model(provider: str) -> str:
    """Get the default model for a provider ("" if unknown)."""
    return _get_provider_config(provider).get("default_model", "")
| # ============================================================================ | |
| # Server Management | |
| # ============================================================================ | |
def start_servers(*, ready_timeout: float = 30.0):
    """Start the BPO API (port 8000) and Registry (port 8001) servers once per process.

    Each app is served by ``uvicorn.run`` on a daemon thread, so the servers
    stop when the process exits. After launching the threads, this call blocks
    until both ports accept TCP connections on localhost, or until
    ``ready_timeout`` seconds have elapsed (a warning is logged in that case).

    The "already started" flag is set only after the server apps have been
    imported and their threads launched. If startup fails (e.g. an import
    error), the exception reaches the caller and a later call can retry,
    instead of silently assuming the servers are running.

    Args:
        ready_timeout: Maximum seconds to wait for both ports to accept connections.
    """
    global _servers_started
    import socket

    with _servers_lock:
        if _servers_started:
            return

        # Import here to avoid circular imports
        from server import app as bpo_app
        from cuga.backend.tools_env.registry.registry.api_registry_server import (
            app as registry_app,
        )

        servers = (
            ("BPO API", bpo_app, 8000),
            ("Registry", registry_app, 8001),
        )
        for label, app, port in servers:
            threading.Thread(
                target=uvicorn.run,
                args=(app,),
                kwargs={"host": "0.0.0.0", "port": port, "log_level": "warning"},
                daemon=True,
            ).start()
            logger.info(f"{label} server starting on port {port}")

        # Threads are running; relaunching them would only collide on the ports.
        _servers_started = True

        # Poll instead of sleeping a fixed time: return as soon as both servers
        # listen, while still tolerating slow startups up to ready_timeout.
        pending = {port for _, _, port in servers}
        deadline = time.monotonic() + ready_timeout
        while pending and time.monotonic() < deadline:
            for port in sorted(pending):
                try:
                    with socket.create_connection(("127.0.0.1", port), timeout=0.5):
                        pending.discard(port)
                except OSError:
                    # Not listening yet; try again on the next round.
                    pass
            if pending:
                time.sleep(0.2)

        if pending:
            logger.warning(
                f"Servers not accepting connections after {ready_timeout}s on ports: {sorted(pending)}"
            )
        logger.info("Servers started")
| # ============================================================================ | |
| # Environment Setup | |
| # ============================================================================ | |
def setup_environment(api_key: str, provider: str, model: Optional[str] = None):
    """Set up environment variables for the CUGA SDK.

    Args:
        api_key: API key for the selected provider; exported under the
            provider's ``env_var`` (e.g. GROQ_API_KEY).
        provider: Provider name (case-insensitive); must be a key of PROVIDER_CONFIGS.
        model: Model name; defaults to the provider's ``default_model``.

    Raises:
        ValueError: If ``provider`` is unknown. The environment is left
            untouched in that case.
    """
    # Validate first so a bad provider does not wipe the caller's existing keys.
    config = PROVIDER_CONFIGS.get(provider.lower())
    if not config:
        raise ValueError(f"Unknown provider: {provider}. Supported: {list(PROVIDER_CONFIGS.keys())}")

    # Clear conflicting env vars (e.g. a stale OPENAI_BASE_URL or the other provider's key)
    for key in ("OPENAI_BASE_URL", "OPENAI_API_KEY", "GROQ_API_KEY"):
        os.environ.pop(key, None)

    # Set provider-specific config
    os.environ[config["env_var"]] = api_key
    os.environ["AGENT_SETTING_CONFIG"] = config["settings_file"]
    os.environ["MODEL_NAME"] = model or config["default_model"]

    # Set MCP servers file path (resolved relative to this module)
    mcp_config = Path(__file__).parent / "mcp_servers" / "bpo.yaml"
    os.environ["MCP_SERVERS_FILE"] = str(mcp_config.resolve())

    # Disable policies for benchmark
    os.environ["DYNACONF_POLICY__ENABLED"] = "false"

    logger.info(f"Environment configured: provider={provider}, model={os.environ.get('MODEL_NAME')}")
| # ============================================================================ | |
| # Langfuse Integration | |
| # ============================================================================ | |
class LangfuseTracker:
    """Tracks evaluation runs and task scores in Langfuse.

    Works with the v2 Langfuse Python SDK (``Langfuse.trace`` /
    ``Langfuse.score``) and the v3 SDK (``Langfuse.start_span`` /
    ``Langfuse.create_score``). Tracking is best-effort: Langfuse errors are
    logged as warnings and swallowed so they never abort an evaluation.
    """

    def __init__(self):
        self.enabled = False  # True only after the client was built and passed its auth check
        self.langfuse = None  # langfuse.Langfuse client, when one could be constructed
        self.trace_id = None  # ID of the trace that scores are currently attached to
        self.init_error = None  # Reason tracking is disabled (None when enabled)
        self._init_langfuse()

    def _init_langfuse(self) -> None:
        """Initialize the Langfuse client if credentials are available.

        Reads LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY and the host from
        LANGFUSE_HOST or LANGFUSE_BASE_URL (default https://cloud.langfuse.com).
        On failure ``enabled`` stays False and ``init_error`` says why.
        """
        # Debug: show which LANGFUSE env vars exist (only set/empty, never values)
        langfuse_vars = {k: ('set' if v else 'empty') for k, v in os.environ.items() if 'LANGFUSE' in k.upper()}
        logger.info(f"Langfuse env vars found: {langfuse_vars}")

        public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
        secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
        # Support both LANGFUSE_HOST and LANGFUSE_BASE_URL
        host = os.environ.get("LANGFUSE_HOST") or os.environ.get("LANGFUSE_BASE_URL") or "https://cloud.langfuse.com"
        logger.info(f"Langfuse init: public_key={'set' if public_key else 'not set'}, secret_key={'set' if secret_key else 'not set'}, host={host}")

        if not public_key or not secret_key:
            self.init_error = "Langfuse credentials not found"
            logger.info(self.init_error)
            return

        try:
            from langfuse import Langfuse
        except ImportError as e:
            self.init_error = f"langfuse package not installed: {e}"
            logger.warning(self.init_error)
            return

        try:
            self.langfuse = Langfuse(
                public_key=public_key,
                secret_key=secret_key,
                host=host,
            )
            # auth_check() may signal bad credentials by raising or by returning
            # False; treat both as "not enabled".
            if self.langfuse.auth_check() is False:
                self.init_error = f"Langfuse auth check failed (host={host})"
                logger.warning(self.init_error)
                return
            self.enabled = True
            logger.info(f"Langfuse tracking initialized successfully (host={host})")
        except Exception as e:
            self.init_error = f"Failed to initialize Langfuse: {e}"
            logger.warning(self.init_error)

    @staticmethod
    def _new_trace_id() -> str:
        """Return a fresh random trace ID (32 lowercase hex characters)."""
        import uuid

        return uuid.uuid4().hex

    def start_trace(self, name: str, metadata: Optional[Dict[str, Any]] = None) -> Optional[str]:
        """Start a new trace for an evaluation run.

        Returns:
            The new trace ID, or None if tracking is disabled or creation failed.
        """
        if not self.enabled or not self.langfuse:
            return None
        # Never keep scoring into a previous run's trace if this one fails.
        self.trace_id = None
        metadata = metadata or {}
        try:
            if hasattr(self.langfuse, "trace"):
                # Langfuse SDK v2: explicit trace objects.
                trace = self.langfuse.trace(name=name, metadata=metadata)
                self.trace_id = trace.id
            elif hasattr(self.langfuse, "start_span"):
                # Langfuse SDK v3: a trace is created implicitly by its root span.
                span = self.langfuse.start_span(name=name, metadata=metadata)
                try:
                    span.update_trace(name=name, metadata=metadata)
                    self.trace_id = span.trace_id
                finally:
                    span.end()
            else:
                # Unknown SDK shape: use a unique ID so runs never share scores.
                self.trace_id = self._new_trace_id()
                logger.info(f"Using fallback trace ID: {self.trace_id}")
            return self.trace_id
        except Exception as e:
            logger.warning(f"Failed to create trace: {e}")
            return None

    def _create_score(self, name: str, value: Any) -> None:
        """Attach one score to the current trace via the v3 or v2 SDK call."""
        if hasattr(self.langfuse, "create_score"):
            self.langfuse.create_score(name=name, value=value, trace_id=self.trace_id)
        else:
            self.langfuse.score(trace_id=self.trace_id, name=name, value=value)

    def score_task(self, task_id: str, scores: Dict[str, float]) -> None:
        """Score a task within the current trace as ``<task_id>_<metric>``."""
        if not self.enabled or not self.langfuse or not self.trace_id:
            return
        try:
            for name, value in scores.items():
                self._create_score(f"{task_id}_{name}", value)
        except Exception as e:
            logger.warning(f"Failed to score task {task_id}: {e}")

    def end_trace(self, summary: Optional[Dict[str, Any]] = None) -> None:
        """End the current trace, recording numeric summary values as ``summary_<name>``.

        Booleans and non-numeric summary values are skipped. Pending events are
        flushed, and the current trace ID is always cleared.
        """
        if not self.enabled or not self.langfuse:
            return
        try:
            if summary and self.trace_id:
                for name, value in summary.items():
                    if isinstance(value, (int, float)) and not isinstance(value, bool):
                        self._create_score(f"summary_{name}", float(value))
            self.langfuse.flush()
        except Exception as e:
            logger.warning(f"Failed to end trace: {e}")
        finally:
            self.trace_id = None
def is_langfuse_configured() -> bool:
    """Return True when both Langfuse API keys are present and non-empty in the environment."""
    required = ("LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY")
    return all(os.environ.get(var) for var in required)
def get_langfuse_host() -> str:
    """Return the Langfuse host URL.

    LANGFUSE_HOST takes precedence over LANGFUSE_BASE_URL. Empty values are
    ignored, and Langfuse Cloud is the fallback.
    """
    for var in ("LANGFUSE_HOST", "LANGFUSE_BASE_URL"):
        candidate = os.environ.get(var)
        if candidate:
            return candidate
    return "https://cloud.langfuse.com"
| # ============================================================================ | |
| # CUGA Agent | |
| # ============================================================================ | |
class CUGAAgent:
    """CUGA SDK agent for BPO benchmark evaluation.

    Constructing an instance exports the provider/model environment variables
    and starts the local BPO API and Registry servers. The underlying
    ``cuga.sdk.CugaAgent`` is built lazily by ``setup()`` on the first ``run()``.
    """

    def __init__(
        self,
        api_key: str,
        provider: str = "groq",
        model: Optional[str] = None,
    ):
        """Initialize the CUGA agent.

        Args:
            api_key: API key for the LLM provider
            provider: "openai" or "groq" (case-insensitive)
            model: Model name (optional, uses the provider default)

        Raises:
            ValueError: If ``provider`` is not a supported provider.
        """
        self.api_key = api_key
        self.provider = provider.lower()
        self.model = model
        self.agent = None  # cuga.sdk.CugaAgent, created by setup()
        self.tool_provider = None  # CombinedToolProvider, created by setup()

        # Set up environment BEFORE importing cuga modules
        setup_environment(api_key, self.provider, model)

        # Start servers
        start_servers()

    async def setup(self):
        """Load tools from the registry and create the CUGA agent.

        Raises:
            RuntimeError: If the registry returned no tools (check the servers).
        """
        from cuga.sdk import CugaAgent
        from cuga.config import settings
        from cuga.backend.cuga_graph.nodes.cuga_lite.combined_tool_provider import (
            CombinedToolProvider,
        )

        logger.info("Setting up CUGA agent...")

        # Enable ActivityTracker for tool call tracking
        settings.update({"ADVANCED_FEATURES": {"TRACKER_ENABLED": True}}, merge=True)

        # Initialize tool provider (will load from registry)
        self.tool_provider = CombinedToolProvider()
        await self.tool_provider.initialize()

        all_tools = await self.tool_provider.get_all_tools()
        logger.info(f"Loaded {len(all_tools)} tools from BPO API")
        if len(all_tools) == 0:
            raise RuntimeError("No tools loaded from registry. Check server status.")

        self.agent = CugaAgent(tool_provider=self.tool_provider)
        logger.info("CUGA agent initialized")

    @staticmethod
    def _tool_calls_from_steps(steps: Any) -> List[Dict[str, Any]]:
        """Extract ``{"name", "args"}`` dicts from ActivityTracker steps.

        Only steps whose name contains "api_call" are considered; their
        ``data`` is parsed as JSON. Payloads that are not JSON objects, or that
        have no ``function_name``, are skipped. Unparseable payloads are logged
        and skipped.
        """
        import json

        calls: List[Dict[str, Any]] = []
        for step in steps:
            if not step.name or "api_call" not in step.name:
                continue
            try:
                call_data = json.loads(step.data) if step.data else {}
            except (json.JSONDecodeError, TypeError) as e:
                logger.warning(f"Failed to parse tool call step data: {e}")
                continue
            # json.loads may return a list/str/number; only an object can name a function.
            if not isinstance(call_data, dict):
                continue
            tool_name = call_data.get("function_name", "")
            if tool_name:
                calls.append({"name": tool_name, "args": call_data.get("args", {})})
        return calls

    @staticmethod
    def _tool_calls_from_result(result: Any) -> List[Dict[str, Any]]:
        """Extract ``{"name", "args"}`` dicts from ``result.tool_calls`` (dicts or objects with ``.name``)."""
        calls: List[Dict[str, Any]] = []
        for tc in getattr(result, "tool_calls", None) or []:
            if isinstance(tc, dict):
                calls.append({"name": tc.get("name", ""), "args": tc.get("args", {})})
            elif hasattr(tc, "name"):
                calls.append({"name": tc.name, "args": getattr(tc, "args", {})})
        return calls

    async def run(self, query: str, thread_id: Optional[str] = None) -> Tuple[str, List[Dict[str, Any]]]:
        """Run the agent on a query.

        Args:
            query: The user's question
            thread_id: Optional thread ID for conversation context
                (defaults to "eval_task")

        Returns:
            Tuple of (response_text, tool_calls); response_text is "" when the
            agent produced no answer.
        """
        if self.agent is None:
            await self.setup()

        from langchain_core.messages import HumanMessage

        task_id = thread_id or "eval_task"

        # Get ActivityTracker singleton and reset it so steps from a previous task are dropped
        try:
            from cuga.backend.activity_tracker.tracker import ActivityTracker

            tracker = ActivityTracker()
            tracker.reset(intent=query, task_id=task_id)
        except ImportError:
            tracker = None
            logger.warning("ActivityTracker not available, tool call tracking disabled")

        result = await self.agent.invoke(
            [HumanMessage(content=query)],
            thread_id=task_id,
            track_tool_calls=True,  # Required for ActivityTracker to capture tool calls
        )

        # Result-shape introspection is only useful when debugging; keep it out of INFO logs.
        result_attrs = [attr for attr in dir(result) if not attr.startswith('_')]
        logger.debug(f"Result object attributes: {result_attrs}")
        if hasattr(result, '__dict__'):
            logger.debug(f"Result __dict__ keys: {list(result.__dict__.keys())}")

        answer = result.answer if hasattr(result, "answer") else str(result)
        response = "" if answer is None else str(answer)

        # Extract tool calls from ActivityTracker.steps (same approach as sdk_eval_helpers.py)
        tool_calls: List[Dict[str, Any]] = []
        if tracker is not None:
            steps = tracker.steps
            logger.info(f"ActivityTracker has {len(steps)} steps")
            logger.info(f"First step names: {[s.name for s in steps[:5]]}")
            tool_calls = self._tool_calls_from_steps(steps)
            logger.info(f"Extracted {len(tool_calls)} tool calls from ActivityTracker")

        # Fallback: use result.tool_calls when the tracker yielded nothing
        if not tool_calls and getattr(result, "tool_calls", None):
            logger.info("Trying fallback: result.tool_calls")
            tool_calls = self._tool_calls_from_result(result)
            logger.info(f"Fallback extracted {len(tool_calls)} tool calls")

        return response, tool_calls

    def close(self):
        """Clean up resources."""
        pass  # Servers run as daemons, will stop with process
| # ============================================================================ | |
| # Evaluation Metrics (copied from main repo for standalone use) | |
| # ============================================================================ | |
# Typographic characters mapped to plain ASCII before keyword matching.
_SPECIAL_CHAR_MAP = str.maketrans({
    "\u202f": " ",  # narrow no-break space
    "\u00a0": " ",  # no-break space
    "\u2009": " ",  # thin space
    "\u2013": "-",  # en dash
    "\u2014": "-",  # em dash
    "\u2212": "-",  # minus sign
})
_MARKDOWN_CHARS_RE = re.compile(r"[`*_~]")
# Everything except word chars, whitespace, "%" and "|" (kept for OR alternatives).
_PUNCTUATION_RE = re.compile(r"[^\w\s%|]")
_WHITESPACE_RE = re.compile(r"\s+")


def normalize_text(text: str) -> str:
    """Normalize text for keyword matching.

    Steps: NFC normalization; special spaces/dashes to ASCII; lower-casing;
    markdown emphasis characters removed; every other punctuation character
    except "%" and "|" replaced by a space; whitespace collapsed and trimmed.
    ``None`` is treated as an empty string.
    """
    import unicodedata

    text = unicodedata.normalize("NFC", text or "")
    text = text.translate(_SPECIAL_CHAR_MAP).lower()
    text = _MARKDOWN_CHARS_RE.sub("", text)
    text = _PUNCTUATION_RE.sub(" ", text)
    return _WHITESPACE_RE.sub(" ", text).strip()


def check_keywords(response: str, expected_keywords: List[str]) -> Dict[str, Any]:
    """Check if expected keywords are present in the response.

    Supports:
    - OR mechanism: keywords can use "|" to specify alternatives
    - Regex keywords: prefix with "re:" to use a regex pattern. The pattern is
      searched case-insensitively in the *normalized* response (punctuation
      other than "%" and "|" already replaced by spaces). A malformed pattern
      counts as missing instead of raising.

    Empty OR alternatives (e.g. "a|" or "a||b") are ignored. A keyword made
    only of characters that normalization strips (e.g. "$") is compared
    literally and case-insensitively against the raw response; otherwise it
    would normalize to "" and match every response.

    Args:
        response: Agent's response text
        expected_keywords: List of keywords (can use "|" for OR, "re:" for regex)

    Returns:
        Dictionary with keyword check results
    """
    if not expected_keywords:
        return {
            "all_found": True,
            "match_rate": 1.0,
            "found_keywords": [],
            "missing_keywords": [],
            "total_keywords": 0,
            "found_count": 0,
        }

    response_normalized = normalize_text(response)
    response_lower = (response or "").lower()
    found_keywords = []
    missing_keywords = []

    for keyword in expected_keywords:
        stripped = str(keyword).strip()

        if stripped.lower().startswith("re:"):
            # Regex keyword support
            try:
                matched = re.search(stripped[3:], response_normalized, flags=re.IGNORECASE) is not None
            except re.error:
                # Malformed pattern in the benchmark data: report it as missing
                # rather than aborting the whole evaluation.
                matched = False
        else:
            keyword_normalized = normalize_text(str(keyword))
            # OR alternatives; drop empty ones, since "" is a substring of everything.
            alternatives = [alt.strip() for alt in keyword_normalized.split("|")]
            alternatives = [alt for alt in alternatives if alt]
            if alternatives:
                matched = any(alt in response_normalized for alt in alternatives)
            else:
                matched = stripped.lower() in response_lower

        if matched:
            found_keywords.append(keyword)
        else:
            missing_keywords.append(keyword)

    total = len(expected_keywords)
    found_count = len(found_keywords)
    return {
        "all_found": len(missing_keywords) == 0,
        "match_rate": found_count / total if total else 1.0,
        "found_keywords": found_keywords,
        "missing_keywords": missing_keywords,
        "total_keywords": total,
        "found_count": found_count,
    }
def compute_string_similarity(predicted: str, expected: str) -> float:
    """Return a case-insensitive similarity score in [0.0, 1.0].

    Uses RapidFuzz's token set ratio when ``rapidfuzz`` is installed, and
    otherwise falls back to ``difflib.SequenceMatcher.ratio``. The two
    measures are not identical, so scores can differ between environments.
    """
    left, right = predicted.lower(), expected.lower()
    try:
        from rapidfuzz import fuzz
    except ImportError:
        from difflib import SequenceMatcher

        return SequenceMatcher(None, left, right).ratio()
    return fuzz.token_set_ratio(left, right) / 100.0
def compute_exact_match(predicted: str, expected: str) -> bool:
    """Return True when both strings are equal after trimming whitespace and lower-casing."""
    lhs, rhs = (value.strip().lower() for value in (predicted, expected))
    return lhs == rhs
def compute_final_score(
    exact_match: bool,
    similarity: float,
    llm_judge_score: Optional[float] = None,
    llm_judge_requested: bool = False,
    agent_output: str = "",
    threshold_exact: float = 0.85,
    threshold_inexact: float = 0.9,
    apis_missing: Optional[List[str]] = None,
    require_api_match: bool = False,
) -> int:
    """Reduce a task's metrics to a binary pass (1) / fail (0).

    Kept in step with bpo_benchmark/evaluation/metrics.py so both agree.

    A task fails when the agent output is empty or starts with "ERROR:",
    when ``require_api_match`` is set and any expected API is missing, or
    when ``similarity`` is None/NaN. The pass threshold is ``threshold_exact``
    for exact matches and ``threshold_inexact`` otherwise. Without an LLM
    judge, only ``similarity`` is compared. With a judge requested, a missing
    or NaN judge score fails the task; otherwise either score reaching the
    threshold passes it.

    Args:
        exact_match: Whether output exactly matched expected
        similarity: String similarity score (0.0-1.0)
        llm_judge_score: Optional LLM judge score (0.0-1.0)
        llm_judge_requested: True if LLM judge was requested for this evaluation
        agent_output: The agent's output string (to detect failures)
        threshold_exact: Threshold when exact match is True
        threshold_inexact: Threshold when exact match is False
        apis_missing: List of expected APIs that were not called
        require_api_match: If True, require apis_missing to be empty to pass

    Returns:
        1 if task passes, 0 otherwise
    """
    import math

    def _unusable(value) -> bool:
        # "No score": None, or a float NaN.
        return value is None or (isinstance(value, float) and math.isnan(value))

    output_failed = not agent_output or (
        isinstance(agent_output, str) and agent_output.startswith("ERROR:")
    )
    if output_failed:
        return 0
    if require_api_match and apis_missing:
        return 0
    if _unusable(similarity):
        return 0

    threshold = threshold_inexact if not exact_match else threshold_exact

    if not llm_judge_requested:
        return int(similarity >= threshold)

    if _unusable(llm_judge_score):
        return 0
    return int(llm_judge_score >= threshold or similarity >= threshold)
| # ============================================================================ | |
| # LLM Judge (for semantic similarity evaluation) | |
| # ============================================================================ | |
class LLMJudge:
    """LLM-based semantic judge using Groq's OpenAI-compatible chat completions API."""

    _SYSTEM_PROMPT = (
        "You are an evaluation judge assessing semantic equivalence between a PREDICTED and EXPECTED answer.\n\n"
        "Scoring Guidelines:\n"
        "- Score 1.0: Semantically identical - same meaning, entities, and facts (minor wording differences OK)\n"
        "- Score 0.8-0.9: Semantically equivalent - same core meaning with slight elaboration or different phrasing\n"
        "- Score 0.5-0.7: Partially equivalent - same topic but missing key details or extra information\n"
        "- Score 0.2-0.4: Somewhat related - addresses same question but with different focus or incomplete answer\n"
        "- Score 0.0-0.1: Unrelated or contradictory - different facts, wrong information, or completely different meaning\n\n"
        "CRITICAL:\n"
        "- Focus on SEMANTIC MEANING, not word-for-word matching or formatting\n"
        "- Both asking for same information (even differently phrased) should score high (0.8-1.0)\n"
        "- Consider context from the UTTERANCE to understand what's being asked\n"
        "- Be precise: don't score 0.0 unless answers are truly unrelated/contradictory\n\n"
        "Return ONLY valid JSON: {\"score\": <number 0.0-1.0>, \"rationale\": \"<explanation>\"}\n"
    )

    def __init__(
        self,
        api_key: str,
        model: str = "llama-3.3-70b-versatile",
        timeout_s: int = 30,
    ):
        self.api_key = api_key
        self.model = model
        self.timeout_s = timeout_s  # HTTP request timeout in seconds
        self.base_url = "https://api.groq.com"

    def name(self) -> str:
        """Return an identifier of the form "groq:<model>"."""
        return f"groq:{self.model}"

    async def judge(
        self,
        predicted: str,
        expected: str,
        utterance: str = "",
    ) -> Dict[str, Any]:
        """Judge similarity between predicted and expected outputs.

        Never raises for request or parsing problems: those produce a result
        whose "score" is None and whose "rationale" describes the failure.

        Returns:
            Dict with score (0.0-1.0 or None), rationale, and metadata
        """
        try:
            import requests
        except ImportError:
            return {"score": None, "rationale": "requests library not available", "metadata": {}}

        # Truncate for cost/speed
        utterance = str(utterance)[:500]
        predicted = str(predicted)[:2000]
        expected = str(expected)[:2000]

        user = (
            f"UTTERANCE:\n{utterance}\n\n"
            f"EXPECTED:\n{expected}\n\n"
            f"PREDICTED:\n{predicted}\n"
        )
        payload = {
            "model": self.model,
            "temperature": 0,
            "messages": [
                {"role": "system", "content": self._SYSTEM_PROMPT},
                {"role": "user", "content": user},
            ],
        }

        def _do_request() -> Dict[str, Any]:
            url = f"{self.base_url}/openai/v1/chat/completions"
            response = requests.post(
                url,
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
                json=payload,
                timeout=self.timeout_s,
            )
            response.raise_for_status()
            return response.json()

        try:
            # requests is blocking; run it off the event loop.
            data = await asyncio.to_thread(_do_request)
        except Exception as e:
            logger.warning(f"LLM judge request failed: {e}")
            return {"score": None, "rationale": f"Request failed: {e}", "metadata": {}}

        return self._parse_content(self._extract_content(data))

    @staticmethod
    def _extract_content(data: Any) -> str:
        """Return the first choice's message content, or "" if it is missing, empty or not a string."""
        try:
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            return ""
        return content if isinstance(content, str) else ""

    def _parse_content(self, content: Optional[str]) -> Dict[str, Any]:
        """Turn the judge's raw reply into a result dict.

        Accepts either a bare JSON object or text with one embedded JSON object.
        A score that is missing, non-numeric or non-finite becomes None; a
        valid score is clamped to [0.0, 1.0].
        """
        import json
        import math

        content = content or ""
        try:
            parsed = json.loads(content)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            start = content.find("{")
            end = content.rfind("}")
            if start == -1 or end == -1 or end <= start:
                return {"score": None, "rationale": f"Invalid JSON response: {content[:200]}", "metadata": {}}
            try:
                parsed = json.loads(content[start:end + 1])
            except ValueError:
                return {"score": None, "rationale": f"Failed to parse JSON: {content[:200]}", "metadata": {}}

        if not isinstance(parsed, dict):
            return {"score": None, "rationale": f"Invalid JSON response: {content[:200]}", "metadata": {}}

        score = parsed.get("score")
        if score is not None:
            try:
                score = float(score)
            except (TypeError, ValueError):
                score = None
            else:
                # NaN would otherwise survive min/max clamping as 1.0.
                score = max(0.0, min(1.0, score)) if math.isfinite(score) else None

        rationale = str(parsed.get("rationale", ""))[:1000]
        return {
            "score": score,
            "rationale": rationale,
            "metadata": {"judge": "groq", "model": self.model},
        }
def get_llm_judge(api_key: str, provider: str = "groq") -> Optional[LLMJudge]:
    """Build an LLM judge for the given provider.

    Args:
        api_key: API key for the judge provider
        provider: Judge provider name (case-insensitive); only "groq" is implemented

    Returns:
        An LLMJudge for Groq, or None for any other provider
    """
    if provider.lower() != "groq":
        return None
    return LLMJudge(api_key=api_key)
| # ============================================================================ | |
| # API Call Tracking | |
| # ============================================================================ | |
def compare_api_calls(
    called_apis: List[str],
    expected_apis: List[str],
) -> Dict[str, Any]:
    """Compare the APIs an agent called against the APIs a task expects.

    Both sides are normalized first: lower-cased, trimmed, a leading "bpo_"
    removed, the suffixes "_get", "_post", "_put", "_delete",
    "_requisition_id" and "_skill_name" stripped in that order, and "-" and
    spaces turned into "_". An expected API counts as called when its
    normalized name is a substring of some called API's normalized name.
    This tolerates the verbose tool names produced by the registry.

    Args:
        called_apis: List of API names that were called
        expected_apis: List of expected API names

    Returns:
        Dict with "missing", "extra", "correct", "expected_count",
        "called_count" and "all_expected_called"
    """
    # Registry tool names are verbose: bpo_candidate_source_sla_per_source_candidate_source_sla_per_source_requisition_id_get
    # Expected names are short: candidate_source_sla_per_source
    trailing_parts = ("_get", "_post", "_put", "_delete", "_requisition_id", "_skill_name")

    def canonical(api_name: str) -> str:
        cleaned = api_name.lower().strip().removeprefix("bpo_")
        for tail in trailing_parts:
            cleaned = cleaned.removesuffix(tail)
        return cleaned.replace("-", "_").replace(" ", "_")

    logger.info(f"[API_TRACKING] Expected APIs: {expected_apis}")
    logger.info(f"[API_TRACKING] Actual APIs: {called_apis}")

    # Normalize every name once; containment also covers the exact-match case.
    expected_pairs = [(api, canonical(api)) for api in expected_apis]
    called_pairs = [(api, canonical(api)) for api in called_apis]

    missing = [
        original for original, exp in expected_pairs
        if not any(exp in act for _, act in called_pairs)
    ]
    extra = [
        original for original, act in called_pairs
        if not any(exp in act for _, exp in expected_pairs)
    ]

    return {
        "missing": missing,
        "extra": extra,
        "correct": len(expected_apis) - len(missing),
        "expected_count": len(expected_apis),
        "called_count": len(called_apis),
        "all_expected_called": not missing,
    }