File size: 2,468 Bytes
c2f781b
15c3238
c524b25
15c3238
c2f781b
15c3238
 
 
 
8f2eab9
15c3238
 
c2f781b
15c3238
782222a
15c3238
782222a
 
15c3238
 
c524b25
782222a
 
15c3238
782222a
 
 
15c3238
 
 
c2f781b
 
 
 
 
d064b19
 
c2f781b
15c3238
 
 
 
c2f781b
 
 
15c3238
c2f781b
15c3238
 
 
 
c2f781b
15c3238
 
 
c2f781b
782222a
15c3238
 
c2f781b
 
 
 
782222a
15c3238
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""NeuralTuner environment client.

NeuralTuner Environment Client."""

from typing import Dict, Optional

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

from models import NeuralTunerAction, NeuralTunerObservation, NeuralTunerState


class NeuralTunerEnv(EnvClient[NeuralTunerAction, NeuralTunerObservation, NeuralTunerState]):
    """
    Client-side handle for the NeuralTuner Environment.

    Keeps a persistent WebSocket connection open to the environment server,
    which makes multi-step LLM-agent rollouts efficient.

    The three private hooks below translate between the typed models and the
    plain dictionaries exchanged with the server.

    Example:
        >>> with NeuralTunerEnv(base_url='http://localhost:8000') as env:
        ...     result = env.reset()
        ...     print(result.observation.output)
        ...
        ...     action = NeuralTunerAction(action_type="profile_layer", layer_id="fc_classifier")
        ...     result = env.step(action)
        ...     print(result.observation.output)
    """

    def _step_payload(self, action: NeuralTunerAction) -> Dict:
        """Serialize ``action`` into the dictionary sent for a step.

        ``action_type`` is always present; ``layer_id``, ``dtype`` and
        ``sparsity`` are included only when they are not ``None``.
        """
        body: Dict = {"action_type": action.action_type}
        for field_name in ("layer_id", "dtype", "sparsity"):
            value = getattr(action, field_name)
            if value is not None:
                body[field_name] = value
        return body

    def _parse_result(self, payload: Dict) -> StepResult[NeuralTunerObservation]:
        """Build a ``StepResult`` from a server response dictionary.

        ``reward`` and ``done`` are read from the top level of ``payload``
        (defaulting to ``0.0`` and ``False``) and are mirrored onto both the
        observation and the returned result. The remaining observation fields
        come from the nested ``"observation"`` mapping.
        """
        raw_obs = payload.get("observation", {})
        reward = payload.get("reward", 0.0)
        done = payload.get("done", False)

        obs_fields = {
            "output": raw_obs.get("output", ""),
            "success": raw_obs.get("success", True),
            "error": raw_obs.get("error"),
            "done": done,
            "reward": reward,
            "metadata": raw_obs.get("metadata", {}),
        }
        return StepResult(
            observation=NeuralTunerObservation(**obs_fields),
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> NeuralTunerState:
        """Build a ``NeuralTunerState`` from a state dictionary.

        Each field falls back to the default listed below when its key is
        missing from ``payload``.
        """
        # (key, fallback) pairs, in the order the state fields are populated.
        field_defaults = (
            ("episode_id", None),
            ("step_count", 0),
            ("model_id", ""),
            ("difficulty", "easy"),
            ("submitted", False),
            ("benchmark_count", 0),
            ("final_reward", None),
        )
        state_fields = {
            key: payload.get(key, fallback) for key, fallback in field_defaults
        }
        return NeuralTunerState(**state_fields)