---
license: apache-2.0
library_name: stylizing-vit
tags:
- style-transfer
- medical
- dermatology
- domain-generalization
- vision-transformer
- pytorch
- stylizing-vit
pipeline_tag: image-to-image
language:
- en
metrics:
- PSNR
- SSIM
- FID
- ArtFID
- LPIPS
- Accuracy
---

# Stylizing ViT Tiny - Fitzpatrick 17k *(Dermatology)*

This model is the **Tiny** variant of **Stylizing ViT**, trained on the [**Fitzpatrick 17k**](https://github.com/mattgroh/fitzpatrick17k) (dermatology) dataset with the following splits: **Train: {5,6} / Val: {3,4} / Test: {1,2}**.

**Stylizing ViT** is a novel Vision Transformer encoder that uses weight-shared attention blocks for both self- and cross-attention. This design allows the same attention block to maintain anatomical consistency (via self-attention) while performing style transfer (via cross-attention), enabling anatomy-preserving instance style transfer for domain generalization in medical imaging.

## Model Details

### Model Description

Deep learning models in medical image analysis often struggle to generalize across domains and demographic groups due to data heterogeneity and scarcity. Traditional augmentation improves robustness but fails under substantial domain shifts. Recent stylistic augmentation methods enhance domain generalization by varying image styles, yet fall short in style diversity or introduce artifacts into the generated images. To address these limitations, we propose **Stylizing ViT**, a modality-agnostic style augmentation method. It uses a single-encoder Vision Transformer (ViT) architecture to fuse anatomical structure from a content image with stylistic attributes from a reference image.
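The weight-sharing idea can be illustrated with a minimal, hypothetical sketch (the class and names below are illustrative, not the library's actual modules; the dimensions follow the Tiny configuration, 192-dim embeddings with 3 heads):

```python
import torch
import torch.nn as nn


class SharedAttentionBlock(nn.Module):
    """Illustrative only: one attention module reused for both roles.

    Self-attention (keys/values from the content tokens) maintains anatomical
    consistency; cross-attention (keys/values from the style tokens) injects
    stylistic attributes -- with the exact same weights.
    """

    def __init__(self, dim: int = 192, num_heads: int = 3):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, content, style=None):
        q = self.norm(content)
        # Same attention weights serve both paths; only the key/value
        # source changes between self- and cross-attention.
        kv = q if style is None else self.norm(style)
        out, _ = self.attn(q, kv, kv)
        return content + out


block = SharedAttentionBlock()
content_tokens = torch.randn(1, 196, 192)  # 224/16 = 14, so 14*14 = 196 patches
style_tokens = torch.randn(1, 196, 192)

anatomy_out = block(content_tokens)                 # self-attention path
stylized_out = block(content_tokens, style_tokens)  # cross-attention path
```

Because both paths share parameters, structure-preserving and style-injecting behavior are learned by the same block rather than by two separate sub-networks.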
- **Developed by:** Sebastian Doerrich (xAILab Bamberg, University of Bamberg)
- **Funded by:** Hightech Agenda Bayern (HTA) of the Free State of Bavaria, Germany
- **Model type:** Vision Transformer (ViT) with Cross-Attention Mechanism
- **Language(s):** English (Documentation)
- **License:** Apache-2.0

### Model Sources

- **Repository:** https://github.com/sdoerrich97/stylizing-vit
- **Paper:** arXiv: https://arxiv.org/abs/2601.17586

## Uses

### Direct Use

The primary use case is **style transfer** for medical images. The model takes a content image (e.g., a specific pathology slide) and a style reference, and generates a new image that retains the anatomical content of the first but adopts the visual style (staining, color distribution) of the second.

### Downstream Use

- **Data Augmentation:** The model is designed to be used during the training of downstream classifiers (e.g., for cancer detection) to improve domain generalization. By generating stylistically diverse samples, it encourages the classifier to learn shape-aware features rather than relying on spurious color/texture correlations.
- **Test-Time Augmentation (TTA):** It can be used at inference time to map input images to a known training distribution, improving performance on out-of-distribution data.

### Out-of-Scope Use

- **Diagnostic Use:** This model is an augmentation tool, not a diagnostic device. The generated images should not be used for primary diagnosis without expert validation, as artifacts could theoretically mask or create pathological features (though the method optimizes for anatomical preservation).
- **Non-Medical Style Transfer:** While the architecture is general, this specific checkpoint is trained on the dataset specified above. Using it for artistic style transfer on natural images may yield suboptimal results.
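The TTA use case above can be sketched as follows; this is a minimal illustration, where `classifier`, `stylizer`, and the bank of training-style references are placeholders for the user's own components, not part of the library API:

```python
import torch


def tta_predict(classifier, stylizer, image, style_bank, n_styles=4):
    """Restyle an out-of-distribution test image toward known training
    styles and average the classifier's logits over the restyled variants.

    `classifier` and `style_bank` (a tensor of normalized training-style
    references, shape (N, 3, 224, 224)) are hypothetical placeholders.
    """
    # Sample a few style references from the training distribution
    idx = torch.randperm(style_bank.size(0))[:n_styles]
    logits = []
    with torch.no_grad():
        for style in style_bank[idx]:
            restyled = stylizer(image, style.unsqueeze(0))
            logits.append(classifier(restyled))
    # Average predictions across the restyled variants
    return torch.stack(logits).mean(dim=0)
```

Averaging over several restyled variants reduces the variance introduced by any single style choice.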
## Bias, Risks, and Limitations

### Limitations

- **Artifacts:** While Stylizing ViT outperforms prior methods in reducing artifacts, style transfer can occasionally introduce unnatural textures or lose fine-grained details if the domain gap is too large.
- **Computational Cost:** As a transformer-based model, it requires more compute than simple color augmentation techniques.

### Recommendations

Users should visually validate that anatomical structures (e.g., cell boundaries, tissue architecture) are preserved in the stylized output before using the generated data to train sensitive downstream models.

## How to Get Started with the Model

You can load this model using the `stylizing-vit` library. Since this model is hosted in its own repository, you can download the weights using `huggingface_hub` and load them into the model.

### Input Requirements

The model requires the following input specification:

- **Resolution:** 224x224 pixels.
- **Format:** PyTorch tensor of shape `(B, C, H, W)`.
- **Normalization:** Images must be normalized (e.g., using ImageNet statistics or dataset-specific mean/std) before inference.

### Installation

```bash
pip install stylizing-vit
```

### Inference Snippet

```python
import torch

from stylizing_vit import create_model

# Select the compute device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model with pretrained weights.
# This automatically downloads the weights from the Hugging Face Hub.
model = create_model(backbone="tiny", weights="fitzpatrick17k_65_43_21", train=False).to(device)
model.eval()

# Apply style transfer:
# content_img and style_img should be normalized torch tensors of shape (1, 3, 224, 224)
# with torch.no_grad():
#     stylized_img = model(content_img, style_img)
```

## Training Details

### Training Procedure

The model is trained to minimize a combination of losses computed with a frozen VGG19 perceptual network:

- **Anatomical Loss** (\\(\lambda_a=7.0\\)): Preserves structural content.
- **Style Loss** (\\(\lambda_s=10.0\\)): Enforces stylistic similarity to the reference.
- **Identity Loss** (\\(\lambda_{id}=70.0\\)): Ensures reconstruction fidelity when content and style inputs are identical.
- **Consistency Loss** (\\(\lambda_c=1.0\\)): Regularizes the feature space.

#### Training Hyperparameters

- **Architecture:** `vit_tiny` (12 layers, 3 heads, 192 embedding dim)
- **Image Size:** 224x224
- **Patch Size:** 16
- **Epochs:** 50
- **Batch Size:** 64
- **Optimizer:** AdamW (`timm.optim.create_optimizer_v2`)
- **LR Scheduler:** Cosine Annealing (`timm.scheduler.CosineLRScheduler`)

#### Training Snippet

```python
import torch
from accelerate import Accelerator

from stylizing_vit.model import StylizingViT

# 1. Setup
accelerator = Accelerator()
device = accelerator.device
model = StylizingViT(backbone="tiny", train=True).to(device)

# Frozen VGG encoder for loss computation
model.vgg_encoder.requires_grad_(False)
model.vgg_encoder.eval()

# Optimized parameters: Encoder + Bottleneck + Post-Process Conv
params = (
    list(model.encoder.parameters())
    + list(model.bottleneck.parameters())
    + list(model.post_process_conv.parameters())
)
optimizer = torch.optim.AdamW(params, lr=1e-4)  # Example LR

# 2. Training Loop
model.train()
for epoch in range(50):
    for batch in train_loader:
        # images: (B, C, H, W)
        images, _ = batch
        images = images.to(device)

        # Create style pairs (e.g., by rolling the batch)
        style_images = images.roll(shifts=1, dims=0)

        # Forward pass returns loss components and reconstructions.
        # The model internally computes the Identity, Consistency,
        # Anatomical, and Style losses.
        loss_dict, _ = model(images, style_images)

        # Total loss is the weighted sum of the individual losses
        total_loss = loss_dict.total_loss

        optimizer.zero_grad()
        accelerator.backward(total_loss)
        optimizer.step()
```

#### Data Augmentation Snippet

```python
import torch

from stylizing_vit import create_model

# Load the pre-trained Stylizing ViT
stylizer = create_model(backbone="tiny", weights="fitzpatrick17k_65_43_21", train=False)
stylizer.eval()
stylizer.requires_grad_(False)


def augment_batch(images):
    """Augment a batch of images using style transfer."""
    # Create style references (e.g., by shuffling the current batch)
    style_reference = images[torch.randperm(images.size(0))]

    with torch.no_grad():
        # Generate stylized images (inputs should be normalized)
        stylized_images = stylizer(images, style_reference)

    return stylized_images


# Usage in a training loop:
# for images, labels in dataloader:
#     augmented_images = augment_batch(images)
#     # Pass augmented_images to your downstream classifier...
```

## Evaluation

### Metrics

The model is evaluated on:

- **Reconstruction:** PSNR, SSIM (structure preservation).
- **Style Transfer:** FID, ArtFID, LPIPS (perceptual quality and diversity).
- **Classification Performance:** Accuracy.

### Results

As reported in the associated ISBI 2026 paper, Stylizing ViT demonstrates improved robustness (up to 13% accuracy gain) over state-of-the-art methods in domain generalization tasks on histopathology and dermatology datasets.
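For reference, the PSNR reconstruction metric can be computed directly in PyTorch; the sketch below is a minimal illustration (in practice, libraries such as `torchmetrics` provide both PSNR and SSIM):

```python
import torch


def psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio for images scaled to [0, max_val].

    Higher is better; identical images diverge toward infinity.
    """
    mse = torch.mean((pred - target) ** 2)
    return float(10.0 * torch.log10(max_val ** 2 / mse))


# Stylizing an image with itself as the style reference should reconstruct
# the input almost perfectly (this is what the identity loss enforces),
# which shows up as a high PSNR against the original.
```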
## Citation

If you use this model in your research, please cite:

```bibtex
@article{doerrich2026stylizingvit,
      title={Stylizing ViT: Anatomy-Preserving Instance Style Transfer for Domain Generalization},
      author={Sebastian Doerrich and Francesco Di Salvo and Jonas Alle and Christian Ledig},
      year={2026},
      eprint={2601.17586},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## Model Card Contact

For questions or issues, please open an issue in the [GitHub repository](https://github.com/sdoerrich97/stylizing-vit).