---
license: apache-2.0
tags:
- jax
- safetensors
---

# Baseline PerceptNet

## Model Description

## How to use it

### Install the model's package from source:
```
git clone https://github.com/Jorgvt/paramperceptnet.git
cd paramperceptnet
pip install -e .
```

### 1. Import necessary libraries

```
import json

from huggingface_hub import hf_hub_download
import flax
import orbax.checkpoint
from ml_collections import ConfigDict

from paramperceptnet.models import Baseline as PerceptNet
```

### 2. Download the configuration

```
config_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                              filename="config.json")
with open(config_path, "r") as f:
    config = ConfigDict(json.load(f))
```

### 3. Download the weights

#### 3.1. Using `safetensors`

```
from safetensors.flax import load_file

weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                               filename="weights.safetensors")
variables = load_file(weights_path)
variables = flax.traverse_util.unflatten_dict(variables, sep=".")
params = variables["params"]
```
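
The `unflatten_dict(..., sep=".")` call above is needed because `safetensors` stores a flat mapping from dotted key strings to arrays, while the model expects a nested dictionary. The following is a minimal pure-Python sketch of that key-splitting idea, not flax's own implementation; the helper `unflatten` and the key names are hypothetical and only serve as illustration.

```python
def unflatten(flat, sep="."):
    """Rebuild a nested dict from flat keys such as 'params.Conv_0.kernel'."""
    nested = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        for name in parents:
            # Descend into the sub-dict for this path component, creating it if missing.
            node = node.setdefault(name, {})
        node[leaf] = value
    return nested


# Hypothetical key names; the real checkpoint's names may differ.
flat = {
    "params.Conv_0.kernel": [[0.1, 0.2]],
    "params.Conv_0.bias": [0.0],
    "batch_stats.mean": [1.0],
}
nested = unflatten(flat)
print(sorted(nested))                      # ['batch_stats', 'params']
print(sorted(nested["params"]["Conv_0"]))  # ['bias', 'kernel']
```

After this step, indexing with `["params"]` selects the sub-tree that is passed to `model.apply`.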

#### 3.2. Using `msgpack`
```
import jax
from jax import numpy as jnp

weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                               filename="weights.msgpack")
with open(weights_path, "rb") as f:
    variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
variables = jax.tree_util.tree_map(lambda x: jnp.array(x), variables)
params = variables["params"]
```
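
Whichever format the weights were loaded from, a quick sanity check on the resulting tree can catch a wrong file or repository before running the model. The sketch below counts the elements in every leaf of a nested dictionary. It uses a toy stand-in for `params` with hypothetical layer names and shapes, and `count_parameters` is an illustrative helper, not part of `paramperceptnet`.

```python
import numpy as np


def count_parameters(tree):
    """Sum the number of elements over every array leaf of a nested dict."""
    if isinstance(tree, dict):
        return sum(count_parameters(child) for child in tree.values())
    # np.asarray also accepts JAX arrays, so this works on real loaded weights.
    return int(np.asarray(tree).size)


# Toy stand-in for `params`; layer names and shapes are hypothetical.
toy_params = {
    "Conv_0": {"kernel": np.zeros((3, 3, 3, 8)), "bias": np.zeros((8,))},
    "Dense_0": {"kernel": np.zeros((8, 1))},
}
total = count_parameters(toy_params)
print(total)  # 232
```

Calling `count_parameters(params)` on the loaded weights should give the same number for the `safetensors` and `msgpack` files if both hold the same model.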

### 4. Use the model

```
from jax import numpy as jnp
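# Build the model from the downloaded configuration.
model = PerceptNet(config)
# Forward pass on a dummy batch: one 384x512 image with 3 channels (NHWC layout).
pred = model.apply({"params": params}, jnp.ones((1, 384, 512, 3)))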
```