| import os |
| import gradio as gr |
| import numpy as np |
| import torch |
| import imageio |
| from stable_baselines3 import SAC |
| from custom_env import create_env |
|
|
| |
| def run_model_episode(x_start, y_start, x_targ, y_targ, z_targ, model_name, random_coords): |
| |
| |
| model_paths = { |
| "Pick & Place (HER)": "App/model/pick_and_place_her.zip", |
| "Pick & Place (Dense)": "App/model/pick_and_place_dense.zip", |
| "Push": "App/model/push.zip", |
| "Reach": "App/model/reach.zip", |
| } |
| checkpoint_path = model_paths[model_name] |
|
|
| |
| environments = { |
| "Pick & Place (HER)": "FetchPickAndPlace-v3", |
| "Pick & Place (Dense)": "FetchPickAndPlaceDense-v3", |
| "Push": "FetchPush-v3", |
| "Reach": "FetchReach-v3", |
| } |
| environment = environments[model_name] |
|
|
| |
| if(environment == "FetchPush-v3"): |
| z_targ = 0.0 |
|
|
| block_xy=(x_start, y_start), |
| goal_xyz=(x_targ, y_targ, z_targ) |
|
|
| if random_coords: |
| block_xy = None |
| goal_xyz = None |
| |
| |
| env = create_env( |
| render_mode="rgb_array", |
| block_xy=block_xy, |
| goal_xyz=goal_xyz, |
| environment=environment |
| ) |
|
|
| |
| model = SAC.load(checkpoint_path, env=env, verbose=0) |
|
|
| frames = [] |
| obs, info = env.reset() |
| for _ in range(200): |
| action, _ = model.predict(obs, deterministic=True) |
| obs, reward, done, trunc, info = env.step(action) |
| frames.append(env.render()) |
| if done or trunc: |
| obs, info = env.reset() |
| env.close() |
|
|
| video_path = "run_video.mp4" |
| imageio.mimsave(video_path, frames, fps=30) |
| return video_path |
|
|
|
|
| with gr.Blocks() as demo: |
| gr.Markdown("## Fetch Robot: Model Demo App") |
| gr.Markdown("Enter coordinates, pick a model, then click **Run Model**.") |
| gr.Markdown("Coordinates are relative to the center of the table.") |
|
|
| |
| model_selector = gr.Radio( |
| choices=["Pick & Place (HER)", "Pick & Place (Dense)", "Push", "Reach"], |
| value="Pick & Place (HER)", |
| label="Select a model/environment" |
| ) |
|
|
| |
| randomize = gr.Checkbox( |
| label="Use randomized coordinates?", |
| value=False |
| ) |
|
|
| with gr.Row(): |
| x_start = gr.Number(label="Start X", value=0.0) |
| y_start = gr.Number(label="Start Y", value=0.0) |
|
|
| with gr.Row(): |
| x_targ = gr.Number(label="Target X", value=0.1) |
| y_targ = gr.Number(label="Target Y", value=0.1) |
| z_targ = gr.Number(label="Target Z", value=0.1) |
|
|
| run_button = gr.Button("Run Model") |
| output_video = gr.Video() |
|
|
| |
| run_button.click( |
| fn=run_model_episode, |
| inputs=[x_start, y_start, x_targ, y_targ, z_targ, model_selector, randomize], |
| outputs=output_video |
| ) |
|
|
| demo.launch(share=True) |
|
|