Neo-X committed
Commit e18f7ef · 1 Parent(s): 472e6ce

Fixing sim evals and adding action stacking support.

Files changed (3):
  1. README.md +1 -2
  2. app.py +20 -14
  3. sim_eval.py +78 -38
README.md CHANGED
@@ -14,5 +14,4 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
 Submit your files in the form
 
 miniGRP.pth
-conf/config.yaml
-grp_model.py
+conf/config.yaml
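
For reference, a submission under this layout is just the checkpoint plus the Hydra config. A minimal upload sketch using huggingface_hub's HfApi (the repo id below is a hypothetical placeholder, not a real repo):

# Hypothetical submission sketch: upload the two required files to a model repo.
from huggingface_hub import HfApi

api = HfApi()
for local_path in ["miniGRP.pth", "conf/config.yaml"]:
    api.upload_file(
        path_or_fileobj=local_path,              # file on disk
        path_in_repo=local_path,                 # same path inside the repo
        repo_id="your-username/your-grp-model",  # placeholder id
        repo_type="model",
    )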
 
app.py CHANGED
@@ -19,7 +19,7 @@ api = HfApi()
 
 def evaluate_policy(model_id):
     """
-    Downloads a PPO model from HF Hub, runs it in Gym, returns mean reward.
+    Downloads a GRP model from HF Hub, runs it in the simulator, returns mean reward.
     """
     print(f"Starting evaluation for: {model_id}")
     try:
@@ -36,7 +36,7 @@ def evaluate_policy(model_id):
                     model_file = os.path.join(root, file)
                 if file.endswith("model.py"):
                     grp_file_path = os.path.join(root, file)
-                if file.endswith(".yaml") or file.endswith(".yalm"):
+                if file.endswith("config.yaml") or file.endswith("config.yalm"):
                     hydra_config_file_path = os.path.join(root, file)
 
         if not model_file:
@@ -44,7 +44,7 @@ def evaluate_policy(model_id):
 
         # 2. Load the PPO Agent
         # custom_objects map may be needed if python versions differ, but usually fine for PPO
-        import torch
+        import torch, dill
         # ------------
         # Train and test splits
         # Loading data
@@ -54,13 +54,14 @@ def evaluate_policy(model_id):
         from omegaconf import OmegaConf
         cfg = OmegaConf.load(hydra_config_file_path)
         cfg.dataset.load_dataset = "skip"
+        cfg.testing = True
         ## load the GRP model from the file downloaded in the snapshot
         # Dynamically load the module
         import importlib.util, sys
         sys.path.insert(0, repo_path+"/") ## dangerous for security but ok for now.
         from grp_model import GRP
 
-        model_ = torch.load(model_file)
+        model_ = torch.load(model_file, pickle_module=dill)
         # model_._cgf = cfg
         # model = PPO.load(model_file)
         print("Memory used by the model:", torch.cuda.memory_allocated(cfg.device) / 1e6, "MB") ## This to the database later.
@@ -69,15 +70,19 @@ def evaluate_policy(model_id):
 
         tokenizer = None
         text_model = None
+        ## Time the evaluation run
+        start_time = time.time()
         if cfg.dataset.encode_with_t5: ## Load T5 model
             from transformers import T5Tokenizer, T5ForConditionalGeneration
             tokenizer = T5Tokenizer.from_pretrained(cfg.dataset.t5_version)
             text_model = T5ForConditionalGeneration.from_pretrained(cfg.dataset.t5_version)
 
         if "libero" in cfg.simEval:
+            from sim_eval import eval_libero
             results = eval_libero(model_.to(cfg.device), device=cfg.device, cfg=cfg,
                                   iter_=0, tokenizer=tokenizer, text_model=text_model, wandb=None,
                                   log_dir="./")
+            print("LIBERO results:", results)
         if "simple_env" in cfg.simEval:
             import simpler_env
             task_name = "widowx_carrot_on_plate" # @param ["google_robot_pick_coke_can", "google_robot_move_near", "google_robot_open_drawer", "google_robot_close_drawer", "widowx_spoon_on_towel", "widowx_carrot_on_plate", "widowx_stack_cube", "widowx_put_eggplant_in_basket"]
@@ -93,11 +98,11 @@ def evaluate_policy(model_id):
                                   wandb=None, iter_=0, tokenizer=tokenizer, text_model=text_model)
             print("results:", results)
 
-        # cbuffer.save(cfg.dataset.to_name)
-        env.close()
-        del env
-
-        return results['rewards'], "Success"
+            # cbuffer.save(cfg.dataset.to_name)
+            env.close()
+            del env
+        results['time'] = time.time() - start_time
+        return results, "Success"
 
     except Exception as e:
         print(f"Evaluation failed: {e}")
@@ -136,23 +141,24 @@ def run_evaluation_loop():
         score, status_msg = evaluate_policy(model_id)
 
         # 4. Update the Dataframes
-
         # Update Requests (Mark as Done or Failed)
         requests_df.loc[row_index, "status"] = "Done" if score is not None else "Failed"
-
+
         # Prepare Results Row
         if score is not None:
             new_result = {
                 "model_id": model_id,
-                "mean_reward": score,
-                "status": "Success"
+                "mean_reward": score['rewards'],
+                "run_time": score["time"],
+                "status": "Success",
+                "completed_at": time.time()
             }
 
             # Load Results Dataset
             try:
                 results_df = pd.read_csv(f"hf://datasets/{RESULTS_DATASET}/results.csv")
             except:
-                results_df = pd.DataFrame(columns=["model_id", "mean_reward", "status"])
+                results_df = pd.DataFrame(columns=["model_id", "mean_reward", "run_time", "status", "completed_at"])
 
             # Append new result
             results_df = pd.concat([results_df, pd.DataFrame([new_result])], ignore_index=True)
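
Since evaluate_policy now returns the full results dict (with 'rewards' and the new 'time' field) plus a status string, the leaderboard bookkeeping can be exercised on its own. A minimal sketch of the updated result-row flow, assuming pandas and a local results.csv in place of the hf://datasets path:

# Sketch of the result-row bookkeeping from run_evaluation_loop(); a local CSV
# stands in for the hf://datasets/{RESULTS_DATASET}/results.csv path.
import time
import pandas as pd

COLUMNS = ["model_id", "mean_reward", "run_time", "status", "completed_at"]

def append_result(score, model_id, csv_path="results.csv"):
    # score mirrors evaluate_policy's return dict: 'rewards' and 'time' keys.
    new_result = {
        "model_id": model_id,
        "mean_reward": score["rewards"],
        "run_time": score["time"],
        "status": "Success",
        "completed_at": time.time(),
    }
    try:
        results_df = pd.read_csv(csv_path)
    except FileNotFoundError:
        results_df = pd.DataFrame(columns=COLUMNS)
    # Append the new row and persist.
    results_df = pd.concat([results_df, pd.DataFrame([new_result])], ignore_index=True)
    results_df.to_csv(csv_path, index=False)
    return results_df

# Example with a fake score dict shaped like the evaluator's output.
print(append_result({"rewards": 0.42, "time": 12.3}, "user/miniGRP-demo"))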
sim_eval.py CHANGED
@@ -1,4 +1,7 @@
 
+import dill
+import numpy as np
+import torch
 
 def get_text_tokens(cfg, tokenizer, text_model, goal, model=None):
     """
@@ -43,28 +46,24 @@ def eval_model_in_sim(cfg, model, device, log_dir, env, env_unwrapped,
     obs, reset_info = env.reset()
     obs_ = get_image_from_maniskill2_obs_dict(env_unwrapped, obs)[:,:,:3]
     obs_hist = deque(maxlen=cfg.policy.obs_stacking)
-    obs_hist.append(obs_)
-    obs_hist.append(obs_)
-    obs_hist.append(obs_)
+    for _ in range(cfg.policy.obs_stacking):
+        obs_hist.append(obs_)
     instruction = env_unwrapped.get_language_instruction()
     # print("Reset info", reset_info)
     print("Instruction", instruction)
     frames = []
     done, truncated, timeLimit, t = False, False, 100, 0
     txt_goal = get_text_tokens(cfg, tokenizer, text_model, instruction, model=model)
+    # obs_hist.append(image) ## Add the new observation to the history buffer
     while not (done or truncated or (t > timeLimit)):
         # action[:3]: delta xyz; action[3:6]: delta rotation in axis-angle representation;
         # action[6:7]: gripper (the meaning of open / close depends on robot URDF)
-        image = get_image_from_maniskill2_obs_dict(env_unwrapped, obs)
-        image = image[:,:,:3] ## Remove last dimension of image color
-
-        obs_hist.append(image) ## Add the new observation to the history buffer
         # obs = [obs_["image"] for obs_ in obs] # obs is a list of dicts
         image = np.stack(obs_hist, axis=-1) # stack along the last dimension
         image = rearrange(image, 'h w c t -> h w (c t)') # add batch dimension
-
-        obs_state = model.preprocess_state(image).to(device)
-        goal_state = model.preprocess_goal_image(image[:,:,:3]).to(device)
+
+        obs_state = torch.tensor(model.preprocess_state(image), dtype=torch.float32)
+        goal_state = torch.tensor(model.preprocess_goal_image(image[:,:,:3]), dtype=torch.float32)
         action, loss = model.forward(torch.tensor(obs_state.unsqueeze(0), dtype=torch.float32).to(device)
                                     ,torch.tensor(txt_goal).to(device)
                                     ,torch.tensor(goal_state.unsqueeze(0), dtype=torch.float32).to(device),
@@ -73,24 +72,39 @@ def eval_model_in_sim(cfg, model, device, log_dir, env, env_unwrapped,
                                     )
 
         action = model.decode_action(action[0]).cpu().detach().numpy() ## Add in the gripper close action
-        obs, reward, done, truncated, info = env.step(action)
-        reward = -np.linalg.norm(info["eof_to_obj1_diff"])
-        frames.append(image)
-        rewards.append(reward)
-        t=t+1
+        ## If the actions are stacked into a longer vector, execute the sequence of actions
+        for step_ in range(cfg.policy.action_stacking):
+            act_ = action[cfg.action_dim*step_:(cfg.action_dim*(step_+1))]
+            obs, reward, done, truncated, info = env.step(act_)
+            image = get_image_from_maniskill2_obs_dict(env_unwrapped, obs)
+            image = image[:,:,:3] ## Remove last dimension of image color
+            # Store the original image for video before stacking/processing
+            frames.append(image)
+            reward = -(np.linalg.norm(info["eof_to_obj1_diff"]) + np.linalg.norm(info["eof_to_obj1_diff"])) ## Use a shaped reward as distance between gripper and objects
+            rewards.append(reward)
+            t=t+1
+            if done or truncated:
+                break
+
 
     episode_stats = info.get('episode_stats', {})
     episode_stats['rewards'] = np.mean(rewards)
     # print("Episode stats", episode_stats)
-    # print(f"avg reward {np.mean(episode_stats['rewards']):.8f}")
+    print(f"avg reward {np.mean(episode_stats['rewards']):.8f}")
     if not cfg.testing:
         wandb.log({"avg reward": np.mean(rewards)})
-    import moviepy.editor as mpy
-    clip = mpy.ImageSequenceClip(list(frames), fps=20)
-    path_ = log_dir+"/sim-env-"+str(iter_)+".mp4"
-    # clip.write_videofile(path_, fps=20, audio=False, logger=None) ## Getting weird Nonetype issues. Will need to fix version issue later.
+
+    import os
+    path_ = os.path.join(log_dir, f"simple-env-{iter_}.mp4")
+    import imageio
+    imageio.mimsave(path_, frames, fps=20)
+
     if not cfg.testing:
-        wandb.log({"example": wandb.Video(path_)})
+        try:
+            wandb.log({"example": wandb.Video(path_)})
+        except Exception as e:
+            print(f"Warning: failed to log video to wandb: {e}")
+
     return episode_stats
 
 import gymnasium as gym
@@ -143,6 +157,7 @@ def eval_libero(model, device, cfg, iter_=0, log_dir="./",
     from libero.libero.utils import get_libero_path
     from gymnasium.wrappers import FrameStackObservation
     from einops import rearrange
+    from collections import deque
 
 
    benchmark_dict = benchmark.get_benchmark_dict()
@@ -172,6 +187,9 @@ def eval_libero(model, device, cfg, iter_=0, log_dir="./",
        env.set_init_state(init_states[init_state_id])
        env = FrameStackObservation(DictWrapper(env, obs_key="agentview_image"), cfg.policy.obs_stacking) ## Stacking the observations
        obs, info = env.reset()
+        # obs_hist = deque(maxlen=cfg.policy.obs_stacking)
+        # for _ in range(cfg.policy.obs_stacking):
+        #     obs_hist.append(obs)
 
        mask = get_blocked_mask(cfg, targets=None, T=0) ## Get the blocked mask
 
@@ -180,9 +198,15 @@ def eval_libero(model, device, cfg, iter_=0, log_dir="./",
        frames = []
        rewards = []
        infos = []
-        for step_ in range(250):
+        done, truncated, timeLimit, t, wait_steps = False, False, 400, 0, 10
+        while not (done or truncated or (t > (timeLimit + wait_steps))):
            ## Reshape the image to the correct size and stack the history on the last channel dimension
-            image = obs[0]
+            # image = obs[0]
+            if t < wait_steps: ## let objects stabilize before acting.
+                obs, reward, done, truncated, info = env.step([0,0,0,0,0,0,-1])
+                # obs_hist.append(obs)
+                t += 1
+                continue
            # obs = obs.reshape((128, 128, 3*cfg.policy.obs_stacking)) ## Assuming the observation is an image of size 128x128 with 3 color channels
            obs = rearrange(obs, 't h w c -> h w (t c)', c=3, t=cfg.policy.obs_stacking) ## Rearranging the image to have the stacked history in the last channel dimension
            # image = obs[:,:,:3] ## Remove the last dimension of the image color
@@ -195,35 +219,50 @@ def eval_libero(model, device, cfg, iter_=0, log_dir="./",
                pose=torch.tensor([[np.concatenate( (info["robot0_eef_pos"],
                                    info["robot0_eef_quat"][:3],
                                    [(info["robot0_gripper_qpos"][0] - info["robot0_gripper_qpos"][0]) < 0.005 ]), axis=-1)]], dtype=torch.float32).to(device),
-                morphology=torch.tensor([0], dtype=torch.uint8).to(device) ## Morphology is 0 for arm, 1 for A1}
+                # morphology=torch.tensor([0], dtype=torch.uint8).to(device) ## Morphology is 0 for arm, 1 for A1}
            )
 
-            action = model.decode_action(action[0,0,:7]).cpu().detach().numpy() ## Add in the gripper close action
-            frames.append(image)
-            x = env.step(action)
-            obs, reward, done, truncated, info = x
-            rewards.append(reward)
-            infos.append(info)
+            action = model.decode_action(action[0]).cpu().detach().numpy() ## Add in the gripper close action
+            ## If the actions are stacked into a longer vector, execute the sequence of actions
+            for step_ in range(cfg.policy.action_stacking):
+                act_ = action[cfg.action_dim*step_:(cfg.action_dim*(step_+1))]
+                ## Need to process LIBERO gripper action [0, 1] -> [-1, 1], then invert, https://github.com/moojink/openvla-oft/blob/e4287e94541f459edc4feabc4e181f537cd569a8/experiments/robot/libero/run_libero_eval.py#L265
+                act_[6] = ((act_[6] - 0.5) * 2) * -1.0
+
+                obs, reward, done, truncated, info = env.step(act_)
+                # image = get_image_from_maniskill2_obs_dict(env_unwrapped, obs)
+                # image = image[:,:,:3] ## Remove last dimension of image color
+                # Store the original image for video before stacking/processing
+                image = obs[0]
+                frames.append(image)
+                # reward = -(np.linalg.norm(info["eof_to_obj1_diff"]) + np.linalg.norm(info["eof_to_obj1_diff"])) ## Use a shaped reward as distance between gripper and objects
+                rewards.append(reward)
+                infos.append(info)
+                t=t+1
+                if done or truncated:
+                    break
            if done:
                print("Episode finished after {} timesteps".format(step_))
                break
-
+
+        episode_stats = info.get('episode_stats', {})
+        episode_stats['rewards'] = np.mean(rewards)
        print(f"avg reward {np.mean(rewards):.8f}")
        if not cfg.testing:
            wandb.log({"avg reward_"+str(task_id): np.mean(rewards)})
-        import moviepy.editor as mpy
-        clip = mpy.ImageSequenceClip(list(frames), fps=20)
-        path_ = log_dir+"/sim-libero-90-"+str(task_id)+"-"+str(iter_)+".mp4"
-        clip.write_videofile(path_, fps=20)
+        import os
+        path_ = os.path.join(log_dir, f"libero-{iter_}.mp4")
+        import imageio
+        imageio.mimsave(path_, frames, fps=20)
        if not cfg.testing:
            wandb.log({"example": wandb.Video(path_)})
        env.close()
+    return episode_stats
 
 import hydra
 from omegaconf import DictConfig
-from mini_grp import *
 
-@hydra.main(config_path="./conf", config_name="libero-simpleEnv-64pix-pose")
+@hydra.main(config_path="./conf", config_name="64pix-pose")
 def my_main(cfg: DictConfig):
     from mini_shuffel_buffer import CircularBuffer
     import torch
@@ -237,7 +276,8 @@ def my_main(cfg: DictConfig):
     # model_ = torch.load("/home/gberseth/playground/mini_grp/miniGRP.pth")
     model_dir = hydra.utils.get_original_cwd()+"/mini-grp/miniGRP.pth"
     print ("Loading model from:", model_dir)
-    model_ = torch.load(model_dir)
+    from grp_model import GRP
+    model_ = torch.load(model_dir, pickle_module=dill)
     # model_._cgf = cfg
 
     tokenizer = None
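
The heart of the new action-stacking support is the decode-and-execute loop above: the policy emits one flat vector of length cfg.policy.action_stacking * cfg.action_dim, which is sliced into per-step actions. A self-contained sketch of that loop with a stub env in place of SIMPLER/LIBERO (the dims and the stub are illustrative, not from the repo):

# Standalone sketch of action stacking: the policy emits one flat vector of
# length ACTION_STACKING * ACTION_DIM; we slice it into per-step actions.
import numpy as np

ACTION_DIM = 7        # illustrative: 3 delta-xyz + 3 axis-angle + 1 gripper
ACTION_STACKING = 4   # illustrative chunk length

class StubEnv:
    """Stands in for the SIMPLER/LIBERO env; returns dummy transitions."""
    def step(self, action):
        obs, reward = np.zeros(3), -float(np.linalg.norm(action[:3]))
        done, truncated, info = False, False, {}
        return obs, reward, done, truncated, info

env = StubEnv()
# One flat prediction covering ACTION_STACKING steps, gripper values in [0, 1].
flat_action = np.random.uniform(0, 1, ACTION_DIM * ACTION_STACKING)

rewards, t = [], 0
for step_ in range(ACTION_STACKING):
    act_ = flat_action[ACTION_DIM * step_ : ACTION_DIM * (step_ + 1)].copy()
    # LIBERO gripper convention: remap [0, 1] -> [-1, 1], then invert
    # (mirrors the comment and link in the diff).
    act_[6] = ((act_[6] - 0.5) * 2) * -1.0
    obs, reward, done, truncated, info = env.step(act_)
    rewards.append(reward)
    t += 1
    if done or truncated:
        break  # stop mid-chunk if the episode ends early

print(f"executed {t} stacked steps, avg reward {np.mean(rewards):.4f}")

Breaking out of the chunk on done or truncated mirrors the diff's early exit, so a finished episode is not stepped past its end.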