poseidon666's picture
Upload folder using huggingface_hub
13260f8 verified
Raw
History Blame Contribute Delete
4.24 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Quantum Circuit Optimization Environment Client."""
from typing import Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from .models import QuantumAction, QuantumObservation, QuantumState
class QuantumCircuitEnv(
EnvClient[QuantumAction, QuantumObservation, QuantumState]
):
"""
Client for the Quantum Circuit Optimization Environment.
Maintains a persistent WebSocket connection to the environment server
for efficient multi-step interactions with lower latency.
Example:
>>> with QuantumCircuitEnv(base_url="http://localhost:8000") as client:
... result = client.reset()
... print(result.observation.fidelity)
...
... from models import ActionType, GateType
... action = QuantumAction(
... action_type=ActionType.ADD,
... gate=GateType.H,
... qubits=[0],
... )
... result = client.step(action)
... print(f"Fidelity: {result.observation.fidelity}")
"""
def _step_payload(self, action: QuantumAction) -> Dict:
"""
Convert QuantumAction to JSON payload for step message.
Args:
action: QuantumAction instance.
Returns:
Dictionary representation suitable for JSON encoding.
"""
payload = {
"action_type": action.action_type.value if hasattr(action.action_type, 'value') else action.action_type,
"qubits": action.qubits,
}
if action.gate is not None:
payload["gate"] = action.gate.value if hasattr(action.gate, 'value') else action.gate
if action.parameter is not None:
payload["parameter"] = action.parameter
return payload
def _parse_result(self, payload: Dict) -> StepResult[QuantumObservation]:
"""
Parse server response into StepResult[QuantumObservation].
Args:
payload: JSON response data from server.
Returns:
StepResult with QuantumObservation.
"""
obs_data = payload.get("observation", {})
observation = QuantumObservation(
circuit_gates=obs_data.get("circuit_gates", []),
fidelity=obs_data.get("fidelity", 0.0),
depth=obs_data.get("depth", 0),
gate_count=obs_data.get("gate_count", 0),
noise_estimate=obs_data.get("noise_estimate", 0.0),
valid_actions=obs_data.get("valid_actions", []),
score=obs_data.get("score", 0.0),
task_id=obs_data.get("task_id", ""),
num_qubits=obs_data.get("num_qubits", 0),
max_steps=obs_data.get("max_steps", 0),
steps_taken=obs_data.get("steps_taken", 0),
target_description=obs_data.get("target_description", ""),
done=payload.get("done", False),
reward=payload.get("reward"),
metadata=obs_data.get("metadata", {}),
)
return StepResult(
observation=observation,
reward=payload.get("reward"),
done=payload.get("done", False),
)
def _parse_state(self, payload: Dict) -> QuantumState:
"""
Parse server response into QuantumState object.
Args:
payload: JSON response from state request.
Returns:
QuantumState with full environment state.
"""
return QuantumState(
episode_id=payload.get("episode_id"),
step_count=payload.get("step_count", 0),
circuit_gates=payload.get("circuit_gates", []),
target_description=payload.get("target_description", ""),
task_id=payload.get("task_id", ""),
max_steps=payload.get("max_steps", 20),
noise_model_name=payload.get("noise_model_name", "none"),
current_fidelity=payload.get("current_fidelity", 0.0),
current_score=payload.get("current_score", 0.0),
)