jcopo commited on
Commit
f7a97bc
·
verified ·
1 Parent(s): bc1a35a

Upload model at step 45000

Browse files
Files changed (2) hide show
  1. README.md +54 -0
  2. config.py +23 -0
README.md ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - jax
4
+ - flax
5
+ - flax-nnx
6
+ library_name: triax
7
+ ---
8
+
9
+ # mnist
10
+
11
+ Model trained using [Triax](https://github.com/your-org/triax), a JAX/Flax training framework.
12
+
13
+ ## Model Details
14
+
15
+ - **Model Type**: CondUNet2D
16
+ - **Training Step**: 45,000
17
+ - **Precision**: float32
18
+ - **Framework**: JAX/Flax (NNX)
19
+ - **Format**: msgpack
20
+
21
+ ## Usage
22
+
23
+ ```python
24
+ from flax import nnx, serialization
25
+ from huggingface_hub import hf_hub_download
26
+ import importlib.util
27
+
28
+ # Download model weights and config
29
+ model_path = hf_hub_download(repo_id="jcopo/mnist", filename="model.msgpack")
30
+ config_path = hf_hub_download(repo_id="jcopo/mnist", filename="config.py")
31
+
32
+ # Load config to get model architecture
33
+ spec = importlib.util.spec_from_file_location("model_config", config_path)
34
+ config_module = importlib.util.module_from_spec(spec)
35
+ spec.loader.exec_module(config_module)
36
+
37
+ # Initialize model from config
38
+ model = config_module.model
39
+
40
+ # Load weights
41
+ with open(model_path, "rb") as f:
42
+ state_dict = serialization.from_bytes(None, f.read())
43
+
44
+ # Restore weights into model
45
+ nnx.update(model, state_dict)
46
+ model.eval() # Set to evaluation mode
47
+
48
+ # Now use the model for inference
49
+ # output = model(input_data)
50
+ ```
51
+
52
+ ## Training Configuration
53
+
54
+ This model was trained with the Triax framework using the configuration saved in the checkpoint.
config.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model configuration for jcopo/mnist
2
+
3
+ This file contains the model architecture definition.
4
+ Training step: 45000
5
+ Precision: float32
6
+ """
7
+
8
+ from triax.models.nn.condUNet import CondUNet2D
9
+
10
+ # Model architecture
11
+ # TODO: Fill in the actual initialization parameters from your training config
12
+ model = CondUNet2D(
13
+ # Add your model parameters here
14
+ # Example:
15
+ # hidden_dim=256,
16
+ # num_layers=4,
17
+ # etc.
18
+ )
19
+
20
+ # Metadata
21
+ STEP = 45000
22
+ PRECISION = "float32"
23
+ MODEL_TYPE = "CondUNet2D"