gkemp181 committed on
Commit
c866ade
·
1 Parent(s): 8edfe19

Fixed naming

Browse files
Files changed (2) hide show
  1. app.py +31 -14
  2. app_test_5.py +0 -72
app.py CHANGED
@@ -4,26 +4,28 @@ import numpy as np
4
  import torch
5
  import imageio
6
  from stable_baselines3 import SAC
7
- from create_env import create_env
8
 
9
  # Define the function that runs the model and outputs a video
10
- def run_model_episode():
11
- # 1. Create environment with render_mode="rgb_array" (needed to capture frames)
12
- env = create_env(render_mode="rgb_array")
 
 
13
 
14
- # 2. Load your trained model
15
- checkpoint_path = os.path.join("models", "test", "model.zip")
16
  model = SAC.load(checkpoint_path, env=env, verbose=1)
17
 
18
- # 3. Rollout the episode
19
  frames = []
20
  obs, info = env.reset()
21
 
22
- for _ in range(200): # Shorter rollout to avoid giant videos
23
  action, _ = model.predict(obs, deterministic=True)
24
  obs, reward, done, trunc, info = env.step(action)
25
 
26
- frame = env.render() # Get current frame as image (rgb_array)
27
  frames.append(frame)
28
 
29
  if done or trunc:
@@ -31,11 +33,10 @@ def run_model_episode():
31
 
32
  env.close()
33
 
34
- # 4. Save the frames into a video
35
  video_path = "run_video.mp4"
36
  imageio.mimsave(video_path, frames, fps=30)
37
 
38
- # 5. Return path to Gradio to display
39
  return video_path
40
 
41
  # --------------------------------------
@@ -43,13 +44,29 @@ def run_model_episode():
43
  # --------------------------------------
44
 
45
  with gr.Blocks() as demo:
46
- gr.Markdown("# 🤖 Fetch Robot: Model Demo App")
47
- gr.Markdown("Click 'Run Model' to watch the SAC agent interact with the FetchPickAndPlace environment.")
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  run_button = gr.Button("Run Model")
50
  output_video = gr.Video()
51
 
52
- run_button.click(fn=run_model_episode, inputs=[], outputs=output_video)
 
 
 
 
53
 
54
  demo.launch(share=True)
55
 
 
4
  import torch
5
  import imageio
6
  from stable_baselines3 import SAC
7
+ from custom_env import create_env
8
 
9
  # Define the function that runs the model and outputs a video
10
def run_model_episode(x_start, y_start, x_targ, y_targ, z_targ):
    """Run one 200-step rollout of the trained SAC agent and save it as a video.

    Parameters
    ----------
    x_start, y_start : float
        Block start position, forwarded to ``create_env`` as ``block_xy``.
        (Per the UI text, metres relative to the table centre — confirm
        against ``create_env``.)
    x_targ, y_targ, z_targ : float
        Goal position, forwarded as ``goal_xyz``.

    Returns
    -------
    str
        Path of the rendered MP4 file, suitable for a Gradio Video output.
    """
    # render_mode="rgb_array" is required so env.render() returns frames.
    env = create_env(
        render_mode="rgb_array",
        block_xy=(x_start, y_start),
        goal_xyz=(x_targ, y_targ, z_targ),
    )

    try:
        # Load the trained policy; passing env lets SB3 validate the spaces.
        checkpoint_path = os.path.join("App", "model", "model.zip")
        model = SAC.load(checkpoint_path, env=env, verbose=1)

        # Rollout the episode, collecting one rendered frame per step.
        frames = []
        obs, info = env.reset()

        for _ in range(200):  # shorter rollout to avoid giant videos
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, trunc, info = env.step(action)

            frames.append(env.render())

            # If the episode ends early, start a fresh one so the video
            # always contains 200 frames.
            if done or trunc:
                obs, info = env.reset()
    finally:
        # Always release the environment (and its renderer) even if the
        # load/rollout raises — otherwise repeated runs leak resources.
        env.close()

    # Save frames into a video for Gradio to display.
    video_path = "run_video.mp4"
    imageio.mimsave(video_path, frames, fps=30)

    return video_path
41
 
42
  # --------------------------------------
 
44
  # --------------------------------------
45
 
46
# Build the Gradio UI: instruction text, coordinate inputs, a run button,
# and a video component wired to run_model_episode.
with gr.Blocks() as demo:
    gr.Markdown("## Fetch Robot: Model Demo App")
    gr.Markdown("Enter start and target coordinates, then click 'Run Model' to watch the robot!")
    gr.Markdown("Coordinates are relative to the center of the table.")
    gr.Markdown("X and Y coordinates are in meters, Z coordinate is height in meters.")
    gr.Markdown("0,0,0 is the center of the table.")

    # Block start position.
    with gr.Row():
        x_start = gr.Number(label="Start X", value=0.0)
        y_start = gr.Number(label="Start Y", value=0.0)

    # Goal position (includes height).
    with gr.Row():
        x_targ = gr.Number(label="Target X", value=0.1)
        y_targ = gr.Number(label="Target Y", value=0.1)
        z_targ = gr.Number(label="Target Z", value=0.1)

    run_button = gr.Button("Run Model")
    output_video = gr.Video()

    # Clicking the button runs the rollout and shows the resulting video.
    run_button.click(
        fn=run_model_episode,
        inputs=[x_start, y_start, x_targ, y_targ, z_targ],
        outputs=output_video
    )

demo.launch(share=True)
 
app_test_5.py DELETED
@@ -1,72 +0,0 @@
1
- import gradio as gr
2
- import os
3
- import numpy as np
4
- import torch
5
- import imageio
6
- from stable_baselines3 import SAC
7
- from custom_env import create_env
8
-
9
- # Define the function that runs the model and outputs a video
10
def run_model_episode(x_start, y_start, x_targ, y_targ, z_targ):
    """Roll out the trained SAC policy for 200 steps and return the video path.

    The block start (``block_xy``) and goal (``goal_xyz``) are taken from the
    user-supplied coordinates; the rendered frames are written to
    ``run_video.mp4`` at 30 fps.
    """
    # rgb_array rendering is needed so each env.render() yields an image.
    env = create_env(
        render_mode="rgb_array",
        block_xy=(x_start, y_start),
        goal_xyz=(x_targ, y_targ, z_targ),
    )

    # Restore the trained policy from the bundled checkpoint.
    model = SAC.load(os.path.join("App", "model", "model.zip"), env=env, verbose=1)

    # Collect one frame per step; restart the episode if it terminates early.
    frames = []
    obs, info = env.reset()
    for _ in range(200):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, trunc, info = env.step(action)
        frames.append(env.render())
        if done or trunc:
            obs, info = env.reset()
    env.close()

    # Encode the collected frames and hand the path back to the UI.
    video_path = "run_video.mp4"
    imageio.mimsave(video_path, frames, fps=30)
    return video_path
41
-
42
- # --------------------------------------
43
- # Build the Gradio App
44
- # --------------------------------------
45
-
46
# Gradio front-end: coordinate entry, a trigger button, and the output video.
with gr.Blocks() as demo:
    gr.Markdown("## Fetch Robot: Model Demo App")
    gr.Markdown("Enter start and target coordinates, then click 'Run Model' to watch the robot!")
    gr.Markdown("Coordinates are relative to the center of the table.")
    gr.Markdown("X and Y coordinates are in meters, Z coordinate is height in meters.")
    gr.Markdown("0,0,0 is the center of the table.")

    with gr.Row():
        x_start = gr.Number(label="Start X", value=0.0)
        y_start = gr.Number(label="Start Y", value=0.0)

    with gr.Row():
        x_targ = gr.Number(label="Target X", value=0.1)
        y_targ = gr.Number(label="Target Y", value=0.1)
        z_targ = gr.Number(label="Target Z", value=0.1)

    run_button = gr.Button("Run Model")
    output_video = gr.Video()

    run_button.click(
        fn=run_model_episode,
        inputs=[x_start, y_start, x_targ, y_targ, z_targ],
        outputs=output_video
    )

demo.launch(share=True)
72
-