conflict-arbitration-env / eval /domain_transfer.py
Jeeevan11's picture
initial commit: conflict arbitration env
8fd5450
Raw
History Blame Contribute Delete
1.58 kB
from training.prompt_templates import format_arbitrator_observation
from training.rollout import generate_decision
def evaluate_domain_transfer(
    trained_model,
    env_client,
    tokenizer,
    unseen_domain: str,
    num_episodes: int = 50
) -> dict:
    """Evaluate the arbitrator (Agent C) on a domain it was never trained on.

    Each episode resets ``env_client``, renders the observation into chat
    messages with ``format_arbitrator_observation``, turns them into a prompt
    string via ``tokenizer.apply_chat_template``, asks ``generate_decision``
    for one decision and submits it with ``env_client.step``.

    Interpretation: if accuracy stays high on the unseen domain, the model
    learned the abstract pattern; if it drops, it memorized the training
    domain.

    NOTE(review): ``unseen_domain`` is only echoed into the returned summary;
    it is never passed to ``env_client``. Presumably the caller has already
    configured the client to serve that domain -- confirm.

    Args:
        trained_model: Model handed unchanged to ``generate_decision``.
        env_client: Environment client exposing ``reset()`` and
            ``step(decision)``; each step result is indexed with ``["info"]``.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        unseen_domain: Label of the held-out domain, copied into the result.
        num_episodes: Number of episodes to run. Zero or a negative value
            runs no episodes.

    Returns:
        A dict with keys ``"domain"``, ``"accuracy"`` (fraction of episodes
        whose ``info["agent_c_was_correct"]`` is truthy), ``"merge_rate"``
        (fraction whose ``info["spec_satisfied"]`` is truthy) and
        ``"episodes"`` (number of episodes actually run). Both rates are
        ``0.0`` when no episode was run.
    """
    results = []
    for _ in range(num_episodes):
        obs = env_client.reset()
        messages = format_arbitrator_observation(obs)
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        _, decision = generate_decision(trained_model, tokenizer, prompt)
        if decision is None:
            # No decision came back (presumably unparseable model output):
            # submit a neutral "do nothing" action so the episode is still
            # stepped and scored instead of being silently dropped.
            decision = {
                "conflict_detected": False,
                "action": "nothing",
                "reason": "parse error",
                "correction_request": "",
            }
        results.append(env_client.step(decision))

    # Count what was really executed: for a negative ``num_episodes`` the loop
    # runs zero times, so dividing by the requested value would report a
    # negative episode count and ``-0.0`` rates.
    episodes_run = len(results)

    # ``bool(...)`` keeps the tallies integral and prevents ``sum`` from
    # raising TypeError when a flag is present but set to ``None``.
    correct = sum(
        bool(r["info"].get("agent_c_was_correct", False)) for r in results
    )
    satisfied = sum(
        bool(r["info"].get("spec_satisfied", False)) for r in results
    )

    return {
        "domain": unseen_domain,
        "accuracy": correct / episodes_run if episodes_run else 0.0,
        "merge_rate": satisfied / episodes_run if episodes_run else 0.0,
        "episodes": episodes_run,
    }