Spaces:
Sleeping
Sleeping
Commit ·
52dde71
1
Parent(s): 50ef6b4
Fix: Add START/STEP/END structured log format to inference.py
Browse files- inference.py +14 -18
inference.py
CHANGED
|
@@ -118,12 +118,10 @@ def run_task(env: FraudShieldEnvironment, agent: object, task_name: str) -> Tupl
|
|
| 118 |
print(f"Accuracy: {sum(p == l for p, l in zip(preds, labels)) / len(preds)}")
|
| 119 |
"""
|
| 120 |
|
| 121 |
-
|
| 122 |
-
logger.info("
|
| 123 |
-
logger.info("%s", "=" * 72)
|
| 124 |
|
| 125 |
reset_result = env.reset(task_name)
|
| 126 |
-
logger.info("Episode %s contains %s transactions", env.episode_id, reset_result.info["num_transactions"])
|
| 127 |
|
| 128 |
observation = reset_result.observation
|
| 129 |
predictions: List[str] = []
|
|
@@ -135,21 +133,21 @@ def run_task(env: FraudShieldEnvironment, agent: object, task_name: str) -> Tupl
|
|
| 135 |
confidences.append(action.confidence)
|
| 136 |
step_result = env.step(action)
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
)
|
| 146 |
|
| 147 |
observation = step_result.observation
|
| 148 |
|
|
|
|
| 149 |
logger.info(
|
| 150 |
-
"
|
| 151 |
task_name.upper(),
|
| 152 |
-
|
| 153 |
env.cumulative_reward,
|
| 154 |
)
|
| 155 |
return predictions, list(env.ground_truth_labels), confidences
|
|
@@ -193,9 +191,7 @@ def main() -> Dict[str, object]:
|
|
| 193 |
print(f"Easy: {result['easy']['score']:.4f}")
|
| 194 |
"""
|
| 195 |
|
| 196 |
-
logger.info("
|
| 197 |
-
logger.info("FraudShield baseline inference")
|
| 198 |
-
logger.info("%s", "=" * 72)
|
| 199 |
|
| 200 |
env = FraudShieldEnvironment(data_path="data", seed=42)
|
| 201 |
if not env.load_data():
|
|
@@ -241,7 +237,7 @@ def main() -> Dict[str, object]:
|
|
| 241 |
logger.info("Easy score: %.4f", grading_result["easy"]["score"])
|
| 242 |
logger.info("Medium score: %.4f", grading_result["medium"]["score"])
|
| 243 |
logger.info("Hard score: %.4f", grading_result["hard"]["score"])
|
| 244 |
-
logger.info("
|
| 245 |
|
| 246 |
with open(RESULTS_FILE, "w", encoding="utf-8") as handle:
|
| 247 |
json.dump(grading_result, handle, indent=2)
|
|
|
|
| 118 |
print(f"Accuracy: {sum(p == l for p, l in zip(preds, labels)) / len(preds)}")
|
| 119 |
"""
|
| 120 |
|
| 121 |
+
agent_name = getattr(agent, "name", agent.__class__.__name__)
|
| 122 |
+
logger.info("START %s %s", task_name, agent_name)
|
|
|
|
| 123 |
|
| 124 |
reset_result = env.reset(task_name)
|
|
|
|
| 125 |
|
| 126 |
observation = reset_result.observation
|
| 127 |
predictions: List[str] = []
|
|
|
|
| 133 |
confidences.append(action.confidence)
|
| 134 |
step_result = env.step(action)
|
| 135 |
|
| 136 |
+
logger.info(
|
| 137 |
+
"STEP %02d %s %.2f %+.2f",
|
| 138 |
+
env.step_count,
|
| 139 |
+
action.decision.value,
|
| 140 |
+
action.confidence,
|
| 141 |
+
step_result.reward.value,
|
| 142 |
+
)
|
|
|
|
| 143 |
|
| 144 |
observation = step_result.observation
|
| 145 |
|
| 146 |
+
accuracy = env.correct_predictions / max(1, env.step_count)
|
| 147 |
logger.info(
|
| 148 |
+
"END %s %.3f %.3f",
|
| 149 |
task_name.upper(),
|
| 150 |
+
accuracy,
|
| 151 |
env.cumulative_reward,
|
| 152 |
)
|
| 153 |
return predictions, list(env.ground_truth_labels), confidences
|
|
|
|
| 191 |
print(f"Easy: {result['easy']['score']:.4f}")
|
| 192 |
"""
|
| 193 |
|
| 194 |
+
logger.info("START FraudShield baseline inference")
|
|
|
|
|
|
|
| 195 |
|
| 196 |
env = FraudShieldEnvironment(data_path="data", seed=42)
|
| 197 |
if not env.load_data():
|
|
|
|
| 237 |
logger.info("Easy score: %.4f", grading_result["easy"]["score"])
|
| 238 |
logger.info("Medium score: %.4f", grading_result["medium"]["score"])
|
| 239 |
logger.info("Hard score: %.4f", grading_result["hard"]["score"])
|
| 240 |
+
logger.info("END FraudShield %.4f", grading_result["final_score"])
|
| 241 |
|
| 242 |
with open(RESULTS_FILE, "w", encoding="utf-8") as handle:
|
| 243 |
json.dump(grading_result, handle, indent=2)
|