---
license: apache-2.0
tags:
- jax
- safetensors
---

# Baseline PerceptNet

## Model Description

## How to use it

### Install the model's package from source:

```
git clone https://github.com/Jorgvt/paramperceptnet.git
cd paramperceptnet
pip install -e .
```

### 1. Import necessary libraries:

```
import json

from huggingface_hub import hf_hub_download
import flax
import orbax.checkpoint
from ml_collections import ConfigDict

from paramperceptnet.models import Baseline as PerceptNet
```

### 2. Download the configuration

```
config_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                              filename="config.json")
with open(config_path, "r") as f:
    config = ConfigDict(json.load(f))
```

### 3. Download the weights

#### 3.1. Using `safetensors`

```
from safetensors.flax import load_file

weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                               filename="weights.safetensors")
# The file is loaded as a flat dict with "."-separated keys
variables = load_file(weights_path)
# Rebuild the nested variable tree from the flat keys
variables = flax.traverse_util.unflatten_dict(variables, sep=".")
params = variables["params"]
```
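
For readers unfamiliar with the unflattening step, here is a minimal pure-Python sketch of what rebuilding a nested dict from dot-separated keys amounts to. It is not Flax's implementation, the helper name `unflatten` is made up for illustration, and the key names are hypothetical placeholders rather than the actual entries in `weights.safetensors`.

```python
def unflatten(flat, sep="."):
    """Rebuild a nested dict from a flat dict whose keys are sep-joined paths."""
    nested = {}
    for path, value in flat.items():
        *parents, leaf = path.split(sep)
        node = nested
        for key in parents:
            # Create intermediate dicts on demand while walking down the path
            node = node.setdefault(key, {})
        node[leaf] = value
    return nested


# Hypothetical key names standing in for real checkpoint entries;
# plain lists play the role of arrays.
flat = {
    "params.Conv_0.kernel": [0.1, 0.2],
    "params.Conv_0.bias": [0.0],
    "state.Norm_0.mean": [1.0],
}
nested = unflatten(flat)
```

In practice `flax.traverse_util.unflatten_dict` should be preferred; the sketch only shows why `variables["params"]` becomes available after that call.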

#### 3.2. Using `msgpack`

```
import jax
from jax import numpy as jnp

weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                               filename="weights.msgpack")
with open(weights_path, "rb") as f:
    variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
# Convert every restored leaf into a JAX array
variables = jax.tree_util.tree_map(lambda x: jnp.array(x), variables)
params = variables["params"]
```
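
The `tree_map` call applies a function to every leaf of the restored variable tree while keeping its structure. A rough plain-Python sketch of that idea, restricted to nested dicts, is shown below. The helper `map_leaves` and the toy tree are hypothetical, and `jax.tree_util.tree_map` handles many more container types than this.

```python
def map_leaves(fn, tree):
    """Apply fn to every non-dict value in a nested dict, keeping the structure."""
    if isinstance(tree, dict):
        return {key: map_leaves(fn, value) for key, value in tree.items()}
    return fn(tree)


# Toy stand-in for restored variables; lists play the role of restored arrays
# and tuple() plays the role of the jnp.array conversion.
restored = {"params": {"layer": {"kernel": [1, 2], "bias": [0]}}}
converted = map_leaves(tuple, restored)
```

The original tree is left untouched and a new tree with converted leaves is returned, which is why the result is assigned back to `variables` in the snippet above.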

### 4. Use the model

```
from jax import numpy as jnp
model = PerceptNet(config)
pred = model.apply({"params": params}, jnp.ones((1,384,512,3)))
```
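
The call above feeds a dummy array of shape `(1, 384, 512, 3)`. As a hedged sketch of preparing a real input, the hypothetical helper `to_batch` below uses NumPy to add the leading batch axis to a single three-dimensional image array. The expected value range and any further preprocessing are not stated here, so they are deliberately left out.

```python
import numpy as np


def to_batch(image):
    """Turn a single 3-D image array into a batch of one by adding a leading axis."""
    image = np.asarray(image, dtype=np.float32)
    if image.ndim != 3:
        raise ValueError(f"expected a 3-D image array, got shape {image.shape}")
    return image[None, ...]


# Placeholder image with the same spatial size as the dummy input above
batch = to_batch(np.zeros((384, 512, 3)))
```

Assuming the model accepts such arrays, `batch` could be passed to `model.apply` in place of `jnp.ones((1,384,512,3))`.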