Spaces:
Sleeping
Sleeping
File size: 2,468 Bytes
c2f781b 15c3238 c524b25 15c3238 c2f781b 15c3238 8f2eab9 15c3238 c2f781b 15c3238 782222a 15c3238 782222a 15c3238 c524b25 782222a 15c3238 782222a 15c3238 c2f781b d064b19 c2f781b 15c3238 c2f781b 15c3238 c2f781b 15c3238 c2f781b 15c3238 c2f781b 782222a 15c3238 c2f781b 782222a 15c3238 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | """NeuralTuner environment client.

Converts actions into request payloads and parses server responses into
typed step results and state objects."""
from typing import Dict, Optional
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from models import NeuralTunerAction, NeuralTunerObservation, NeuralTunerState
class NeuralTunerEnv(EnvClient[NeuralTunerAction, NeuralTunerObservation, NeuralTunerState]):
    """Client for the NeuralTuner Environment.

    Maintains a persistent WebSocket connection to the environment server,
    enabling efficient multi-step LLM-agent rollouts.

    This class only supplies the (de)serialization hooks: turning an action
    into a request payload, and turning response payloads into observation,
    step-result and state objects.

    Example:
        >>> with NeuralTunerEnv(base_url='http://localhost:8000') as env:
        ...     result = env.reset()
        ...     print(result.observation.output)
        ...
        ...     action = NeuralTunerAction(action_type="profile_layer", layer_id="fc_classifier")
        ...     result = env.step(action)
        ...     print(result.observation.output)
    """

    def _step_payload(self, action: NeuralTunerAction) -> Dict:
        """Serialize *action* into the dict sent for a step request.

        ``action_type`` is always present. ``layer_id``, ``dtype`` and
        ``sparsity`` are added, in that order, only when they are not ``None``.
        """
        body: Dict = {"action_type": action.action_type}
        for field_name in ("layer_id", "dtype", "sparsity"):
            field_value = getattr(action, field_name)
            # Compare against None explicitly so falsy values such as
            # 0 or "" are still transmitted.
            if field_value is not None:
                body[field_name] = field_value
        return body

    def _parse_result(self, payload: Dict) -> StepResult[NeuralTunerObservation]:
        """Build a ``StepResult`` from a step/reset response payload.

        Observation fields are read from ``payload["observation"]`` (an empty
        dict when absent); ``reward`` and ``done`` are read from the top level
        of *payload* and copied onto both the observation and the result.
        Missing keys fall back to: ``output=""``, ``success=True``,
        ``error=None``, ``done=False``, ``reward=0.0``, ``metadata={}``.
        """
        raw_obs = payload.get("observation", {})
        reward = payload.get("reward", 0.0)
        done = payload.get("done", False)

        parsed_obs = NeuralTunerObservation(
            output=raw_obs.get("output", ""),
            success=raw_obs.get("success", True),
            error=raw_obs.get("error"),
            done=done,
            reward=reward,
            metadata=raw_obs.get("metadata", {}),
        )
        return StepResult(observation=parsed_obs, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> NeuralTunerState:
        """Build a ``NeuralTunerState`` from a state response payload.

        Each field is looked up by its own name in *payload*; a missing key
        takes the default listed in the table below.
        """
        # (field name, value used when the key is missing from the payload)
        field_defaults = (
            ("episode_id", None),
            ("step_count", 0),
            ("model_id", ""),
            ("difficulty", "easy"),
            ("submitted", False),
            ("benchmark_count", 0),
            ("final_reward", None),
        )
        state_kwargs = {
            name: payload.get(name, fallback) for name, fallback in field_defaults
        }
        return NeuralTunerState(**state_kwargs)
|