Fouzanjaved commited on
Commit
2f9a092
·
verified ·
1 Parent(s): dfeaf40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -31
app.py CHANGED
@@ -1,34 +1,47 @@
 
 
 
 
 
1
 
2
- import torch
3
- from diffusion_policy.model.diffusion_policy import DiffusionPolicy
4
- from diffusion_policy.env.robot_env import RobotEnv # Custom environment for robot motion
5
 
6
- # Check if CUDA is available
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
- dtype = torch.float16 if device == "cuda" else torch.float32
9
-
10
- print("Using device:", device)
11
-
12
- # Load a pretrained diffusion model for robot motion
13
- model_path = "diffusion-policy/pretrained/6dof_robot" # Example path, replace with actual model
14
- model = DiffusionPolicy.load_from_checkpoint(model_path, map_location=device)
15
- model.to(device)
16
- model.eval() # Set model to evaluation mode
17
-
18
- # Define the robot environment
19
- env = RobotEnv(num_dof=6) # 6-DOF robot environment
20
- start_state = env.get_random_start_state() # Get a random starting position
21
- goal_state = env.get_random_goal_state() # Get a random goal position
22
-
23
- # Generate trajectory using diffusion model
24
- num_steps = 50
25
- trajectory = torch.zeros((num_steps, 6), device=device, dtype=dtype) # Store trajectory
26
-
27
- with torch.no_grad(): # No gradients needed for inference
28
- state = start_state.clone().to(device)
29
- for t in range(num_steps):
30
- action = model.predict(state) # Predict the next action using the diffusion model
31
- state = env.step(state, action) # Move the robot in simulation
32
- trajectory[t] = state
33
-
34
- print("Generated 6-DOF trajectory:\n", trajectory.cpu().numpy())
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from model import load_model, generate_trajectory
5
+ from environment import RobotEnv
6
 
7
+ # Set Streamlit title
8
+ st.title("6-DOF Robot Motion Planning (Diffusion Model)")
 
9
 
10
+ # Load model
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ model = load_model(device)
13
+
14
+ # Create robot environment
15
+ env = RobotEnv()
16
+
17
+ # User input for start and goal state
18
+ st.sidebar.header("Set Start & Goal Positions")
19
+ start_state = env.get_random_start_state()
20
+ goal_state = env.get_random_goal_state()
21
+
22
+ start_state = st.sidebar.text_input("Start State (comma-separated)", ",".join(map(str, start_state)))
23
+ goal_state = st.sidebar.text_input("Goal State (comma-separated)", ",".join(map(str, goal_state)))
24
+
25
+ # Convert input to NumPy array
26
+ try:
27
+ start_state = np.array([float(x) for x in start_state.split(",")])
28
+ goal_state = np.array([float(x) for x in goal_state.split(",")])
29
+ except ValueError:
30
+ st.sidebar.error("Enter valid numbers separated by commas.")
31
+ start_state = env.get_random_start_state()
32
+ goal_state = env.get_random_goal_state()
33
+
34
+ # Generate trajectory button
35
+ if st.button("Generate Trajectory"):
36
+ trajectory = generate_trajectory(model, start_state, goal_state)
37
+
38
+ # Plot the trajectory
39
+ fig, ax = plt.subplots()
40
+ for i in range(6): # 6-DOF joints
41
+ ax.plot(trajectory[:, i], label=f"Joint {i+1}")
42
+
43
+ ax.set_title("Generated 6-DOF Robot Trajectory")
44
+ ax.set_xlabel("Time Steps")
45
+ ax.set_ylabel("Joint Angles")
46
+ ax.legend()
47
+ st.pyplot(fig)