Nigou Julien commited on
Commit
b22ac70
·
1 Parent(s): 9c8d6f1

Use LiteLLM for agent model calls

Browse files
.env.example CHANGED
@@ -1,9 +1,21 @@
1
  # Copy this file to .env for local development.
2
  # Do not commit real secrets.
3
 
4
- # LLM provider
 
 
 
 
 
 
 
 
 
 
 
5
  OPENAI_API_KEY=
6
- OPENAI_MODEL=gpt-4o-mini
 
7
 
8
  # Langfuse tracing
9
  LANGFUSE_PUBLIC_KEY=
 
1
  # Copy this file to .env for local development.
2
  # Do not commit real secrets.
3
 
4
+ # LiteLLM settings.
5
+ # Example models:
6
+ # openai/gpt-4o-mini
7
+ # anthropic/claude-3-5-sonnet-latest
8
+ # gemini/gemini-2.0-flash
9
+ LITELLM_MODEL=
10
+ LITELLM_TEMPERATURE=0
11
+ LITELLM_API_KEY=
12
+ LITELLM_API_BASE=
13
+
14
+ # Provider keys are still read by LiteLLM when you call provider-backed models.
15
+ # Set only the keys you need for your selected LITELLM_MODEL.
16
  OPENAI_API_KEY=
17
+ ANTHROPIC_API_KEY=
18
+ GEMINI_API_KEY=
19
 
20
  # Langfuse tracing
21
  LANGFUSE_PUBLIC_KEY=
README.md CHANGED
@@ -12,4 +12,30 @@ hf_oauth: true
12
  hf_oauth_expiration_minutes: 480
13
  ---
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  hf_oauth_expiration_minutes: 480
13
  ---
14
 
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
16
+
17
+ ## Local Setup
18
+
19
+ Install dependencies and run the smoke checks with `uv`:
20
+
21
+ ```bash
22
+ uv sync
23
+ uv run pytest
24
+ uv run python scripts/run_one.py
25
+ ```
26
+
27
+ Copy `.env.example` to `.env` and set the LiteLLM model you want to test:
28
+
29
+ ```env
30
+ LITELLM_MODEL=anthropic/claude-3-5-sonnet-latest
31
+ ANTHROPIC_API_KEY=...
32
+ ```
33
+
34
+ LiteLLM can also route OpenAI-compatible models through the same code path:
35
+
36
+ ```env
37
+ LITELLM_MODEL=openai/gpt-4o-mini
38
+ OPENAI_API_KEY=...
39
+ ```
40
+
41
+ If you run a LiteLLM proxy, set `LITELLM_API_BASE` and `LITELLM_API_KEY`.
app.py CHANGED
@@ -1,4 +1,7 @@
1
  import os
 
 
 
2
  import gradio as gr
3
  import requests
4
  import pandas as pd
@@ -63,6 +66,11 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
63
  results_log = []
64
  answers_payload = []
65
  print(f"Running agent on {len(questions_data)} questions...")
 
 
 
 
 
66
  for item in questions_data:
67
  task_id = item.get("task_id")
68
  question_text = item.get("question")
@@ -70,7 +78,12 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
70
  print(f"Skipping item with missing task_id or question: {item}")
71
  continue
72
  try:
73
- submitted_answer = agent(question_text)
 
 
 
 
 
74
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
75
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
76
  except Exception as e:
 
1
  import os
2
+ from datetime import UTC, datetime
3
+ from uuid import uuid4
4
+
5
  import gradio as gr
6
  import requests
7
  import pandas as pd
 
66
  results_log = []
67
  answers_payload = []
68
  print(f"Running agent on {len(questions_data)} questions...")
69
+ session_id = (
70
+ f"gaia-eval-{username.strip()}-"
71
+ f"{datetime.now(UTC).strftime('%Y%m%dT%H%M%SZ')}-"
72
+ f"{uuid4().hex[:8]}"
73
+ )
74
  for item in questions_data:
75
  task_id = item.get("task_id")
76
  question_text = item.get("question")
 
78
  print(f"Skipping item with missing task_id or question: {item}")
79
  continue
80
  try:
81
+ submitted_answer = agent(
82
+ question_text,
83
+ session_id=session_id,
84
+ user_id=username.strip(),
85
+ task_id=task_id,
86
+ )
87
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
88
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
89
  except Exception as e:
gaia_agent/agent.py CHANGED
@@ -3,13 +3,26 @@ from gaia_agent.observability import trace_agent_run
3
 
4
 
5
  class GaiaAgent:
6
- def __init__(self):
 
7
  print("GaiaAgent initialized.")
8
 
9
- def __call__(self, question: str) -> str:
 
 
 
 
 
 
 
10
  print(f"Agent received question (first 80 chars): {question[:80]}...")
11
- with trace_agent_run(question) as trace:
12
- graph = build_graph(trace=trace)
 
 
 
 
 
13
  result = graph.invoke({"question": question})
14
  final_answer = result["final_answer"]
15
  print(f"Agent returning answer: {final_answer}")
 
3
 
4
 
5
  class GaiaAgent:
6
+ def __init__(self, llm=None):
7
+ self.llm = llm
8
  print("GaiaAgent initialized.")
9
 
10
+ def __call__(
11
+ self,
12
+ question: str,
13
+ *,
14
+ session_id: str | None = None,
15
+ user_id: str | None = None,
16
+ task_id: str | None = None,
17
+ ) -> str:
18
  print(f"Agent received question (first 80 chars): {question[:80]}...")
19
+ with trace_agent_run(
20
+ question,
21
+ session_id=session_id,
22
+ user_id=user_id,
23
+ task_id=task_id,
24
+ ) as trace:
25
+ graph = build_graph(trace=trace, llm=self.llm)
26
  result = graph.invoke({"question": question})
27
  final_answer = result["final_answer"]
28
  print(f"Agent returning answer: {final_answer}")
gaia_agent/answer.py CHANGED
@@ -1,3 +1,3 @@
1
  def normalize_answer(answer: str) -> str:
2
  """Apply minimal GAIA answer cleanup without changing meaning."""
3
- return answer.strip()
 
1
  def normalize_answer(answer: str) -> str:
2
  """Apply minimal GAIA answer cleanup without changing meaning."""
3
+ return answer.strip().removesuffix(".")
gaia_agent/config.py CHANGED
@@ -9,8 +9,10 @@ load_dotenv()
9
 
10
  @dataclass(frozen=True)
11
  class Settings:
12
- openai_api_key: str | None = os.getenv("OPENAI_API_KEY")
13
- openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
 
 
14
  langfuse_public_key: str | None = os.getenv("LANGFUSE_PUBLIC_KEY")
15
  langfuse_secret_key: str | None = os.getenv("LANGFUSE_SECRET_KEY")
16
  langfuse_host: str = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
 
9
 
10
  @dataclass(frozen=True)
11
  class Settings:
12
+ litellm_model: str | None = os.getenv("LITELLM_MODEL")
13
+ litellm_temperature: float = float(os.getenv("LITELLM_TEMPERATURE", "0"))
14
+ litellm_api_key: str | None = os.getenv("LITELLM_API_KEY")
15
+ litellm_api_base: str | None = os.getenv("LITELLM_API_BASE")
16
  langfuse_public_key: str | None = os.getenv("LANGFUSE_PUBLIC_KEY")
17
  langfuse_secret_key: str | None = os.getenv("LANGFUSE_SECRET_KEY")
18
  langfuse_host: str = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
gaia_agent/graph.py CHANGED
@@ -1,19 +1,25 @@
1
  from langgraph.graph import END, StateGraph
2
 
3
  from gaia_agent.answer import normalize_answer
 
4
  from gaia_agent.observability import traced_step
 
5
  from gaia_agent.state import GaiaState
6
 
7
 
8
- PLACEHOLDER_ANSWER = "Julien test"
9
-
10
-
11
- def build_graph(trace=None):
12
  graph = StateGraph(GaiaState)
 
13
 
14
  def draft_answer(state: GaiaState) -> GaiaState:
15
  def run() -> dict[str, str]:
16
- return {"draft_answer": PLACEHOLDER_ANSWER}
 
 
 
 
 
 
17
 
18
  return traced_step(trace, "draft_answer", run)
19
 
 
1
  from langgraph.graph import END, StateGraph
2
 
3
  from gaia_agent.answer import normalize_answer
4
+ from gaia_agent.llms import create_chat_model
5
  from gaia_agent.observability import traced_step
6
+ from gaia_agent.prompts import DUMMY_LLM_TEST_PROMPT
7
  from gaia_agent.state import GaiaState
8
 
9
 
10
+ def build_graph(trace=None, llm=None):
 
 
 
11
  graph = StateGraph(GaiaState)
12
+ chat_model = llm or create_chat_model()
13
 
14
  def draft_answer(state: GaiaState) -> GaiaState:
15
  def run() -> dict[str, str]:
16
+ response = chat_model.invoke(
17
+ [
18
+ ("system", DUMMY_LLM_TEST_PROMPT),
19
+ ("user", state["question"]),
20
+ ]
21
+ )
22
+ return {"draft_answer": str(response.content)}
23
 
24
  return traced_step(trace, "draft_answer", run)
25
 
gaia_agent/llms.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.language_models.chat_models import BaseChatModel
2
+ from langchain_litellm import ChatLiteLLM
3
+
4
+ from gaia_agent.config import Settings, settings
5
+
6
+
7
+ def create_chat_model(config: Settings = settings) -> BaseChatModel:
8
+ """Create the configured LiteLLM chat model."""
9
+ if not config.litellm_model:
10
+ raise ValueError("LITELLM_MODEL must be set to create a chat model.")
11
+
12
+ kwargs = {
13
+ "model": config.litellm_model,
14
+ "temperature": config.litellm_temperature,
15
+ }
16
+ if config.litellm_api_key:
17
+ kwargs["api_key"] = config.litellm_api_key
18
+ if config.litellm_api_base:
19
+ kwargs["api_base"] = config.litellm_api_base
20
+
21
+ return ChatLiteLLM(**kwargs)
gaia_agent/observability.py CHANGED
@@ -2,6 +2,8 @@ from collections.abc import Callable
2
  from contextlib import contextmanager
3
  from typing import Any
4
 
 
 
5
  from gaia_agent.config import settings
6
 
7
 
@@ -10,7 +12,13 @@ def langfuse_enabled() -> bool:
10
 
11
 
12
  @contextmanager
13
- def trace_agent_run(question: str):
 
 
 
 
 
 
14
  """Create a Langfuse trace when credentials are configured.
15
 
16
  This keeps local development and HF Space startup working before secrets are set.
@@ -26,16 +34,23 @@ def trace_agent_run(question: str):
26
  secret_key=settings.langfuse_secret_key,
27
  host=settings.langfuse_host,
28
  )
29
- with client.start_as_current_observation(
30
- name="gaia-agent-run",
31
- as_type="agent",
32
- input={"question": question},
33
- metadata={"component": "GaiaAgent"},
34
- ) as observation:
35
- try:
36
- yield observation
37
- finally:
38
- client.flush()
 
 
 
 
 
 
 
39
 
40
 
41
  def traced_step(trace: Any, name: str, fn: Callable[[], dict[str, Any]]) -> dict[str, Any]:
 
2
  from contextlib import contextmanager
3
  from typing import Any
4
 
5
+ from langfuse import propagate_attributes
6
+
7
  from gaia_agent.config import settings
8
 
9
 
 
12
 
13
 
14
  @contextmanager
15
+ def trace_agent_run(
16
+ question: str,
17
+ *,
18
+ session_id: str | None = None,
19
+ user_id: str | None = None,
20
+ task_id: str | None = None,
21
+ ):
22
  """Create a Langfuse trace when credentials are configured.
23
 
24
  This keeps local development and HF Space startup working before secrets are set.
 
34
  secret_key=settings.langfuse_secret_key,
35
  host=settings.langfuse_host,
36
  )
37
+ with propagate_attributes(
38
+ trace_name="gaia-agent-run",
39
+ user_id=user_id,
40
+ session_id=session_id,
41
+ metadata={"task_id": task_id} if task_id else None,
42
+ tags=["gaia", "final-assignment"],
43
+ ):
44
+ with client.start_as_current_observation(
45
+ name="gaia-agent-run",
46
+ as_type="agent",
47
+ input={"question": question},
48
+ metadata={"component": "GaiaAgent", "task_id": task_id},
49
+ ) as observation:
50
+ try:
51
+ yield observation
52
+ finally:
53
+ client.flush()
54
 
55
 
56
  def traced_step(trace: Any, name: str, fn: Callable[[], dict[str, Any]]) -> dict[str, Any]:
gaia_agent/prompts.py CHANGED
@@ -3,3 +3,9 @@ Return only the final answer.
3
  The answer should be a number, as few words as possible, or a comma-separated
4
  list of numbers and/or strings.
5
  """.strip()
 
 
 
 
 
 
 
3
  The answer should be a number, as few words as possible, or a comma-separated
4
  list of numbers and/or strings.
5
  """.strip()
6
+
7
+
8
+ DUMMY_LLM_TEST_PROMPT = """
9
+ You are testing the LLM connection for a GAIA agent.
10
+ Answer the user question directly in a few words.
11
+ """.strip()
pyproject.toml CHANGED
@@ -10,9 +10,9 @@ dependencies = [
10
  "pandas>=2.2.0",
11
  "python-dotenv>=1.0.1",
12
  "langchain>=0.3.0",
13
- "langchain-openai>=0.3.0",
14
  "langgraph>=0.2.60",
15
  "langfuse>=2.57.0",
 
16
  ]
17
 
18
  [build-system]
 
10
  "pandas>=2.2.0",
11
  "python-dotenv>=1.0.1",
12
  "langchain>=0.3.0",
 
13
  "langgraph>=0.2.60",
14
  "langfuse>=2.57.0",
15
+ "langchain-litellm>=0.6.4",
16
  ]
17
 
18
  [build-system]
tests/test_answer.py CHANGED
@@ -3,3 +3,7 @@ from gaia_agent.answer import normalize_answer
3
 
4
  def test_normalize_answer_strips_whitespace():
5
  assert normalize_answer(" Paris \n") == "Paris"
 
 
 
 
 
3
 
4
  def test_normalize_answer_strips_whitespace():
5
  assert normalize_answer(" Paris \n") == "Paris"
6
+
7
+
8
+ def test_normalize_answer_removes_trailing_period():
9
+ assert normalize_answer("Paris.") == "Paris"
tests/test_graph_smoke.py CHANGED
@@ -1,7 +1,22 @@
 
 
1
  from gaia_agent.agent import GaiaAgent
2
 
3
 
4
- def test_agent_returns_placeholder_answer():
5
- agent = GaiaAgent()
 
 
 
 
 
6
 
7
- assert agent("What is the answer?") == "Julien test"
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.messages import AIMessage
2
+
3
  from gaia_agent.agent import GaiaAgent
4
 
5
 
6
+ class FakeChatModel:
7
+ def invoke(self, messages):
8
+ return AIMessage(content="Dummy LLM answer")
9
+
10
+
11
+ def test_agent_returns_llm_answer():
12
+ agent = GaiaAgent(llm=FakeChatModel())
13
 
14
+ assert (
15
+ agent(
16
+ "What is the answer?",
17
+ session_id="test-session",
18
+ user_id="test-user",
19
+ task_id="test-task",
20
+ )
21
+ == "Dummy LLM answer"
22
+ )
tests/test_llms.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from gaia_agent.config import Settings
4
+ from gaia_agent.llms import create_chat_model
5
+
6
+
7
+ def test_create_litellm_chat_model():
8
+ model = create_chat_model(
9
+ Settings(
10
+ litellm_model="anthropic/claude-3-5-sonnet-latest",
11
+ litellm_api_key="test-key",
12
+ )
13
+ )
14
+
15
+ assert type(model).__name__ == "ChatLiteLLM"
16
+ assert model.model == "anthropic/claude-3-5-sonnet-latest"
17
+
18
+
19
+ def test_create_litellm_chat_model_supports_proxy_settings():
20
+ model = create_chat_model(
21
+ Settings(
22
+ litellm_model="gaia-router",
23
+ litellm_api_key="test-key",
24
+ litellm_api_base="http://localhost:4000",
25
+ )
26
+ )
27
+
28
+ assert type(model).__name__ == "ChatLiteLLM"
29
+ assert model.model == "gaia-router"
30
+ assert model.api_base == "http://localhost:4000"
31
+
32
+
33
+ def test_create_chat_model_requires_litellm_model():
34
+ with pytest.raises(ValueError, match="LITELLM_MODEL must be set"):
35
+ create_chat_model(Settings(litellm_model=None))
uv.lock CHANGED
The diff for this file is too large to render. See raw diff