Retrained: batch 16, lambda_cycle=lambda_id=0.125
- README.md +28 -44
- config.json +16 -7
- full_checkpoint.pth +1 -1
- generator.pth +1 -1
- training_log.csv +0 -0
README.md
CHANGED
@@ -12,47 +12,37 @@ library_name: pytorch
 pipeline_tag: image-to-image
 ---

-# CIMP Style Transfer Generator (ResNet-18 CIMP, crop 512)

-A metadata-conditioned style-transfer generator that translates HAADF-STEM images between acquisition settings. Given an input image $x$ and a pair of CIMP metadata embeddings $(e_{\text{id}}, e_{\text{tgt}})$, the network produces an image that preserves the content of $x$ but matches the style associated with $e_{\text{tgt}}$

-Conditioning runs on top of [Stemson-AI/cmmp-resnet18-512](https://huggingface.co/Stemson-AI/cmmp-resnet18-512)

 ## Architecture

-- **Generator**: `StyleUNet`
-
-- ~9.9M parameters.
-- **Discriminator**: `NoisePatchGAN`, PatchGAN-style with metadata conditioning tiled spatially.
-- `base_filters = 64`, `meta_embed_dim = 128`.
-- ~0.66M parameters.

 ## Training Configuration

 | Parameter | Value |
 |---|---|
 | CIMP encoder | [Stemson-AI/cmmp-resnet18-512](https://huggingface.co/Stemson-AI/cmmp-resnet18-512) (frozen) |
-| Crop size | 512
-| Batch size |
-| Epochs | 250
 | Optimizer | Adam, $\text{lr}_G = \text{lr}_D = 2 \cdot 10^{-4}$ |
-|
-| Cycle
-| Identity
-|
 | Hardware | 1 $\times$ H100 |

-## Best-Epoch Validation Metrics
-
-| Metric | Best epoch | Value |
-|---|---|---|
-| Embedding alignment | 234 | 0.00638 |
-| Cycle consistency | 40 | 0.00432 |
-| Identity stability | 118 | 0.00183 |
-| Composite validation | **176** | **0.00391** |
-
-The uploaded `generator.pth` / `full_checkpoint.pth` correspond to the composite-best checkpoint (epoch 176).
-
 ## Files

 - `generator.pth` - inference-ready state dict for the `StyleUNet` generator.
@@ -70,37 +60,31 @@ import torch.nn.functional as F

 device = "cuda"

-# 1. Load the CIMP encoder (provides the metadata embeddings)
 cimp = CMMP(
-    meta_input_dim=7,
-
-
-    image_size=512,
-    meta_hidden_dim=256,
-    meta_num_layers=3,
 ).to(device)
-
-
 cimp.eval()

-# 2. Load the style-transfer generator
 gen = StyleUNet(embed_dim=256, in_channels=1, out_channels=1, base_filters=32, use_FiLM=True).to(device)
-
-
 gen.eval()

-#
-# src_meta and tgt_meta are 7-d z-scored metadata vectors (1, 7)
 with torch.no_grad():
     e_id = F.normalize(cimp.meta(src_meta.to(device)), p=2, dim=-1)
     e_tgt = F.normalize(cimp.meta(tgt_meta.to(device)), p=2, dim=-1)
-    y = gen(x.to(device), e_tgt, e_id)
 ```

 ## Related Models

-- [Stemson-AI/cmmp-resnet18-512](https://huggingface.co/Stemson-AI/cmmp-resnet18-512) - the CIMP encoder
-- [Stemson-AI/cmmp-resnet18-256](https://huggingface.co/Stemson-AI/cmmp-resnet18-256) - earlier

 ## Citation

@@ -12,47 +12,37 @@ library_name: pytorch
 pipeline_tag: image-to-image
 ---

+# CIMP Style Transfer Generator (ResNet-18 CIMP, crop 512, batch 16)

+A metadata-conditioned style-transfer generator that translates HAADF-STEM images between acquisition settings. Given an input image $x$ and a pair of CIMP metadata embeddings $(e_{\text{id}}, e_{\text{tgt}})$, the network produces an image that preserves the content of $x$ but matches the style associated with $e_{\text{tgt}}$.

+Conditioning runs on top of [Stemson-AI/cmmp-resnet18-512](https://huggingface.co/Stemson-AI/cmmp-resnet18-512). Training used four objectives: LSGAN adversarial, LPIPS cycle-consistency, LPIPS identity, and a CIMP-space embedding-alignment term.
+
+## What's new vs. the previous upload
+
+- `batch_size` increased from 8 to 16.
+- `lambda_cycle` and `lambda_id` increased from 0.1 to 0.125 to put slightly more weight on content preservation.

 ## Architecture

+- **Generator**: `StyleUNet` with FiLM conditioning at every convolutional block. `base_filters=32`, `embed_dim=256` (concatenated 128-d target + identity CIMP embeddings). ~9.9M params.
+- **Discriminator**: `NoisePatchGAN`, `base_filters=64`, `meta_embed_dim=128`. ~0.66M params.
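
To make the `embed_dim=256` figure concrete, here is a minimal sketch, assuming the generator's conditioning vector is the L2-normalized 128-d target embedding followed by the L2-normalized 128-d identity embedding. The helper names `l2_normalize` and `build_condition` and the concatenation order are assumptions for illustration only; the actual `StyleUNet` may combine the two embeddings differently.

```python
import math


def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit L2 norm; eps guards against division by zero."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / max(norm, eps) for v in vec]


def build_condition(e_tgt, e_id):
    """Hypothetical helper: concatenate target and identity embeddings (assumed order)."""
    return l2_normalize(e_tgt) + l2_normalize(e_id)


# Placeholder 128-d embeddings, not real CIMP outputs.
e_tgt = [0.5] * 128
e_id = [-1.0] * 128
cond = build_condition(e_tgt, e_id)
```

With two 128-d inputs the resulting vector has 256 entries, and each half has unit norm, which mirrors the `F.normalize(..., p=2, dim=-1)` calls in the usage snippet.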

 ## Training Configuration

 | Parameter | Value |
 |---|---|
 | CIMP encoder | [Stemson-AI/cmmp-resnet18-512](https://huggingface.co/Stemson-AI/cmmp-resnet18-512) (frozen) |
+| Crop size | $512 \times 512$ |
+| Batch size | 16 |
+| Epochs | 250 |
 | Optimizer | Adam, $\text{lr}_G = \text{lr}_D = 2 \cdot 10^{-4}$ |
+| LSGAN | $\mathcal{L}_{\text{LSGAN}}$ |
+| Cycle | LPIPS, $\lambda_1 = 0.125$ |
+| Identity | LPIPS, $\lambda_2 = 0.125$ |
+| Emb alignment | MSE in CIMP visual-embedding space, $\lambda_3 = 0.5$ |
 | Hardware | 1 $\times$ H100 |

 ## Files

 - `generator.pth` - inference-ready state dict for the `StyleUNet` generator.
@@ -70,37 +60,31 @@ import torch.nn.functional as F

 device = "cuda"

 cimp = CMMP(
+    meta_input_dim=7, embed_dim=128,
+    image_encoder="resnet18", image_size=512,
+    meta_hidden_dim=256, meta_num_layers=3,
 ).to(device)
+cimp.load_state_dict(torch.load(hf_hub_download("Stemson-AI/cmmp-resnet18-512", "model.pth"),
+    map_location=device))
 cimp.eval()

 gen = StyleUNet(embed_dim=256, in_channels=1, out_channels=1, base_filters=32, use_FiLM=True).to(device)
+gen.load_state_dict(torch.load(hf_hub_download("Stemson-AI/cimp-style-transfer-512", "generator.pth"),
+    map_location=device))
 gen.eval()

+# x: (1, 1, 512, 512) in [0, 1]; src_meta, tgt_meta: (1, 7) z-scored
 with torch.no_grad():
     e_id = F.normalize(cimp.meta(src_meta.to(device)), p=2, dim=-1)
     e_tgt = F.normalize(cimp.meta(tgt_meta.to(device)), p=2, dim=-1)
+    y = gen(x.to(device), e_tgt, e_id)
 ```
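
The snippet above expects `src_meta` and `tgt_meta` to be z-scored `(1, 7)` vectors but does not show how they are built. The following is a minimal sketch of that step, assuming the seven fields are ordered as in `meta_names` from `config.json` and that per-field mean and standard deviation from the training set are available. The `zscore_metadata` helper and all numeric values are placeholders, not statistics from the real dataset.

```python
# Field order taken from "meta_names" in config.json.
META_NAMES = [
    "pixel_size", "dwell_time", "convergence_angle", "beam_current",
    "gain", "offset", "inner_coll_angle",
]


def zscore_metadata(raw, mean, std):
    """Hypothetical helper: return the 7 values in META_NAMES order, z-scored per field."""
    return [(raw[name] - mean[name]) / std[name] for name in META_NAMES]


# Placeholder statistics and acquisition settings (assumptions, not real data).
mean = {name: 1.0 for name in META_NAMES}
std = {name: 2.0 for name in META_NAMES}
raw = {name: 3.0 for name in META_NAMES}

# Nested list with shape (1, 7), matching the layout the snippet expects.
src_meta = [zscore_metadata(raw, mean, std)]
```

The nested list can then be wrapped with `torch.tensor(src_meta, dtype=torch.float32)` before it is passed to `cimp.meta`.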

 ## Related Models

+- [Stemson-AI/cmmp-resnet18-512](https://huggingface.co/Stemson-AI/cmmp-resnet18-512) - the CIMP encoder.
+- [Stemson-AI/cmmp-resnet18-256](https://huggingface.co/Stemson-AI/cmmp-resnet18-256) - earlier CIMP variant.

 ## Citation

config.json
CHANGED
@@ -22,19 +22,28 @@
   },
   "training": {
     "crop_size": 512,
-    "batch_size":
+    "batch_size": 16,
     "epochs": 250,
-    "lr_G":
-    "lr_D":
+    "lr_G": 0.0002,
+    "lr_D": 0.0002,
     "lambda_GAN": 1.0,
-    "lambda_cycle": 0.
+    "lambda_cycle": 0.125,
     "lambda_emb": 0.5,
-    "lambda_id": 0.
+    "lambda_id": 0.125,
     "id_loss_fn": "lpips",
     "cycle_loss_fn": "lpips",
     "optimizer": "Adam",
     "adversarial_loss": "LSGAN"
   },
   "meta_dim": 7,
-  "meta_names": [
-
+  "meta_names": [
+    "pixel_size",
+    "dwell_time",
+    "convergence_angle",
+    "beam_current",
+    "gain",
+    "offset",
+    "inner_coll_angle"
+  ],
+  "best_epoch_composite_val": 216
+}
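
As a rough illustration of how the four `lambda_*` weights in this config might be used, the sketch below assumes the generator objective is a plain weighted sum of the adversarial, cycle, identity and embedding-alignment terms. That combination rule and the `generator_objective` helper are assumptions; the diff shows only the weights, not the training code.

```python
import json

# Only the loss weights from the "training" block of config.json are reproduced here.
config_text = """
{"training": {"lambda_GAN": 1.0, "lambda_cycle": 0.125,
              "lambda_emb": 0.5, "lambda_id": 0.125}}
"""
weights = json.loads(config_text)["training"]


def generator_objective(l_gan, l_cycle, l_id, l_emb, w):
    """Assumed form: weighted sum of the four generator loss terms."""
    return (
        w["lambda_GAN"] * l_gan
        + w["lambda_cycle"] * l_cycle
        + w["lambda_id"] * l_id
        + w["lambda_emb"] * l_emb
    )


# Placeholder per-term loss values, not numbers from training_log.csv.
total = generator_objective(l_gan=1.0, l_cycle=0.5, l_id=0.25, l_emb=0.5, w=weights)
```

Under this assumption, raising `lambda_cycle` and `lambda_id` from 0.1 to 0.125 increases the relative contribution of the two LPIPS terms while the adversarial and embedding terms keep their weights.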
full_checkpoint.pth
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:e687972a154b1f8a765e5ebd7dda8dde706cac83e13305295bcd0728004c8426
 size 42249339
generator.pth
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:633756be32f683595ce9f61a2c91ed04a90de4a052b0005ed94083a9660a34a5
 size 39597771
training_log.csv
CHANGED
The diff for this file is too large to render.