File size: 1,901 Bytes
b4bc906
 
 
 
b22ac70
 
b4bc906
 
 
 
 
 
 
 
b22ac70
 
 
 
 
 
 
b4bc906
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b22ac70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4bc906
 
 
 
 
 
 
 
 
9c8d6f1
 
b4bc906
 
9c8d6f1
 
b4bc906
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any

from langfuse import propagate_attributes

from gaia_agent.config import settings


def langfuse_enabled() -> bool:
    """Return True only when both the Langfuse public and secret keys are set (truthy)."""
    # Without a public key there is nothing to authenticate with; skip the secret check.
    if not settings.langfuse_public_key:
        return False
    return bool(settings.langfuse_secret_key)


@contextmanager
def trace_agent_run(
    question: str,
    *,
    session_id: str | None = None,
    user_id: str | None = None,
    task_id: str | None = None,
):
    """Create a Langfuse trace for one agent run when credentials are configured.

    This keeps local development and HF Space startup working before secrets are set:
    unless both Langfuse keys are configured, the context manager yields ``None``
    and performs no tracing.

    Args:
        question: The question being answered; recorded as the root observation input.
        session_id: Optional session identifier propagated to the trace.
        user_id: Optional user identifier propagated to the trace.
        task_id: Optional task identifier; added to the propagated trace metadata
            only when non-empty, and always recorded in the root observation metadata.

    Yields:
        The root ``agent`` observation named "gaia-agent-run", or ``None`` when
        Langfuse is disabled.
    """
    if not langfuse_enabled():
        yield None
        return

    from langfuse import Langfuse

    client = Langfuse(
        public_key=settings.langfuse_public_key,
        secret_key=settings.langfuse_secret_key,
        host=settings.langfuse_host,
    )
    try:
        with propagate_attributes(
            trace_name="gaia-agent-run",
            user_id=user_id,
            session_id=session_id,
            metadata={"task_id": task_id} if task_id else None,
            tags=["gaia", "final-assignment"],
        ):
            with client.start_as_current_observation(
                name="gaia-agent-run",
                as_type="agent",
                input={"question": question},
                metadata={"component": "GaiaAgent", "task_id": task_id},
            ) as observation:
                yield observation
    finally:
        # Flush only after the root observation's context has exited. Spans are
        # handed to the exporter when they end, so flushing inside the
        # observation context would push the children but not the still-open
        # root "gaia-agent-run" observation.
        client.flush()


def traced_step(trace: Any, name: str, fn: Callable[[], dict[str, Any]]) -> dict[str, Any]:
    """Run ``fn`` inside a child span of ``trace`` and return its result.

    Args:
        trace: Parent observation exposing ``start_observation``, or ``None`` to
            run ``fn`` without any tracing.
        name: Name given to the child span.
        fn: Zero-argument callable that produces the step's output dict.

    Returns:
        Exactly what ``fn`` returns. On success the span's ``output`` is set to it.

    Raises:
        Whatever ``fn`` raises. For ``Exception`` subclasses the span is first
        updated with ``level="ERROR"`` and the exception message. The span is
        ended in every case, including ``BaseException`` such as cancellation
        or ``KeyboardInterrupt``.
    """
    if trace is None:
        return fn()

    span = trace.start_observation(name=name, as_type="span")
    try:
        output = fn()
    except Exception as exc:
        span.update(level="ERROR", status_message=str(exc))
        raise
    else:
        span.update(output=output)
        return output
    finally:
        # Always close the span. Previously a BaseException (not caught by
        # `except Exception`) left it open.
        span.end()