Spaces:
Running
Running
File size: 1,792 Bytes
80d8c84 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | # ── 13. Quick interactive demo ────────────────────────────────────────────────
import torch

# Quick interactive demo: roll out a single episode with greedy decoding and
# print the action/reward of every round.
#
# Relies on names defined earlier in this file: `client`, `model`,
# `tokenizer`, `SYSTEM_PROMPT`, `obs_to_prompt`, `parse_action`,
# `MAX_NEW_TOKENS`.
DEMO_SEED = 999
DEMO_SCENARIO = "math_reasoning"
# Safety cap so a parser/environment that never reports `done` cannot hang
# the demo. NOTE(review): assumed to exceed the environment's own round
# limit — confirm against the env configuration.
DEMO_MAX_ROUNDS = 50

reset_data = client.reset(seed=DEMO_SEED, scenario=DEMO_SCENARIO, difficulty="easy")
obs = reset_data["observation"]
print(f"Episode: {reset_data['episode_id']}")
print(f"Paper: {obs['scientist']['paper_title']}\n")

done = False
total_reward = 0.0
result = {}  # last step result; read after the loop
rounds_played = 0
model.eval()
while not done and rounds_played < DEMO_MAX_ROUNDS:
    rounds_played += 1
    # Content is passed as a list of typed text blocks; the original author
    # noted transformers 5.x required this instead of a plain string.
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "text", "text": obs_to_prompt(obs)}]},
    ]
    # return_dict=True pins the return type to a BatchEncoding (input_ids +
    # attention_mask) regardless of the library's default for this call.
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    inputs = encoded["input_ids"]
    with torch.no_grad():
        out = model.generate(
            inputs,
            # pad_token_id == eos_token_id, so generate cannot infer the mask
            # from the ids; pass the tokenizer's mask explicitly.
            attention_mask=encoded["attention_mask"],
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (drop the prompt prefix).
    text = tokenizer.decode(out[0][inputs.shape[-1]:], skip_special_tokens=True)
    action = parse_action(text)
    result = client.step(action)
    rnd = obs['scientist']['round_number'] + 1
    r = result['reward']
    total_reward += r
    print(f"Round {rnd}: action={action['action_type']} reward={r:.3f}")
    if action.get('rationale'):
        print(f" rationale: {action['rationale'][:80]}")
    done = result["done"]
    if not done:
        obs = result["observation"]

if done:
    print(f"\nEpisode done. Total reward: {total_reward:.3f}")
else:
    print(
        f"\nStopped after {DEMO_MAX_ROUNDS} rounds without the episode finishing. "
        f"Total reward: {total_reward:.3f}"
    )
print("Agreement reached:", result.get("info", {}).get("agreement_reached"))
|