---
license: cc-by-4.0
library_name: pytorch
inference: true
tags:
  - image-to-image
  - style-transfer
  - cyclegan
  - vision-transformer
  - mvit
  - sketch-synthesis
  - hand-drawn
---

MViT-Stroke-CycleGAN - Pretrained Weights

Pretrained weights for MViT-Stroke-CycleGAN, a CycleGAN variant using Mobile Vision Transformer (MViT) for hand-drawn stroke synthesis.

🔗 Code: [GitHub Repository](https://github.com/HJ-Peng/MViT-Stroke-GAN)


📦 Checkpoint

| File | Description |
|------|-------------|
| `MViT-Stroke-GAN.pth` | Trained model weights. |
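The loading code in the inference section indexes this checkpoint with `"netG_A"`, so the file is presumably a dictionary of state dicts. Below is a minimal sketch for checking its layout before loading. It assumes the checkpoint is a plain dict whose values are either state dicts or tensors. The helper name `summarize_checkpoint` is hypothetical and not part of the repository.

```python
import torch


def summarize_checkpoint(path):
    """Return {top-level key: number of tensor elements} for a checkpoint.

    Assumes the file holds a dict whose values are state dicts
    (e.g. {"netG_A": {...}}) or plain tensors; `path` may also be
    a file-like object, since torch.load accepts both.
    """
    ckpt = torch.load(path, map_location="cpu")
    if not isinstance(ckpt, dict):
        raise TypeError(f"expected a dict checkpoint, got {type(ckpt).__name__}")
    summary = {}
    for key, value in ckpt.items():
        if isinstance(value, dict):
            # Sub-state-dict: count elements across all of its tensors
            summary[key] = sum(t.numel() for t in value.values() if torch.is_tensor(t))
        elif torch.is_tensor(value):
            summary[key] = value.numel()
    return summary
```

For example, `summarize_checkpoint("MViT-Stroke-GAN.pth")` should list a `netG_A` entry if the file matches what the loading code expects.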

⚠️ Inference: Use train() Mode Only

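Do not use `model.eval()`. Because of MViT's design, it causes severe artifacts. Instead, keep the model in training mode and switch only the dropout layers to evaluation mode, so that inference does not randomly zero activations.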
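One common reason a network's output changes under `eval()` is normalization. In training mode, `BatchNorm` layers normalize with the statistics of the current batch. In eval mode, they use running averages collected during training. Whether this is the mechanism behind the artifacts in MViT-Stroke-GAN is an assumption. The toy layer below only illustrates the general PyTorch behavior.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the (unknown) normalization layers of the real generator (assumption)
bn = nn.BatchNorm2d(3)

# Input whose statistics differ from BatchNorm's initial running stats (mean 0, var 1)
x = torch.randn(1, 3, 8, 8) * 5 + 2

bn.train()
with torch.no_grad():
    y_train = bn(x)  # normalized with this batch's own mean/variance

bn.eval()
with torch.no_grad():
    y_eval = bn(x)   # normalized with running_mean / running_var

print(f"train mode: mean={y_train.mean().item():.3f}, std={y_train.std().item():.3f}")
print(f"eval mode:  mean={y_eval.mean().item():.3f}, std={y_eval.std().item():.3f}")
print("running_mean after one train-mode pass:", bn.running_mean)
```

The sketch also exposes a side effect. The training-mode forward pass updated `bn.running_mean` and `bn.running_var` even inside `torch.no_grad()`. For that reason, avoid saving a model's state dict after running inference in training mode.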

✅ Correct:

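```python
import torch

# MViTCycleGANModel is defined in the GitHub repository; import it from there.
model = MViTCycleGANModel(3, 3)  # or your model class

# Load the generator weights (stored under the "netG_A" key of the checkpoint)
state = torch.load("MViT-Stroke-GAN.pth", map_location="cpu")
model.netG_A.load_state_dict(state["netG_A"])

generator = model.netG_A
generator.train()  # Keep training mode

# Put only the dropout layers into eval mode so they stop zeroing activations
def disable_dropout(m):
    if isinstance(m, (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)):
        m.eval()

generator.apply(disable_dropout)

# input_tensor: a float tensor of shape (N, 3, H, W), preprocessed as during training
with torch.no_grad():
    output = generator(input_tensor)
```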

📦 Installation

```bash
pip install torch torchvision timm
```

📄 Citation

If you use this model or weights in your work, please cite:

```bibtex
@misc{mvit_stroke_cyclegan_2025,
  author = {},
  title = {},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://github.com/HJ-Peng/MViT-Stroke-GAN}}
}
```