| import gradio as gr |
| import os |
| import numpy as np |
| import torch |
| import imageio |
| from stable_baselines3 import SAC |
| from custom_env import create_env |
|
|
| |
def run_model_episode(block_xy=(0.0, 0.0),
                      goal_xyz=(0.1, 0.1, 0.1),
                      n_steps=200,
                      video_path="run_video_2.mp4"):
    """Run one rollout of the trained SAC agent and save it as an MP4.

    Called by the Gradio button with no arguments, so every parameter
    has a default matching the original hard-coded values.

    Parameters
    ----------
    block_xy : tuple[float, float]
        Initial (x, y) position of the block in the environment.
    goal_xyz : tuple[float, float, float]
        Target (x, y, z) goal position.
    n_steps : int
        Number of environment steps to record.
    video_path : str
        Path where the rendered video is written.

    Returns
    -------
    str
        ``video_path``, handed to the Gradio ``Video`` component.
    """
    env = create_env(render_mode="rgb_array",
                     block_xy=block_xy,
                     goal_xyz=goal_xyz)

    # Ensure the env (and its rendering context) is released even if
    # loading the checkpoint or stepping the env raises.
    try:
        checkpoint_path = os.path.join("model", "model.zip")
        model = SAC.load(checkpoint_path, env=env, verbose=1)

        frames = []
        obs, info = env.reset()

        for _ in range(n_steps):
            # Deterministic actions give a reproducible demo rollout.
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, trunc, info = env.step(action)

            frames.append(env.render())

            # Start a fresh episode if this one ended; keep recording.
            if done or trunc:
                obs, info = env.reset()
    finally:
        env.close()

    imageio.mimsave(video_path, frames, fps=30)
    return video_path
|
|
| |
| |
| |
|
|
# --- Gradio UI: a single button that triggers a rollout and shows the video ---
with gr.Blocks() as demo:
    gr.Markdown("Fetch Robot: Model Demo App")
    gr.Markdown("Click 'Run Model' to watch the SAC agent interact with the FetchPickAndPlace environment.")

    # One trigger button and one output slot for the rendered rollout.
    trigger_btn = gr.Button("Run Model")
    video_out = gr.Video()

    # The callback takes no inputs and returns the path of the saved video.
    trigger_btn.click(fn=run_model_episode, inputs=[], outputs=video_out)

demo.launch(share=True)