---
license: apache-2.0
library_name: stylizing-vit
tags:
- style-transfer
- medical
- dermatology
- domain-generalization
- vision-transformer
- pytorch
- stylizing-vit
pipeline_tag: image-to-image
language:
- en
metrics:
- PSNR
- SSIM
- FID
- ArtFID
- LPIPS
- Accuracy
---

# Stylizing ViT Base - Fitzpatrick 17k *(Dermatology)*

This model is the **Base** variant of **Stylizing ViT**, trained on the [**Fitzpatrick 17k**](https://github.com/mattgroh/fitzpatrick17k) (dermatology) dataset with the following splits: **Train: {1,2} / Val: {3,4} / Test: {5,6}**.

**Stylizing ViT** is a novel Vision Transformer encoder that utilizes weight-shared attention blocks for both self- and cross-attention. This design allows the same attention block to maintain anatomical consistency (via self-attention) while performing style transfer (via cross-attention), enabling anatomy-preserving instance style transfer for domain generalization in medical imaging.
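
A minimal conceptual sketch of this weight-sharing idea (an illustration, not the library's actual implementation): a single attention module runs in self-attention mode (query, key, and value all drawn from the content tokens) or in cross-attention mode (keys and values drawn from the style tokens), with the same weights serving both.

```python
import torch
import torch.nn as nn

class SharedAttentionBlock(nn.Module):
    """Conceptual sketch: one attention module serves both modes."""

    def __init__(self, dim: int = 768, num_heads: int = 12):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, content, style=None):
        if style is None:
            # Self-attention: attends within the content image to
            # preserve anatomical structure
            out, _ = self.attn(content, content, content)
        else:
            # Cross-attention with the SAME weights: pulls stylistic
            # attributes from the style tokens into the content tokens
            out, _ = self.attn(content, style, style)
        return out

# Token sequences of shape (batch, num_patches, dim); 196 = (224 / 16)^2
block = SharedAttentionBlock()
content_tokens = torch.randn(1, 196, 768)
style_tokens = torch.randn(1, 196, 768)
anatomy_out = block(content_tokens)                 # self-attention pass
stylized_out = block(content_tokens, style_tokens)  # cross-attention pass
```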

## Model Details

### Model Description

Deep learning models in medical image analysis often struggle to generalize across domains and demographic groups due to data heterogeneity and scarcity. Traditional augmentation improves robustness but fails under substantial domain shifts. Recent advances in stylistic augmentation enhance domain generalization by varying image styles, but they either lack style diversity or introduce artifacts into the generated images.

To address these limitations, we propose **Stylizing ViT**, a modality-agnostic style augmentation method. It uses a single-encoder Vision Transformer (ViT) architecture to fuse anatomical structure from a content image with stylistic attributes from a reference image.

- **Developed by:** Sebastian Doerrich (xAILab Bamberg, University of Bamberg)
- **Funded by:** Hightech Agenda Bayern (HTA) of the Free State of Bavaria, Germany
- **Model type:** Vision Transformer (ViT) with Cross-Attention Mechanism
- **Language(s):** English (Documentation)
- **License:** Apache-2.0

### Model Sources

- **Repository:** https://github.com/sdoerrich97/stylizing-vit
- **Paper:**
  - arXiv: https://arxiv.org/abs/2601.17586

## Uses

### Direct Use

The primary use case is **style transfer** for medical images. The model takes a content image (e.g., a specific skin lesion image) and a style reference, and generates a new image that retains the anatomical content of the first but adopts the visual style (staining, color distribution) of the second.

### Downstream Use

- **Data Augmentation:** The model is designed to be used during the training of downstream classifiers (e.g., for cancer detection) to improve domain generalization. By generating stylistically diverse samples, it encourages the classifier to learn shape-aware features rather than relying on spurious color/texture correlations.
- **Test-Time Augmentation (TTA):** It can be used at inference time to map input images to a known training distribution, improving performance on out-of-distribution data (see the sketch below).
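
A rough sketch of the TTA workflow. The tensors below are random placeholders for normalized images, and `classifier` stands in for your trained downstream model; neither is part of this repository's API.

```python
import torch
from stylizing_vit import create_model

# Load the frozen stylizer (see "How to Get Started" below)
stylizer = create_model(backbone="base", weights="fitzpatrick17k", train=False)
stylizer.eval()

# Placeholders standing in for normalized (1, 3, 224, 224) images
test_img = torch.randn(1, 3, 224, 224)         # out-of-distribution input
train_style_ref = torch.randn(1, 3, 224, 224)  # style reference drawn from the training set

with torch.no_grad():
    # Restyle the test image toward the training distribution, then classify it
    restyled_img = stylizer(test_img, train_style_ref)
    # prediction = classifier(restyled_img).argmax(dim=1)
```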

### Out-of-Scope Use

- **Diagnostic Use:** This model is an augmentation tool, not a diagnostic device. The generated images should not be used for primary diagnosis without expert validation, as artifacts could theoretically mask or create pathological features (though the method optimizes for anatomical preservation).
- **Non-Medical Style Transfer:** While the architecture is general, this specific checkpoint is trained on the dataset specified above. Using it for artistic style transfer on natural images may yield suboptimal results.

## Bias, Risks, and Limitations

### Limitations

- **Artifacts:** While Stylizing ViT outperforms prior methods in reducing artifacts, style transfer can occasionally introduce unnatural textures or lose fine-grained details if the domain gap is too large.
- **Computational Cost:** As a transformer-based model, it requires more compute than simple color augmentation techniques.

### Recommendations

Users should visually validate that anatomical structures (e.g., cell boundaries, tissue architecture) are preserved in the stylized output before using the generated data to train sensitive downstream models.

## How to Get Started with the Model

You can load this model using the `stylizing-vit` library. Since this model is hosted in its own repository, you can also download the weights with `huggingface_hub` and load them into the model.
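
For the manual `huggingface_hub` route, a sketch (the repository id and checkpoint filename below are placeholders, not verified values; check this repository's file list for the actual names):

```python
from huggingface_hub import hf_hub_download

# Replace the placeholders with this repository's actual id and
# checkpoint filename before running.
ckpt_path = hf_hub_download(
    repo_id="<this-repo-id>",
    filename="<checkpoint-filename>",
)
# The downloaded checkpoint can then be loaded into a Stylizing ViT instance.
```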

### Input Requirements

The model requires the following input specification:
- **Resolution:** 224x224 pixels.
- **Format:** PyTorch tensor of shape `(B, C, H, W)`.
- **Normalization:** Images must be normalized (e.g., using ImageNet statistics or dataset-specific mean/std) before inference.
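
A minimal preprocessing sketch satisfying these requirements, assuming `torchvision` and ImageNet statistics (swap in dataset-specific mean/std where available):

```python
from torchvision import transforms

# Resize, convert to a tensor, and normalize; the ImageNet statistics
# below are an assumption, not values required by the checkpoint.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Example: turn a PIL image into a (1, 3, 224, 224) batch
# batch = preprocess(pil_image).unsqueeze(0)
```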

### Installation

```bash
pip install stylizing-vit
```

### Inference Snippet

```python
import torch
from stylizing_vit import create_model

# Select the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model with pretrained weights
# (this automatically downloads the weights from the Hugging Face Hub)
model = create_model(backbone="base", weights="fitzpatrick17k", train=False).to(device)
model.eval()

# Apply style transfer
# content_img and style_img must be normalized tensors of shape (1, 3, 224, 224);
# the random tensors below stand in for real, preprocessed images
content_img = torch.randn(1, 3, 224, 224, device=device)
style_img = torch.randn(1, 3, 224, 224, device=device)

with torch.no_grad():
    stylized_img = model(content_img, style_img)
```

## Training Details

### Training Procedure

The model is trained to minimize a combination of losses using a frozen VGG19 perceptual network:
- **Anatomical Loss** (\\(\lambda_a = 7.0\\)): Preserves structural content.
- **Style Loss** (\\(\lambda_s = 10.0\\)): Enforces stylistic similarity to the reference.
- **Identity Loss** (\\(\lambda_{id} = 70.0\\)): Ensures reconstruction fidelity when input and style are identical.
- **Consistency Loss** (\\(\lambda_c = 1.0\\)): Regularizes the feature space.
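
Taken together, and assuming the standard weighted-sum formulation (see the paper for the exact objective), the total training loss is:

\\[
\mathcal{L}_{total} = \lambda_a \mathcal{L}_{anatomical} + \lambda_s \mathcal{L}_{style} + \lambda_{id} \mathcal{L}_{identity} + \lambda_c \mathcal{L}_{consistency}
\\]

This is presumably the quantity returned as `loss_dict.total_loss` in the training snippet below.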

#### Training Hyperparameters

- **Architecture:** `vit_base` (12 layers, 12 heads, 768 embedding dim)
- **Image Size:** 224x224
- **Patch Size:** 16
- **Epochs:** 50
- **Batch Size:** 64
- **Optimizer:** AdamW (`timm.optim.create_optimizer_v2`)
- **LR Scheduler:** Cosine Annealing (`timm.scheduler.CosineLRScheduler`)
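
A minimal sketch of the optimizer and scheduler setup with the `timm` utilities listed above (the learning rate and weight decay values are assumptions, not the paper's settings):

```python
from timm.optim import create_optimizer_v2
from timm.scheduler import CosineLRScheduler
from stylizing_vit.model import StylizingViT

model = StylizingViT(backbone="base", train=True)

# AdamW via timm's factory; lr/weight_decay are illustrative values
optimizer = create_optimizer_v2(model, opt="adamw", lr=1e-4, weight_decay=0.05)

# Cosine annealing over the 50 training epochs
scheduler = CosineLRScheduler(optimizer, t_initial=50)

# Inside the training loop, step once per epoch:
# scheduler.step(epoch)
```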

#### Training Snippet

```python
import torch
from accelerate import Accelerator
from stylizing_vit.model import StylizingViT

# 1. Setup
accelerator = Accelerator()
device = accelerator.device
model = StylizingViT(backbone="base", train=True).to(device)

# Frozen VGG encoder for loss computation
model.vgg_encoder.requires_grad_(False)
model.vgg_encoder.eval()

# Optimized parameters: Encoder + Bottleneck + Post-Process Conv
params = (
    list(model.encoder.parameters()) +
    list(model.bottleneck.parameters()) +
    list(model.post_process_conv.parameters())
)
optimizer = torch.optim.AdamW(params, lr=1e-4)  # example LR

# 2. Training loop (50 epochs, matching the hyperparameters above);
# train_loader is assumed to yield (images, labels) batches
model.train()
for epoch in range(50):
    for batch in train_loader:
        # images: (B, C, H, W)
        images, _ = batch
        images = images.to(device)

        # Create style pairs (e.g., by rolling the batch)
        style_images = images.roll(shifts=1, dims=0)

        # Forward pass returns loss components and reconstructions;
        # the model internally computes Identity, Consistency, Anatomical, and Style losses
        loss_dict, _ = model(images, style_images)

        # Total loss is the weighted sum of the individual terms
        total_loss = loss_dict.total_loss

        optimizer.zero_grad()
        accelerator.backward(total_loss)
        optimizer.step()
```

#### Data Augmentation Snippet

```python
import torch
from stylizing_vit import create_model

# Load the pre-trained Stylizing ViT and freeze it
stylizer = create_model(backbone="base", weights="fitzpatrick17k", train=False)
stylizer.eval()
stylizer.requires_grad_(False)

def augment_batch(images):
    """Augment a batch of normalized images using style transfer."""
    # Create style references by shuffling the current batch
    style_reference = images[torch.randperm(images.size(0), device=images.device)]

    with torch.no_grad():
        # Generate stylized images
        stylized_images = stylizer(images, style_reference)

    return stylized_images

# Usage in a training loop:
# for images, labels in dataloader:
#     augmented_images = augment_batch(images)
#     # Pass augmented_images to your downstream classifier...
```

## Evaluation

### Metrics

The model is evaluated on:
- **Reconstruction:** PSNR, SSIM (structure preservation).
- **Style Transfer:** FID, ArtFID, LPIPS (perceptual quality and diversity).
- **Classification Performance:** Accuracy.
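
As a sketch, the reconstruction metrics can be computed with `torchmetrics` (an assumed tooling choice; the paper's exact evaluation pipeline may differ):

```python
import torch
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

psnr = PeakSignalNoiseRatio(data_range=1.0)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)

# Random tensors stand in for reconstructed and reference images in [0, 1]
reconstructed = torch.rand(4, 3, 224, 224)
reference = torch.rand(4, 3, 224, 224)
print(f"PSNR: {psnr(reconstructed, reference):.2f}")
print(f"SSIM: {ssim(reconstructed, reference):.4f}")
```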

### Results

As reported in the associated ISBI 2026 paper, Stylizing ViT demonstrates improved robustness (up to a 13% accuracy gain) over state-of-the-art methods in domain generalization tasks on histopathology and dermatology datasets.

## Citation

If you use this model in your research, please cite:

```bibtex
@article{doerrich2026stylizingvit,
  title={Stylizing ViT: Anatomy-Preserving Instance Style Transfer for Domain Generalization},
  author={Sebastian Doerrich and Francesco Di Salvo and Jonas Alle and Christian Ledig},
  year={2026},
  eprint={2601.17586},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

## Model Card Contact

For questions or issues, please open an issue in the [GitHub repository](https://github.com/sdoerrich97/stylizing-vit).