Spaces:
Sleeping
Sleeping
File size: 3,129 Bytes
d1221ff | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 | """HTTP client for the Long-Context Summarization environment.

Usage:
    from client import SummarizationClient

    client = SummarizationClient(base_url="http://localhost:7860")
    obs = client.reset(task_name="easy", seed=42)
    while not obs.done:
        action_text = my_llm(obs.messages)
        obs = client.step(action_text)
    print(f"Final reward: {obs.reward}")
"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import requests
from typing import Optional
from models import SummarizationObservation, SummarizationState
class SummarizationClient:
    """Thin HTTP wrapper around the summarization environment REST API.

    Every call is a blocking request made with the module-level ``requests``
    functions. A non-2xx response raises ``requests.HTTPError`` (via
    ``Response.raise_for_status``).
    """

    def __init__(self, base_url: str = "http://localhost:7860", timeout: float = 60):
        """Create a client.

        Args:
            base_url: Root URL of the environment server. A trailing "/" is
                stripped so endpoint paths can be appended directly.
            timeout: Per-request timeout in seconds, passed to ``requests``.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    # ------------------------------------------------------------------
    # Core API
    # ------------------------------------------------------------------

    def reset(
        self,
        task_name: Optional[str] = None,
        seed: Optional[int] = None,
    ) -> SummarizationObservation:
        """Reset the environment and return the initial observation.

        Args:
            task_name: Task to load. Left out of the request body when None.
            seed: Random seed. Left out of the request body when None.

        Raises:
            requests.HTTPError: If the server responds with a non-2xx status.
        """
        payload: dict = {}
        if task_name is not None:
            payload["task_name"] = task_name
        if seed is not None:
            payload["seed"] = seed
        return self._parse_response(self._post("/reset", payload))

    def step(self, response: str) -> SummarizationObservation:
        """Send an action (the agent's text response) and return the next observation.

        Raises:
            requests.HTTPError: If the server responds with a non-2xx status.
        """
        payload = {"action": {"response": response}}
        return self._parse_response(self._post("/step", payload))

    def state(self) -> SummarizationState:
        """Return current episode metadata from ``GET /state``.

        Raises:
            requests.HTTPError: If the server responds with a non-2xx status.
        """
        return SummarizationState(**self._get("/state"))

    def health(self) -> dict:
        """Return the decoded JSON body of ``GET /health``.

        Raises:
            requests.HTTPError: If the server responds with a non-2xx status.
        """
        return self._get("/health")

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _post(self, path: str, payload: dict) -> dict:
        """POST *payload* as JSON to *path*; return the decoded JSON body."""
        resp = requests.post(
            f"{self.base_url}{path}", json=payload, timeout=self.timeout
        )
        resp.raise_for_status()
        return resp.json()

    def _get(self, path: str) -> dict:
        """GET *path*; return the decoded JSON body."""
        resp = requests.get(f"{self.base_url}{path}", timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    def _parse_response(self, data: dict) -> SummarizationObservation:
        """Parse a /reset or /step JSON response into a typed observation.

        Accepts either a wrapped body with an "observation" key, or a flat
        observation dict, which is used as-is when that key is absent.
        Top-level "reward" and "done" values override the ones inside the
        observation, but only when they are not None.
        """
        # Copy so the overrides below never mutate the caller's dict.
        obs_data = dict(data.get("observation", data))
        # Treat a null "done" the same way as a null "reward": fall back to
        # the observation's own value instead of forwarding None.
        for key in ("reward", "done"):
            value = data.get(key)
            if value is not None:
                obs_data[key] = value
        return SummarizationObservation(**obs_data)
|