improved
Browse files
README.md
CHANGED
|
@@ -51,6 +51,8 @@ weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-bio-fitted",
|
|
| 51 |
filename="weights.safetensors")
|
| 52 |
variables = load_file(weights_path)
|
| 53 |
variables = flax.traverse_util.unflatten_dict(variables, sep=".")
|
|
|
|
|
|
|
| 54 |
```
|
| 55 |
|
| 56 |
#### 3.2. Using `mgspack`
|
|
@@ -59,10 +61,9 @@ weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-bio-fitted",
|
|
| 59 |
filename="weights.msgpack")
|
| 60 |
with open(weights_path, "rb") as f:
|
| 61 |
variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
|
|
|
|
| 62 |
state = variables["state"]
|
| 63 |
params = variables["params"]
|
| 64 |
-
state = jax.tree_util.tree_map(lambda x: jnp.array(x), state)
|
| 65 |
-
params = jax.tree_util.tree_map(lambda x: jnp.array(x), params)
|
| 66 |
```
|
| 67 |
|
| 68 |
### 4. Use the model
|
|
|
|
| 51 |
filename="weights.safetensors")
|
| 52 |
variables = load_file(weights_path)
|
| 53 |
variables = flax.traverse_util.unflatten_dict(variables, sep=".")
|
| 54 |
+
state = variables["state"]
|
| 55 |
+
params = variables["params"]
|
| 56 |
```
|
| 57 |
|
| 58 |
#### 3.2. Using `mgspack`
|
|
|
|
| 61 |
filename="weights.msgpack")
|
| 62 |
with open(weights_path, "rb") as f:
|
| 63 |
variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
|
| 64 |
+
variables = jax.tree_util.tree_map(lambda x: jnp.array(x), variables)
|
| 65 |
state = variables["state"]
|
| 66 |
params = variables["params"]
|
|
|
|
|
|
|
| 67 |
```
|
| 68 |
|
| 69 |
### 4. Use the model
|