| import gradio as gr |
| import os |
| import numpy as np |
| import torch |
| import imageio |
| from stable_baselines3 import SAC |
| from custom_env import create_env |
|
|
| |
def run_model_episode(x_start, y_start, x_targ, y_targ, z_targ):
    """Roll out the saved SAC policy and record the rollout as an MP4 video.

    An environment is created with the block placed at ``(x_start, y_start)``
    and the goal at ``(x_targ, y_targ, z_targ)``, the checkpoint at
    ``App/model/model.zip`` is loaded, and the policy is stepped
    deterministically for exactly 200 steps. One rendered frame is captured
    per step. If an episode terminates or is truncated before the 200 steps
    are over, the environment is reset and recording continues, so the video
    may contain more than one episode.

    Args:
        x_start: Start X of the block, forwarded to ``create_env``.
        y_start: Start Y of the block, forwarded to ``create_env``.
        x_targ: Target X of the goal, forwarded to ``create_env``.
        y_targ: Target Y of the goal, forwarded to ``create_env``.
        z_targ: Target Z of the goal, forwarded to ``create_env``.
            (Units/frame of reference are defined by ``create_env``; the UI
            text describes them as meters relative to the table center --
            TODO confirm against ``custom_env``.)

    Returns:
        The path of the written video file, ``"run_video.mp4"``.
    """
    env = create_env(
        render_mode="rgb_array",
        block_xy=(x_start, y_start),
        goal_xyz=(x_targ, y_targ, z_targ),
    )

    # Everything that uses the environment runs under try/finally so the
    # environment is closed even when loading the checkpoint or stepping the
    # simulation raises; this function is invoked on every button click.
    try:
        checkpoint_path = os.path.join("App", "model", "model.zip")
        model = SAC.load(checkpoint_path, env=env, verbose=1)

        frames = []
        obs, _info = env.reset()

        for _step in range(200):
            # predict() returns (action, state); the state is not needed here.
            action, _state = model.predict(obs, deterministic=True)
            obs, _reward, done, trunc, _info = env.step(action)

            frames.append(env.render())

            # Keep recording across episode boundaries instead of stopping.
            if done or trunc:
                obs, _info = env.reset()
    finally:
        env.close()

    # NOTE(review): the output path is fixed, so overlapping requests would
    # write to the same file -- confirm whether concurrent use is expected.
    video_path = "run_video.mp4"
    imageio.mimsave(video_path, frames, fps=30)

    return video_path
|
|
| |
| |
| |
|
|
# Gradio front end: five numeric inputs (block start XY, goal XYZ), a button
# that runs one recorded rollout, and a video player showing the result.
with gr.Blocks() as demo:
    # Title and usage notes; each entry becomes its own Markdown component,
    # created in the order listed.
    for _text in (
        "## Fetch Robot: Model Demo App",
        "Enter start and target coordinates, then click 'Run Model' to watch the robot!",
        "Coordinates are relative to the center of the table.",
        "X and Y coordinates are in meters, Z coordinate is height in meters.",
        "0,0,0 is the center of the table.",
    ):
        gr.Markdown(_text)

    # Start position of the block.
    with gr.Row():
        x_start = gr.Number(value=0.0, label="Start X")
        y_start = gr.Number(value=0.0, label="Start Y")

    # Goal position.
    with gr.Row():
        x_targ = gr.Number(value=0.1, label="Target X")
        y_targ = gr.Number(value=0.1, label="Target Y")
        z_targ = gr.Number(value=0.1, label="Target Z")

    run_button = gr.Button("Run Model")
    output_video = gr.Video()

    # Input order must match the parameter order of run_model_episode.
    _coordinate_inputs = [x_start, y_start, x_targ, y_targ, z_targ]
    run_button.click(
        fn=run_model_episode,
        inputs=_coordinate_inputs,
        outputs=output_video,
    )

# Start the app server; share=True asks Gradio for a public share link.
demo.launch(share=True)
|
|
|
|