Stylizing ViT Tiny - Cholec80 (Laparoscopy, Cholecystectomy)

This model is the Tiny variant of Stylizing ViT, trained on the Cholec80 (laparoscopy, cholecystectomy) dataset with the following splits: Train: {41, 42} / Val: {43} / Test: {44, 45}.

Stylizing ViT is a novel Vision Transformer encoder that utilizes weight-shared attention blocks for both self- and cross-attention. This design allows the same attention block to maintain anatomical consistency (via self-attention) while performing style transfer (via cross-attention), enabling anatomy-preserving instance style transfer for domain generalization in medical imaging.
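The weight-sharing idea can be sketched as follows. This is a hypothetical illustration (the class and names below are not from the released library): one attention module whose projection weights are reused for both roles, with only the source of keys/values changing between the self- and cross-attention paths.

```python
import torch
import torch.nn as nn
from typing import Optional

class SharedAttentionBlock(nn.Module):
    """Hypothetical sketch of a weight-shared attention block: the same
    projection weights serve both self- and cross-attention."""

    def __init__(self, dim: int = 192, heads: int = 3):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, content: torch.Tensor,
                style: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Self-attention: keys/values come from the content tokens themselves.
        # Cross-attention: keys/values come from the style tokens instead,
        # computed with the exact same weights.
        kv = content if style is None else style
        out, _ = self.attn(content, kv, kv)
        return self.norm(content + out)

block = SharedAttentionBlock()
content_tokens = torch.randn(1, 196, 192)  # 14x14 patches of a 224x224 image
style_tokens = torch.randn(1, 196, 192)
anatomy_path = block(content_tokens)                 # self-attention
transfer_path = block(content_tokens, style_tokens)  # cross-attention, same weights
```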

Model Details

Model Description

Deep learning models in medical image analysis often struggle to generalize across domains and demographic groups due to data heterogeneity and scarcity. Traditional augmentation improves robustness but fails under substantial domain shifts. Recent advances in stylistic augmentation enhance domain generalization by varying image styles, but either fall short in style diversity or introduce artifacts into the generated images.

To address these limitations, we propose Stylizing ViT, a modality-agnostic style augmentation method. It uses a single-encoder Vision Transformer (ViT) architecture to fuse anatomical structure from a content image with stylistic attributes from a reference image.

  • Developed by: Sebastian Doerrich (xAILab Bamberg, University of Bamberg)
  • Funded by: Hightech Agenda Bayern (HTA) of the Free State of Bavaria, Germany
  • Model type: Vision Transformer (ViT) with Cross-Attention Mechanism
  • Language(s): English (Documentation)
  • License: Apache-2.0

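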
Uses

Direct Use

The primary use case is style transfer for medical images. This model takes a content image (e.g., a specific pathology slide) and a style reference, and generates a new image that retains the anatomical content of the first but adopts the visual style (staining, color distribution) of the second.

Downstream Use

  • Data Augmentation: The model is designed to be used during the training of downstream classifiers (e.g., for cancer detection) to improve domain generalization. By generating stylistically diverse samples, it encourages the classifier to learn shape-aware features rather than relying on spurious color/texture correlations.
  • Test-Time Augmentation (TTA): It can be used at inference time to map input images to a known training distribution, improving performance on out-of-distribution data.
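The TTA idea above can be sketched as follows. The helper below is a hypothetical illustration, not library API: it averages classifier predictions over the original image and several variants stylized toward known training-style references. Toy stand-ins are used so the sketch runs; swap in the real Stylizing ViT and your classifier.

```python
import torch

def tta_predict(classifier, stylizer, image, style_refs):
    """Average predictions over the original image and stylized variants."""
    with torch.no_grad():
        logits = [classifier(image)]
        for ref in style_refs:
            # Map the OOD input toward a known training style before classifying.
            logits.append(classifier(stylizer(image, ref)))
    return torch.stack(logits).mean(dim=0)

# Toy stand-ins so the sketch runs end to end.
stylizer = lambda img, ref: 0.5 * img + 0.5 * ref
classifier = lambda img: img.mean(dim=(1, 2, 3)).unsqueeze(-1)

image = torch.randn(1, 3, 224, 224)
refs = [torch.randn(1, 3, 224, 224) for _ in range(3)]
pred = tta_predict(classifier, stylizer, image, refs)
```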

Out-of-Scope Use

  • Diagnostic Use: This model is an augmentation tool, not a diagnostic device. The generated images should not be used for primary diagnosis without expert validation, as artifacts could theoretically mask or create pathological features (though the method optimizes for anatomical preservation).
  • Non-Medical Style Transfer: While the architecture is general, this specific checkpoint was trained on the dataset specified above. Using it for artistic style transfer on natural images may yield suboptimal results.

Bias, Risks, and Limitations

Limitations

  • Artifacts: While Stylizing ViT outperforms prior methods in reducing artifacts, style transfer can occasionally introduce unnatural textures or lose fine-grained details if the domain gap is too large.
  • Computational Cost: Being a transformer-based model, it requires more compute than simple color augmentation techniques.

Recommendations

Users should visually validate that the anatomical structures (e.g., cell boundaries, tissue architecture) are preserved in the stylized output before using the generated data to train sensitive downstream models.

How to Get Started with the Model

You can load this model using the stylizing-vit library. Since this model is hosted in its own repository, you can download the weights using huggingface_hub and load them into the model.

Input Requirements

The model requires the following input specification:

  • Resolution: 224x224 pixels.
  • Format: PyTorch Tensor (B, C, H, W).
  • Normalization: Images must be normalized (e.g., using ImageNet statistics or dataset-specific mean/std) before inference.
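A minimal preprocessing sketch that satisfies the requirements above, assuming inputs are float tensors in [0, 1] and using ImageNet statistics (substitute dataset-specific values where available):

```python
import torch
import torch.nn.functional as F

# ImageNet statistics; replace with dataset-specific mean/std if you have them.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def preprocess(image: torch.Tensor) -> torch.Tensor:
    """Resize a (B, C, H, W) float tensor in [0, 1] to 224x224 and normalize."""
    image = F.interpolate(image, size=(224, 224), mode="bilinear",
                          align_corners=False)
    return (image - MEAN) / STD

batch = torch.rand(2, 3, 480, 640)  # e.g. raw laparoscopy frames
ready = preprocess(batch)
```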

Installation

pip install stylizing-vit

Inference Snippet

import torch
from stylizing_vit import create_model

# Initialize the model
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model with pretrained weights
# This automatically downloads the weights from the Hugging Face Hub
model = create_model(backbone="tiny", weights="cholec80", train=False).to(device)
model.eval()

# Apply Style Transfer
# content_img and style_img should be normalized torch tensors of shape (1, 3, 224, 224)
# with torch.no_grad():
#     stylized_img = model(content_img, style_img)

Training Details

Training Procedure

The model is trained to minimize a combination of losses using a frozen VGG19 perceptual network:

  • Anatomical Loss (λ_a = 7.0): Preserves structural content.
  • Style Loss (λ_s = 10.0): Enforces stylistic similarity to the reference.
  • Identity Loss (λ_id = 70.0): Ensures reconstruction fidelity when content and style inputs are identical.
  • Consistency Loss (λ_c = 1.0): Regularizes the feature space.
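The combination above can be sketched as follows, assuming the total objective is a simple weighted sum of the four terms:

```python
# Loss weights from the training configuration above.
WEIGHTS = {"anatomical": 7.0, "style": 10.0, "identity": 70.0, "consistency": 1.0}

def total_loss(components: dict) -> float:
    """Weighted sum of the individual loss terms (assumed combination)."""
    return sum(WEIGHTS[name] * value for name, value in components.items())

# Illustrative per-term values:
loss = total_loss({"anatomical": 0.2, "style": 0.1,
                   "identity": 0.05, "consistency": 0.3})
# 7.0*0.2 + 10.0*0.1 + 70.0*0.05 + 1.0*0.3 ≈ 6.2
```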

Training Hyperparameters

  • Architecture: vit_tiny (12 layers, 3 heads, 192 embedding dim)
  • Image Size: 224x224
  • Patch Size: 16
  • Epochs: 50
  • Batch Size: 64
  • Optimizer: AdamW (timm.optim.create_optimizer_v2)
  • LR Scheduler: Cosine Annealing (timm.scheduler.CosineLRScheduler)
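The optimizer/scheduler pairing above can be reproduced with plain PyTorch as well. The sketch below uses `torch.optim.AdamW` and `CosineAnnealingLR` as stand-ins for the timm helpers named above (the tiny linear model and the learning rate are placeholders):

```python
import torch

model = torch.nn.Linear(192, 192)  # stand-in for the Stylizing ViT parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
# Cosine annealing over the 50-epoch schedule
# (plain-PyTorch equivalent of timm.scheduler.CosineLRScheduler).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

for epoch in range(50):
    # ... run one training epoch, then anneal the learning rate ...
    scheduler.step()
```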

Training Snippet

import torch
from accelerate import Accelerator
from stylizing_vit.model import StylizingViT

# 1. Setup
accelerator = Accelerator()
device = accelerator.device
model = StylizingViT(backbone="tiny", train=True).to(device)

# Frozen VGG encoder for loss computation
model.vgg_encoder.requires_grad_(False) 
model.vgg_encoder.eval()

# Optimized parameters: Encoder + Bottleneck + Post-Process Conv
params = (
    list(model.encoder.parameters()) + 
    list(model.bottleneck.parameters()) + 
    list(model.post_process_conv.parameters())
)
optimizer = torch.optim.AdamW(params, lr=1e-4) # Example LR

# 2. Training Loop
model.train()
for epoch in range(50):  # matches the 50-epoch schedule above
    for batch in train_loader:
        # images: (B, C, H, W)
        images, _ = batch
        images = images.to(device)

        # Create style pairs (e.g., by rolling the batch)
        style_images = images.roll(shifts=1, dims=0)

        # Forward pass returns loss components and reconstructions
        # Model internally computes Identity, Consistency, Anatomical, and Style losses
        loss_dict, _ = model(images, style_images)

        # Total loss is the weighted sum of the individual terms
        total_loss = loss_dict.total_loss

        optimizer.zero_grad()
        accelerator.backward(total_loss)
        optimizer.step()

Data Augmentation Snippet

import torch
from stylizing_vit import create_model

# Load pre-trained Stylizing ViT
stylizer = create_model(backbone="tiny", weights="cholec80", train=False)
stylizer.eval()
stylizer.requires_grad_(False)

def augment_batch(images):
    """
    Augment a batch of images using style transfer.
    """
    # Create style reference (e.g., shuffle current batch)
    style_reference = images[torch.randperm(images.size(0))]
    
    with torch.no_grad():
        # Generate stylized images
        # Input images should be normalized
        stylized_images = stylizer(images, style_reference)
    
    return stylized_images

# Usage in training loop
# for images, labels in dataloader:
#     augmented_images = augment_batch(images)
#     # Pass augmented_images to your downstream classifier...

Evaluation

Metrics

The model is evaluated on:

  • Reconstruction: PSNR, SSIM (structure preservation).
  • Style Transfer: FID, ArtFID, LPIPS (perceptual quality and diversity).
  • Classification Performance: Accuracy.
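Of the reconstruction metrics above, PSNR is simple enough to sketch inline (SSIM needs windowed statistics and is best taken from a library such as torchmetrics). A minimal implementation for images scaled to [0, 1]:

```python
import torch

def psnr(pred: torch.Tensor, target: torch.Tensor,
         max_val: float = 1.0) -> torch.Tensor:
    """Peak signal-to-noise ratio in dB; higher means closer reconstruction."""
    mse = torch.mean((pred - target) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / mse)

reconstruction = torch.rand(1, 3, 224, 224)
noisy = (reconstruction + 0.01 * torch.randn_like(reconstruction)).clamp(0, 1)
score = psnr(noisy, reconstruction)  # identical images give +inf
```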

Results

As reported in the associated ISBI 2026 paper, Stylizing ViT demonstrates improved robustness (up to 13% accuracy gain) over state-of-the-art methods in domain generalization tasks on histopathology and dermatology datasets.

Citation

If you use this model in your research, please cite:

@article{doerrich2026stylizingvit,
  title={Stylizing ViT: Anatomy-Preserving Instance Style Transfer for Domain Generalization},
  author={Sebastian Doerrich and Francesco Di Salvo and Jonas Alle and Christian Ledig},
  year={2026},
  eprint={2601.17586},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}

Model Card Contact

For questions or issues, please open an issue in the GitHub repository.
