gkemp181 committed on
Commit
ca09bf5
·
1 Parent(s): 12362f2

Initial commit

Browse files
Files changed (3) hide show
  1. app.py +72 -0
  2. custom_env.py +111 -0
  3. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ import imageio
6
+ from stable_baselines3 import SAC
7
+ from custom_env import create_env
8
+
9
+ # Define the function that runs the model and outputs a video
10
def run_model_episode(x_start, y_start, x_targ, y_targ, z_targ):
    """Roll out the trained SAC policy and return the path of a rendered video.

    Args:
        x_start: Block start X, forwarded to ``create_env`` in ``block_xy``.
        y_start: Block start Y, forwarded to ``create_env`` in ``block_xy``.
        x_targ: Goal X, forwarded to ``create_env`` in ``goal_xyz``.
        y_targ: Goal Y, forwarded to ``create_env`` in ``goal_xyz``.
        z_targ: Goal Z, forwarded to ``create_env`` in ``goal_xyz``.

    Returns:
        Filesystem path of the MP4 written for this call. Each call gets its
        own file so that simultaneous requests cannot overwrite one another.
    """
    # Local import keeps this function self-contained; the module header does
    # not import tempfile.
    import tempfile

    # Create environment with user inputs
    env = create_env(
        render_mode="rgb_array",
        block_xy=(x_start, y_start),
        goal_xyz=(x_targ, y_targ, z_targ),
    )

    frames = []
    try:
        # Load your trained model
        checkpoint_path = os.path.join("App", "model", "model.zip")
        model = SAC.load(checkpoint_path, env=env, verbose=1)

        # Rollout the episode
        obs, info = env.reset()
        for _ in range(200):  # Shorter rollout
            action, _state = model.predict(obs, deterministic=True)
            obs, reward, done, trunc, info = env.step(action)

            frames.append(env.render())

            # Keep recording across episode boundaries: start a new episode
            # and continue until the fixed step budget is used up.
            if done or trunc:
                obs, info = env.reset()
    finally:
        # Always release the environment, even when loading the checkpoint,
        # stepping, or rendering raises.
        env.close()

    # Save frames into a video. A unique temp file replaces the previous
    # fixed "run_video.mp4", which concurrent requests would clobber.
    # NOTE(review): these files are not deleted here; confirm the host cleans
    # its temp directory or add explicit cleanup.
    fd, video_path = tempfile.mkstemp(prefix="run_video_", suffix=".mp4")
    os.close(fd)
    imageio.mimsave(video_path, frames, fps=30)

    return video_path
41
+
42
+ # --------------------------------------
43
+ # Build the Gradio App
44
+ # --------------------------------------
45
+
46
# Lay out the page: explanatory text, two rows of coordinate inputs, a trigger
# button and a video player, then wire the button to the rollout function.
with gr.Blocks() as demo:
    for intro_text in (
        "## Fetch Robot: Model Demo App",
        "Enter start and target coordinates, then click 'Run Model' to watch the robot!",
        "Coordinates are relative to the center of the table.",
        "X and Y coordinates are in meters, Z coordinate is height in meters.",
        "0,0,0 is the center of the table.",
    ):
        gr.Markdown(intro_text)

    # Block start position.
    with gr.Row():
        x_start = gr.Number(label="Start X", value=0.0)
        y_start = gr.Number(label="Start Y", value=0.0)

    # Goal position.
    with gr.Row():
        x_targ = gr.Number(label="Target X", value=0.1)
        y_targ = gr.Number(label="Target Y", value=0.1)
        z_targ = gr.Number(label="Target Z", value=0.1)

    run_button = gr.Button("Run Model")
    output_video = gr.Video()

    # Input order must match the parameter order of run_model_episode.
    coordinate_inputs = [x_start, y_start, x_targ, y_targ, z_targ]
    run_button.click(
        fn=run_model_episode,
        inputs=coordinate_inputs,
        outputs=output_video,
    )

demo.launch(share=True)
72
+
custom_env.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gymnasium as gym
3
+ import gymnasium_robotics
4
+ import mujoco
5
+
6
class CustomFetchWrapper(gym.Wrapper):
    """Wrapper that re-poses the robot, block and goal after every reset.

    When ``block_xy`` / ``goal_xyz`` are given, each reset places the block
    and the goal at those coordinates; when they are ``None`` the wrapper
    samples them itself (see ``reset``). The coordinates are written straight
    into the simulation state, so they are expected in the same frame as the
    wrapped env's ``initial_gripper_xpos`` and ``desired_goal``.
    """

    # Joint names and values written to the robot's slide joints on reset.
    HOME_SLIDE_QPOS = (
        ("robot0:slide0", 0.405),
        ("robot0:slide1", 0.48),
        ("robot0:slide2", 0.0),
    )
    # Z written into the block's joint position (table height).
    TABLE_Z = 0.42

    def __init__(self, env, block_xy=None, goal_xyz=None):
        """Store the optional fixed block/goal coordinates.

        Args:
            env: Environment to wrap.
            block_xy: Fixed block XY, or ``None`` to sample one per reset.
            goal_xyz: Fixed goal XYZ, or ``None`` to derive one per reset.
        """
        super().__init__(env)
        # Kept as a public attribute for existing callers.
        # Presumably a MujocoFetchPickAndPlaceEnv; verify against create_env.
        self.u = env.unwrapped
        # stash your fixed coords (or None to randomize); np.array copies, so
        # later changes to the caller's sequences do not leak in.
        self.default_block_xy = (
            None if block_xy is None else np.array(block_xy, dtype=float)
        )
        self.default_goal_xyz = (
            None if goal_xyz is None else np.array(goal_xyz, dtype=float)
        )

    def reset(self, *args, **kwargs):
        """Reset the wrapped env, then override robot pose, block and goal.

        All positional and keyword arguments are forwarded to the wrapped
        env's ``reset``.

        Returns:
            ``(obs, info)`` where ``obs`` is a fresh observation taken after
            the overrides and ``info`` is the dict from the underlying reset.
        """
        # 1) do the normal reset; gets you a random goal in obs
        obs, info = super().reset(*args, **kwargs)
        u = self.unwrapped
        model = u.model
        data = u.data
        utils = u._utils
        rng = u.np_random

        # 2) reset the robot slides to your home pose
        for name, val in self.HOME_SLIDE_QPOS:
            utils.set_joint_qpos(model, data, name, val)

        # 3) pick block position
        if self.default_block_xy is None:
            # Rejection-sample an XY offset around the initial gripper XY that
            # is at least distance_threshold from both the gripper and the goal.
            home_xy = u.initial_gripper_xpos[:2]
            obj_range = u.obj_range
            min_dist = u.distance_threshold
            while True:
                offset = rng.uniform(-obj_range, obj_range, size=2)
                if np.linalg.norm(offset) < min_dist:
                    continue
                cand = home_xy + offset
                if np.linalg.norm(cand - obs["desired_goal"][:2]) < min_dist:
                    continue
                break
            block_xy = cand
        else:
            block_xy = self.default_block_xy

        # place the block: overwrite x, y, z and keep the remaining entries
        blk_qpos = utils.get_joint_qpos(model, data, "object0:joint")
        blk_qpos[0:2] = block_xy
        blk_qpos[2] = self.TABLE_Z
        utils.set_joint_qpos(model, data, "object0:joint", blk_qpos)

        # 4) pick goal position
        if self.default_goal_xyz is None:
            # keep the sampled goal's XY, raise Z by 0.1-0.3 above the block
            raise_z = 0.1 + rng.uniform(0, 0.2)
            new_goal = obs["desired_goal"].copy()
            new_goal[2] = blk_qpos[2] + raise_z
        else:
            # Copy so the env never holds a reference to the stored default;
            # an in-place edit of u.goal must not change future resets.
            new_goal = self.default_goal_xyz.copy()

        # override the goal both in the env and in the MuJoCo site
        u.goal = new_goal
        sid = mujoco.mj_name2id(model,
                                mujoco.mjtObj.mjOBJ_SITE,
                                "target0")
        # NOTE(review): site_xpos looks like a derived quantity that the
        # mj_forward call below may recompute, which would discard this
        # write. Confirm whether the marker position relies on it.
        data.site_xpos[sid] = new_goal

        # 5) forward-kinematics + fresh obs
        u._mujoco.mj_forward(model, data)
        obs = u._get_obs()

        return obs, info
77
+
78
+
79
def create_env(render_mode=None, block_xy=None, goal_xyz=None):
    """Build a FetchPickAndPlace-v3 env wrapped in ``CustomFetchWrapper``.

    ``block_xy`` and ``goal_xyz`` are offsets. They are converted to absolute
    coordinates by adding them to an origin made of the env's initial gripper
    XY and a fixed Z of 0.42, then handed to the wrapper.

    Args:
        render_mode: Passed through to ``gym.make``.
        block_xy: Block XY offset from the origin, or ``None``.
        goal_xyz: Goal XYZ offset from the origin, or ``None``.

    Returns:
        The wrapped environment.
    """
    gym.register_envs(gymnasium_robotics)
    base_env = gym.make("FetchPickAndPlace-v3", render_mode=render_mode)

    # Origin of the offset frame. XY comes from the gripper's initial
    # position (presumably above the table centre; verify against the env).
    # Z is the same 0.42 the wrapper writes into the block's joint position.
    origin_xy = base_env.unwrapped.initial_gripper_xpos[:2]
    table_z = 0.42
    origin_xyz = np.array([*origin_xy, table_z])

    # Offsets become absolute coordinates; None passes through so the
    # wrapper samples the value itself.
    abs_block_xy = None
    if block_xy is not None:
        abs_block_xy = origin_xy + np.array(block_xy, dtype=float)

    abs_goal_xyz = None
    if goal_xyz is not None:
        abs_goal_xyz = origin_xyz + np.array(goal_xyz, dtype=float)

    return CustomFetchWrapper(
        base_env,
        block_xy=abs_block_xy,
        goal_xyz=abs_goal_xyz,
    )
requirements.txt ADDED
Binary file (3.9 kB). View file