DevikaJ2005 committed on
Commit
52dde71
·
1 Parent(s): 50ef6b4

Fix: Add START/STEP/END structured log format to inference.py

Browse files
Files changed (1) hide show
  1. inference.py +14 -18
inference.py CHANGED
@@ -118,12 +118,10 @@ def run_task(env: FraudShieldEnvironment, agent: object, task_name: str) -> Tupl
118
  print(f"Accuracy: {sum(p == l for p, l in zip(preds, labels)) / len(preds)}")
119
  """
120
 
121
- logger.info("%s", "=" * 72)
122
- logger.info("Running %s task with %s", task_name.upper(), getattr(agent, "name", agent.__class__.__name__))
123
- logger.info("%s", "=" * 72)
124
 
125
  reset_result = env.reset(task_name)
126
- logger.info("Episode %s contains %s transactions", env.episode_id, reset_result.info["num_transactions"])
127
 
128
  observation = reset_result.observation
129
  predictions: List[str] = []
@@ -135,21 +133,21 @@ def run_task(env: FraudShieldEnvironment, agent: object, task_name: str) -> Tupl
135
  confidences.append(action.confidence)
136
  step_result = env.step(action)
137
 
138
- if env.step_count in {1, len(env.current_cases)} or env.step_count % 10 == 0:
139
- logger.info(
140
- "Step %02d | decision=%s | confidence=%.2f | reward=%+.2f",
141
- env.step_count,
142
- action.decision.value,
143
- action.confidence,
144
- step_result.reward.value,
145
- )
146
 
147
  observation = step_result.observation
148
 
 
149
  logger.info(
150
- "Finished %s: accuracy_so_far=%.3f cumulative_reward=%.3f",
151
  task_name.upper(),
152
- env.correct_predictions / max(1, env.step_count),
153
  env.cumulative_reward,
154
  )
155
  return predictions, list(env.ground_truth_labels), confidences
@@ -193,9 +191,7 @@ def main() -> Dict[str, object]:
193
  print(f"Easy: {result['easy']['score']:.4f}")
194
  """
195
 
196
- logger.info("%s", "=" * 72)
197
- logger.info("FraudShield baseline inference")
198
- logger.info("%s", "=" * 72)
199
 
200
  env = FraudShieldEnvironment(data_path="data", seed=42)
201
  if not env.load_data():
@@ -241,7 +237,7 @@ def main() -> Dict[str, object]:
241
  logger.info("Easy score: %.4f", grading_result["easy"]["score"])
242
  logger.info("Medium score: %.4f", grading_result["medium"]["score"])
243
  logger.info("Hard score: %.4f", grading_result["hard"]["score"])
244
- logger.info("Final score: %.4f", grading_result["final_score"])
245
 
246
  with open(RESULTS_FILE, "w", encoding="utf-8") as handle:
247
  json.dump(grading_result, handle, indent=2)
 
118
  print(f"Accuracy: {sum(p == l for p, l in zip(preds, labels)) / len(preds)}")
119
  """
120
 
121
+ agent_name = getattr(agent, "name", agent.__class__.__name__)
122
+ logger.info("START %s %s", task_name, agent_name)
 
123
 
124
  reset_result = env.reset(task_name)
 
125
 
126
  observation = reset_result.observation
127
  predictions: List[str] = []
 
133
  confidences.append(action.confidence)
134
  step_result = env.step(action)
135
 
136
+ logger.info(
137
+ "STEP %02d %s %.2f %+.2f",
138
+ env.step_count,
139
+ action.decision.value,
140
+ action.confidence,
141
+ step_result.reward.value,
142
+ )
 
143
 
144
  observation = step_result.observation
145
 
146
+ accuracy = env.correct_predictions / max(1, env.step_count)
147
  logger.info(
148
+ "END %s %.3f %.3f",
149
  task_name.upper(),
150
+ accuracy,
151
  env.cumulative_reward,
152
  )
153
  return predictions, list(env.ground_truth_labels), confidences
 
191
  print(f"Easy: {result['easy']['score']:.4f}")
192
  """
193
 
194
+ logger.info("START FraudShield baseline inference")
 
 
195
 
196
  env = FraudShieldEnvironment(data_path="data", seed=42)
197
  if not env.load_data():
 
237
  logger.info("Easy score: %.4f", grading_result["easy"]["score"])
238
  logger.info("Medium score: %.4f", grading_result["medium"]["score"])
239
  logger.info("Hard score: %.4f", grading_result["hard"]["score"])
240
+ logger.info("END FraudShield %.4f", grading_result["final_score"])
241
 
242
  with open(RESULTS_FILE, "w", encoding="utf-8") as handle:
243
  json.dump(grading_result, handle, indent=2)