pbanavara commited on
Commit
fb9e30c
·
verified ·
1 Parent(s): b7e2b4d

Upload folder using huggingface_hub

Browse files
client.py CHANGED
@@ -6,17 +6,17 @@ from openenv.core.client_types import StepResult
6
  from openenv.core.env_server.types import State
7
  from openenv.core import EnvClient
8
 
9
- from .models import AdminAction, AdminObservation
10
 
11
 
12
- class PranaEnv(EnvClient[AdminAction, AdminObservation, State]):
13
  """
14
  Client for PRANA-Env.
15
 
16
  Example:
17
  >>> with PranaEnv(base_url="http://localhost:8000") as client:
18
  ... client.reset()
19
- ... result = client.step(AdminAction(
20
  ... action_type="query_db",
21
  ... target="PatientDB",
22
  ... field="hba1c",
@@ -25,12 +25,12 @@ class PranaEnv(EnvClient[AdminAction, AdminObservation, State]):
25
  ... print(result.observation.query_result) # "7.2"
26
  """
27
 
28
- def _step_payload(self, action: AdminAction) -> Dict:
29
  return {k: v for k, v in action.model_dump().items() if v is not None}
30
 
31
- def _parse_result(self, payload: Dict) -> StepResult[AdminObservation]:
32
  obs_data = payload.get("observation", {})
33
- observation = AdminObservation(
34
  query_result=obs_data.get("query_result", ""),
35
  active_task=obs_data.get("active_task", "t1"),
36
  policy_alerts=obs_data.get("policy_alerts", ""),
 
6
  from openenv.core.env_server.types import State
7
  from openenv.core import EnvClient
8
 
9
+ from .models import PranaAction, PranaObservation
10
 
11
 
12
+ class PranaEnv(EnvClient[PranaAction, PranaObservation, State]):
13
  """
14
  Client for PRANA-Env.
15
 
16
  Example:
17
  >>> with PranaEnv(base_url="http://localhost:8000") as client:
18
  ... client.reset()
19
+ ... result = client.step(PranaAction(
20
  ... action_type="query_db",
21
  ... target="PatientDB",
22
  ... field="hba1c",
 
25
  ... print(result.observation.query_result) # "7.2"
26
  """
27
 
28
+ def _step_payload(self, action: PranaAction) -> Dict:
29
  return {k: v for k, v in action.model_dump().items() if v is not None}
30
 
31
+ def _parse_result(self, payload: Dict) -> StepResult[PranaObservation]:
32
  obs_data = payload.get("observation", {})
33
+ observation = PranaObservation(
34
  query_result=obs_data.get("query_result", ""),
35
  active_task=obs_data.get("active_task", "t1"),
36
  policy_alerts=obs_data.get("policy_alerts", ""),
models.py CHANGED
@@ -12,14 +12,14 @@ from pydantic import Field
12
  from openenv.core.env_server.types import Action, Observation
13
 
14
 
15
- class AdminAction(Action):
16
  """
17
  Action for PRANA-Env.
18
 
19
  For the smoke test, only action_type='query_db' is supported.
20
 
21
  Example:
22
- >>> action = AdminAction(
23
  ... action_type="query_db",
24
  ... target="PatientDB",
25
  ... field="hba1c",
@@ -55,7 +55,7 @@ class AdminAction(Action):
55
  )
56
 
57
 
58
- class AdminObservation(Observation):
59
  """
60
  Observation from PRANA-Env.
61
 
 
12
  from openenv.core.env_server.types import Action, Observation
13
 
14
 
15
+ class PranaAction(Action):
16
  """
17
  Action for PRANA-Env.
18
 
19
  For the smoke test, only action_type='query_db' is supported.
20
 
21
  Example:
22
+ >>> action = PranaAction(
23
  ... action_type="query_db",
24
  ... target="PatientDB",
25
  ... field="hba1c",
 
55
  )
56
 
57
 
58
+ class PranaObservation(Observation):
59
  """
60
  Observation from PRANA-Env.
61
 
openenv_prana_env.egg-info/PKG-INFO ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: openenv-prana_env
3
+ Version: 0.1.0
4
+ Summary: Prana Env environment for OpenEnv
5
+ Requires-Python: >=3.10
6
+ Requires-Dist: openenv-core[core]>=0.2.0
7
+ Provides-Extra: dev
8
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
9
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_prana_env.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ pyproject.toml
3
+ ./__init__.py
4
+ ./client.py
5
+ ./models.py
6
+ ./test_client.py
7
+ openenv_prana_env.egg-info/PKG-INFO
8
+ openenv_prana_env.egg-info/SOURCES.txt
9
+ openenv_prana_env.egg-info/dependency_links.txt
10
+ openenv_prana_env.egg-info/entry_points.txt
11
+ openenv_prana_env.egg-info/requires.txt
12
+ openenv_prana_env.egg-info/top_level.txt
13
+ server/__init__.py
14
+ server/app.py
15
+ server/prana_env_environment.py
openenv_prana_env.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
openenv_prana_env.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [console_scripts]
2
+ server = prana_env.server.app:main
openenv_prana_env.egg-info/requires.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ openenv-core[core]>=0.2.0
2
+
3
+ [dev]
4
+ pytest>=8.0.0
5
+ pytest-cov>=4.0.0
openenv_prana_env.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ prana_env
server/app.py CHANGED
@@ -36,15 +36,15 @@ except Exception as e: # pragma: no cover
36
  ) from e
37
 
38
  # Import from local models.py (PYTHONPATH includes /app/env in Docker)
39
- from models import AdminAction, AdminObservation
40
  from .prana_env_environment import PranaEnvironment
41
 
42
 
43
  # Create the app with web interface and README integration
44
  app = create_app(
45
  PranaEnvironment,
46
- AdminAction,
47
- AdminObservation,
48
  env_name="prana_env",
49
  max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions
50
  )
 
36
  ) from e
37
 
38
  # Import from local models.py (PYTHONPATH includes /app/env in Docker)
39
+ from models import PranaAction, PranaObservation
40
  from .prana_env_environment import PranaEnvironment
41
 
42
 
43
  # Create the app with web interface and README integration
44
  app = create_app(
45
  PranaEnvironment,
46
+ PranaAction,
47
+ PranaObservation,
48
  env_name="prana_env",
49
  max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions
50
  )
server/prana_env_environment.py CHANGED
@@ -13,7 +13,7 @@ from uuid import uuid4
13
  from openenv.core.env_server.interfaces import Environment
14
  from openenv.core.env_server.types import State
15
 
16
- from models import AdminAction, AdminObservation
17
 
18
  tag = "[prana_env/environment]"
19
  logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class PranaEnvironment(Environment):
52
  with open(path) as f:
53
  return json.load(f)
54
 
55
- def reset(self, seed: int | None = None, episode_id: str | None = None, **kwargs) -> AdminObservation:
56
  patient_id: str | None = kwargs.get("patient_id")
57
  self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
58
  self._active_task = "t1"
@@ -60,14 +60,14 @@ class PranaEnvironment(Environment):
60
 
61
  logger.info(f"{tag} reset — episode={self._state.episode_id} patient_id={patient_id}")
62
 
63
- return AdminObservation(
64
  query_result="Episode reset. Ready for task t1: Initial Labs.",
65
  active_task=self._active_task,
66
  done=False,
67
  reward=0.0,
68
  )
69
 
70
- def step(self, action: AdminAction) -> AdminObservation: # type: ignore[override]
71
  self._state.step_count += 1
72
  logger.info(
73
  f"{tag} step={self._state.step_count} action_type={action.action_type} "
@@ -78,14 +78,14 @@ class PranaEnvironment(Environment):
78
  return self._handle_query_db(action)
79
 
80
  logger.warning(f"{tag} Unsupported action_type={action.action_type}")
81
- return AdminObservation(
82
  query_result=f"NOT_SUPPORTED: action_type '{action.action_type}' not implemented yet.",
83
  active_task=self._active_task,
84
  done=False,
85
  reward=0.0,
86
  )
87
 
88
- def _handle_query_db(self, action: AdminAction) -> AdminObservation:
89
  db_name = (action.target or "").lower()
90
  field = (action.field or "").lower()
91
  patient_id = action.patient_id or self._patient_id
@@ -94,7 +94,7 @@ class PranaEnvironment(Environment):
94
 
95
  if db_name != "patientdb":
96
  logger.warning(f"{tag} Datastore '{db_name}' not available in Phase 1")
97
- return AdminObservation(
98
  query_result=f"NOT_AVAILABLE: datastore '{action.target}' not loaded in Phase 1.",
99
  active_task=self._active_task,
100
  done=False,
@@ -102,7 +102,7 @@ class PranaEnvironment(Environment):
102
  )
103
 
104
  if not patient_id:
105
- return AdminObservation(
106
  query_result="ERROR: patient_id is required for query_db.",
107
  active_task=self._active_task,
108
  done=False,
@@ -114,7 +114,7 @@ class PranaEnvironment(Environment):
114
 
115
  if not patient:
116
  logger.info(f"{tag} patient_id={patient_id} NOT_FOUND in PatientDB")
117
- return AdminObservation(
118
  query_result=f"NOT_FOUND: patient '{patient_id}' not in PatientDB.",
119
  active_task=self._active_task,
120
  done=False,
@@ -123,7 +123,7 @@ class PranaEnvironment(Environment):
123
 
124
  if field not in patient:
125
  logger.info(f"{tag} field={field} NOT_FOUND for patient={patient_id}")
126
- return AdminObservation(
127
  query_result=f"NOT_FOUND: field '{field}' not in PatientDB for patient '{patient_id}'.",
128
  active_task=self._active_task,
129
  done=False,
@@ -133,7 +133,7 @@ class PranaEnvironment(Environment):
133
  value = patient[field]
134
  if value is None:
135
  logger.info(f"{tag} field={field} is NULL for patient={patient_id}")
136
- return AdminObservation(
137
  query_result=f"NOT_FOUND: field '{field}' has no recorded value for patient '{patient_id}'.",
138
  active_task=self._active_task,
139
  done=False,
@@ -141,7 +141,7 @@ class PranaEnvironment(Environment):
141
  )
142
 
143
  logger.info(f"{tag} query_db success field={field} value={value} patient={patient_id}")
144
- return AdminObservation(
145
  query_result=str(value),
146
  active_task=self._active_task,
147
  done=False,
 
13
  from openenv.core.env_server.interfaces import Environment
14
  from openenv.core.env_server.types import State
15
 
16
+ from models import PranaAction, PranaObservation
17
 
18
  tag = "[prana_env/environment]"
19
  logger = logging.getLogger(__name__)
 
52
  with open(path) as f:
53
  return json.load(f)
54
 
55
+ def reset(self, seed: int | None = None, episode_id: str | None = None, **kwargs) -> PranaObservation:
56
  patient_id: str | None = kwargs.get("patient_id")
57
  self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
58
  self._active_task = "t1"
 
60
 
61
  logger.info(f"{tag} reset — episode={self._state.episode_id} patient_id={patient_id}")
62
 
63
+ return PranaObservation(
64
  query_result="Episode reset. Ready for task t1: Initial Labs.",
65
  active_task=self._active_task,
66
  done=False,
67
  reward=0.0,
68
  )
69
 
70
+ def step(self, action: PranaAction) -> PranaObservation: # type: ignore[override]
71
  self._state.step_count += 1
72
  logger.info(
73
  f"{tag} step={self._state.step_count} action_type={action.action_type} "
 
78
  return self._handle_query_db(action)
79
 
80
  logger.warning(f"{tag} Unsupported action_type={action.action_type}")
81
+ return PranaObservation(
82
  query_result=f"NOT_SUPPORTED: action_type '{action.action_type}' not implemented yet.",
83
  active_task=self._active_task,
84
  done=False,
85
  reward=0.0,
86
  )
87
 
88
+ def _handle_query_db(self, action: PranaAction) -> PranaObservation:
89
  db_name = (action.target or "").lower()
90
  field = (action.field or "").lower()
91
  patient_id = action.patient_id or self._patient_id
 
94
 
95
  if db_name != "patientdb":
96
  logger.warning(f"{tag} Datastore '{db_name}' not available in Phase 1")
97
+ return PranaObservation(
98
  query_result=f"NOT_AVAILABLE: datastore '{action.target}' not loaded in Phase 1.",
99
  active_task=self._active_task,
100
  done=False,
 
102
  )
103
 
104
  if not patient_id:
105
+ return PranaObservation(
106
  query_result="ERROR: patient_id is required for query_db.",
107
  active_task=self._active_task,
108
  done=False,
 
114
 
115
  if not patient:
116
  logger.info(f"{tag} patient_id={patient_id} NOT_FOUND in PatientDB")
117
+ return PranaObservation(
118
  query_result=f"NOT_FOUND: patient '{patient_id}' not in PatientDB.",
119
  active_task=self._active_task,
120
  done=False,
 
123
 
124
  if field not in patient:
125
  logger.info(f"{tag} field={field} NOT_FOUND for patient={patient_id}")
126
+ return PranaObservation(
127
  query_result=f"NOT_FOUND: field '{field}' not in PatientDB for patient '{patient_id}'.",
128
  active_task=self._active_task,
129
  done=False,
 
133
  value = patient[field]
134
  if value is None:
135
  logger.info(f"{tag} field={field} is NULL for patient={patient_id}")
136
+ return PranaObservation(
137
  query_result=f"NOT_FOUND: field '{field}' has no recorded value for patient '{patient_id}'.",
138
  active_task=self._active_task,
139
  done=False,
 
141
  )
142
 
143
  logger.info(f"{tag} query_db success field={field} value={value} patient={patient_id}")
144
+ return PranaObservation(
145
  query_result=str(value),
146
  active_task=self._active_task,
147
  done=False,
test_client.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from prana_env.client import PranaEnv
2
+ from prana_env.models import PranaAction
3
+
4
+ with PranaEnv(base_url="http://localhost:8000") as client:
5
+ client.reset()
6
+ result = client.step(PranaAction(action_type="query_db",
7
+ target="PatientDB",
8
+ field="hba1c",
9
+ patient_id="P001",
10
+ ))
11
+ print(result.observation.query_result)
12
+
13
+