# Source: Meta-Hackathon-main / client.py
# Uploaded by Parth3841 using huggingface_hub (commit 7c2f148, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Compiler Pass Ordering Environment Client."""
from typing import Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from .models import CompilerOptAction, CompilerOptObservation
class CompilerOptEnv(EnvClient[CompilerOptAction, CompilerOptObservation, State]):
    """Client-side handle for the Compiler Pass Ordering Environment.

    Maintains a persistent WebSocket connection to the environment server;
    every client instance is backed by its own dedicated environment session.

    The three hooks defined here translate between the typed objects used by
    callers and the plain dict payloads exchanged with the server.

    Example (sync):
        >>> with CompilerOptEnv(base_url="http://localhost:8000").sync() as env:
        ...     obs = env.reset()
        ...     result = env.step(CompilerOptAction(pass_id=13, task_id=1))
        ...     print(result.observation.improvement_pct)

    Example (async):
        >>> async with CompilerOptEnv(base_url="http://localhost:8000") as env:
        ...     obs = await env.reset()
        ...     result = await env.step(CompilerOptAction(pass_id=13, task_id=1))
    """

    def _step_payload(self, action: CompilerOptAction) -> Dict:
        """Turn ``action`` into the dict sent with a step request.

        Only ``pass_id`` and ``task_id`` are forwarded.
        """
        return dict(pass_id=action.pass_id, task_id=action.task_id)

    def _parse_result(self, payload: Dict) -> StepResult[CompilerOptObservation]:
        """Build a ``StepResult`` from a server response dict.

        Observation fields are read from ``payload["observation"]`` (an empty
        dict is used when that key is absent); ``done`` and ``reward`` are read
        from the top level of ``payload``. Any missing field is replaced by the
        fallback listed below.
        """
        raw_obs = payload.get("observation", {})
        is_done = payload.get("done", False)

        # Rebuilt on every call so the list-valued fallbacks are never shared
        # between two observations.
        fallbacks = {
            "estimated_cost": 0.0,
            "baseline_cost": 0.0,
            "num_instructions": 0,
            "num_loops": 0,
            "num_branches": 0,
            "num_functions": 0,
            "loop_depth": 0,
            "program_type": "",
            "passes_applied": [],
            "passes_available": [],
            "step_count": 0,
            "max_steps": 10,
            "synergy_state": [1.0] * 15,
            "task_id": 3,
            "task_description": "",
            "improvement_pct": 0.0,
            "last_pass_name": None,
            "grader_score": None,
        }
        obs_fields = {
            key: raw_obs.get(key, fallback) for key, fallback in fallbacks.items()
        }

        # The observation's own reward falls back to 0.0, whereas the
        # StepResult's reward below stays None when the server omits it.
        obs_fields["done"] = is_done
        obs_fields["reward"] = payload.get("reward", 0.0)

        return StepResult(
            observation=CompilerOptObservation(**obs_fields),
            reward=payload.get("reward"),
            done=is_done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """Build a ``State`` from a server state dict.

        ``episode_id`` is ``None`` when missing; ``step_count`` defaults to 0.
        """
        episode = payload.get("episode_id")
        steps_taken = payload.get("step_count", 0)
        return State(episode_id=episode, step_count=steps_taken)