gkemp181 committed on
Commit
07bd4cd
·
1 Parent(s): 479bef5

Added model selection to app

Browse files
App/model/{model.zip → pick_and_place_dense.zip} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d257f9937f3914c65c7ad21a2c25d601862ffc7d0ede4a7c6d3270fc04db2eec
3
- size 3372650
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8582301dcbe21ded7d266bd5548a629d76f603ceeb44995af657a3b5b322295a
3
+ size 3377664
App/model/pick_and_place_her.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab787b78fb54a6ee447bfd046248a1217a6e3207633e6753a2824282af3c08ad
3
+ size 3379264
App/model/push.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9953fc1dfd1c19b9faa56d898cbc985790468b41c46c530f797e5b7f56106715
3
+ size 3377665
App/model/reach.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51a4a2ae881f240be42ff6cae71e54c2a0487d5b083cacd52e346359d6fbb139
3
+ size 3207511
__pycache__/app.cpython-311.pyc ADDED
Binary file (4.31 kB). View file
 
__pycache__/custom_env.cpython-311.pyc ADDED
Binary file (4.68 kB). View file
 
app.py CHANGED
@@ -1,7 +1,4 @@
1
- # <-- this must come first, before any mujoco / gym imports
2
  import os
3
- os.environ["MUJOCO_GL"] = "osmesa"
4
-
5
  import gradio as gr
6
  import numpy as np
7
  import torch
@@ -9,49 +6,81 @@ import imageio
9
  from stable_baselines3 import SAC
10
  from custom_env import create_env
11
 
12
- # Define the function that runs the model and outputs a video
13
- def run_model_episode(x_start, y_start, x_targ, y_targ, z_targ):
14
- # Create environment with user inputs
15
- env = create_env(render_mode="rgb_array",
16
- block_xy=(x_start, y_start),
17
- goal_xyz=(x_targ, y_targ, z_targ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # Load your trained model
20
- checkpoint_path = os.path.join("App", "model", "model.zip")
21
- model = SAC.load(checkpoint_path, env=env, verbose=1)
22
 
23
- # Rollout the episode
24
  frames = []
25
  obs, info = env.reset()
26
-
27
- for _ in range(200): # Shorter rollout
28
  action, _ = model.predict(obs, deterministic=True)
29
  obs, reward, done, trunc, info = env.step(action)
30
-
31
- frame = env.render()
32
- frames.append(frame)
33
-
34
  if done or trunc:
35
  obs, info = env.reset()
36
-
37
  env.close()
38
 
39
- # Save frames into a video
40
  video_path = "run_video.mp4"
41
  imageio.mimsave(video_path, frames, fps=30)
42
-
43
  return video_path
44
 
45
- # --------------------------------------
46
- # Build the Gradio App
47
- # --------------------------------------
48
 
49
  with gr.Blocks() as demo:
50
  gr.Markdown("## Fetch Robot: Model Demo App")
51
- gr.Markdown("Enter start and target coordinates, then click 'Run Model' to watch the robot!")
52
  gr.Markdown("Coordinates are relative to the center of the table.")
53
- gr.Markdown("X and Y coordinates are in meters, Z coordinate is height in meters.")
54
- gr.Markdown("0,0,0 is the center of the table.")
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  with gr.Row():
57
  x_start = gr.Number(label="Start X", value=0.0)
@@ -62,14 +91,14 @@ with gr.Blocks() as demo:
62
  y_targ = gr.Number(label="Target Y", value=0.1)
63
  z_targ = gr.Number(label="Target Z", value=0.1)
64
 
65
- run_button = gr.Button("Run Model")
66
  output_video = gr.Video()
67
 
 
68
  run_button.click(
69
  fn=run_model_episode,
70
- inputs=[x_start, y_start, x_targ, y_targ, z_targ],
71
  outputs=output_video
72
  )
73
 
74
  demo.launch(share=True)
75
-
 
 
1
  import os
 
 
2
  import gradio as gr
3
  import numpy as np
4
  import torch
 
6
  from stable_baselines3 import SAC
7
  from custom_env import create_env
8
 
9
def run_model_episode(x_start, y_start, x_targ, y_targ, z_targ, model_name, random_coords):
    """Roll out one 200-step episode with the selected model and save a video.

    Parameters
    ----------
    x_start, y_start : float
        Block start coordinates, relative to the table centre (metres).
    x_targ, y_targ, z_targ : float
        Goal coordinates; ``z_targ`` is height above the table (metres).
    model_name : str
        Radio choice — one of "Pick & Place (HER)", "Pick & Place (Dense)",
        "Push", "Reach". Raises ``KeyError`` for any other value.
    random_coords : bool
        When True, ignore the coordinate inputs and let the environment
        randomize block and goal placement.

    Returns
    -------
    str
        Path of the saved MP4 ("run_video.mp4").
    """
    # Map the radio choice to the checkpoint on disk.
    model_paths = {
        "Pick & Place (HER)": "App/model/pick_and_place_her.zip",
        "Pick & Place (Dense)": "App/model/pick_and_place_dense.zip",
        "Push": "App/model/push.zip",
        "Reach": "App/model/reach.zip",
    }
    checkpoint_path = model_paths[model_name]

    # Map the radio choice to the gymnasium environment id.
    environments = {
        "Pick & Place (HER)": "FetchPickAndPlace-v3",
        "Pick & Place (Dense)": "FetchPickAndPlaceDense-v3",
        "Push": "FetchPush-v3",
        "Reach": "FetchReach-v3",
    }
    environment = environments[model_name]

    # Push goals lie on the table surface, so force the goal height to 0.
    if environment == "FetchPush-v3":
        z_targ = 0.0

    # BUGFIX: the original read `block_xy=(x_start, y_start),` — the trailing
    # comma made block_xy a 1-tuple containing the pair, i.e. ((x, y),),
    # which create_env would then mis-handle whenever random_coords is False.
    block_xy = (x_start, y_start)
    goal_xyz = (x_targ, y_targ, z_targ)

    if random_coords:
        # None tells the wrapper to fall back to its random sampling.
        block_xy = None
        goal_xyz = None

    env = create_env(
        render_mode="rgb_array",
        block_xy=block_xy,
        goal_xyz=goal_xyz,
        environment=environment,
    )

    # Load the selected checkpoint against the freshly created env.
    model = SAC.load(checkpoint_path, env=env, verbose=0)

    # Roll out a fixed-length episode, collecting rendered frames.
    frames = []
    obs, info = env.reset()
    for _ in range(200):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, trunc, info = env.step(action)
        frames.append(env.render())
        if done or trunc:
            # Episode ended early — start another so the clip stays 200 frames.
            obs, info = env.reset()
    env.close()

    video_path = "run_video.mp4"
    imageio.mimsave(video_path, frames, fps=30)
    return video_path
65
 
 
 
 
66
 
67
  with gr.Blocks() as demo:
68
  gr.Markdown("## Fetch Robot: Model Demo App")
69
+ gr.Markdown("Enter coordinates, pick a model, then click **Run Model**.")
70
  gr.Markdown("Coordinates are relative to the center of the table.")
71
+
72
+ # 1) add a radio (or gr.Dropdown) for model selection
73
+ model_selector = gr.Radio(
74
+ choices=["Pick & Place (HER)", "Pick & Place (Dense)", "Push", "Reach"],
75
+ value="Pick & Place (HER)",
76
+ label="Select a model/environment"
77
+ )
78
+
79
+ # Randomize coordinates
80
+ randomize = gr.Checkbox(
81
+ label="Use randomized coordinates?",
82
+ value=False
83
+ )
84
 
85
  with gr.Row():
86
  x_start = gr.Number(label="Start X", value=0.0)
 
91
  y_targ = gr.Number(label="Target Y", value=0.1)
92
  z_targ = gr.Number(label="Target Z", value=0.1)
93
 
94
+ run_button = gr.Button("Run Model")
95
  output_video = gr.Video()
96
 
97
+ # 2) include the selector as an input to your click callback
98
  run_button.click(
99
  fn=run_model_episode,
100
+ inputs=[x_start, y_start, x_targ, y_targ, z_targ, model_selector, randomize],
101
  outputs=output_video
102
  )
103
 
104
  demo.launch(share=True)
 
custom_env.py CHANGED
@@ -1,6 +1,6 @@
1
  # <-- this must come first, before any mujoco / gym imports
2
- import os
3
- os.environ["MUJOCO_GL"] = "osmesa"
4
 
5
  import numpy as np
6
  import gymnasium as gym
@@ -8,7 +8,7 @@ import gymnasium_robotics
8
  import mujoco
9
 
10
  class CustomFetchWrapper(gym.Wrapper):
11
- def __init__(self, env, block_xy=None, goal_xyz=None):
12
  super().__init__(env)
13
  self.u = env.unwrapped # MujocoFetchPickAndPlaceEnv
14
  # stash your fixed coords (or None to randomize)
@@ -16,6 +16,7 @@ class CustomFetchWrapper(gym.Wrapper):
16
  if block_xy is not None else None)
17
  self.default_goal_xyz = (np.array(goal_xyz, dtype=float)
18
  if goal_xyz is not None else None)
 
19
 
20
  def reset(self, *args, **kwargs):
21
  # 1) do the normal reset — gets you a random goal in obs
@@ -33,45 +34,49 @@ class CustomFetchWrapper(gym.Wrapper):
33
  ):
34
  utils.set_joint_qpos(model, data, name, val)
35
 
36
- # 3) pick block position
37
- if self.default_block_xy is None:
38
- # — original random‐sampling —
39
- home_xy = u.initial_gripper_xpos[:2]
40
- obj_range = u.obj_range
41
- min_dist = u.distance_threshold
42
- while True:
43
- offset = rng.uniform(-obj_range, obj_range, size=2)
44
- if np.linalg.norm(offset) < min_dist:
45
- continue
46
- cand = home_xy + offset
47
- if np.linalg.norm(cand - obs["desired_goal"][:2]) < min_dist:
48
- continue
49
- break
50
- block_xy = cand
51
- else:
52
- block_xy = self.default_block_xy
53
-
54
- # place the block
55
- blk_qpos = utils.get_joint_qpos(model, data, "object0:joint")
56
- blk_qpos[0:2] = block_xy
57
- blk_qpos[2] = 0.42 # table height
58
- utils.set_joint_qpos(model, data, "object0:joint", blk_qpos)
 
 
 
 
 
 
 
 
 
59
 
60
  # 4) pick goal position
61
- if self.default_goal_xyz is None:
62
- # — original “raise above table” logic —
63
- raise_z = 0.1 + rng.uniform(0, 0.2)
64
- new_goal = obs["desired_goal"].copy()
65
- new_goal[2] = blk_qpos[2] + raise_z
66
- else:
67
  new_goal = self.default_goal_xyz
68
 
69
- # override the goal both in the env and in the MuJoCo site
70
- u.goal = new_goal
71
- sid = mujoco.mj_name2id(model,
72
- mujoco.mjtObj.mjOBJ_SITE,
73
- "target0")
74
- data.site_xpos[sid] = new_goal
75
 
76
  # 5) forward‐kinematics + fresh obs
77
  u._mujoco.mj_forward(model, data)
@@ -80,9 +85,15 @@ class CustomFetchWrapper(gym.Wrapper):
80
  return obs, info
81
 
82
 
83
- def create_env(render_mode=None, block_xy=None, goal_xyz=None):
84
  gym.register_envs(gymnasium_robotics)
85
- base_env = gym.make("FetchPickAndPlace-v3", render_mode=render_mode)
 
 
 
 
 
 
86
  u = base_env.unwrapped
87
 
88
  # 1) compute table center in world coords
@@ -110,6 +121,7 @@ def create_env(render_mode=None, block_xy=None, goal_xyz=None):
110
  env = CustomFetchWrapper(
111
  base_env,
112
  block_xy=abs_block_xy,
113
- goal_xyz=abs_goal_xyz
 
114
  )
115
  return env
 
1
  # <-- this must come first, before any mujoco / gym imports
2
+ # import os
3
+ # os.environ["MUJOCO_GL"] = "osmesa"
4
 
5
  import numpy as np
6
  import gymnasium as gym
 
8
  import mujoco
9
 
10
  class CustomFetchWrapper(gym.Wrapper):
11
+ def __init__(self, env, block_xy=None, goal_xyz=None, object=True):
12
  super().__init__(env)
13
  self.u = env.unwrapped # MujocoFetchPickAndPlaceEnv
14
  # stash your fixed coords (or None to randomize)
 
16
  if block_xy is not None else None)
17
  self.default_goal_xyz = (np.array(goal_xyz, dtype=float)
18
  if goal_xyz is not None else None)
19
+ self.object = object
20
 
21
  def reset(self, *args, **kwargs):
22
  # 1) do the normal reset — gets you a random goal in obs
 
34
  ):
35
  utils.set_joint_qpos(model, data, name, val)
36
 
37
+ # pull out the actual goal so we can avoid it
38
+ goal_pos = obs["desired_goal"][:2].copy()
39
+
40
+ if (self.object==True):
41
+ # 3) pick block position
42
+ if self.default_block_xy is None:
43
+ home_xy = u.initial_gripper_xpos[:2]
44
+ obj_range = u.obj_range
45
+ min_dist = u.distance_threshold
46
+
47
+ while True:
48
+ offset = rng.uniform(-obj_range, obj_range, size=2)
49
+ # 3a) must be outside the “too-close to gripper” zone
50
+ if np.linalg.norm(offset) < min_dist:
51
+ continue
52
+ candidate_xy = home_xy + offset
53
+ # 3b) must be outside the “too-close to goal” zone
54
+ if np.linalg.norm(candidate_xy - goal_pos) < min_dist:
55
+ continue
56
+ # if we get here, both checks passed
57
+ break
58
+
59
+ block_xy = candidate_xy
60
+
61
+ else:
62
+ block_xy = self.default_block_xy
63
+
64
+ # place the block
65
+ blk_qpos = utils.get_joint_qpos(model, data, "object0:joint")
66
+ blk_qpos[0:2] = block_xy
67
+ blk_qpos[2] = 0.42 # table height
68
+ utils.set_joint_qpos(model, data, "object0:joint", blk_qpos)
69
 
70
  # 4) pick goal position
71
+ if self.default_goal_xyz is not None:
 
 
 
 
 
72
  new_goal = self.default_goal_xyz
73
 
74
+ # override the goal both in the env and in the MuJoCo site
75
+ u.goal = new_goal
76
+ sid = mujoco.mj_name2id(model,
77
+ mujoco.mjtObj.mjOBJ_SITE,
78
+ "target0")
79
+ data.site_xpos[sid] = new_goal
80
 
81
  # 5) forward‐kinematics + fresh obs
82
  u._mujoco.mj_forward(model, data)
 
85
  return obs, info
86
 
87
 
88
+ def create_env(render_mode=None, block_xy=None, goal_xyz=None, environment = "FetchPickAndPlace-v3"):
89
  gym.register_envs(gymnasium_robotics)
90
+
91
+ if(environment == "FetchReach-v3"):
92
+ object = False
93
+ else:
94
+ object = True
95
+
96
+ base_env = gym.make(environment, render_mode=render_mode)
97
  u = base_env.unwrapped
98
 
99
  # 1) compute table center in world coords
 
121
  env = CustomFetchWrapper(
122
  base_env,
123
  block_xy=abs_block_xy,
124
+ goal_xyz=abs_goal_xyz,
125
+ object=object
126
  )
127
  return env
app_test_2.py → old_apps/app_test_2.py RENAMED
File without changes
app_test_3.py → old_apps/app_test_3.py RENAMED
File without changes
app_test_4.py → old_apps/app_test_4.py RENAMED
File without changes