File size: 2,472 Bytes
0b87551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b63102
0b87551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
Pure helper functions with no heavy (model) imports, so they can be unit-tested
quickly in CI without downloading embedding/NLI models.
"""

import re

# Source markers emitted by the retrievers, e.g. "[File: resume.pdf]" / "[Source: Page 3]".
# Group 1 captures the text between the brackets (the label itself), which must
# begin with "File" or "Source".
_CITATION_RE = re.compile(r"\[((?:File|Source)[^\]]*)\]")
# "retryDelay", then (lazily) anything on the same line — "." does not cross
# newlines — up to the first digit run immediately followed by "s". Group 1 is
# that number, optionally with a fractional part (e.g. "4.5").
_RETRY_RE = re.compile(r"retryDelay.*?(\d+\.?\d*)s")

# Case-sensitive substrings whose presence in an error message is treated as
# "the API key is invalid or expired" by is_invalid_key().
# NOTE(review): PERMISSION_DENIED is a generic status; presumably it can also be
# reported for causes other than the key itself — confirm that is intended.
_INVALID_KEY_MARKERS: tuple[str, ...] = (
    "API_KEY_INVALID",
    "API key not valid",
    "PERMISSION_DENIED",
    "API key expired",
)


def extract_content(message) -> str:
    """Flatten a chat message's content to plain text.

    Gemini may return message content as a list of typed blocks rather than a
    single string. ``message`` may be a message object (its ``content``
    attribute is used) or the raw content value itself.

    - ``None`` content yields ``""`` (not the literal string ``"None"``).
    - ``list`` content: dict blocks are kept only when ``type == "text"`` and
      contribute their ``text`` field (missing/``None`` counts as ``""``);
      non-dict items are converted with ``str()``; the pieces are joined with
      single spaces.
    - Anything else (including ``str``) is returned via ``str()``.
    """
    content = getattr(message, "content", message)
    if content is None:
        return ""
    if isinstance(content, list):
        content = " ".join(
            # .get avoids a KeyError on a malformed text block with no "text".
            str(block.get("text") or "") if isinstance(block, dict) else str(block)
            for block in content
            # Typed blocks of any other type are dropped; bare items are kept.
            if not isinstance(block, dict) or block.get("type") == "text"
        )
    return str(content)


def parse_tool_results(messages: list) -> tuple[str, str]:
    """Return (source_type, combined_tool_output) from the agent message chain.

    source_type is 'rag', 'web', 'rag+web', or 'unknown'. The combined output is
    the actual text the tools returned, which faithfulness is measured against.

    Messages are classified by their ``name`` attribute: ``"lookup_documents"``
    is RAG output and ``"search_web"`` is web output; every other message is
    ignored. Outputs are space-joined, all RAG parts first and then all web
    parts, each group in message order. Missing or falsy content counts as "".
    Non-string content (e.g. a list of typed content blocks) is flattened with
    ``extract_content`` so the join does not raise ``TypeError``.
    """
    rag_parts: list[str] = []
    web_parts: list[str] = []
    for msg in messages:
        name = getattr(msg, "name", None)
        if name not in ("lookup_documents", "search_web"):
            continue
        content = getattr(msg, "content", "") or ""
        if not isinstance(content, str):
            # Tool messages may carry a list of content blocks instead of a
            # string; " ".join below only accepts str items.
            content = extract_content(content)
        if name == "lookup_documents":
            rag_parts.append(content)
        else:
            web_parts.append(content)

    if rag_parts and web_parts:
        return "rag+web", " ".join(rag_parts + web_parts)
    if rag_parts:
        return "rag", " ".join(rag_parts)
    if web_parts:
        return "web", " ".join(web_parts)
    return "unknown", ""


def extract_citations(tool_output: str) -> list[str]:
    """List the distinct source labels in retrieved text, in first-seen order.

    Labels are the bracket contents of markers such as "[File: resume.pdf]" or
    "[Source: Page 3]", with surrounding whitespace stripped. Empty labels are
    skipped, and a falsy ``tool_output`` (``None`` or "") yields ``[]``.
    """
    cleaned = (raw.strip() for raw in _CITATION_RE.findall(tool_output or ""))
    # dict keys keep insertion order, so this dedupes without reordering.
    return list(dict.fromkeys(label for label in cleaned if label))


def is_invalid_key(error_str: str) -> bool:
    """Return True when *error_str* contains any known invalid/expired-key marker.

    Matching is a plain, case-sensitive substring test against
    ``_INVALID_KEY_MARKERS``.
    """
    for marker in _INVALID_KEY_MARKERS:
        if marker in error_str:
            return True
    return False


def retry_delay(error_str: str) -> float:
    """Return the server-suggested retry delay, in seconds, from a Gemini 429 error.

    Two renderings of the delay are recognised, tried in this order:

    - ``retryDelay ... Xs`` (the JSON-style field, e.g. ``'retryDelay': '43s'``),
      matched by ``_RETRY_RE``; fractional seconds are supported.
    - ``retry_delay { seconds: N }`` — the protobuf text-format rendering of
      ``google.rpc.RetryInfo``, as embedded in gRPC-based client exception
      strings. Only the ``seconds`` field is read.

    Returns 0.0 when neither form is present.
    """
    match = _RETRY_RE.search(error_str)
    if match:
        return float(match.group(1))
    # Fallback for the protobuf text format; \s* lets it span the newlines the
    # text format puts inside the braces.
    match = re.search(r"retry_delay\s*\{\s*seconds:\s*(\d+)", error_str)
    return float(match.group(1)) if match else 0.0