File size: 1,212 Bytes
a8d4cdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from qlearning import QTable, ACTIONS, encode_state
from environment import GarbageRobotEnv
from scenarios import SCENARIOS
import json

# Build a text dataset by replaying the saved Q-table policy in the
# garbage-robot environment: every step contributes one Alpaca-style record
# pairing the environment's status message with the action the policy chose.

EPISODES_PER_TASK = 10

# The policy is read from a Q-table previously saved to qtable.json.
qt = QTable()
qt.load('qtable.json')
env = GarbageRobotEnv()

instruction = (
    'You are an AI brain controlling a garbage collecting robot.\n'
    'Reply with EXACTLY ONE of: UP DOWN LEFT RIGHT COLLECT'
)

# Placeholders, in order: instruction, environment status message, action.
alpaca = (
    '### Instruction:\n{}\n\n'
    '### Input:\nENVIRONMENT STATUS:\n{}\n\n'
    '### Response:\n{}'
)

data = []
for task_id in SCENARIOS:
    for _episode in range(EPISODES_PER_TASK):
        env.reset(task_id)
        # NOTE(review): this loop only ends when env.step reports 'done';
        # confirm the environment caps episode length so a stuck policy
        # cannot spin forever.
        while True:
            snapshot = env.get_observation()
            # encode_state is handed a plain dict rather than the
            # observation object itself.
            state = encode_state(dict(
                robot_position=snapshot.robot_position,
                garbage_positions=list(snapshot.garbage_positions),
                grid_size=snapshot.grid_size,
            ))
            action = ACTIONS[qt.best_action(state)]
            prompt = alpaca.format(instruction, snapshot.message, action)
            data.append({'text': prompt})
            if env.step(action)['done']:
                break

# One JSON object per line (JSONL).
with open('rl_trajectories.jsonl', 'w') as f:
    f.writelines(json.dumps(row) + '\n' for row in data)
print(f'Generated {len(data)} samples')