Fouzanjaved commited on
Commit
18d48bd
·
verified ·
1 Parent(s): 6d8bb59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -45
app.py CHANGED
@@ -1,47 +1,93 @@
1
  import streamlit as st
2
  import numpy as np
3
- import matplotlib.pyplot as plt
4
- from model import load_model, generate_trajectory
5
- from environment import RobotEnv
6
-
7
- # Set Streamlit title
8
- st.title("6-DOF Robot Motion Planning (Diffusion Model)")
9
-
10
- # Load model
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
- model = load_model(device)
13
-
14
- # Create robot environment
15
- env = RobotEnv()
16
-
17
- # User input for start and goal state
18
- st.sidebar.header("Set Start & Goal Positions")
19
- start_state = env.get_random_start_state()
20
- goal_state = env.get_random_goal_state()
21
-
22
- start_state = st.sidebar.text_input("Start State (comma-separated)", ",".join(map(str, start_state)))
23
- goal_state = st.sidebar.text_input("Goal State (comma-separated)", ",".join(map(str, goal_state)))
24
-
25
- # Convert input to NumPy array
26
- try:
27
- start_state = np.array([float(x) for x in start_state.split(",")])
28
- goal_state = np.array([float(x) for x in goal_state.split(",")])
29
- except ValueError:
30
- st.sidebar.error("Enter valid numbers separated by commas.")
31
- start_state = env.get_random_start_state()
32
- goal_state = env.get_random_goal_state()
33
-
34
- # Generate trajectory button
35
- if st.button("Generate Trajectory"):
36
- trajectory = generate_trajectory(model, start_state, goal_state)
37
-
38
- # Plot the trajectory
39
- fig, ax = plt.subplots()
40
- for i in range(6): # 6-DOF joints
41
- ax.plot(trajectory[:, i], label=f"Joint {i+1}")
42
-
43
- ax.set_title("Generated 6-DOF Robot Trajectory")
44
- ax.set_xlabel("Time Steps")
45
- ax.set_ylabel("Joint Angles")
46
- ax.legend()
47
- st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import numpy as np
3
+ import pybullet as p
4
+ import pybullet_data
5
+ import time
6
+ import tensorflow as tf
7
+
8
+ # Streamlit app
9
+ st.title("6-DOF Robotic Arm Path Finder with Pretrained Model")
10
+
11
+ # User inputs
12
+ start_position = st.text_input("Enter start position (x, y, z):", "0, 0, 0")
13
+ target_position = st.text_input("Enter target position (x, y, z):", "1, 1, 1")
14
+ obstacles = st.text_area("Enter obstacle positions (x, y, z, radius):", "0.5, 0.5, 0.5, 0.2")
15
+
16
+ # Parse inputs
17
+ start = np.array([float(x) for x in start_position.split(",")])
18
+ target = np.array([float(x) for x in target_position.split(",")])
19
+ obstacle_list = [tuple(map(float, line.split(","))) for line in obstacles.split("\n")]
20
+
21
+ # Load pretrained model (replace with your model)
22
+ def load_pretrained_model():
23
+ # Example: Load a simple neural network
24
+ model = tf.keras.Sequential([
25
+ tf.keras.layers.Dense(64, activation='relu', input_shape=(6,)), # Input: start + target
26
+ tf.keras.layers.Dense(64, activation='relu'),
27
+ tf.keras.layers.Dense(3) # Output: next waypoint
28
+ ])
29
+ # Load pretrained weights (dummy for demonstration)
30
+ model.load_weights("pretrained_model_weights.h5")
31
+ return model
32
+
33
+ # Predict path using the pretrained model
34
+ def predict_path(model, start, target, obstacles):
35
+ # Prepare input (start + target + obstacle info)
36
+ input_data = np.concatenate([start, target, np.array(obstacles).flatten()])
37
+ input_data = np.expand_dims(input_data, axis=0)
38
+
39
+ # Predict waypoints
40
+ waypoints = model.predict(input_data)
41
+ return waypoints
42
+
43
+ # PyBullet simulation
44
+ def visualize_path(path):
45
+ # Initialize PyBullet
46
+ p.connect(p.GUI)
47
+ p.setAdditionalSearchPath(pybullet_data.getDataPath())
48
+ p.setGravity(0, 0, -10)
49
+
50
+ # Load plane and robot arm
51
+ plane_id = p.loadURDF("plane.urdf")
52
+ robot_id = p.loadURDF("kuka_iiwa/model.urdf", [0, 0, 0])
53
+
54
+ # Add obstacles
55
+ for obstacle in obstacle_list:
56
+ obstacle_pos = obstacle[:3]
57
+ obstacle_radius = obstacle[3]
58
+ obstacle_shape = p.createCollisionShape(p.GEOM_SPHERE, radius=obstacle_radius)
59
+ obstacle_id = p.createMultiBody(baseMass=0, baseCollisionShapeIndex=obstacle_shape, basePosition=obstacle_pos)
60
+
61
+ # Visualize path
62
+ for i in range(len(path) - 1):
63
+ p.addUserDebugLine(path[i], path[i + 1], lineColorRGB=[1, 0, 0], lineWidth=2)
64
+ p.addUserDebugText(f"{path[i]}", path[i] + [0, 0, 0.2], textColorRGB=[1, 0, 0])
65
+
66
+ # Move the robot arm along the path
67
+ for point in path:
68
+ # Set target position (simplified for demonstration)
69
+ p.resetBasePositionAndOrientation(robot_id, point, [0, 0, 0, 1])
70
+ time.sleep(1)
71
+
72
+ # Keep the simulation running
73
+ st.write("Simulation running... Close the PyBullet window to continue.")
74
+ while True:
75
+ p.stepSimulation()
76
+ time.sleep(1 / 240)
77
+
78
+ # Run the app
79
+ if st.button("Find Path and Visualize"):
80
+ # Load pretrained model
81
+ model = load_pretrained_model()
82
+
83
+ # Predict path
84
+ path = predict_path(model, start, target, obstacle_list)
85
+
86
+ if path is not None:
87
+ st.success("Path found!")
88
+ st.write("Path:", path)
89
+
90
+ # Visualize the path in PyBullet
91
+ visualize_path(path)
92
+ else:
93
+ st.error("No path found.")