# Page header residue: author gkemp181, commit c866ade ("Fixed naming")
import gradio as gr
import os
import numpy as np
import torch
import imageio
from stable_baselines3 import SAC
from custom_env import create_env
# Define the function that runs the model and outputs a video
def run_model_episode(x_start, y_start, x_targ, y_targ, z_targ):
    """Roll out the saved SAC model in a fresh environment and record a video.

    Args:
        x_start: Block start X, forwarded to ``create_env`` in ``block_xy``.
        y_start: Block start Y, forwarded to ``create_env`` in ``block_xy``.
        x_targ: Goal X, forwarded to ``create_env`` in ``goal_xyz``.
        y_targ: Goal Y, forwarded to ``create_env`` in ``goal_xyz``.
        z_targ: Goal Z, forwarded to ``create_env`` in ``goal_xyz``.

    Returns:
        The path of the MP4 file the rendered frames were written to.

    Raises:
        ValueError: If any coordinate is ``None``.
    """
    # NOTE(review): an emptied gr.Number field presumably arrives here as
    # None; reject it up front instead of failing inside the simulator.
    named_coords = (
        ("x_start", x_start),
        ("y_start", y_start),
        ("x_targ", x_targ),
        ("y_targ", y_targ),
        ("z_targ", z_targ),
    )
    missing = [name for name, value in named_coords if value is None]
    if missing:
        raise ValueError(f"Missing coordinate value(s): {', '.join(missing)}")

    # Create environment with user inputs
    env = create_env(
        render_mode="rgb_array",
        block_xy=(x_start, y_start),
        goal_xyz=(x_targ, y_targ, z_targ),
    )

    # try/finally guarantees the environment is closed even when loading the
    # checkpoint or stepping the rollout raises.
    try:
        # Load the trained model
        checkpoint_path = os.path.join("App", "model", "model.zip")
        model = SAC.load(checkpoint_path, env=env, verbose=1)

        # Roll out a fixed budget of 200 steps, one rendered frame per step.
        frames = []
        obs, _info = env.reset()
        for _ in range(200):
            action, _state = model.predict(obs, deterministic=True)
            obs, _reward, done, trunc, _info = env.step(action)
            frames.append(env.render())
            if done or trunc:
                # Episode ended early: reset and keep recording until the
                # step budget is used up.
                obs, _info = env.reset()
    finally:
        env.close()

    # Save frames into a video
    # NOTE(review): the output path is fixed, so overlapping requests would
    # write to the same file — confirm whether that matters for deployment.
    video_path = "run_video.mp4"
    imageio.mimsave(video_path, frames, fps=30)
    return video_path
# --------------------------------------
# Build the Gradio App
# --------------------------------------
with gr.Blocks() as demo:
    # Title and usage notes; each line is its own Markdown component.
    gr.Markdown("## Fetch Robot: Model Demo App")
    gr.Markdown("Enter start and target coordinates, then click 'Run Model' to watch the robot!")
    gr.Markdown("Coordinates are relative to the center of the table.")
    gr.Markdown("X and Y coordinates are in meters, Z coordinate is height in meters.")
    gr.Markdown("0,0,0 is the center of the table.")

    # Start position of the block (X/Y only).
    with gr.Row():
        x_start = gr.Number(label="Start X", value=0.0)
        y_start = gr.Number(label="Start Y", value=0.0)

    # Target position (X/Y/Z).
    with gr.Row():
        x_targ = gr.Number(label="Target X", value=0.1)
        y_targ = gr.Number(label="Target Y", value=0.1)
        z_targ = gr.Number(label="Target Z", value=0.1)

    run_button = gr.Button("Run Model")
    output_video = gr.Video()

    # Input order must match the parameter order of run_model_episode.
    run_button.click(
        fn=run_model_episode,
        inputs=[x_start, y_start, x_targ, y_targ, z_targ],
        outputs=output_video,
    )

# Launch only when executed as a script, so importing this module (e.g. to
# reuse `demo` or `run_model_episode`) does not start a server.
if __name__ == "__main__":
    demo.launch(share=True)