{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Prepare trained environment using MDPDataset" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [
"import os\n",
"\n",
"# Move to the repository root only when not already there, so re-running this cell is a no-op\n",
"# instead of failing with 'No such file or directory'.\n",
"if os.path.basename(os.getcwd()) != 'IRL-MOOC':\n",
"    os.chdir('../../IRL-MOOC/')\n",
"print(os.getcwd())"
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"import pandas as pd\n",
"import importlib\n",
"from tqdm import tqdm\n",
"import environment.raw_world as Env\n",
"import models.maxcausal as maxcausal\n",
"from utils.plot import *\n",
"from utils.data_helper import *\n",
"from utils.distance import *\n",
"from utils.irl_helper import make_syn_student, make_syn_student_personalized\n",
"import d3rlpy\n",
"from sklearn.model_selection import train_test_split\n",
"from d3rlpy.algos import DQNConfig\n",
"from d3rlpy.metrics import TDErrorEvaluator\n",
"import torch\n",
"from torch.nn import Softmax"
] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [
"course_id = 'dsp-002'\n",
"DATA_DIR = 'data'\n",
"OUTPUT_DIR = f'results/{course_id}'\n",
"metadata_dir = 'metadata'"
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "np.random.seed(0)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [
"combinedg = pd.read_csv(f'{DATA_DIR}/{course_id}/combinedg_features_{course_id}.csv')\n",
"combinedg = combinedg.drop_duplicates(subset=['event_id', 'action', 'timestamp'], keep='first')\n",
"metadata = pd.read_csv(f\"{metadata_dir}/metadata.csv\")\n",
"course_meta = metadata[metadata['course_id'] == course_id]\n",
"assert not course_meta.empty, f'{course_id} not found in {metadata_dir}/metadata.csv'\n",
"start_date, end_date = course_meta['start_date'].values[0], course_meta['end_date'].values[0]\n",
"combinedg['date'] = combinedg['timestamp'].apply(tmp2dt)\n",
"combinedg['week'] = combinedg['date'].apply(dt2w)\n",
"combinedg = filter_range_dates(combinedg, start_date, end_date)"
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", "Map event_id to integer\n", "Map week to discretized week number\n", "Map action to integer\n", "\"\"\"\n", "sorted_week = combinedg.sort_values('timestamp')['week'].unique()\n", "map_week = dict(zip(sorted_week, np.arange(len(sorted_week))))\n", "sorted_event_id = sorted(combinedg['event_id'].unique())\n", "unique_action = combinedg['action'].unique()\n", "unique_action = np.append(unique_action, [\"Move To Quiz\", \"Move To Video\"])\n", "\n", "map_event_id = dict(zip(sorted_event_id, np.arange(len(sorted_event_id))))\n", "dict_event = dict(zip(np.arange(len(sorted_event_id)), sorted_event_id))\n", "\n", "map_action = dict(zip(unique_action, np.arange(len(unique_action))))\n", "dict_action = dict(zip(np.arange(len(unique_action)), unique_action))" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
user_idUnnamed: 0event_idtimestamplabel-pass-faileventactionproblem_gradesubmission_numbergradedateweek
02002[935674, 1304925, 975953, 222247, 651793, 1792...[22, 22, 22, 22, 22, 6, 21, 9, 8, 22, 22, 22, ...[1386261382, 1386261383, 1386261383, 138626138...1.0[Video, Video, Video, Video, Video, Video, Vid...[Video.Load, Video.SpeedChange, Video.Play, Vi...[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....[nan, nan, nan, nan, nan, nan, nan, nan, nan, ...[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[2013-12-05 16:36:22, 2013-12-05 16:36:23, 201...[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...
\n", "
" ], "text/plain": [ " user_id Unnamed: 0 \\\n", "0 2002 [935674, 1304925, 975953, 222247, 651793, 1792... \n", "\n", " event_id \\\n", "0 [22, 22, 22, 22, 22, 6, 21, 9, 8, 22, 22, 22, ... \n", "\n", " timestamp label-pass-fail \\\n", "0 [1386261382, 1386261383, 1386261383, 138626138... 1.0 \n", "\n", " event \\\n", "0 [Video, Video, Video, Video, Video, Video, Vid... \n", "\n", " action \\\n", "0 [Video.Load, Video.SpeedChange, Video.Play, Vi... \n", "\n", " problem_grade \\\n", "0 [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.... \n", "\n", " submission_number \\\n", "0 [nan, nan, nan, nan, nan, nan, nan, nan, nan, ... \n", "\n", " grade \\\n", "0 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "\n", " date \\\n", "0 [2013-12-05 16:36:22, 2013-12-05 16:36:23, 201... \n", "\n", " week \n", "0 [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ... " ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "unique week in dataset: [0 1 2 3 4 5 6 7 8 9]\n" ] } ], "source": [
"### try discretize week based on timestamp only\n",
"combinedg['week'] = combinedg['week'].map(map_week)\n",
"combinedg['event_id'] = combinedg['event_id'].map(map_event_id)\n",
"# combinedg['action'] = combinedg['action'].map(map_action)\n",
"def process_week(lst):\n",
"    \"\"\"Renumber a per-event week sequence into consecutive active-week indices, in place.\n",
"\n",
"    The first entry becomes 1 and the counter increases by one every time the raw\n",
"    value differs from the previous raw value. Mutates and returns ``lst``; an empty\n",
"    sequence is returned unchanged instead of raising IndexError.\n",
"    \"\"\"\n",
"    if len(lst) == 0:\n",
"        return lst\n",
"    old_val = lst[0]\n",
"    lst[0] = 1\n",
"    for i in range(1, len(lst)):\n",
"        if lst[i] != old_val:\n",
"            old_val = lst[i]\n",
"            lst[i] = lst[i-1] + 1\n",
"        else:\n",
"            lst[i] = lst[i-1]\n",
"    return lst\n",
"\n",
"# try discretize week based on when the student first does an event\n",
"dataset = combinedg.sort_values('timestamp').groupby('user_id').agg(list).reset_index()\n",
"# keep the first event's label as the student's label (assumes it is constant per student)\n",
"dataset['label-pass-fail'] = dataset['label-pass-fail'].apply(lambda x: x[0])\n",
"display(dataset.head(1))\n",
"unique_week = np.unique(np.concatenate(dataset['week'].values))\n",
"print('unique week in dataset:', unique_week)"
] }, { "cell_type": "code", 
"execution_count": 29, "metadata": {}, "outputs": [], "source": [
"# Rows with problem_grade == -1 are treated as video events, all other rows as problem events.\n",
"is_problem_row = combinedg.problem_grade != -1\n",
"problem_event = np.unique(combinedg.loc[is_problem_row, 'event_id'])\n",
"video_event = np.unique(combinedg.loc[~is_problem_row, 'event_id'])"
] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "max length, min length 13916 1\n", "num of student in None week: 3974\n", "----------------------\n" ] } ], "source": [ "trajectories = data_to_trajectories(dataset, map_action=map_action, all=True, remove_empty=False)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "max length, min length 2750 0\n", "num of student in [0] week: 3974\n", "----------------------\n", "max length, min length 3157 0\n", "num of student in [1] week: 3974\n", "----------------------\n", "max length, min length 2616 0\n", "num of student in [2] week: 3974\n", "----------------------\n", "max length, min length 3442 0\n", "num of student in [3] week: 3974\n", "----------------------\n", "max length, min length 2438 0\n", "num of student in [4] week: 3974\n", "----------------------\n", "max length, min length 2330 0\n", "num of student in [5] week: 3974\n", "----------------------\n", "max length, min length 1892 0\n", "num of student in [6] week: 3974\n", "----------------------\n", "max length, min length 2473 0\n", "num of student in [7] week: 3974\n", "----------------------\n", "max length, min length 2059 0\n", "num of student in [8] week: 3974\n", "----------------------\n", "max length, min length 1475 0\n", "num of student in [9] week: 3974\n", "----------------------\n" ] } ], "source": [
"def trajectories_by_week(df, weeks):\n",
"    \"\"\"Return one data_to_trajectories result per week, each restricted to weeks=[week].\"\"\"\n",
"    return [data_to_trajectories(df, map_action=map_action, weeks=[week], all=False, remove_empty=False) for week in weeks]\n",
"\n",
"trajectories_each_week = trajectories_by_week(dataset, unique_week)"
] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of failing students: 3046\n", "Number of passing students: 928\n", "max length, min length 1715 0\n", "num of student in [0] week: 928\n", "----------------------\n", "max length, min length 3157 0\n", "num of student in [1] week: 928\n", "----------------------\n", "max length, min length 1956 0\n", "num of student in [2] week: 928\n", "----------------------\n", "max length, min length 2755 0\n", "num of student in [3] week: 928\n", "----------------------\n", "max length, min length 1414 0\n", "num of student in [4] week: 928\n", "----------------------\n", "max length, min length 2330 0\n", "num of student in [5] week: 928\n", "----------------------\n", "max length, min length 1892 0\n", "num of student in [6] week: 928\n", "----------------------\n", "max length, min length 1376 0\n", "num of student in [7] week: 928\n", "----------------------\n", "max length, min length 1942 0\n", "num of student in [8] week: 928\n", "----------------------\n", "max length, min length 1475 0\n", "num of student in [9] week: 928\n", "----------------------\n", "max length, min length 2750 0\n", "num of student in [0] week: 3046\n", "----------------------\n", "max length, min length 2436 0\n", "num of student in [1] week: 3046\n", "----------------------\n", "max length, min length 2616 0\n", "num of student in [2] week: 3046\n", "----------------------\n", "max length, min length 3442 0\n", "num of student in [3] week: 3046\n", "----------------------\n", "max length, min length 2438 0\n", "num of student in [4] week: 3046\n", "----------------------\n", "max length, min length 931 0\n", "num of student in [5] week: 3046\n", "----------------------\n", "max length, min length 1635 0\n", "num of student in [6] week: 3046\n", "----------------------\n", "max length, min length 2473 0\n", "num of student in [7] week: 3046\n", "----------------------\n", "max length, min length 2059 0\n", "num of student in [8] week: 3046\n", "----------------------\n", "max length, min length 1264 0\n", "num of student in [9] week: 3046\n", "----------------------\n" ] } ], "source": [
"# label-pass-fail: 1 = fail, 0 = pass\n",
"fail_dataset = dataset[dataset['label-pass-fail'] == 1]\n",
"pass_dataset = dataset[dataset['label-pass-fail'] == 0]\n",
"assert len(fail_dataset) + len(pass_dataset) == len(dataset), 'label-pass-fail must be 0 or 1 for every student'\n",
"print('Number of failing students:', len(fail_dataset))\n",
"print('Number of passing students:', len(pass_dataset))\n",
"trajectories_each_week_pass = trajectories_by_week(pass_dataset, unique_week)\n",
"trajectories_each_week_fail = trajectories_by_week(fail_dataset, unique_week)"
] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "max length, min length 13916 21\n", "num of student in None week: 928\n", "----------------------\n", "max length, min length 12171 1\n", "num of student in None week: 3046\n", "----------------------\n" ] } ], "source": [ "trajectories_pass = data_to_trajectories(pass_dataset, map_action=map_action, all=True, remove_empty=False)\n", "trajectories_fail = data_to_trajectories(fail_dataset, map_action=map_action, all=True, remove_empty=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "number of trajectories: 3974\n", "(61, 3)\n" ] } ], "source": [ "print('number of trajectories:', len(trajectories))\n", "print(trajectories[0].shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [
"def create_mdp_data(trajectories, dataset, save_dir='/data/lan/dsp-002/mdp_dataset_split_pass.h5'):\n",
"    \"\"\"Flatten per-student trajectories into a d3rlpy MDPDataset and dump it to ``save_dir``.\n",
"\n",
"    Every non-empty trajectory becomes one episode: column 0 of the trajectory array is\n",
"    the observation and column 1 the action. Intermediate rewards are 0 and only the last\n",
"    step of each episode is terminal. The terminal reward encodes the outcome: failing\n",
"    students (label-pass-fail == 1) get -1 and passing students get +1, so the learned\n",
"    policy is pushed towards passing behaviour and away from failing behaviour.\n",
"\n",
"    Args:\n",
"        trajectories: per-student 2-D arrays accessed with ``.iloc`` (e.g. a pandas Series),\n",
"            positionally aligned with the rows of ``dataset``.\n",
"        dataset: DataFrame with a ``label-pass-fail`` column (1 = fail, 0 = pass).\n",
"        save_dir: path of the file the dataset is written to; missing parent\n",
"            directories are created.\n",
"\n",
"    Returns:\n",
"        Tuple ``(mdp_dataset, observations, actions, rewards, terminals)`` where the\n",
"        arrays are exactly what was passed to ``MDPDataset``.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``trajectories`` and ``dataset`` differ in length (labels would be\n",
"            misaligned) or if every trajectory is empty.\n",
"    \"\"\"\n",
"    import os\n",
"\n",
"    if len(trajectories) != len(dataset):\n",
"        raise ValueError(f'trajectories ({len(trajectories)}) and dataset ({len(dataset)}) must have the same length')\n",
"\n",
"    labels = dataset['label-pass-fail'].values\n",
"    observations, actions, rewards, terminals = [], [], [], []\n",
"    for idx in range(len(trajectories)):\n",
"        traj = trajectories.iloc[idx]\n",
"        if len(traj) <= 0:\n",
"            continue\n",
"        state_seq = traj[:, 0]\n",
"        observations.extend(state_seq)\n",
"        actions.extend(traj[:, 1])\n",
"        rewards.extend([0] * len(state_seq))\n",
"        terminals.extend([0] * len(state_seq))\n",
"        terminals[-1] = 1\n",
"        # fail (label == 1) -> reward -1, pass -> reward +1: learn the passing policy, penalize the failing one\n",
"        rewards[-1] = -1.0 if labels[idx] == 1.0 else 1.0\n",
"\n",
"    if not observations:\n",
"        raise ValueError('every trajectory is empty; nothing to put in the MDPDataset')\n",
"\n",
"    print(len(observations))\n",
"    observations = np.array(observations).reshape(-1, 1)\n",
"    print(observations.shape)\n",
"    actions = np.array(actions)\n",
"    rewards = np.array(rewards)\n",
"    terminals = np.array(terminals)\n",
"\n",
"    mdp_dataset = d3rlpy.dataset.MDPDataset(\n",
"        # Env=Env.ClickstreamWorld,\n",
"        observations=observations,\n",
"        actions=actions,\n",
"        rewards=rewards,\n",
"        terminals=terminals,\n",
"        timeouts=None,  # consider setting timeout\n",
"    )\n",
"\n",
"    parent_dir = os.path.dirname(save_dir)\n",
"    if parent_dir:\n",
"        os.makedirs(parent_dir, exist_ok=True)\n",
"    with open(save_dir, \"w+b\") as f:\n",
"        mdp_dataset.dump(f)\n",
"\n",
"    return mdp_dataset, observations, actions, rewards, terminals"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3046\n", "928\n" ] } ], "source": [ "print(len(trajectories_fail))\n", "print(len(trajectories_pass))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3046" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(trajectories_fail)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", 
"text": [ "1053503\n", "(1053503, 1)\n", "\u001b[2m2024-11-28 10:11.46\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:11.46\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:11.46\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "1154296\n", "(1154296, 1)\n", "\u001b[2m2024-11-28 10:11.53\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:11.53\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:11.53\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n" ] }, { "data": { "text/plain": [ "(,\n", " array([[22],\n", " [22],\n", " [22],\n", " ...,\n", " [61],\n", " [61],\n", " [75]]),\n", " array([3, 0, 2, ..., 1, 9, 8]),\n", " array([ 0., 0., 0., ..., 0., 0., -1.]),\n", " 
array([0, 0, 0, ..., 0, 0, 1]))" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Absolute directory the per-cohort MDP datasets are written to.\n",
"MDP_RESULTS_DIR = f\"/data/lan/irl/results/{course_id}\"\n",
"\n",
"# Keep both results so later cells can reuse them without reloading from disk.\n",
"fail_mdp_result = create_mdp_data(trajectories_fail, fail_dataset, save_dir=f\"{MDP_RESULTS_DIR}/mdp_dataset_fail_split.h5\")\n",
"pass_mdp_result = create_mdp_data(trajectories_pass, pass_dataset, save_dir=f\"{MDP_RESULTS_DIR}/mdp_dataset_pass_split.h5\")\n",
"fail_mdp_result[0], pass_mdp_result[0]"
] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/huonglan/Documents/codeproject/IRL-MOOC/utils/data_helper.py:247: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " video_length['duration'] = (video_length['duration']-min_len)/(max_len-min_len)\n", "/Users/huonglan/Documents/codeproject/IRL-MOOC/utils/data_helper.py:252: FutureWarning: A value is trying to be set on a copy of a DataFrame or Series through chained assignment using an inplace method.\n", "The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.\n", "\n", "For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.\n", "\n", "\n", " event_value['duration'].fillna(event_value['duration'].mean(), inplace=True)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
event_iddifficultychapterdurationis_problem
000.155000100.507272True
110.38449010.507272True
220.35906530.507272True
330.46041820.507272True
440.55692440.507272True
\n", "
" ], "text/plain": [ " event_id difficulty chapter duration is_problem\n", "0 0 0.155000 10 0.507272 True\n", "1 1 0.384490 1 0.507272 True\n", "2 2 0.359065 3 0.507272 True\n", "3 3 0.460418 2 0.507272 True\n", "4 4 0.556924 4 0.507272 True" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [
"datadir = 'data/mooc_raw/'\n",
"schedule_path = f'{datadir}/schedule/{course_id}.csv'\n",
"schedule = pd.read_csv(schedule_path)\n",
"\n",
"# Translate the integer-mapped event ids in combinedg back to their original ids.\n",
"map_event = [dict_event[event_idx] for event_idx in combinedg.event_id.unique()]\n",
"\n",
"# Per-event feature table (difficulty, chapter, duration, is_problem).\n",
"values = whatif_values(combinedg, schedule, map_event_id=map_event_id, problem_event=problem_event)\n",
"values.head(5)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m2024-11-28 10:12.32\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.32\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.32\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.33\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m 
\u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.33\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.33\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.35\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.35\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.35\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.36\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.36\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m 
\u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.36\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.38\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.38\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.38\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.39\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.39\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.39\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.40\u001b[0m 
[\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.40\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.40\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.41\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.41\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.41\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.42\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], 
shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.42\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.42\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.43\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.43\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.43\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.44\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.44\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m 
\u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.44\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.47\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.47\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.47\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.51\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.51\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.51\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.53\u001b[0m 
[\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.53\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.53\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.54\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.54\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.54\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.55\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], 
shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.55\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.55\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.55\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.55\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.55\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.56\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.56\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m 
\u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.56\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.56\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.56\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.56\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-11-28 10:12.57\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-11-28 10:12.57\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-11-28 10:12.57\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n" ] } ], "source": [ "def 
create_mdp_data_each_week(traj, save_dir='/data/lan/dsp-002/mdp_dataset_split_pass'):\n", " for week in range(10):\n", " observations, actions, rewards, terminals, start_states = [], [], [], [], []\n", " for id in range(len(traj[week])):\n", " if len(traj[week].iloc[id]) == 0:\n", " continue\n", " state_seq = traj[week].iloc[id][:, 0]\n", " observations.extend(state_seq)\n", " actions.extend(traj[week].iloc[id][:, 1])\n", " rewards.extend([0]*len(state_seq))\n", " terminals.extend([0]*len(state_seq))\n", " terminals[-1] = 1\n", " label = dataset['label-pass-fail'].values[id]\n", " rewards[-1] = 1.0 if label == 1.0 else -1.0 # if fail label==1, reward=-1, else reward=1. objective: learn passing policy, penalize failing policy\n", "\n", " observations = np.array(observations).reshape(-1, 1)\n", " actions = np.array(actions)\n", " rewards = np.array(rewards)\n", " terminals = np.array(terminals)\n", " \n", " mdp_dataset = d3rlpy.dataset.MDPDataset(\n", " observations=observations,\n", " actions=actions,\n", " rewards=rewards,\n", " terminals=terminals,\n", " timeouts=None, \n", " )\n", " with open(f\"{save_dir}_{week}.h5\", \"w+b\") as f:\n", " mdp_dataset.dump(f)\n", "\n", "create_mdp_data_each_week(trajectories_each_week_pass, save_dir=f\"/data/lan/irl/results/{course_id}/mdp_dataset_pass_split\")\n", "create_mdp_data_each_week(trajectories_each_week_fail, save_dir=f\"/data/lan/irl/results/{course_id}/mdp_dataset_fail_split\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m2024-11-26 23:40.09\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mdataset info \u001b[0m \u001b[36mdataset_info\u001b[0m=\u001b[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=, action_size=11)\u001b[0m\n", 
"\u001b[2m2024-11-26 23:40.09\u001b[0m [\u001b[32m\u001b[1mdebug \u001b[0m] \u001b[1mBuilding models... \u001b[0m\n", "\u001b[2m2024-11-26 23:40.09\u001b[0m [\u001b[32m\u001b[1mdebug \u001b[0m] \u001b[1mModels have been built. \u001b[0m\n", "\u001b[2m2024-11-26 23:40.09\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mDirectory is created at d3rlpy_logs/DiscreteFQE_20241126234009\u001b[0m\n", "\u001b[2m2024-11-26 23:40.09\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mParameters \u001b[0m \u001b[36mparams\u001b[0m=\u001b[35m{'observation_shape': [1], 'action_size': 11, 'config': {'type': 'fqe', 'params': {'batch_size': 100, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'learning_rate': 0.0001, 'optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'n_critics': 1, 'target_update_interval': 100}}}\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "89a001ca487c4a38a0fae2502894ed68", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch 1/1: 0%| | 0/10000 [00:00>>> Length bound', length_bound)\n", " syn_students = []\n", " state, length = 0, 0\n", " models = []\n", " if model_name == 'decision_transformer':\n", " path = find_file(model_dir+f'_all', f'model_10000.d3')\n", " print(f'>>>> Load model for all weeks from', path)\n", " dqn = d3rlpy.load_learnable(path)\n", " actor = dqn.as_stateful_wrapper(target_return=0)\n", " models = [actor] * num_week\n", " else:\n", " for i in range(num_week):\n", " path = 
find_file(model_dir+f'_week_{i}', f'model_10000.d3')\n", " print(f'>>>> Load model week {i + 1} from', path)\n", " dqn = d3rlpy.load_learnable(path)\n", " models.append(dqn)\n", "\n", " for i in tqdm(range(n_trajectories)):\n", " student = []\n", " for week in range(num_week):\n", " trajectory = []\n", " obs, _ = whatif_world.reset()\n", " state = obs['agent'] \n", " length = np.random.randint(length_bound[week][0], length_bound[week][1])\n", " for i in range(length):\n", " obs = None\n", " iter = 0\n", " while obs is None and iter < MAX_ITER: # make sure return valid action\n", " action = soft_policy_predict(models[week], np.array([state]).reshape(1, -1), model_name=model_name)\n", " obs = whatif_world.step(action)\n", " iter += 1\n", " if iter >= MAX_ITER:\n", " print('>>>> Invalid action', ': end at step', i)\n", " break\n", " next_state = obs[0]['agent']\n", " trajectory.append([state, action, next_state])\n", " state = next_state\n", " if model_name == 'decision_transformer':\n", " models[week].reset()\n", " student.append(trajectory)\n", " syn_students.append(student)\n", " return syn_students\n", "\n", "# make_syn_student(whatif_world, trajectories=trajectories_each_week, model_dir='results/dsp-002/d3rlpy_logs_pass', n_trajectories=3, num_week=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mode_student = 'pass'\n", "model_list = ['dqn', 'decision_transformer', 'bc', 'sac', 'ddgp']\n", "model_name = 'sac'\n", "model = None\n", "assert model_name in model_list\n", "# td_error_evaluator = TDErrorEvaluator(episodes=mdp_dataset.episodes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m2024-12-09 11:08.54\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mSignatures have been automatically determined.\u001b[0m \u001b[36maction_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], 
shape=[(1,)])\u001b[0m \u001b[36mobservation_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('int64')], shape=[(1,)])\u001b[0m \u001b[36mreward_signature\u001b[0m=\u001b[35mSignature(dtype=[dtype('float64')], shape=[(1,)])\u001b[0m\n", "\u001b[2m2024-12-09 11:08.54\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction-space has been automatically determined.\u001b[0m \u001b[36maction_space\u001b[0m=\u001b[35m\u001b[0m\n", "\u001b[2m2024-12-09 11:08.54\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-12-09 11:08.54\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mdataset info \u001b[0m \u001b[36mdataset_info\u001b[0m=\u001b[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=, action_size=11)\u001b[0m\n", "\u001b[2m2024-12-09 11:08.54\u001b[0m [\u001b[33m\u001b[1mwarning \u001b[0m] \u001b[1mSkip building models since they're already built.\u001b[0m\n", "\u001b[2m2024-12-09 11:08.54\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mDirectory is created at results/dsp-002/sac/d3rlpy_logs_pass_week_0/DiscreteSAC_20241209110854\u001b[0m\n", "\u001b[2m2024-12-09 11:08.54\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mParameters \u001b[0m \u001b[36mparams\u001b[0m=\u001b[35m{'observation_shape': [1], 'action_size': 11, 'config': {'type': 'discrete_sac', 'params': {'batch_size': 64, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0003, 'critic_learning_rate': 0.0003, 'temp_learning_rate': 0.0003, 'actor_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 
'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'critic_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'temp_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'n_critics': 2, 'initial_temperature': 1.0, 'target_update_interval': 8000}}}\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "64078974919247058688a599e02cf06f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch 1/1: 0%| | 0/10000 [00:00\u001b[0m\n", "\u001b[2m2024-12-09 11:10.22\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-12-09 11:10.22\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mdataset info \u001b[0m \u001b[36mdataset_info\u001b[0m=\u001b[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=, action_size=11)\u001b[0m\n", "\u001b[2m2024-12-09 11:10.22\u001b[0m [\u001b[33m\u001b[1mwarning \u001b[0m] \u001b[1mSkip building models since they're already built.\u001b[0m\n", "\u001b[2m2024-12-09 11:10.22\u001b[0m [\u001b[32m\u001b[1minfo 
\u001b[0m] \u001b[1mDirectory is created at results/dsp-002/sac/d3rlpy_logs_pass_week_1/DiscreteSAC_20241209111022\u001b[0m\n", "\u001b[2m2024-12-09 11:10.22\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mParameters \u001b[0m \u001b[36mparams\u001b[0m=\u001b[35m{'observation_shape': [1], 'action_size': 11, 'config': {'type': 'discrete_sac', 'params': {'batch_size': 64, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0003, 'critic_learning_rate': 0.0003, 'temp_learning_rate': 0.0003, 'actor_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'critic_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'temp_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'n_critics': 2, 'initial_temperature': 1.0, 'target_update_interval': 8000}}}\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ee77a7ee25184509a0760681b4100887", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch 1/1: 0%| | 0/10000 [00:00\u001b[0m\n", "\u001b[2m2024-12-09 11:12.00\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] 
\u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-12-09 11:12.00\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mdataset info \u001b[0m \u001b[36mdataset_info\u001b[0m=\u001b[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=, action_size=11)\u001b[0m\n", "\u001b[2m2024-12-09 11:12.00\u001b[0m [\u001b[33m\u001b[1mwarning \u001b[0m] \u001b[1mSkip building models since they're already built.\u001b[0m\n", "\u001b[2m2024-12-09 11:12.00\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mDirectory is created at results/dsp-002/sac/d3rlpy_logs_pass_week_2/DiscreteSAC_20241209111200\u001b[0m\n", "\u001b[2m2024-12-09 11:12.00\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mParameters \u001b[0m \u001b[36mparams\u001b[0m=\u001b[35m{'observation_shape': [1], 'action_size': 11, 'config': {'type': 'discrete_sac', 'params': {'batch_size': 64, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0003, 'critic_learning_rate': 0.0003, 'temp_learning_rate': 0.0003, 'actor_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'critic_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'temp_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 
'weight_decay': 0, 'amsgrad': False}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'n_critics': 2, 'initial_temperature': 1.0, 'target_update_interval': 8000}}}\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ac52c2c5392d44b5b6a8d8ef78b714b6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch 1/1: 0%| | 0/10000 [00:00\u001b[0m\n", "\u001b[2m2024-12-09 11:13.20\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-12-09 11:13.20\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mdataset info \u001b[0m \u001b[36mdataset_info\u001b[0m=\u001b[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=, action_size=11)\u001b[0m\n", "\u001b[2m2024-12-09 11:13.20\u001b[0m [\u001b[33m\u001b[1mwarning \u001b[0m] \u001b[1mSkip building models since they're already built.\u001b[0m\n", "\u001b[2m2024-12-09 11:13.20\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mDirectory is created at results/dsp-002/sac/d3rlpy_logs_pass_week_3/DiscreteSAC_20241209111320\u001b[0m\n", "\u001b[2m2024-12-09 11:13.20\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mParameters \u001b[0m \u001b[36mparams\u001b[0m=\u001b[35m{'observation_shape': [1], 'action_size': 11, 'config': {'type': 'discrete_sac', 'params': {'batch_size': 64, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 
'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0003, 'critic_learning_rate': 0.0003, 'temp_learning_rate': 0.0003, 'actor_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'critic_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'temp_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'n_critics': 2, 'initial_temperature': 1.0, 'target_update_interval': 8000}}}\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "845a45ba14824fc0a6927207b136ffe0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch 1/1: 0%| | 0/10000 [00:00\u001b[0m\n", "\u001b[2m2024-12-09 11:14.43\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-12-09 11:14.43\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mdataset info \u001b[0m \u001b[36mdataset_info\u001b[0m=\u001b[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=, 
action_size=11)\u001b[0m\n", "\u001b[2m2024-12-09 11:14.43\u001b[0m [\u001b[33m\u001b[1mwarning \u001b[0m] \u001b[1mSkip building models since they're already built.\u001b[0m\n", "\u001b[2m2024-12-09 11:14.43\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mDirectory is created at results/dsp-002/sac/d3rlpy_logs_pass_week_4/DiscreteSAC_20241209111443\u001b[0m\n", "\u001b[2m2024-12-09 11:14.43\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mParameters \u001b[0m \u001b[36mparams\u001b[0m=\u001b[35m{'observation_shape': [1], 'action_size': 11, 'config': {'type': 'discrete_sac', 'params': {'batch_size': 64, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0003, 'critic_learning_rate': 0.0003, 'temp_learning_rate': 0.0003, 'actor_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'critic_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'temp_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'n_critics': 2, 'initial_temperature': 1.0, 'target_update_interval': 8000}}}\u001b[0m\n" ] }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "fbdfa5e408fe4b058b709f1e2950162a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch 1/1: 0%| | 0/10000 [00:00\u001b[0m\n", "\u001b[2m2024-12-09 11:15.51\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-12-09 11:15.51\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mdataset info \u001b[0m \u001b[36mdataset_info\u001b[0m=\u001b[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=, action_size=11)\u001b[0m\n", "\u001b[2m2024-12-09 11:15.51\u001b[0m [\u001b[33m\u001b[1mwarning \u001b[0m] \u001b[1mSkip building models since they're already built.\u001b[0m\n", "\u001b[2m2024-12-09 11:15.51\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mDirectory is created at results/dsp-002/sac/d3rlpy_logs_pass_week_5/DiscreteSAC_20241209111551\u001b[0m\n", "\u001b[2m2024-12-09 11:15.51\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mParameters \u001b[0m \u001b[36mparams\u001b[0m=\u001b[35m{'observation_shape': [1], 'action_size': 11, 'config': {'type': 'discrete_sac', 'params': {'batch_size': 64, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0003, 'critic_learning_rate': 0.0003, 'temp_learning_rate': 0.0003, 'actor_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'critic_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 
'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'temp_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'n_critics': 2, 'initial_temperature': 1.0, 'target_update_interval': 8000}}}\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "59a2ca59cdb0442d857e3f94d60eaef6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch 1/1: 0%| | 0/10000 [00:00\u001b[0m\n", "\u001b[2m2024-12-09 11:17.04\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-12-09 11:17.04\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mdataset info \u001b[0m \u001b[36mdataset_info\u001b[0m=\u001b[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=, action_size=11)\u001b[0m\n", "\u001b[2m2024-12-09 11:17.04\u001b[0m [\u001b[33m\u001b[1mwarning \u001b[0m] \u001b[1mSkip building models since they're already built.\u001b[0m\n", "\u001b[2m2024-12-09 11:17.04\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mDirectory is created at results/dsp-002/sac/d3rlpy_logs_pass_week_6/DiscreteSAC_20241209111704\u001b[0m\n", "\u001b[2m2024-12-09 11:17.04\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mParameters \u001b[0m 
\u001b[36mparams\u001b[0m=\u001b[35m{'observation_shape': [1], 'action_size': 11, 'config': {'type': 'discrete_sac', 'params': {'batch_size': 64, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0003, 'critic_learning_rate': 0.0003, 'temp_learning_rate': 0.0003, 'actor_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'critic_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'temp_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'n_critics': 2, 'initial_temperature': 1.0, 'target_update_interval': 8000}}}\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a57b72c5516240be832516b216b2cc22", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch 1/1: 0%| | 0/10000 [00:00\u001b[0m\n", "\u001b[2m2024-12-09 11:18.44\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-12-09 11:18.44\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mdataset info \u001b[0m 
\u001b[36mdataset_info\u001b[0m=\u001b[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=, action_size=11)\u001b[0m\n", "\u001b[2m2024-12-09 11:18.44\u001b[0m [\u001b[33m\u001b[1mwarning \u001b[0m] \u001b[1mSkip building models since they're already built.\u001b[0m\n", "\u001b[2m2024-12-09 11:18.44\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mDirectory is created at results/dsp-002/sac/d3rlpy_logs_pass_week_7/DiscreteSAC_20241209111844\u001b[0m\n", "\u001b[2m2024-12-09 11:18.44\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mParameters \u001b[0m \u001b[36mparams\u001b[0m=\u001b[35m{'observation_shape': [1], 'action_size': 11, 'config': {'type': 'discrete_sac', 'params': {'batch_size': 64, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0003, 'critic_learning_rate': 0.0003, 'temp_learning_rate': 0.0003, 'actor_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'critic_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'temp_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 
'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'n_critics': 2, 'initial_temperature': 1.0, 'target_update_interval': 8000}}}\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2d80e8e2b1484217bce57abd52c746ce", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch 1/1: 0%| | 0/10000 [00:00\u001b[0m\n", "\u001b[2m2024-12-09 11:19.46\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-12-09 11:19.46\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mdataset info \u001b[0m \u001b[36mdataset_info\u001b[0m=\u001b[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=, action_size=11)\u001b[0m\n", "\u001b[2m2024-12-09 11:19.46\u001b[0m [\u001b[33m\u001b[1mwarning \u001b[0m] \u001b[1mSkip building models since they're already built.\u001b[0m\n", "\u001b[2m2024-12-09 11:19.46\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mDirectory is created at results/dsp-002/sac/d3rlpy_logs_pass_week_8/DiscreteSAC_20241209111946\u001b[0m\n", "\u001b[2m2024-12-09 11:19.46\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mParameters \u001b[0m \u001b[36mparams\u001b[0m=\u001b[35m{'observation_shape': [1], 'action_size': 11, 'config': {'type': 'discrete_sac', 'params': {'batch_size': 64, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0003, 'critic_learning_rate': 0.0003, 'temp_learning_rate': 0.0003, 'actor_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 
'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'critic_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'temp_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'n_critics': 2, 'initial_temperature': 1.0, 'target_update_interval': 8000}}}\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6fca4fd21db3493083934158ef203baf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch 1/1: 0%| | 0/10000 [00:00\u001b[0m\n", "\u001b[2m2024-12-09 11:20.49\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mAction size has been automatically determined.\u001b[0m \u001b[36maction_size\u001b[0m=\u001b[35m11\u001b[0m\n", "\u001b[2m2024-12-09 11:20.49\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mdataset info \u001b[0m \u001b[36mdataset_info\u001b[0m=\u001b[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=, action_size=11)\u001b[0m\n", "\u001b[2m2024-12-09 11:20.49\u001b[0m [\u001b[33m\u001b[1mwarning \u001b[0m] \u001b[1mSkip building models since they're already built.\u001b[0m\n", "\u001b[2m2024-12-09 11:20.49\u001b[0m [\u001b[32m\u001b[1minfo 
\u001b[0m] \u001b[1mDirectory is created at results/dsp-002/sac/d3rlpy_logs_pass_week_9/DiscreteSAC_20241209112049\u001b[0m\n", "\u001b[2m2024-12-09 11:20.49\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mParameters \u001b[0m \u001b[36mparams\u001b[0m=\u001b[35m{'observation_shape': [1], 'action_size': 11, 'config': {'type': 'discrete_sac', 'params': {'batch_size': 64, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0003, 'critic_learning_rate': 0.0003, 'temp_learning_rate': 0.0003, 'actor_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'critic_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'temp_optim_factory': {'type': 'adam', 'params': {'clip_grad_norm': None, 'lr_scheduler_factory': {'type': 'none', 'params': {}}, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'n_critics': 2, 'initial_temperature': 1.0, 'target_update_interval': 8000}}}\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1c6107444d9140ddae0e8b010a50bcc8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch 1/1: 0%| | 0/10000 [00:00>>> Length bound [[0, 267], [1, 396], [7, 462], [6, 413], [6, 399], [4, 346], [2, 310], [0, 201], 
[0, 266], [0, 247]]\n", ">>>> Load model week 1 from results/dsp-002/sac/d3rlpy_logs_pass_week_0/DiscreteSAC_20241209110854/model_10000.d3\n", ">>>> Load model week 2 from results/dsp-002/sac/d3rlpy_logs_pass_week_1/DiscreteSAC_20241209111022/model_10000.d3\n", ">>>> Load model week 3 from results/dsp-002/sac/d3rlpy_logs_pass_week_2/DiscreteSAC_20241209111200/model_10000.d3\n", ">>>> Load model week 4 from results/dsp-002/sac/d3rlpy_logs_pass_week_3/DiscreteSAC_20241209111320/model_10000.d3\n", ">>>> Load model week 5 from results/dsp-002/sac/d3rlpy_logs_pass_week_4/DiscreteSAC_20241209111443/model_10000.d3\n", ">>>> Load model week 6 from results/dsp-002/sac/d3rlpy_logs_pass_week_5/DiscreteSAC_20241209111551/model_10000.d3\n", ">>>> Load model week 7 from results/dsp-002/sac/d3rlpy_logs_pass_week_6/DiscreteSAC_20241209111704/model_10000.d3\n", ">>>> Load model week 8 from results/dsp-002/sac/d3rlpy_logs_pass_week_7/DiscreteSAC_20241209111844/model_10000.d3\n", ">>>> Load model week 9 from results/dsp-002/sac/d3rlpy_logs_pass_week_8/DiscreteSAC_20241209111946/model_10000.d3\n", ">>>> Load model week 10 from results/dsp-002/sac/d3rlpy_logs_pass_week_9/DiscreteSAC_20241209112049/model_10000.d3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/100 [00:00>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 1%| | 1/100 [00:01<01:50, 1.12s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 1\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 2%|▏ | 2/100 [00:02<01:44, 1.06s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 3\n", ">>>> Invalid action : end at step 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 
3%|▎ | 3/100 [00:02<01:34, 1.03it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 1\n", ">>>> Invalid action : end at step 1\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 4%|▍ | 4/100 [00:04<01:40, 1.04s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 5\n", ">>>> Invalid action : end at step 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 5%|▌ | 5/100 [00:05<01:38, 1.04s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 4\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 6%|▌ | 6/100 [00:06<01:38, 1.04s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 7%|▋ | 7/100 [00:07<01:35, 1.03s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 8%|▊ | 8/100 [00:07<01:18, 1.17it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 9%|▉ | 9/100 [00:08<01:27, 1.04it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 4\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 1\n" ] }, { "name": "stderr", "output_type": 
"stream", "text": [ " 10%|█ | 10/100 [00:09<01:27, 1.02it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 11%|█ | 11/100 [00:11<01:35, 1.07s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 3\n", ">>>> Invalid action : end at step 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 12%|█▏ | 12/100 [00:11<01:26, 1.02it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 1\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 13%|█▎ | 13/100 [00:13<01:28, 1.02s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 14%|█▍ | 14/100 [00:14<01:37, 1.13s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 1\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 15%|█▌ | 15/100 [00:15<01:40, 1.18s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 16%|█▌ | 16/100 
[00:16<01:34, 1.12s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 17%|█▋ | 17/100 [00:17<01:25, 1.03s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 18%|█▊ | 18/100 [00:18<01:30, 1.10s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 4\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 19%|█▉ | 19/100 [00:19<01:23, 1.03s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 20%|██ | 20/100 [00:20<01:17, 1.04it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 0\n", ">>>> Invalid action : end at step 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 21%|██ | 21/100 [00:21<01:13, 1.07it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>>> Invalid action : end at step 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 21%|██ | 21/100 [00:21<01:21, 1.03s/it]\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[53], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m 
\u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mNUM WEEK: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_week\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m; ID: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mid\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(model_name)\n\u001b[0;32m----> 6\u001b[0m syn_students_pass \u001b[38;5;241m=\u001b[39m \u001b[43mmake_syn_student\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwhatif_world\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrajectories\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrajectories_each_week_pass\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mresults/dsp-002/\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mmodel_name\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m/d3rlpy_logs_pass\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_trajectories\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_week\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_week\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_name\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m syn_students_fail \u001b[38;5;241m=\u001b[39m make_syn_student(whatif_world, trajectories\u001b[38;5;241m=\u001b[39mtrajectories_each_week_fail, 
model_dir\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mresults/dsp-002/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/d3rlpy_logs_fail\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m 9\u001b[0m n_trajectories\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m, num_week\u001b[38;5;241m=\u001b[39mnum_week, model_name\u001b[38;5;241m=\u001b[39mmodel_name)\n\u001b[1;32m 10\u001b[0m trajectories_to_features(syn_students_pass, whatif_world, path\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mOUTPUT_DIR\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_synthesized_early-prediction_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcourse_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_all_pass_num_week_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_week\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mid\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m, num_week\u001b[38;5;241m=\u001b[39mnum_week)\n", "Cell \u001b[0;32mIn[44], line 39\u001b[0m, in \u001b[0;36mmake_syn_student\u001b[0;34m(whatif_world, trajectories, model_dir, n_trajectories, num_week, model_name)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28miter\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m obs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28miter\u001b[39m \u001b[38;5;241m<\u001b[39m MAX_ITER: \u001b[38;5;66;03m# make sure return valid action\u001b[39;00m\n\u001b[0;32m---> 39\u001b[0m action \u001b[38;5;241m=\u001b[39m 
\u001b[43msoft_policy_predict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodels\u001b[49m\u001b[43m[\u001b[49m\u001b[43mweek\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreshape\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_name\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m obs \u001b[38;5;241m=\u001b[39m whatif_world\u001b[38;5;241m.\u001b[39mstep(action)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28miter\u001b[39m \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n", "Cell \u001b[0;32mIn[48], line 26\u001b[0m, in \u001b[0;36msoft_policy_predict\u001b[0;34m(model, state, action, temperature, model_name)\u001b[0m\n\u001b[1;32m 23\u001b[0m action \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mchoice(np\u001b[38;5;241m.\u001b[39marange(model\u001b[38;5;241m.\u001b[39maction_size), p\u001b[38;5;241m=\u001b[39mprobabilities)\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m action\n\u001b[0;32m---> 26\u001b[0m q_values \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpredict_value\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 28\u001b[0m \u001b[38;5;66;03m# Compute probabilities using softmax\u001b[39;00m\n\u001b[1;32m 29\u001b[0m probabilities \u001b[38;5;241m=\u001b[39m softmax(q_values 
\u001b[38;5;241m/\u001b[39m temperature)\n", "File \u001b[0;32m~/miniconda3/envs/irl/lib/python3.12/site-packages/d3rlpy/algos/qlearning/base.py:319\u001b[0m, in \u001b[0;36mQLearningAlgoBase.predict_value\u001b[0;34m(self, x, action)\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_impl \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m, IMPL_NOT_INITIALIZED_ERROR\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m check_non_1d_array(x), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInput must have batch dimension.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 319\u001b[0m torch_x \u001b[38;5;241m=\u001b[39m \u001b[43mconvert_to_torch_recursively\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_device\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 321\u001b[0m torch_action \u001b[38;5;241m=\u001b[39m convert_to_torch(action, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_device)\n\u001b[1;32m 323\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n", "File \u001b[0;32m~/miniconda3/envs/irl/lib/python3.12/site-packages/d3rlpy/torch_utility.py:118\u001b[0m, in \u001b[0;36mconvert_to_torch_recursively\u001b[0;34m(array, device)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [convert_to_torch(data, device) \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m array]\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(array, np\u001b[38;5;241m.\u001b[39mndarray):\n\u001b[0;32m--> 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mconvert_to_torch\u001b[49m\u001b[43m(\u001b[49m\u001b[43marray\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 120\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minvalid array type: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(array)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[0;32m~/miniconda3/envs/irl/lib/python3.12/site-packages/d3rlpy/torch_utility.py:96\u001b[0m, in \u001b[0;36mconvert_to_torch\u001b[0;34m(array, device)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mconvert_to_torch\u001b[39m(array: NDArray, device: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[1;32m 95\u001b[0m dtype \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39muint8 \u001b[38;5;28;01mif\u001b[39;00m array\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;241m==\u001b[39m np\u001b[38;5;241m.\u001b[39muint8 \u001b[38;5;28;01melse\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mfloat32\n\u001b[0;32m---> 96\u001b[0m tensor \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mfloat()\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "OUTPUT_DIR = f'results/{course_id}'\n", "for num_week in [10]:\n", " for id in range(1, 3):\n", " 
print(f'NUM WEEK: {num_week}; ID: {id}')\n", " print(model_name)\n", " syn_students_pass = make_syn_student(whatif_world, trajectories=trajectories_each_week_pass, model_dir=f'results/dsp-002/{model_name}/d3rlpy_logs_pass', \n", " n_trajectories=100, num_week=num_week, model_name=model_name)\n", " syn_students_fail = make_syn_student(whatif_world, trajectories=trajectories_each_week_fail, model_dir=f'results/dsp-002/{model_name}/d3rlpy_logs_fail',\n", " n_trajectories=100, num_week=num_week, model_name=model_name)\n", " trajectories_to_features(syn_students_pass, whatif_world, path=f'{OUTPUT_DIR}/{model_name}_synthesized_early-prediction_{course_id}_all_pass_num_week_{num_week}_{id}', num_week=num_week)\n", " trajectories_to_features(syn_students_fail, whatif_world, path=f'{OUTPUT_DIR}/{model_name}_synthesized_early-prediction_{course_id}_all_fail_num_week_{num_week}_{id}', num_week=num_week)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [
"def make_syn_student_personalized(students, test_labels, world,\n",
"                                  week=5, model_dir='results/dsp-002/d3rlpy_logs', model_name='dqn',\n",
"                                  ids_path='results/dsp-002/test_students_5.npy'):\n",
"    \"\"\"\n",
"    Generate one synthetic trajectory per selected test student with the pass or fail policy.\n",
"\n",
"    Here we made an assumption that we already know label pass/fail of a student:\n",
"    test_labels[k] (0 -> pass model, otherwise fail model) belongs to the k-th id in ids_path.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    students : trajectories indexed with the ids array, so it must support fancy indexing\n",
"        (e.g. an object np.ndarray). A non-empty trajectory t starts at state t[0, 0].\n",
"    test_labels : sequence of labels, read positionally.\n",
"    world : environment with `_agent_location` and `step(action)`; `step` is retried until it\n",
"        returns something other than None, and `obs[0]['agent']` is taken as the next state.\n",
"    week : 0-based week index; selects the per-week models (not used by decision_transformer paths).\n",
"    model_dir : prefix of the d3rlpy log directories that contain model_10000.d3.\n",
"    model_name : 'decision_transformer' loads <model_dir>_pass_all / <model_dir>_fail_all as\n",
"        stateful actors; any other value loads <model_dir>_pass_week_<week> / _fail_week_<week>.\n",
"    ids_path : .npy file with the indices of the students to synthesize.\n",
"\n",
"    Returns\n",
"    -------\n",
"    list with one [trajectory] per id; a trajectory is a list of [state, action, next_state].\n",
"\n",
"    Raises\n",
"    ------\n",
"    FileNotFoundError if a model file is missing; ValueError if there are fewer labels than ids.\n",
"    \"\"\"\n",
"    MAX_ITER = 1000  # retries per step before giving up on finding a valid action\n",
"    ids = np.load(ids_path)\n",
"    print('Week: ', week + 1, 'personalized')\n",
"\n",
"    selected = students[ids]\n",
"    start_state_arr = [t[0, 0] if len(t) > 0 else None for t in selected]\n",
"    lengths = [len(t) for t in selected]\n",
"\n",
"    if model_name == 'decision_transformer':\n",
"        # Two separate paths (pass / fail); a single comma-joined string would never be found.\n",
"        paths = [f'{model_dir}_pass_all', f'{model_dir}_fail_all']\n",
"    else:\n",
"        paths = [f'{model_dir}_pass_week_{week}', f'{model_dir}_fail_week_{week}']\n",
"\n",
"    models = []  # models[0] = pass policy, models[1] = fail policy\n",
"    for path in paths:\n",
"        model_file = find_file(path, 'model_10000.d3')\n",
"        print(f'>>>> Load model week {week + 1} from', model_file)\n",
"        if model_file is None:\n",
"            raise FileNotFoundError(f'model_10000.d3 not found under {path}')\n",
"        learner = d3rlpy.load_learnable(model_file)\n",
"        if model_name == 'decision_transformer':\n",
"            models.append(learner.as_stateful_wrapper(target_return=0))\n",
"        else:\n",
"            models.append(learner)\n",
"    assert len(models) == 2\n",
"\n",
"    labels = np.asarray(test_labels)  # positional, aligned with ids\n",
"    print(len(ids), len(labels))\n",
"    if len(labels) < len(ids):\n",
"        raise ValueError(f'got {len(labels)} labels for {len(ids)} students')\n",
"\n",
"    syn_students = []\n",
"    for idx in tqdm(range(len(ids))):\n",
"        model = models[0] if labels[idx] == 0 else models[1]\n",
"        trajectory = []\n",
"        state = start_state_arr[idx]\n",
"        world._agent_location = state\n",
"        for step in range(lengths[idx]):  # empty range -> empty trajectory\n",
"            obs = None\n",
"            n_tries = 0\n",
"            while obs is None and n_tries < MAX_ITER:  # make sure return valid action\n",
"                action = soft_policy_predict(model, np.array([state]).reshape(1, -1), model_name=model_name)\n",
"                obs = world.step(action)\n",
"                n_tries += 1\n",
"            # Test obs rather than the counter: a valid action found on the last try must be kept.\n",
"            if obs is None:\n",
"                print('>>>> Invalid action', ': end at step', step)\n",
"                break\n",
"            next_state = obs[0]['agent']\n",
"            trajectory.append([state, action, next_state])\n",
"            state = next_state\n",
"        if model_name == 'decision_transformer':\n",
"            # Reset the stateful decision-transformer actors before the next student.\n",
"            models[0].reset()\n",
"            models[1].reset()\n",
"        syn_students.append([trajectory])\n",
"\n",
"    return syn_students"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">>>>> WEEK: 5 ID: 0\n", "Week: 6 personalized\n", ">>>> Load model week 6 from None\n" ] }, { "ename": "TypeError", "evalue": "expected str, bytes or os.PathLike object, not NoneType", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[24], line 7\u001b[0m\n\u001b[1;32m 5\u001b[0m pred \u001b[38;5;241m=\u001b[39m 
pd\u001b[38;5;241m.\u001b[39mread_csv(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mOUTPUT_DIR\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/predictions_weeks_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mweek\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.csv\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 6\u001b[0m pred \u001b[38;5;241m=\u001b[39m pred[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124my_pred\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m----> 7\u001b[0m syn_students \u001b[38;5;241m=\u001b[39m \u001b[43mmake_syn_student_personalized\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstudents\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrajectories_each_week\u001b[49m\u001b[43m[\u001b[49m\u001b[43mweek\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mworld\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwhatif_world\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweek\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_labels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mresults/dsp-002/\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mmodel_name\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m/d3rlpy_logs\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m course_feature \u001b[38;5;241m=\u001b[39m trajectories_to_features(syn_students, 
whatif_world, num_week\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, path\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mOUTPUT_DIR\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_synthesized_early-prediction_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcourse_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_personalized_week_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mweek\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124monly_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mid\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m----------------\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", "Cell \u001b[0;32mIn[23], line 22\u001b[0m, in \u001b[0;36mmake_syn_student_personalized\u001b[0;34m(students, test_labels, world, week, model_dir, model_name)\u001b[0m\n\u001b[1;32m 20\u001b[0m file \u001b[38;5;241m=\u001b[39m find_file(path, \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_10000.d3\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m>>>> Load model week \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mweek\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from\u001b[39m\u001b[38;5;124m'\u001b[39m, file)\n\u001b[0;32m---> 22\u001b[0m dqn \u001b[38;5;241m=\u001b[39m \u001b[43md3rlpy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_learnable\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 23\u001b[0m 
\u001b[38;5;28;01mif\u001b[39;00m model_name \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdecision_transformer\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 24\u001b[0m actor \u001b[38;5;241m=\u001b[39m dqn\u001b[38;5;241m.\u001b[39mas_stateful_wrapper(target_return\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n", "File \u001b[0;32m~/miniconda3/envs/irl/lib/python3.12/site-packages/d3rlpy/base.py:192\u001b[0m, in \u001b[0;36mload_learnable\u001b[0;34m(fname, device)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mload_learnable\u001b[39m(\n\u001b[1;32m 190\u001b[0m fname: \u001b[38;5;28mstr\u001b[39m, device: DeviceArg \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 191\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLearnableBase[ImplBase, LearnableConfig]\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 192\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 193\u001b[0m obj \u001b[38;5;241m=\u001b[39m pickle\u001b[38;5;241m.\u001b[39mload(f)\n\u001b[1;32m 194\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m obj[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mversion\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m!=\u001b[39m __version__:\n", "\u001b[0;31mTypeError\u001b[0m: expected str, bytes or os.PathLike object, not NoneType" ] } ], "source": [
"# Personalized synthesis: each test student is rolled out with the pass or fail policy chosen by its predicted label.\n",
"model_name = 'decision_transformer'\n",
"for week in [5]:\n",
"    # The predicted labels depend only on the week, so read them once per week.\n",
"    pred = pd.read_csv(f'{OUTPUT_DIR}/predictions_weeks_{week + 1}.csv')['y_pred']\n",
"    for run_id in range(3):\n",
"        print('>>>>> WEEK:', week, 'ID:', run_id)\n",
"        # Use the loop's week rather than a hard-coded 5.\n",
"        syn_students = make_syn_student_personalized(students=trajectories_each_week[week],\n",
"                                                     world=whatif_world, week=week, test_labels=pred, model_name=model_name,\n",
"                                                     model_dir=f'results/{course_id}/{model_name}/d3rlpy_logs')\n",
"        course_feature = trajectories_to_features(syn_students, whatif_world, num_week=1,\n",
"                                                  path=f'{OUTPUT_DIR}/{model_name}_synthesized_early-prediction_{course_id}_personalized_week_{week+1}only_{run_id}')\n",
"        print('----------------')"
] } ], "metadata": { "kernelspec": { "display_name": "irl", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 2 }