JonathanShiju12 commited on
Commit
63fa460
·
verified ·
1 Parent(s): 4cb0f71

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. Inference.py +2 -3
  2. client.py +1 -2
Inference.py CHANGED
@@ -209,9 +209,8 @@ async def run_episode_loop(
209
  actions = await call_llm_for_actions(client, user_prompt)
210
  action_payload = actions_to_hft_action(actions)
211
  observation = await step_fn(action_payload)
212
- observation, user_prompt = build_user_prompt(
213
- observation.get("observation", observation)
214
- )
215
 
216
  reward = observation.reward
217
  done = observation.done
 
209
  actions = await call_llm_for_actions(client, user_prompt)
210
  action_payload = actions_to_hft_action(actions)
211
  observation = await step_fn(action_payload)
212
+ log_transcript(f"Received observation: {observation}")
213
+ observation, user_prompt = build_user_prompt(observation.get("observation"))
 
214
 
215
  reward = observation.reward
216
  done = observation.done
client.py CHANGED
@@ -80,9 +80,8 @@ class HftEnv(EnvClient[HftAction, HftObservation, HftState]):
80
  obs_data = payload.get("observation", {})
81
  observation = HftObservation(
82
  time=obs_data.get("time", 0.0),
83
- reward=obs_data.get("reward", 0.0),
84
  done=obs_data.get("done", False),
85
- spread=obs_data.get("spread"),
86
  history=obs_data.get("history", []),
87
  active_orders=obs_data.get("active_orders", []),
88
  )
 
80
  obs_data = payload.get("observation", {})
81
  observation = HftObservation(
82
  time=obs_data.get("time", 0.0),
83
+ reward=payload.get("reward", 0.0),
84
  done=obs_data.get("done", False),
 
85
  history=obs_data.get("history", []),
86
  active_orders=obs_data.get("active_orders", []),
87
  )