File size: 2,187 Bytes
ca09bf5
 
 
 
 
 
c866ade
ca09bf5
 
c866ade
 
 
 
 
ca09bf5
c866ade
 
ca09bf5
 
c866ade
ca09bf5
 
 
c866ade
ca09bf5
 
 
c866ade
ca09bf5
 
 
 
 
 
 
c866ade
ca09bf5
 
 
 
 
 
 
 
 
 
c866ade
 
 
 
 
 
 
 
 
 
 
 
 
 
ca09bf5
 
 
 
c866ade
 
 
 
 
ca09bf5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import tempfile

import gradio as gr
import imageio
import numpy as np
import torch
from stable_baselines3 import SAC

from custom_env import create_env

# Define the function that runs the model and outputs a video
def run_model_episode(x_start, y_start, x_targ, y_targ, z_targ, *,
                      n_steps=200, fps=30, checkpoint_path=None):
    """Roll out the trained SAC policy for one scenario and save it as an MP4.

    Builds an environment whose block starts at ``(x_start, y_start)`` and
    whose goal is ``(x_targ, y_targ, z_targ)``, loads the SAC checkpoint,
    runs the deterministic policy for ``n_steps`` environment steps (starting
    a new episode whenever one terminates or is truncated), and encodes one
    rendered frame per step into a video file.

    Args:
        x_start, y_start: Passed to ``create_env`` as ``block_xy``.
        x_targ, y_targ, z_targ: Passed to ``create_env`` as ``goal_xyz``.
        n_steps: Number of environment steps (and video frames) to record.
        fps: Frame rate of the written video.
        checkpoint_path: Saved SAC model to load; defaults to
            ``App/model/model.zip`` relative to the current working directory.

    Returns:
        Path to a newly created ``.mp4`` file holding the rendered frames.

    Raises:
        ValueError: If ``n_steps`` is less than 1 (an empty frame list cannot
            be encoded as a video).
    """
    if n_steps < 1:
        raise ValueError("n_steps must be at least 1")
    if checkpoint_path is None:
        checkpoint_path = os.path.join("App", "model", "model.zip")

    # Create environment with user inputs
    env = create_env(render_mode="rgb_array",
                     block_xy=(x_start, y_start),
                     goal_xyz=(x_targ, y_targ, z_targ))
    frames = []
    try:
        # Passing env lets SB3 check the checkpoint matches this env's spaces.
        model = SAC.load(checkpoint_path, env=env, verbose=1)

        obs, _info = env.reset()
        for _ in range(n_steps):
            action, _ = model.predict(obs, deterministic=True)
            obs, _reward, terminated, truncated, _info = env.step(action)
            frames.append(env.render())

            # Keep the video running for the full n_steps by starting a new
            # episode once the current one ends.
            if terminated or truncated:
                obs, _info = env.reset()
    finally:
        # Release the simulator/renderer even if loading or stepping fails.
        env.close()

    # Use a unique file per call instead of a shared fixed name, so that
    # overlapping requests cannot overwrite each other's video.
    fd, video_path = tempfile.mkstemp(suffix=".mp4", prefix="run_video_")
    os.close(fd)  # imageio reopens the path itself; we only need the name.
    imageio.mimsave(video_path, frames, fps=fps)

    return video_path

# --------------------------------------
# Build the Gradio App
# --------------------------------------

with gr.Blocks() as demo:
    # Usage notes shown above the inputs.
    gr.Markdown("## Fetch Robot: Model Demo App")
    gr.Markdown("Enter start and target coordinates, then click 'Run Model' to watch the robot!")
    gr.Markdown("Coordinates are relative to the center of the table.")
    gr.Markdown("X and Y coordinates are in meters, Z coordinate is height in meters.")
    gr.Markdown("0,0,0 is the center of the table.")

    # Block start position (forwarded to the environment as block_xy).
    with gr.Row():
        x_start = gr.Number(label="Start X", value=0.0)
        y_start = gr.Number(label="Start Y", value=0.0)

    # Goal position (forwarded to the environment as goal_xyz).
    with gr.Row():
        x_targ = gr.Number(label="Target X", value=0.1)
        y_targ = gr.Number(label="Target Y", value=0.1)
        z_targ = gr.Number(label="Target Z", value=0.1)

    run_button = gr.Button("Run Model")
    output_video = gr.Video()

    # The inputs are passed positionally, in this order, to run_model_episode;
    # the video file path it returns is displayed in output_video.
    run_button.click(
        fn=run_model_episode,
        inputs=[x_start, y_start, x_targ, y_targ, z_targ],
        outputs=output_video,
    )

if __name__ == "__main__":
    # Start the server (and its public share link) only when executed as a
    # script, so importing this module to reuse `demo` or `run_model_episode`
    # has no side effects.
    demo.launch(share=True)