farffadet commited on
Commit
349dcd3
·
verified ·
1 Parent(s): deae344

refactor: multi-turn Judge Agent — diversity + UCC generators

Browse files
Files changed (1) hide show
  1. client.py +30 -34
client.py CHANGED
@@ -1,23 +1,34 @@
1
  """
2
  SylloGym Environment Client.
3
 
4
- Typed client for connecting to a running SylloGym server.
 
 
 
5
 
6
  Example:
 
 
 
7
  >>> env = SylloGymEnv(base_url="http://localhost:8000")
8
  >>> result = env.reset()
9
  >>> obs = result.observation
10
  >>> print(obs.rule)
11
  >>> print(obs.facts)
 
12
  >>>
13
- >>> from syllogym_env.models import SylloAction
14
- >>> action = SylloAction(
15
- ... reasoning="<reasoning>The rule states X applies when Y. The facts show Y. Therefore X applies.</reasoning>",
16
- ... answer="<answer>Yes</answer>"
17
- ... )
18
- >>> result = env.step(action)
19
- >>> print(result.observation.reward) # 0.0 to 1.3
20
- >>> print(result.observation.done) # True
 
 
 
 
21
  >>> env.close()
22
  """
23
 
@@ -31,28 +42,14 @@ from .models import SylloAction, SylloObservation, SylloState
31
 
32
  class SylloGymEnv(EnvClient[SylloAction, SylloObservation, SylloState]):
33
  """
34
- Client for the SylloGym legal reasoning environment.
35
-
36
- Connects to a SylloGym server that serves LegalBench-based
37
- syllogistic reasoning tasks. Each episode is a single-step interaction:
38
- 1. reset() → receive a legal rule + case facts
39
- 2. step(SylloAction) → submit reasoning + answer, receive reward
40
 
41
- Args:
42
- base_url: URL of the running SylloGym server.
 
 
43
 
44
- Example:
45
- >>> env = SylloGymEnv(base_url="http://localhost:8000")
46
- >>> result = env.reset()
47
- >>> obs = result.observation
48
- >>>
49
- >>> action = SylloAction(
50
- ... reasoning="<reasoning>Applying the rule to the facts...</reasoning>",
51
- ... answer="<answer>Yes</answer>"
52
- ... )
53
- >>> result = env.step(action)
54
- >>> print(f"Reward: {result.observation.reward}")
55
- >>> env.close()
56
  """
57
 
58
  def _step_payload(self, action: SylloAction) -> dict:
@@ -64,11 +61,13 @@ class SylloGymEnv(EnvClient[SylloAction, SylloObservation, SylloState]):
64
  def _parse_result(self, payload: dict) -> StepResult[SylloObservation]:
65
  obs_data = payload.get("observation", {})
66
  reward = payload.get("reward")
67
- done = bool(payload.get("done", True))
68
  # Mirror reward/done into the observation for convenience
69
  obs_data["reward"] = reward
70
  obs_data["done"] = done
71
- obs = SylloObservation(**obs_data)
 
 
72
  return StepResult(
73
  observation=obs,
74
  reward=reward,
@@ -77,11 +76,8 @@ class SylloGymEnv(EnvClient[SylloAction, SylloObservation, SylloState]):
77
 
78
  def _parse_state(self, payload: dict) -> SylloState:
79
  return SylloState(
80
- episode_id=payload.get("episode_id", ""),
81
- step_count=payload.get("step_count", 0),
82
  task_name=payload.get("task_name", ""),
83
  task_mode=payload.get("task_mode", "mixed"),
84
- current_difficulty=payload.get("current_difficulty", 1.0),
85
  total_correct=payload.get("total_correct", 0),
86
  total_steps=payload.get("total_steps", 0),
87
  )
 
1
  """
2
  SylloGym Environment Client.
3
 
4
+ Multi-turn typed client for connecting to a running SylloGym server.
5
+
6
+ The agent plays a judge who receives new facts turn by turn.
7
+ Each episode: reset() → step() → step() → ... → done=True
8
 
9
  Example:
10
+ >>> from syllogym_env import SylloGymEnv
11
+ >>> from syllogym_env.models import SylloAction
12
+ >>>
13
  >>> env = SylloGymEnv(base_url="http://localhost:8000")
14
  >>> result = env.reset()
15
  >>> obs = result.observation
16
  >>> print(obs.rule)
17
  >>> print(obs.facts)
18
+ >>> print(obs.question) # Turn 0 question
19
  >>>
20
+ >>> while not obs.done:
21
+ ... action = SylloAction(
22
+ ... reasoning="<reasoning>Applying the rule...</reasoning>",
23
+ ... answer="<answer>Yes</answer>",
24
+ ... )
25
+ ... result = env.step(action)
26
+ ... obs = result.observation
27
+ ... if not obs.done:
28
+ ... print(f"Turn {obs.layer_index}: {obs.new_info}")
29
+ ... print(f"Next question: {obs.question}")
30
+ ...
31
+ >>> print(f"Final reward: {obs.reward}")
32
  >>> env.close()
33
  """
34
 
 
42
 
43
  class SylloGymEnv(EnvClient[SylloAction, SylloObservation, SylloState]):
44
  """
45
+ Client for the SylloGym multi-turn legal reasoning environment.
 
 
 
 
 
46
 
47
+ Each episode is a sequence of steps:
48
+ reset() → Turn 0 observation (rule + initial facts + first question)
49
+ step(action) → Turn 1 observation (new_info revealed + next question), reward=1.0
50
+ step(action) → ... until done=True
51
 
52
+ Reward is dense: 1.0 for each correct answer, 0.0 terminates the episode.
 
 
 
 
 
 
 
 
 
 
 
53
  """
54
 
55
  def _step_payload(self, action: SylloAction) -> dict:
 
61
  def _parse_result(self, payload: dict) -> StepResult[SylloObservation]:
62
  obs_data = payload.get("observation", {})
63
  reward = payload.get("reward")
64
+ done = bool(payload.get("done", False))
65
  # Mirror reward/done into the observation for convenience
66
  obs_data["reward"] = reward
67
  obs_data["done"] = done
68
+ # Only pass fields that SylloObservation knows about
69
+ valid_fields = SylloObservation.model_fields
70
+ obs = SylloObservation(**{k: v for k, v in obs_data.items() if k in valid_fields})
71
  return StepResult(
72
  observation=obs,
73
  reward=reward,
 
76
 
77
  def _parse_state(self, payload: dict) -> SylloState:
78
  return SylloState(
 
 
79
  task_name=payload.get("task_name", ""),
80
  task_mode=payload.get("task_mode", "mixed"),
 
81
  total_correct=payload.get("total_correct", 0),
82
  total_steps=payload.get("total_steps", 0),
83
  )