---
license: cc-by-4.0
library_name: pytorch
inference: true
tags:
- image-to-image
- style-transfer
- cyclegan
- vision-transformer
- mvit
- sketch-synthesis
- hand-drawn
---
# MViT-Stroke-CycleGAN: Pretrained Weights
Pretrained weights for **MViT-Stroke-CycleGAN**, a CycleGAN variant that uses a **Mobile Vision Transformer (MViT)** for hand-drawn stroke synthesis.
> 🔗 Code: [GitHub Repository](https://github.com/HJ-Peng/MViT-Stroke-GAN)
---
## 📦 Checkpoint
| File | Description |
|------|-------------|
| `MViT-Stroke-GAN.pth` | Trained model weights. |
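
The loading code in the inference section indexes the checkpoint with `"netG_A"`, which suggests the `.pth` file stores a dictionary of per-network state dicts. Assuming that layout, a small helper like the hypothetical `summarize_checkpoint` below lists which entries a downloaded file contains before you load it. The checkpoint it builds is a stand-in for demonstration, not the released weights; on the real file you would call `summarize_checkpoint("MViT-Stroke-GAN.pth")` and confirm that the `netG_A` entry used by the loading code is present.

```python
import os
import tempfile

import torch
import torch.nn as nn


def summarize_checkpoint(path: str) -> dict:
    """Return {entry name: number of tensors} for a dict-of-state-dicts checkpoint."""
    checkpoint = torch.load(path, map_location="cpu")
    summary = {}
    for name, value in checkpoint.items():
        # A state dict maps parameter/buffer names to tensors; count only tensors.
        if isinstance(value, dict):
            summary[name] = sum(torch.is_tensor(t) for t in value.values())
        else:
            summary[name] = 0
    return summary


# Stand-in checkpoint with the assumed layout (the real file holds the MViT generator).
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo_checkpoint.pth")
    torch.save({"netG_A": nn.Linear(4, 2).state_dict()}, path)
    summary = summarize_checkpoint(path)

print(summary)  # {'netG_A': 2}
```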
---
## ⚠️ Inference: Use `train()` Mode Only
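Do **not** call `model.eval()`: it causes severe artifacts due to MViT's design. Instead, keep the model in training mode and switch only the dropout layers to evaluation mode, so that dropout does not randomly zero activations at inference time.

✅ Correct: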
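```python
import torch

# MViTCycleGANModel is defined in the GitHub repository linked above;
# import it from your local checkout (the module path depends on the repo layout).

# Load the model and the generator weights (file name as listed under Checkpoint)
model = MViTCycleGANModel(3, 3)  # or your model class
checkpoint = torch.load("MViT-Stroke-GAN.pth", map_location="cpu")
model.netG_A.load_state_dict(checkpoint["netG_A"])
model.train()  # Keep training mode; do not call model.eval()

def disable_dropout(m):
    # Put only dropout layers in eval mode so they stop zeroing activations
    if isinstance(m, (torch.nn.Dropout, torch.nn.Dropout2d)):
        m.eval()

model.apply(disable_dropout)

# input_tensor: an (N, 3, H, W) batch, preprocessed the same way as during training
with torch.no_grad():
    output = model.netG_A(input_tensor)
```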
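
The recipe above can be reproduced on any PyTorch module. The sketch below is illustrative only: `toy` is a hypothetical stand-in for the generator, the helper names `prepare_train_mode_inference` and `layer_modes` are not part of the repository, and the tuple of dropout types is an assumed list. It applies the same two steps (train mode, then dropout switched to eval) and checks that two forward passes on the same batch agree. One side effect worth knowing, assuming the generator contains BatchNorm layers (not confirmed here): in training mode they normalize with the current batch's statistics, so an image's output can depend on the other images in the batch, and their running-statistics buffers are still updated on every forward pass, even under `torch.no_grad()`.

```python
import torch
import torch.nn as nn

# Assumed set of stochastic layers to freeze; extend it if the generator uses others.
DROPOUT_TYPES = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)


def prepare_train_mode_inference(model: nn.Module) -> nn.Module:
    """Put the model in train mode, then switch only dropout layers to eval mode."""
    model.train()
    for m in model.modules():
        if isinstance(m, DROPOUT_TYPES):
            m.eval()
    return model


def layer_modes(model: nn.Module) -> dict:
    """Map each submodule's type name to its `training` flag (root module skipped)."""
    return {type(m).__name__: m.training for m in model.modules() if m is not model}


# Hypothetical stand-in for a generator: conv + BatchNorm + dropout.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Dropout2d(p=0.5),
    nn.Conv2d(8, 3, kernel_size=3, padding=1),
)
prepare_train_mode_inference(toy)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    out1 = toy(x)  # BatchNorm uses batch statistics; dropout is a no-op
    out2 = toy(x)

print(layer_modes(toy))  # {'Conv2d': True, 'BatchNorm2d': True, 'ReLU': True, 'Dropout2d': False}
print(torch.allclose(out1, out2))  # True
```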
---
## 📦 Installation
```bash
pip install torch torchvision timm
```
---
## 📄 Citation
If you use this model or weights in your work, please cite:
```bibtex
@misc{mvit_stroke_cyclegan_2025,
author = {},
title = {},
year = {2025},
publisher = {Hugging Face},
howpublished = {\url{https://github.com/HJ-Peng/MViT-Stroke-GAN}}
}
```