---
license: cc-by-4.0
library_name: pytorch
inference: true
tags:
- image-to-image
- style-transfer
- cyclegan
- vision-transformer
- mvit
- sketch-synthesis
- hand-drawn
---

# MViT-Stroke-CycleGAN - Pretrained Weights

Pretrained weights for **MViT-Stroke-CycleGAN**, a CycleGAN variant that uses a **Mobile Vision Transformer (MViT)** for hand-drawn stroke synthesis.

> 🔗 Code: [GitHub Repository](https://github.com/HJ-Peng/MViT-Stroke-GAN)

---

## 📦 Checkpoint

| File | Description |
|------|-------------|
| `MViT-Stroke-GAN.pth` | Trained model weights. |

---

## ⚠️ Inference: Use `train()` Mode Only

Do **not** call `model.eval()` for inference: it causes severe artifacts due to MViT's design. Instead, keep the model in training mode and switch only the dropout layers to evaluation mode.

✅ Correct:
```python
import torch

# MViTCycleGANModel is defined in the GitHub repository; import it from there
# (or substitute your own model class).
model = MViTCycleGANModel(3, 3)

# Load the generator weights (filename as listed in the Checkpoint table)
checkpoint = torch.load("MViT-Stroke-GAN.pth", map_location="cpu")
model.netG_A.load_state_dict(checkpoint["netG_A"])

model.train()  # Keep training mode; do not call model.eval()

def disable_dropout(m):
    # Switch only dropout layers to eval mode so they stop dropping activations
    if isinstance(m, (torch.nn.Dropout, torch.nn.Dropout2d)):
        m.eval()

model.apply(disable_dropout)

# input_tensor is your preprocessed input image batch
with torch.no_grad():
    output = model.netG_A(input_tensor)
```
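
To illustrate what the snippet above does to module modes, here is a minimal sketch on a toy network. The `toy_generator` below is a hypothetical stand-in, not the MViT generator from the repository, and its choice of layers (including `BatchNorm2d`) is an assumption made only for demonstration. It shows that `disable_dropout` leaves every other layer in training mode while repeated forward passes on the same input give the same output.

```python
import torch
from torch import nn

# Hypothetical toy network; NOT the real MViT-Stroke-CycleGAN generator.
toy_generator = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Dropout2d(p=0.5),
    nn.Conv2d(8, 3, kernel_size=3, padding=1),
)

def disable_dropout(m):
    # Put only dropout layers into eval mode
    if isinstance(m, (nn.Dropout, nn.Dropout2d)):
        m.eval()

toy_generator.train()                 # everything in training mode
toy_generator.apply(disable_dropout)  # then turn off dropout only

# Record the training flag of each layer type
modes = {type(m).__name__: m.training for m in toy_generator}
print(modes)

# With dropout disabled, two passes over the same batch should agree
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    out_a = toy_generator(x)
    out_b = toy_generator(x)

deterministic = torch.allclose(out_a, out_b)
print("deterministic:", deterministic)
```

Printing `modes` should show `Dropout2d` as `False` and the remaining layer types as `True`.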

---

## 📦 Installation

```bash
pip install torch torchvision timm
```
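
After installing, it may help to confirm that the packages are visible to the Python interpreter you plan to use. The following sketch relies only on the standard library; `check_packages` and `REQUIRED` are illustrative names and not part of the repository.

```python
from importlib.util import find_spec

# Packages named in the pip command above
REQUIRED = ("torch", "torchvision", "timm")

def check_packages(names=REQUIRED):
    """Return {package_name: bool} telling whether each package can be located."""
    return {name: find_spec(name) is not None for name in names}

status = check_packages()
for name, found in status.items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```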

---

## 📄 Citation

If you use this model or weights in your work, please cite:

```bibtex
@misc{mvit_stroke_cyclegan_2025,
  author = {},
  title = {},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://github.com/HJ-Peng/MViT-Stroke-GAN}}
}
```