Commit d58da04 (verified), committed by cgeorgiaw (HF Staff)
1 parent: d824ddc

Retrained: batch 16, lambda_cycle=lambda_id=0.125

Files changed (5)
  1. README.md +28 -44
  2. config.json +16 -7
  3. full_checkpoint.pth +1 -1
  4. generator.pth +1 -1
  5. training_log.csv +0 -0
README.md CHANGED
@@ -12,47 +12,37 @@ library_name: pytorch
12
  pipeline_tag: image-to-image
13
  ---
14
 
15
- # CIMP Style Transfer Generator (ResNet-18 CIMP, crop 512)
16
 
17
- A metadata-conditioned style-transfer generator that translates HAADF-STEM images between acquisition settings. Given an input image $x$ and a pair of CIMP metadata embeddings $(e_{\text{id}}, e_{\text{tgt}})$, the network produces an image that preserves the content of $x$ but matches the style associated with $e_{\text{tgt}}$ (target dwell time, beam current, detector gain, offset, inner collection angle, or convergence angle).
18
 
19
- Conditioning runs on top of [Stemson-AI/cmmp-resnet18-512](https://huggingface.co/Stemson-AI/cmmp-resnet18-512), our best CIMP visual-metadata encoder. The generator was trained against four objectives: an LSGAN adversarial loss, LPIPS cycle-consistency, LPIPS identity, and a CIMP-space embedding-alignment term.
 
 
 
 
 
20
 
21
  ## Architecture
22
 
23
- - **Generator**: `StyleUNet`, 4-level U-Net with FiLM conditioning at every convolutional block.
24
- - `base_filters = 32`, `embed_dim = 256` (concatenated target + identity CIMP embeddings of 128-d each).
25
- - ~9.9M parameters.
26
- - **Discriminator**: `NoisePatchGAN`, PatchGAN-style with metadata conditioning tiled spatially.
27
- - `base_filters = 64`, `meta_embed_dim = 128`.
28
- - ~0.66M parameters.
29
 
30
  ## Training Configuration
31
 
32
  | Parameter | Value |
33
  |---|---|
34
  | CIMP encoder | [Stemson-AI/cmmp-resnet18-512](https://huggingface.co/Stemson-AI/cmmp-resnet18-512) (frozen) |
35
- | Crop size | 512 $\times$ 512 |
36
- | Batch size | 8 |
37
- | Epochs | 250 (best checkpoint by composite val loss) |
38
  | Optimizer | Adam, $\text{lr}_G = \text{lr}_D = 2 \cdot 10^{-4}$ |
39
- | Adversarial loss | LSGAN ($\mathcal{L}_{\text{LSGAN}}$) |
40
- | Cycle loss | LPIPS round-trip ($\mathcal{L}_{\text{cyc}}$), $\lambda_1 = 0.1$ |
41
- | Identity loss | LPIPS ($\mathcal{L}_{\text{id}}$), $\lambda_2 = 0.1$ |
42
- | Embedding alignment | MSE in CIMP visual-embedding space ($\mathcal{L}_{\text{emb}}$), $\lambda_3 = 0.5$ |
43
  | Hardware | 1 $\times$ H100 |
44
 
45
- ## Best-Epoch Validation Metrics
46
-
47
- | Metric | Best epoch | Value |
48
- |---|---|---|
49
- | Embedding alignment | 234 | 0.00638 |
50
- | Cycle consistency | 40 | 0.00432 |
51
- | Identity stability | 118 | 0.00183 |
52
- | Composite validation | **176** | **0.00391** |
53
-
54
- The uploaded `generator.pth` / `full_checkpoint.pth` correspond to the composite-best checkpoint (epoch 176).
55
-
56
  ## Files
57
 
58
  - `generator.pth` - inference-ready state dict for the `StyleUNet` generator.
@@ -70,37 +60,31 @@ import torch.nn.functional as F
70
 
71
  device = "cuda"
72
 
73
- # 1. Load the CIMP encoder (provides the metadata embeddings)
74
  cimp = CMMP(
75
- meta_input_dim=7,
76
- embed_dim=128,
77
- image_encoder="resnet18",
78
- image_size=512,
79
- meta_hidden_dim=256,
80
- meta_num_layers=3,
81
  ).to(device)
82
- cimp_path = hf_hub_download("Stemson-AI/cmmp-resnet18-512", "model.pth")
83
- cimp.load_state_dict(torch.load(cimp_path, map_location=device))
84
  cimp.eval()
85
 
86
- # 2. Load the style-transfer generator
87
  gen = StyleUNet(embed_dim=256, in_channels=1, out_channels=1, base_filters=32, use_FiLM=True).to(device)
88
- gen_path = hf_hub_download("Stemson-AI/cimp-style-transfer-512", "generator.pth")
89
- gen.load_state_dict(torch.load(gen_path, map_location=device))
90
  gen.eval()
91
 
92
- # 3. Generate: x is a (1, 1, 512, 512) grayscale image in [0, 1]
93
- # src_meta and tgt_meta are 7-d z-scored metadata vectors (1, 7)
94
  with torch.no_grad():
95
  e_id = F.normalize(cimp.meta(src_meta.to(device)), p=2, dim=-1)
96
  e_tgt = F.normalize(cimp.meta(tgt_meta.to(device)), p=2, dim=-1)
97
- y = gen(x.to(device), e_tgt, e_id) # (1, 1, 512, 512) in [-1, 1] (Tanh output)
98
  ```
99
 
100
  ## Related Models
101
 
102
- - [Stemson-AI/cmmp-resnet18-512](https://huggingface.co/Stemson-AI/cmmp-resnet18-512) - the CIMP encoder this generator conditions on (required for inference).
103
- - [Stemson-AI/cmmp-resnet18-256](https://huggingface.co/Stemson-AI/cmmp-resnet18-256) - earlier ResNet-18 CIMP variant at crop 256.
104
 
105
  ## Citation
106
 
 
12
  pipeline_tag: image-to-image
13
  ---
14
 
15
+ # CIMP Style Transfer Generator (ResNet-18 CIMP, crop 512, batch 16)
16
 
17
+ A metadata-conditioned style-transfer generator that translates HAADF-STEM images between acquisition settings. Given an input image $x$ and a pair of CIMP metadata embeddings $(e_{\text{id}}, e_{\text{tgt}})$, the network produces an image that preserves the content of $x$ but matches the style associated with $e_{\text{tgt}}$.
18
 
19
+ Conditioning runs on top of [Stemson-AI/cmmp-resnet18-512](https://huggingface.co/Stemson-AI/cmmp-resnet18-512). Training used four objectives: LSGAN adversarial, LPIPS cycle-consistency, LPIPS identity, and a CIMP-space embedding-alignment term.
20
+
21
+ ## What's new vs. the previous upload
22
+
23
+ - `batch_size` increased from 8 to 16.
24
+ - `lambda_cycle` and `lambda_id` increased from 0.1 to 0.125 to put slightly more weight on content preservation.
25
 
26
  ## Architecture
27
 
28
+ - **Generator**: `StyleUNet` with FiLM conditioning at every convolutional block. `base_filters=32`, `embed_dim=256` (concatenated 128-d target + identity CIMP embeddings). ~9.9M params.
29
+ - **Discriminator**: `NoisePatchGAN`, `base_filters=64`, `meta_embed_dim=128`. ~0.66M params.
 
 
 
 
30
 
31
  ## Training Configuration
32
 
33
  | Parameter | Value |
34
  |---|---|
35
  | CIMP encoder | [Stemson-AI/cmmp-resnet18-512](https://huggingface.co/Stemson-AI/cmmp-resnet18-512) (frozen) |
36
+ | Crop size | $512 \times 512$ |
37
+ | Batch size | 16 |
38
+ | Epochs | 250 |
39
  | Optimizer | Adam, $\text{lr}_G = \text{lr}_D = 2 \cdot 10^{-4}$ |
40
+ | LSGAN | $\mathcal{L}_{\text{LSGAN}}$ |
41
+ | Cycle | LPIPS, $\lambda_1 = 0.125$ |
42
+ | Identity | LPIPS, $\lambda_2 = 0.125$ |
43
+ | Emb alignment | MSE in CIMP visual-embedding space, $\lambda_3 = 0.5$ |
44
  | Hardware | 1 $\times$ H100 |
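Read together with `lambda_GAN = 1.0` from `config.json`, the weights in this table suggest a generator objective of the following form. This is only a sketch: it assumes the four terms are combined as a plain weighted sum, which the commit does not show, and $\lambda_{\text{GAN}}$ is a label introduced here for the `lambda_GAN` config key. The symbols $\mathcal{L}_{\text{cyc}}$, $\mathcal{L}_{\text{id}}$ and $\mathcal{L}_{\text{emb}}$ are the ones used in the previous version of the table.

$$
\mathcal{L}_G = \lambda_{\text{GAN}}\,\mathcal{L}_{\text{LSGAN}} + \lambda_1\,\mathcal{L}_{\text{cyc}} + \lambda_2\,\mathcal{L}_{\text{id}} + \lambda_3\,\mathcal{L}_{\text{emb}}, \qquad (\lambda_{\text{GAN}}, \lambda_1, \lambda_2, \lambda_3) = (1.0,\ 0.125,\ 0.125,\ 0.5)
$$

Under that assumption, moving $\lambda_1$ and $\lambda_2$ from 0.1 to 0.125 raises the weight of the two LPIPS terms by 25% relative to the adversarial and embedding terms, which are unchanged.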
45
 
 
 
 
 
 
 
 
 
 
 
 
46
  ## Files
47
 
48
  - `generator.pth` - inference-ready state dict for the `StyleUNet` generator.
 
60
 
61
  device = "cuda"
62
 
 
63
  cimp = CMMP(
64
+ meta_input_dim=7, embed_dim=128,
65
+ image_encoder="resnet18", image_size=512,
66
+ meta_hidden_dim=256, meta_num_layers=3,
 
 
 
67
  ).to(device)
68
+ cimp.load_state_dict(torch.load(hf_hub_download("Stemson-AI/cmmp-resnet18-512", "model.pth"),
69
+ map_location=device))
70
  cimp.eval()
71
 
 
72
  gen = StyleUNet(embed_dim=256, in_channels=1, out_channels=1, base_filters=32, use_FiLM=True).to(device)
73
+ gen.load_state_dict(torch.load(hf_hub_download("Stemson-AI/cimp-style-transfer-512", "generator.pth"),
74
+ map_location=device))
75
  gen.eval()
76
 
77
+ # x: (1, 1, 512, 512) in [0, 1]; src_meta, tgt_meta: (1, 7) z-scored
 
78
  with torch.no_grad():
79
  e_id = F.normalize(cimp.meta(src_meta.to(device)), p=2, dim=-1)
80
  e_tgt = F.normalize(cimp.meta(tgt_meta.to(device)), p=2, dim=-1)
81
+ y = gen(x.to(device), e_tgt, e_id)
82
  ```
83
 
84
  ## Related Models
85
 
86
+ - [Stemson-AI/cmmp-resnet18-512](https://huggingface.co/Stemson-AI/cmmp-resnet18-512) - the CIMP encoder.
87
+ - [Stemson-AI/cmmp-resnet18-256](https://huggingface.co/Stemson-AI/cmmp-resnet18-256) - earlier CIMP variant.
88
 
89
  ## Citation
90
 
config.json CHANGED
@@ -22,19 +22,28 @@
22
  },
23
  "training": {
24
  "crop_size": 512,
25
- "batch_size": 8,
26
  "epochs": 250,
27
- "lr_G": 2e-4,
28
- "lr_D": 2e-4,
29
  "lambda_GAN": 1.0,
30
- "lambda_cycle": 0.1,
31
  "lambda_emb": 0.5,
32
- "lambda_id": 0.1,
33
  "id_loss_fn": "lpips",
34
  "cycle_loss_fn": "lpips",
35
  "optimizer": "Adam",
36
  "adversarial_loss": "LSGAN"
37
  },
38
  "meta_dim": 7,
39
- "meta_names": ["pixel_size", "dwell_time", "convergence_angle", "beam_current", "gain", "offset", "inner_coll_angle"]
40
- }
 
 
 
 
 
 
 
 
 
 
22
  },
23
  "training": {
24
  "crop_size": 512,
25
+ "batch_size": 16,
26
  "epochs": 250,
27
+ "lr_G": 0.0002,
28
+ "lr_D": 0.0002,
29
  "lambda_GAN": 1.0,
30
+ "lambda_cycle": 0.125,
31
  "lambda_emb": 0.5,
32
+ "lambda_id": 0.125,
33
  "id_loss_fn": "lpips",
34
  "cycle_loss_fn": "lpips",
35
  "optimizer": "Adam",
36
  "adversarial_loss": "LSGAN"
37
  },
38
  "meta_dim": 7,
39
+ "meta_names": [
40
+ "pixel_size",
41
+ "dwell_time",
42
+ "convergence_angle",
43
+ "beam_current",
44
+ "gain",
45
+ "offset",
46
+ "inner_coll_angle"
47
+ ],
48
+ "best_epoch_composite_val": 216
49
+ }
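Several lines in this `config.json` diff are changes of serialization only (for example `2e-4` rewritten as `0.0002`, and `meta_names` reflowed across lines). A minimal sketch of how to separate those from real hyperparameter changes, by comparing parsed values rather than text, is shown below. The two dictionaries are copied by hand from the numeric entries of the `training` block on each side of the diff; the helper name `changed_values` is hypothetical.

```python
# Numeric entries of the "training" block, copied from each side of the diff.
old_training = {
    "crop_size": 512, "batch_size": 8, "epochs": 250,
    "lr_G": 2e-4, "lr_D": 2e-4,
    "lambda_GAN": 1.0, "lambda_cycle": 0.1, "lambda_emb": 0.5, "lambda_id": 0.1,
}
new_training = {
    "crop_size": 512, "batch_size": 16, "epochs": 250,
    "lr_G": 0.0002, "lr_D": 0.0002,
    "lambda_GAN": 1.0, "lambda_cycle": 0.125, "lambda_emb": 0.5, "lambda_id": 0.125,
}


def changed_values(old, new):
    """Return {key: (old_value, new_value)} for keys whose parsed values differ."""
    return {
        key: (old.get(key), new.get(key))
        for key in sorted(set(old) | set(new))
        if old.get(key) != new.get(key)
    }


changes = changed_values(old_training, new_training)
for key, (before, after) in changes.items():
    print(f"{key}: {before} -> {after}")
```

Only `batch_size`, `lambda_cycle` and `lambda_id` are reported, which matches the commit message. The learning rates compare equal once parsed. The one genuinely new entry outside the `training` block is the top-level key `best_epoch_composite_val` with value 216.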
full_checkpoint.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fa0cda61783480aa28d683fb7d388d6950bdace9cf8f53eea63f6e50c1a0c410
3
  size 42249339
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e687972a154b1f8a765e5ebd7dda8dde706cac83e13305295bcd0728004c8426
3
  size 42249339
generator.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9632cd2d5a858200654e95cefd29ed3870dbd7bc360259cbcd59bfb1eb596555
3
  size 39597771
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:633756be32f683595ce9f61a2c91ed04a90de4a052b0005ed94083a9660a34a5
3
  size 39597771
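Both `.pth` pointers keep the same `size` as before and differ only in their `oid`, so file size alone cannot tell the retrained weights from the previous upload. A minimal sketch for checking a downloaded file against a Git LFS pointer follows. It assumes the file has already been fetched to a local path; the function name `matches_lfs_pointer` is hypothetical, and the constants are copied from the new `generator.pth` pointer above.

```python
import hashlib
import os


def matches_lfs_pointer(path, expected_oid, expected_size):
    """Check a local file against the sha256 oid and byte size of a Git LFS pointer."""
    if os.path.getsize(path) != expected_size:
        return False
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Hash in 1 MiB chunks so large checkpoints are not read into memory at once.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_oid


# Pointer fields for the new generator.pth, copied from the diff above.
GENERATOR_OID = "633756be32f683595ce9f61a2c91ed04a90de4a052b0005ed94083a9660a34a5"
GENERATOR_SIZE = 39597771
```

For example, `matches_lfs_pointer("generator.pth", GENERATOR_OID, GENERATOR_SIZE)` should return `True` for the retrained weights and `False` for a cached copy of the previous upload, where `"generator.pth"` stands for wherever the file was saved locally.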
training_log.csv CHANGED
The diff for this file is too large to render. See raw diff