"""CUGA SDK agent for BPO benchmark evaluation with Langfuse tracking."""

import asyncio
import logging
import os
import re
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import uvicorn

logger = logging.getLogger(__name__)

# Global flags to track server status
_servers_started = False
_servers_lock = threading.Lock()


# ============================================================================
# Provider Configuration
# ============================================================================

PROVIDER_CONFIGS = {
    "groq": {
        "env_var": "GROQ_API_KEY",
        "settings_file": "settings.groq.toml",
        "default_model": "openai/gpt-oss-120b",
        "models": [
            "openai/gpt-oss-120b",
            "llama-3.3-70b-versatile",
            "llama-3.1-8b-instant",
            "mixtral-8x7b-32768",
        ],
        "placeholder": "gsk_...",
    },
    "openai": {
        "env_var": "OPENAI_API_KEY",
        "settings_file": "settings.openai.toml",
        "default_model": "gpt-4o-mini",
        "models": [
            "gpt-4o-mini",
            "gpt-4.1",
            "gpt-5",
            "gpt-4o",
        ],
        "placeholder": "sk-...",
    },
}


def get_provider_models(provider: str) -> List[str]:
    """Get available models for a provider."""
    config = PROVIDER_CONFIGS.get(provider.lower(), {})
    return config.get("models", [])


def get_provider_placeholder(provider: str) -> str:
    """Get API key placeholder for a provider."""
    config = PROVIDER_CONFIGS.get(provider.lower(), {})
    return config.get("placeholder", "...")


def get_default_model(provider: str) -> str:
    """Get default model for a provider."""
    config = PROVIDER_CONFIGS.get(provider.lower(), {})
    return config.get("default_model", "")
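
# For example (values taken straight from PROVIDER_CONFIGS above):
#   >>> get_default_model("groq")
#   'openai/gpt-oss-120b'
#   >>> get_provider_placeholder("openai")
#   'sk-...'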


# ============================================================================
# Server Management
# ============================================================================

def start_servers():
    """Start BPO API and Registry servers if not already running."""
    global _servers_started

    with _servers_lock:
        if _servers_started:
            return
        _servers_started = True

    # Import here to avoid circular imports
    from server import app as bpo_app
    from cuga.backend.tools_env.registry.registry.api_registry_server import (
        app as registry_app,
    )

    # Start BPO API server on port 8000
    def run_bpo():
        uvicorn.run(bpo_app, host="0.0.0.0", port=8000, log_level="warning")

    bpo_thread = threading.Thread(target=run_bpo, daemon=True)
    bpo_thread.start()
    logger.info("BPO API server starting on port 8000")

    # Start Registry server on port 8001
    def run_registry():
        uvicorn.run(registry_app, host="0.0.0.0", port=8001, log_level="warning")

    registry_thread = threading.Thread(target=run_registry, daemon=True)
    registry_thread.start()
    logger.info("Registry server starting on port 8001")

    # Crude readiness wait: give both uvicorn servers a moment to bind their ports
    time.sleep(4)
    logger.info("Servers started")


# ============================================================================
# Environment Setup
# ============================================================================

def setup_environment(api_key: str, provider: str, model: Optional[str] = None, policies_enabled: bool = True):
    """Set up environment variables for CUGA SDK."""
    # Clear conflicting env vars
    for key in ["OPENAI_BASE_URL", "OPENAI_API_KEY", "GROQ_API_KEY"]:
        if key in os.environ:
            del os.environ[key]

    provider_lower = provider.lower()
    config = PROVIDER_CONFIGS.get(provider_lower)

    if not config:
        raise ValueError(f"Unknown provider: {provider}. Supported: {list(PROVIDER_CONFIGS.keys())}")

    # Set provider-specific config
    os.environ[config["env_var"]] = api_key
    os.environ["AGENT_SETTING_CONFIG"] = config["settings_file"]
    os.environ["MODEL_NAME"] = model or config["default_model"]

    # Set MCP servers file path
    mcp_config = Path(__file__).parent / "mcp_servers" / "bpo.yaml"
    os.environ["MCP_SERVERS_FILE"] = str(mcp_config.resolve())

    # Policy toggle
    os.environ["DYNACONF_POLICY__ENABLED"] = "true" if policies_enabled else "false"

    logger.info(f"Environment configured: provider={provider}, model={os.environ.get('MODEL_NAME')}, policies={policies_enabled}")


# ============================================================================
# Langfuse Integration
# ============================================================================

class LangfuseTracker:
    """Tracks evaluation runs and task scores in Langfuse."""

    def __init__(self):
        self.enabled = False
        self.langfuse = None
        self.trace_id = None
        self.init_error = None
        self._init_langfuse()

    def _init_langfuse(self) -> None:
        """Initialize Langfuse client if credentials are available."""
        # Debug: show all LANGFUSE env vars
        langfuse_vars = {k: ('set' if v else 'empty') for k, v in os.environ.items() if 'LANGFUSE' in k.upper()}
        logger.info(f"Langfuse env vars found: {langfuse_vars}")

        public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
        secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
        # Support both LANGFUSE_HOST and LANGFUSE_BASE_URL
        host = os.environ.get("LANGFUSE_HOST") or os.environ.get("LANGFUSE_BASE_URL") or "https://cloud.langfuse.com"

        logger.info(f"Langfuse init: public_key={'set' if public_key else 'not set'}, secret_key={'set' if secret_key else 'not set'}, host={host}")

        if not public_key or not secret_key:
            self.init_error = "Langfuse credentials not found"
            logger.info(self.init_error)
            return

        try:
            from langfuse import Langfuse

            self.langfuse = Langfuse(
                public_key=public_key,
                secret_key=secret_key,
                host=host,
            )
            # Test the connection by checking auth
            self.langfuse.auth_check()
            self.enabled = True
            logger.info(f"Langfuse tracking initialized successfully (host={host})")
        except ImportError as e:
            self.init_error = f"langfuse package not installed: {e}"
            logger.warning(self.init_error)
        except Exception as e:
            self.init_error = f"Failed to initialize Langfuse: {e}"
            logger.warning(self.init_error)

    def start_trace(self, name: str, metadata: Optional[Dict[str, Any]] = None) -> Optional[str]:
        """Start a new trace for an evaluation run."""
        if not self.enabled or not self.langfuse:
            return None

        try:
            # langfuse.trace() is the v2 SDK entry point; SDKs without it fall
            # through to the AttributeError handler below
            trace = self.langfuse.trace(name=name, metadata=metadata or {})
            self.trace_id = trace.id
            return self.trace_id
        except AttributeError:
            # SDK version without langfuse.trace(): fall back to a locally
            # generated ID so downstream calls don't crash (scores sent with it
            # may not attach to a server-side trace)
            self.trace_id = f"trace_{name}_{id(self)}"
            logger.info(f"Using fallback trace ID: {self.trace_id}")
            return self.trace_id
        except Exception as e:
            logger.warning(f"Failed to create trace: {e}")
            return None

    def score_task(self, task_id: str, scores: Dict[str, float]) -> None:
        """Score a task within the current trace."""
        if not self.enabled or not self.langfuse or not self.trace_id:
            return

        try:
            for name, value in scores.items():
                self.langfuse.score(
                    trace_id=self.trace_id,
                    name=f"{task_id}_{name}",
                    value=value,
                )
        except Exception as e:
            logger.warning(f"Failed to score task {task_id}: {e}")

    def end_trace(self, summary: Optional[Dict[str, Any]] = None) -> None:
        """End the current trace with summary metrics."""
        if not self.enabled or not self.langfuse:
            return

        try:
            if summary and self.trace_id:
                for name, value in summary.items():
                    if isinstance(value, (int, float)) and not isinstance(value, bool):
                        self.langfuse.score(
                            trace_id=self.trace_id,
                            name=f"summary_{name}",
                            value=float(value),
                        )
            self.langfuse.flush()
        except Exception as e:
            logger.warning(f"Failed to end trace: {e}")
        finally:
            self.trace_id = None
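
# Typical tracker lifecycle (illustrative sketch; assumes LANGFUSE_PUBLIC_KEY
# and LANGFUSE_SECRET_KEY are set, and uses made-up task names/scores):
#
#   tracker = LangfuseTracker()
#   if tracker.enabled:
#       tracker.start_trace("bpo_eval", metadata={"provider": "groq"})
#       tracker.score_task("task_001", {"similarity": 0.92, "final_score": 1.0})
#       tracker.end_trace(summary={"pass_rate": 0.85})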


def is_langfuse_configured() -> bool:
    """Check if Langfuse environment variables are set."""
    return bool(
        os.environ.get("LANGFUSE_PUBLIC_KEY") and
        os.environ.get("LANGFUSE_SECRET_KEY")
    )


def get_langfuse_host() -> str:
    """Get the configured Langfuse host."""
    return os.environ.get("LANGFUSE_HOST") or os.environ.get("LANGFUSE_BASE_URL") or "https://cloud.langfuse.com"


# ============================================================================
# CUGA Agent
# ============================================================================

class CUGAAgent:
    """CUGA SDK agent for BPO benchmark evaluation."""

    def __init__(
        self,
        api_key: str,
        provider: str = "groq",
        model: Optional[str] = None,
        policies_enabled: bool = True,
    ):
        """Initialize the CUGA agent.

        Args:
            api_key: API key for the LLM provider
            provider: "openai" or "groq"
            model: Model name (optional, uses defaults)
            policies_enabled: Whether to load CUGA policies (default True)
        """
        self.api_key = api_key
        self.provider = provider.lower()
        self.model = model
        self.policies_enabled = policies_enabled
        self.agent = None
        self.tool_provider = None

        # Set up environment BEFORE importing cuga modules
        setup_environment(api_key, self.provider, model, policies_enabled=policies_enabled)

        # Start servers
        start_servers()

    async def setup(self):
        """Initialize the CUGA agent with tools."""
        from cuga.sdk import CugaAgent
        from cuga.config import settings
        from cuga.backend.cuga_graph.nodes.cuga_lite.combined_tool_provider import (
            CombinedToolProvider,
        )

        logger.info("Setting up CUGA agent...")

        # Enable ActivityTracker and set policy toggle via Dynaconf directly
        # (env vars are only read on first import, so we must update settings at runtime)
        settings.update({
            "ADVANCED_FEATURES": {"TRACKER_ENABLED": True},
            "POLICY": {"ENABLED": self.policies_enabled},
        }, merge=True)
        logger.info(f"Dynaconf POLICY.ENABLED set to {self.policies_enabled}")

        # Initialize tool provider (will load from registry)
        self.tool_provider = CombinedToolProvider()
        await self.tool_provider.initialize()

        all_tools = await self.tool_provider.get_all_tools()
        logger.info(f"Loaded {len(all_tools)} tools from BPO API")

        if len(all_tools) == 0:
            raise RuntimeError("No tools loaded from registry. Check server status.")

        # Create agent
        self.agent = CugaAgent(tool_provider=self.tool_provider)
        logger.info("CUGA agent initialized")

        # Load or clear policies based on toggle
        if self.policies_enabled:
            await self._load_policies()
        else:
            await self._clear_policies()
            logger.info("Policies disabled and cleared")

    async def _clear_policies(self):
        """Remove all persisted policies from the vector DB."""
        try:
            existing = await self.agent.policies.list()
            for policy in existing:
                await self.agent.policies.delete(policy["id"])
            if existing:
                logger.info(f"Cleared {len(existing)} persisted policies")
        except Exception as e:
            logger.warning(f"Failed to clear policies: {e}")

    async def _load_policies(self):
        """Load policies from policies.json using CUGA SDK."""
        policies_json = Path(__file__).parent / "policies" / "policies.json"
        if not policies_json.exists():
            logger.warning(f"policies.json not found: {policies_json}")
            return

        await self._clear_policies()

        try:
            result = await self.agent.policies.load_from_json(
                str(policies_json), clear_existing=True
            )
            logger.info(f"Loaded {result['count']} policies from policies.json")
            if result.get("errors"):
                for err in result["errors"]:
                    logger.warning(f"Policy load warning: {err}")
        except Exception as e:
            logger.warning(f"Failed to load policies: {e}")

    async def run(self, query: str, thread_id: Optional[str] = None) -> Tuple[str, List[Dict[str, Any]]]:
        """Run the agent on a query.

        Args:
            query: The user's question
            thread_id: Optional thread ID for conversation context

        Returns:
            Tuple of (response_text, tool_calls)
        """
        if self.agent is None:
            await self.setup()

        from langchain_core.messages import HumanMessage

        # Get ActivityTracker singleton and reset for this task
        try:
            from cuga.backend.activity_tracker.tracker import ActivityTracker
            tracker = ActivityTracker()
            tracker.reset(intent=query, task_id=thread_id or "eval_task")
        except ImportError:
            tracker = None
            logger.warning("ActivityTracker not available, tool call tracking disabled")

        result = await self.agent.invoke(
            [HumanMessage(content=query)],
            thread_id=thread_id or "eval_task",
            track_tool_calls=True,  # Required for ActivityTracker to capture tool calls
        )

        # Debug: log result object structure
        result_attrs = [attr for attr in dir(result) if not attr.startswith('_')]
        logger.info(f"Result object attributes: {result_attrs}")
        if hasattr(result, '__dict__'):
            logger.info(f"Result __dict__ keys: {list(result.__dict__.keys())}")

        # Extract response
        response = result.answer if hasattr(result, "answer") else str(result)

        # Extract tool calls from ActivityTracker.steps (same approach as sdk_eval_helpers.py)
        tool_calls = []
        if tracker is not None:
            import json
            logger.info(f"ActivityTracker has {len(tracker.steps)} steps")

            # Debug: log step names to understand structure (first 5 only)
            step_names = [s.name for s in tracker.steps[:5]]
            logger.info(f"First step names: {step_names}")

            # Match "api_call" in step name (the standard CUGA SDK pattern)
            for step in tracker.steps:
                if step.name and "api_call" in step.name:
                    try:
                        call_data = json.loads(step.data) if step.data else {}
                        tool_name = call_data.get("function_name", "")
                        if tool_name:
                            tool_calls.append({
                                "name": tool_name,
                                "args": call_data.get("args", {}),
                            })
                    except (json.JSONDecodeError, TypeError) as e:
                        logger.warning(f"Failed to parse tool call step data: {e}")
                        continue

            logger.info(f"Extracted {len(tool_calls)} tool calls from ActivityTracker")

        # Fallback 1: try to extract from result.tool_calls attribute
        if not tool_calls and hasattr(result, 'tool_calls') and result.tool_calls:
            logger.info("Trying fallback: result.tool_calls")
            for tc in result.tool_calls:
                if isinstance(tc, dict):
                    tool_calls.append({"name": tc.get("name", ""), "args": tc.get("args", {})})
                elif hasattr(tc, 'name'):
                    tool_calls.append({"name": tc.name, "args": getattr(tc, 'args', {})})
            logger.info(f"Fallback extracted {len(tool_calls)} tool calls")

        return response, tool_calls

    def close(self):
        """Clean up resources."""
        pass  # Servers run in daemon threads and stop with the process
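
# Illustrative end-to-end usage (not executed at import; assumes a valid Groq
# key and that the BPO/Registry servers can start locally; the query is made up):
#
#   async def main():
#       agent = CUGAAgent(api_key=os.environ["GROQ_API_KEY"], provider="groq")
#       response, tool_calls = await agent.run("How many open requisitions are there?")
#       print(response)
#       print([tc["name"] for tc in tool_calls])
#
#   asyncio.run(main())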


# ============================================================================
# Evaluation Metrics (copied from main repo for standalone use)
# ============================================================================

def normalize_text(text: str) -> str:
    """Normalize text for keyword matching."""
    import unicodedata

    text = unicodedata.normalize("NFC", text)
    # Replace special spaces
    text = text.replace("\u202f", " ").replace("\u00a0", " ").replace("\u2009", " ")
    # Replace dashes
    text = text.replace("\u2013", "-").replace("\u2014", "-").replace("\u2212", "-")
    text = text.lower()
    # Remove markdown
    text = re.sub(r"[`*_~]", "", text)
    # Replace punctuation except | (for OR alternatives)
    text = re.sub(r"[^\w\s%|]", " ", text)
    # Collapse whitespace
    text = re.sub(r"\s+", " ", text).strip()
    return text
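
# Worked example: unicode spaces/dashes are folded, markdown is stripped, and
# punctuation collapses to single spaces:
#   >>> normalize_text("**Total:**  1\u202f200\u2013USD")
#   'total 1 200 usd'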


def check_keywords(response: str, expected_keywords: List[str]) -> Dict[str, Any]:
    """Check if expected keywords are present in the response.

    Supports:
    - OR mechanism: keywords can use "|" to specify alternatives
    - Regex keywords: prefix with "re:" to use regex pattern

    Args:
        response: Agent's response text
        expected_keywords: List of keywords (can use "|" for OR, "re:" for regex)

    Returns:
        Dictionary with keyword check results
    """
    if not expected_keywords:
        return {
            "all_found": True,
            "match_rate": 1.0,
            "found_keywords": [],
            "missing_keywords": [],
            "total_keywords": 0,
            "found_count": 0,
        }

    response_normalized = normalize_text(response)
    found_keywords = []
    missing_keywords = []

    for keyword in expected_keywords:
        # Regex keyword support
        if keyword.strip().lower().startswith("re:"):
            pattern = keyword.strip()[3:]
            if re.search(pattern, response_normalized, flags=re.IGNORECASE):
                found_keywords.append(keyword)
            else:
                missing_keywords.append(keyword)
            continue

        keyword_normalized = normalize_text(keyword)

        # OR alternatives
        if "|" in keyword_normalized:
            alternatives = [alt.strip() for alt in keyword_normalized.split("|")]
            matched = any(alt in response_normalized for alt in alternatives)
        else:
            matched = keyword_normalized.strip() in response_normalized

        if matched:
            found_keywords.append(keyword)
        else:
            missing_keywords.append(keyword)

    total = len(expected_keywords)
    found_count = len(found_keywords)

    return {
        "all_found": len(missing_keywords) == 0,
        "match_rate": found_count / total if total else 1.0,
        "found_keywords": found_keywords,
        "missing_keywords": missing_keywords,
        "total_keywords": total,
        "found_count": found_count,
    }
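
# Example checks (illustrative response text): "|" declares alternatives and
# "re:" switches to regex matching against the normalized response:
#   >>> check_keywords("Spend was 1,200 USD", ["usd|dollars"])["all_found"]
#   True
#   >>> check_keywords("Spend was 1,200 USD", ["re:1 ?200"])["all_found"]
#   True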


def compute_string_similarity(predicted: str, expected: str) -> float:
    """Compute string similarity using RapidFuzz token set ratio."""
    try:
        from rapidfuzz import fuzz
        return fuzz.token_set_ratio(predicted.lower(), expected.lower()) / 100.0
    except ImportError:
        from difflib import SequenceMatcher
        return SequenceMatcher(None, predicted.lower(), expected.lower()).ratio()


def compute_exact_match(predicted: str, expected: str) -> bool:
    """Check if predicted exactly matches expected (case-insensitive)."""
    return predicted.strip().lower() == expected.strip().lower()


def compute_final_score(
    exact_match: bool,
    similarity: float,
    llm_judge_score: Optional[float] = None,
    llm_judge_requested: bool = False,
    agent_output: str = "",
    threshold_exact: float = 0.85,
    threshold_inexact: float = 0.9,
    apis_missing: Optional[List[str]] = None,
    require_api_match: bool = False,
) -> int:
    """Compute final binary score for a task.

    This matches the logic in bpo_benchmark/evaluation/metrics.py for consistency.

    Args:
        exact_match: Whether output exactly matched expected
        similarity: String similarity score (0.0-1.0)
        llm_judge_score: Optional LLM judge score (0.0-1.0)
        llm_judge_requested: True if LLM judge was requested for this evaluation
        agent_output: The agent's output string (to detect failures)
        threshold_exact: Threshold when exact match is True
        threshold_inexact: Threshold when exact match is False
        apis_missing: List of expected APIs that were not called
        require_api_match: If True, require apis_missing to be empty to pass

    Returns:
        1 if task passes, 0 otherwise
    """
    import math

    # Check for task failure indicators
    if not agent_output or (isinstance(agent_output, str) and agent_output.startswith("ERROR:")):
        return 0

    # Check for missing API calls when API metrics are required
    if require_api_match and apis_missing:
        return 0

    # Handle missing/invalid similarity
    if similarity is None or (isinstance(similarity, float) and math.isnan(similarity)):
        return 0

    # Determine the threshold based on exact match status
    threshold = threshold_exact if exact_match else threshold_inexact

    # If LLM judge was requested but failed/unavailable, return 0
    if llm_judge_requested:
        if llm_judge_score is None or (isinstance(llm_judge_score, float) and math.isnan(llm_judge_score)):
            return 0
        # Judge was requested and available: pass if EITHER score meets threshold
        if llm_judge_score >= threshold or similarity >= threshold:
            return 1
        return 0
    else:
        # No judge requested: use similarity only
        return 1 if similarity >= threshold else 0
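
# Worked example of the decision tree above (numbers are illustrative): with
# exact_match=False the threshold is 0.9, so similarity alone can pass; but if
# a judge was requested and returned no score, the task fails regardless:
#   >>> compute_final_score(False, 0.92, agent_output="42 open requisitions")
#   1
#   >>> compute_final_score(False, 0.92, llm_judge_score=None,
#   ...                     llm_judge_requested=True, agent_output="42 open requisitions")
#   0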


# ============================================================================
# LLM Judge (for semantic similarity evaluation)
# ============================================================================

class LLMJudge:
    """LLM-based semantic judge using Groq's API."""

    def __init__(
        self,
        api_key: str,
        model: str = "llama-3.3-70b-versatile",
        timeout_s: int = 30,
    ):
        self.api_key = api_key
        self.model = model
        self.timeout_s = timeout_s
        self.base_url = "https://api.groq.com"

    @property
    def name(self) -> str:
        return f"groq:{self.model}"

    async def judge(
        self,
        predicted: str,
        expected: str,
        utterance: str = "",
    ) -> Dict[str, Any]:
        """Judge similarity between predicted and expected outputs.

        Returns:
            Dict with score (0.0-1.0), rationale, and metadata
        """
        import json

        try:
            import requests
        except ImportError:
            return {"score": None, "rationale": "requests library not available", "metadata": {}}

        # Truncate for cost/speed
        utterance = str(utterance)[:500]
        predicted = str(predicted)[:2000]
        expected = str(expected)[:2000]

        system = (
            "You are an evaluation judge assessing semantic equivalence between a PREDICTED and EXPECTED answer.\n\n"
            "Scoring Guidelines:\n"
            "- Score 1.0: Semantically identical - same meaning, entities, and facts (minor wording differences OK)\n"
            "- Score 0.8-0.9: Semantically equivalent - same core meaning with slight elaboration or different phrasing\n"
            "- Score 0.5-0.7: Partially equivalent - same topic but missing key details or extra information\n"
            "- Score 0.2-0.4: Somewhat related - addresses same question but with different focus or incomplete answer\n"
            "- Score 0.0-0.1: Unrelated or contradictory - different facts, wrong information, or completely different meaning\n\n"
            "CRITICAL:\n"
            "- Focus on SEMANTIC MEANING, not word-for-word matching or formatting\n"
            "- Both asking for same information (even differently phrased) should score high (0.8-1.0)\n"
            "- Consider context from the UTTERANCE to understand what's being asked\n"
            "- Be precise: don't score 0.0 unless answers are truly unrelated/contradictory\n\n"
            "Return ONLY valid JSON: {\"score\": <number 0.0-1.0>, \"rationale\": \"<explanation>\"}\n"
        )

        user = (
            f"UTTERANCE:\n{utterance}\n\n"
            f"EXPECTED:\n{expected}\n\n"
            f"PREDICTED:\n{predicted}\n"
        )

        payload = {
            "model": self.model,
            "temperature": 0,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
        }

        def _do_request() -> Dict[str, Any]:
            url = f"{self.base_url}/openai/v1/chat/completions"
            response = requests.post(
                url,
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
                json=payload,
                timeout=self.timeout_s,
            )
            response.raise_for_status()
            return response.json()

        try:
            data = await asyncio.to_thread(_do_request)
        except Exception as e:
            logger.warning(f"LLM judge request failed: {e}")
            return {"score": None, "rationale": f"Request failed: {e}", "metadata": {}}

        content = (
            data.get("choices", [{}])[0]
            .get("message", {})
            .get("content", "")
        )

        # Parse JSON response
        try:
            parsed = json.loads(content)
        except Exception:
            start = content.find("{")
            end = content.rfind("}")
            if start == -1 or end == -1 or end <= start:
                return {"score": None, "rationale": f"Invalid JSON response: {content[:200]}", "metadata": {}}
            try:
                parsed = json.loads(content[start:end + 1])
            except Exception:
                return {"score": None, "rationale": f"Failed to parse JSON: {content[:200]}", "metadata": {}}

        score = parsed.get("score")
        if score is not None:
            score = float(score)
            score = max(0.0, min(1.0, score))

        rationale = str(parsed.get("rationale", ""))[:1000]

        return {
            "score": score,
            "rationale": rationale,
            "metadata": {"judge": "groq", "model": self.model},
        }


def get_llm_judge(api_key: str, provider: str = "groq") -> Optional[LLMJudge]:
    """Get an LLM judge instance.

    Args:
        api_key: API key for the judge provider
        provider: Currently only "groq" is supported

    Returns:
        LLMJudge instance or None if not supported
    """
    if provider.lower() == "groq":
        return LLMJudge(api_key=api_key)
    return None
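
# Illustrative judge call (runs inside an async context and makes a network
# request to Groq; predicted/expected strings are made up for the example):
#
#   judge = get_llm_judge(api_key=os.environ["GROQ_API_KEY"])
#   verdict = await judge.judge(predicted="12 candidates", expected="twelve candidates")
#   # verdict -> {"score": <float 0.0-1.0 or None>, "rationale": "...", "metadata": {...}}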


# ============================================================================
# API Call Tracking
# ============================================================================

def compare_api_calls(
    called_apis: List[str],
    expected_apis: List[str],
) -> Dict[str, Any]:
    """Compare called APIs against expected APIs.

    Args:
        called_apis: List of API names that were called
        expected_apis: List of expected API names

    Returns:
        Dict with missing, extra, correct count, and match info
    """
    # Normalize API names for comparison
    # Registry tool names are verbose: bpo_candidate_source_sla_per_source_candidate_source_sla_per_source_requisition_id_get
    # Expected names are short: candidate_source_sla_per_source
    def normalize_api_name(name: str) -> str:
        name = name.lower().strip()
        # Remove app prefix
        if name.startswith("bpo_"):
            name = name[4:]
        # Remove common suffixes (HTTP methods and parameter patterns)
        for suffix in ["_get", "_post", "_put", "_delete"]:
            if name.endswith(suffix):
                name = name[:-len(suffix)]
        for suffix in ["_requisition_id", "_skill_name"]:
            if name.endswith(suffix):
                name = name[:-len(suffix)]
        return name.replace("-", "_").replace(" ", "_")

    def api_matches(expected: str, actual: str) -> bool:
        """Check if expected API name matches actual (allowing for verbose registry names)."""
        exp_norm = normalize_api_name(expected)
        act_norm = normalize_api_name(actual)
        # Direct match
        if exp_norm == act_norm:
            return True
        # Check if expected is contained in actual (for verbose registry names)
        # e.g., "candidate_source_sla_per_source" in "candidate_source_sla_per_source_candidate_source_sla_per_source"
        if exp_norm in act_norm:
            return True
        return False

    logger.info(f"[API_TRACKING] Expected APIs: {expected_apis}")
    logger.info(f"[API_TRACKING] Actual APIs: {called_apis}")

    # Compute API metrics using flexible matching
    missing = []
    for exp_api in expected_apis:
        if not any(api_matches(exp_api, act_api) for act_api in called_apis):
            missing.append(exp_api)

    extra = []
    for act_api in called_apis:
        if not any(api_matches(exp_api, act_api) for exp_api in expected_apis):
            extra.append(act_api)

    correct = len(expected_apis) - len(missing)

    return {
        "missing": missing,
        "extra": extra,
        "correct": correct,
        "expected_count": len(expected_apis),
        "called_count": len(called_apis),
        "all_expected_called": len(missing) == 0,
    }
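
# Illustrative comparison using the verbose/short name pair from the comment above:
#   >>> compare_api_calls(
#   ...     called_apis=["bpo_candidate_source_sla_per_source_candidate_source_sla_per_source_requisition_id_get"],
#   ...     expected_apis=["candidate_source_sla_per_source"],
#   ... )["all_expected_called"]
#   True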