Commit 19cb6d4 (verified), committed by kuds · Parent(s): 8f1db03

Update README.md

Files changed (1): README.md (+94 −3)
---
library_name: stable-baselines3
tags:
- FetchReachDense-v4
- deep-reinforcement-learning
- reinforcement-learning
- stable-baselines3
model-index:
- name: DDPG
  results:
  - task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: FetchReachDense-v4
      type: FetchReachDense-v4
    metrics:
    - type: mean_reward
      value: -0.65 +/- 0.39
      name: mean_reward
      verified: false
license: mit
language:
- en
---

# **DDPG** Agent playing **FetchReachDense-v4**
- [Github Repository](https://github.com/kuds/rl-fetch)
- [Google Colab Notebook](https://colab.research.google.com/github/kuds/rl-fetch/blob/main/Fetch/Reach/%5BFetch%20Reach%5D%20Deep%20Deterministic%20Policy%20Gradient%20(DDPG).ipynb)
- [Finding Theta - Blog Post](https://www.findingtheta.com/blog/mastering-robotic-manipulation-with-reinforcement-learning-tqc-and-ddpg-for-fetch-environments)
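The examples below assume the runtime dependencies are already installed. A minimal setup sketch, assuming the PyPI package names `stable-baselines3`, `gymnasium`, and `gymnasium-robotics` (this repository does not pin versions):

```shell
# Install the libraries used in the examples below (versions unpinned; adjust as needed)
pip install stable-baselines3 gymnasium gymnasium-robotics
```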

You can load the model locally using the following Python code:

```python
import gymnasium as gym
import gymnasium_robotics
from stable_baselines3 import DDPG
from stable_baselines3.common.env_util import make_vec_env

# Register the Fetch environments with Gymnasium
gym.register_envs(gymnasium_robotics)

# Load the trained model
model = DDPG.load("best-model.zip")

# Create the environment
env = make_vec_env("FetchReachDense-v4", n_envs=1)

# Reset the environment (a VecEnv reset returns only the observations)
obs = env.reset()

# Enjoy the trained agent
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    # A VecEnv step returns 4 values and auto-resets finished episodes
    obs, rewards, dones, info = env.step(action)
    env.render()
env.close()
```
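For context on the reported mean reward of -0.65 +/- 0.39: in the dense Fetch variants, each step is rewarded with the negative Euclidean distance between the gripper and the goal, so values near zero mean the gripper stays close to the target. A minimal sketch of that reward shape (the `dense_reward` helper is illustrative, not part of the environment API):

```python
import math

def dense_reward(achieved_goal, desired_goal):
    # FetchReachDense-v4 rewards each step with the negative Euclidean
    # distance between the gripper position and the goal position.
    return -math.dist(achieved_goal, desired_goal)

# Exactly at the goal, the per-step reward is 0, the best achievable value
print(dense_reward([1.0, 0.5, 0.4], [1.0, 0.5, 0.4]))

# 5 cm away from the goal along one axis gives roughly -0.05 per step
print(round(dense_reward([1.05, 0.5, 0.4], [1.0, 0.5, 0.4]), 3))
```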

### Hugging Face Hub

You can also load the model directly from the Hugging Face Hub. First, install the `huggingface_hub` library:

```bash
pip install huggingface_hub
```

Then, load the model from the Hub using the following code:
70
+
71
+ ```python
72
+ from huggingface_hub import hf_hub_download
73
+ import torch as th
74
+ from stable_baselines3 import DDPG
75
+ from stable_baselines3.common.env_util import make_vec_env
76
+
77
+ gymnasium.register_envs(gymnasium_robotics)
78
+
79
+ # Download the model from the Hub
80
+ model_path = hf_hub_download(repo_id="kuds/fetch-reach-ddpg", filename="best-model.zip")
81
+
82
+ # Load the model
83
+ model = DDPG.load(model_path)
84
+
85
+ # Create the environment
86
+ env = make_vec_env("FetchReachDense-v4", n_envs=1)
87
+
88
+ # Enjoy the trained agent
89
+ obs = env.reset()
90
+ for i in range(1000):
91
+ action, _states = model.predict(obs, deterministic=True)
92
+ obs, rewards, dones, info = env.step(action)
93
+ env.render("human")
94
+ ```