let's go (#2)
Commit 28a202cb896ecb7388dcd60a130a61a763b87287
Co-authored-by: Seif benayed <seifbenayed@users.noreply.huggingface.co>
- .gitattributes +3 -0
- DEPLOYMENT.md +194 -0
- README.md +105 -7
- app.py +152 -0
- checkpoints/dtd_doctamper.pth +3 -0
- checkpoints/qt_table.pk +0 -0
- checkpoints/swin_imagenet.pt +3 -0
- checkpoints/vph_imagenet.pt +3 -0
- examples/Paystub.jpg +3 -0
- examples/TamperedPaystub.jpg +3 -0
- examples/TamperedPaystubv1.jpg +3 -0
- examples/carte.jpeg +0 -0
- inference.py +187 -0
- models/__init__.py +1 -0
- models/dtd.py +356 -0
- models/fix_imports.py +31 -0
- models/fph.py +132 -0
- models/patch_droppath.py +17 -0
- models/patch_gelu.py +34 -0
- models/swins.py +454 -0
- requirements.txt +11 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/Paystub.jpg filter=lfs diff=lfs merge=lfs -text
+examples/TamperedPaystub.jpg filter=lfs diff=lfs merge=lfs -text
+examples/TamperedPaystubv1.jpg filter=lfs diff=lfs merge=lfs -text
DEPLOYMENT.md
ADDED
@@ -0,0 +1,194 @@
# 🚀 Deployment Guide - DTD Document Tampering Detection

## Hugging Face Spaces Deployment

### Prerequisites

- Hugging Face account
- Git installed locally
- Git LFS installed (for large model files)

### Step 1: Install Git LFS

```bash
# Mac
brew install git-lfs

# Linux
sudo apt-get install git-lfs

# Initialize Git LFS
git lfs install
```

### Step 2: Create Hugging Face Space

1. Go to https://huggingface.co/new-space
2. Choose a name (e.g., `dtd-doctamper-detection`)
3. Select **Gradio** as SDK
4. Choose license: **MIT**
5. Click **Create Space**
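
If you prefer to script this step, the Space can also be created with the `huggingface_hub` client. A minimal sketch, assuming `huggingface_hub` is installed and you are already authenticated (e.g., via `huggingface-cli login`); substitute your own username:

```python
# Sketch: create the Space programmatically instead of via the web UI.
from huggingface_hub import create_repo

create_repo(
    "YOUR_USERNAME/dtd-doctamper-detection",  # same name as in the web flow
    repo_type="space",                        # create a Space, not a model repo
    space_sdk="gradio",                       # matches the SDK chosen above
)
```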

### Step 3: Clone and Setup

```bash
# Clone your new space
git clone https://huggingface.co/spaces/YOUR_USERNAME/dtd-doctamper-detection
cd dtd-doctamper-detection

# Copy app files
cp -r /path/to/gradio_dtd_app/* .

# Track large files with Git LFS
git lfs track "*.pth"
git lfs track "*.pt"
git add .gitattributes

# Add all files
git add .

# Commit
git commit -m "Initial commit: DTD document tampering detection app"

# Push to Hugging Face
git push
```

### Step 4: Configure Space Settings

After pushing, Hugging Face will automatically:
- Install dependencies from `requirements.txt`
- Build the Docker container
- Start the Gradio app
- Assign a public URL

### Step 5: Test Your Space

Visit: `https://huggingface.co/spaces/YOUR_USERNAME/dtd-doctamper-detection`

## Local Testing

Before deploying, test locally:

```bash
cd gradio_dtd_app

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt

# Run app
python app.py
```

Open a browser at `http://localhost:7860`.

## Troubleshooting

### Issue: Git LFS bandwidth limit

**Solution**: Use Hugging Face's built-in LFS storage:

```bash
# Track checkpoint files
git lfs track "checkpoints/*.pth"
git lfs track "checkpoints/*.pt"
git add .gitattributes checkpoints/
git commit -m "Add model checkpoints"
git push
```

### Issue: Build timeout

**Solution**: Reduce pinned requirement versions, or pin the SDK and Python versions in the Space configuration. Note these settings live in the `README.md` front matter, not in a CI workflow file:

```yaml
# In the README.md front matter
sdk: gradio
sdk_version: 4.44.0
python_version: "3.10"
```

### Issue: Out of memory

**Solution**: Enable GPU hardware in the Space settings:
1. Go to Space settings
2. Select **Hardware**: GPU (e.g., T4)
3. Save changes

## File Size Optimization

Current app size: **~450MB**

To reduce size:

1. **Quantize models** (reduce precision):
```python
# In inference.py; note that quantize_dynamic returns the quantized model
model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
```

2. **Use model compression** (see the export sketch after this list):
```bash
pip install onnx onnxruntime
# Convert the model to ONNX format (smaller, faster CPU inference)
```

3. **Lazy loading**:
```python
# Load the model on first request instead of at startup
from functools import lru_cache

@lru_cache()
def get_model():
    return DTDPredictor()
```
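
For item 2 above, here is a minimal, untested export sketch rather than a verified conversion: the DTD forward pass takes three inputs (image, DCT coefficients, quantization table), and the dummy shapes below mirror how `inference.py` prepares them, with an assumed 512x512 input:

```python
# Sketch: export the loaded DTD model to ONNX. Shapes and the opset are
# illustrative assumptions; whether every op exports cleanly is untested.
import torch

model.eval()
dummy_img = torch.randn(1, 3, 512, 512)           # normalized RGB image
dummy_dct = torch.zeros(1, 1, 512, 512)           # clipped |DCT| coefficients
dummy_qt = torch.zeros(1, 64, dtype=torch.long)   # flattened quantization table

torch.onnx.export(
    model,
    (dummy_img, dummy_dct, dummy_qt),
    "dtd_doctamper.onnx",
    input_names=["image", "dct", "qt"],
    output_names=["mask_logits"],
    opset_version=17,
)
```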

## Environment Variables

Add to Space secrets:

```bash
# Optional: Analytics tracking
ANALYTICS_TOKEN=your_token

# Optional: Rate limiting
MAX_REQUESTS_PER_HOUR=100
```
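
Inside `app.py`, such secrets would be read with `os.getenv`. A minimal sketch; note the current app does not read either variable, these names are just the optional ones suggested above:

```python
# Sketch: reading optional Space secrets at startup.
import os

analytics_token = os.getenv("ANALYTICS_TOKEN")                 # None if unset
max_requests = int(os.getenv("MAX_REQUESTS_PER_HOUR", "100"))  # default 100
```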

## Monitoring

Check the Space logs:
1. Go to the Space page
2. Click the **Logs** tab
3. Monitor real-time inference

## Custom Domain (Optional)

1. Go to Space settings
2. Add a custom domain
3. Configure DNS records

## Cost Optimization

**Free Tier Limits:**
- CPU: Free (slower inference)
- GPU T4: available as paid hardware or via a community grant
- Storage: 50GB LFS
- Bandwidth: Limited

**Upgrade Options:**
- GPU A10G: Faster inference
- Persistent storage
- Higher bandwidth

## Support

- [Hugging Face Docs](https://huggingface.co/docs/hub/spaces)
- [Gradio Docs](https://gradio.app/docs/)
- [Git LFS](https://git-lfs.github.com/)

## License

MIT License - See LICENSE file
README.md
CHANGED
@@ -1,13 +1,111 @@
---
title: DTD Document Tampering Detection
emoji: 🔍
colorFrom: blue
colorTo: red
sdk: gradio
sdk_version: 4.44.0
app_file: app.py
pinned: false
license: mit
---

# 🔍 DTD: Document Tampering Detection

Detect forged or tampered regions in document images using the **DTD (Document Tampering Detector)** model.

## 📝 Description

This application uses state-of-the-art deep learning to identify manipulated text in document images by analyzing JPEG compression artifacts (DCT coefficients).

### ✨ Features

- **DCT Analysis**: Examines JPEG compression patterns to detect inconsistencies
- **Real-time Detection**: Fast inference on CPU or GPU
- **Visual Heatmaps**: Clear visualization of tampered regions
- **High Accuracy**: Trained on the DocTamper dataset of 120K+ document images
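
The DCT analysis operates on the JPEG bitstream itself. As a quick illustration of the signal involved, the quantized coefficients and quantization table can be inspected with `jpegio`, the same library the inference code uses internally:

```python
# Sketch: peek at the JPEG data the detector consumes (requires a JPEG input).
import jpegio

jpg = jpegio.read("examples/Paystub.jpg")
dct = jpg.coef_arrays[0]       # quantized DCT coefficients (luma channel)
qt = jpg.quant_tables[0]       # 8x8 quantization table
print(dct.shape, qt.flatten()[:8])
```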

### 🎯 Use Cases

- Verify document authenticity
- Detect forged receipts, invoices, and forms
- Identify copy-paste text manipulation
- Detect splicing and content insertion

## 🚀 How It Works

1. **Upload** a document image (JPEG format works best)
2. **Adjust** the JPEG quality setting for DCT analysis (default: 90)
3. **View** tampering detection results:
   - **Heatmap**: Red overlay shows tampered regions
   - **Binary Mask**: Clear segmentation of authentic vs tampered regions
   - **Original**: Compare with the input

## 🏗️ Model Architecture

- **Backbone**: VPH (Vision Pyramid Hybrid) + Swin Transformer
- **Decoder**: Multi-scale Iterative Decoder (MID)
- **Inputs**: RGB image + DCT coefficients + Quantization tables
- **Output**: Binary segmentation mask (0=authentic, 1=tampered)
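
In code, the forward pass consumes the three inputs as tensors shaped the way `inference.py` prepares them (batch size 1; H, W are the image size):

```python
# image: float32 (1, 3, H, W) - normalized RGB
# dct:   float32 (1, 1, H, W) - |DCT| coefficients clipped to [0, 20]
# qt:    int64   (1, 64)      - flattened quantization table
logits = model(image, dct, qt)   # (1, 2, H, W) class scores
mask = logits.argmax(1)          # (1, H, W): 0 = authentic, 1 = tampered
```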

## 📚 Citation

```bibtex
@inproceedings{qu2023towards,
  title={Towards Robust Tampered Text Detection in Document Image: New Dataset and New Solution},
  author={Qu, Chenfan and Liu, Chongyu and Liu, Yuliang and Chen, Xinhong and Peng, Dezhi and Guo, Fengjun and Jin, Lianwen},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={5937--5946},
  year={2023}
}
```

## 📖 Paper

[Towards Robust Tampered Text Detection in Document Image: New Dataset and New Solution](https://openaccess.thecvf.com/content/CVPR2023/papers/Qu_Towards_Robust_Tampered_Text_Detection_in_Document_Image_New_Dataset_CVPR_2023_paper.pdf) (CVPR 2023)

## ⚠️ Limitations

- **JPEG Dependency**: Requires JPEG format for DCT analysis
- **Quality Sensitivity**: Detection accuracy varies with compression quality
- **False Positives**: May occur on low-quality scans or heavily compressed images
- **Preprocessing**: Images must contain text/document content

## 🛠️ Technical Details

### Model Weights

- **Main Model**: `dtd_doctamper.pth` (257MB)
- **VPH Backbone**: `vph_imagenet.pt` (4.8MB)
- **Swin Transformer**: `swin_imagenet.pt` (187MB)
- **Total Size**: ~449MB

### Performance

- **Input Size**: Variable (auto-resized)
- **Inference Time**: ~2-5 seconds on CPU
- **GPU Acceleration**: Supported (CUDA)

## 📦 Local Installation

```bash
# Clone repository
git clone https://huggingface.co/spaces/YOUR_USERNAME/dtd-doctamper-detection
cd dtd-doctamper-detection

# Install dependencies
pip install -r requirements.txt

# Run application
python app.py
```
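
Beyond the Gradio UI, the predictor defined in `inference.py` can be used programmatically:

```python
# Use the predictor shipped in this repo directly.
from inference import DTDPredictor

predictor = DTDPredictor(checkpoint_path="checkpoints/dtd_doctamper.pth", device="auto")
result = predictor.predict("examples/TamperedPaystub.jpg", quality=90)
# result contains 'original', 'mask' (uint8, 0/255), and 'heatmap' numpy arrays
```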

## 📄 License

MIT License - See LICENSE file for details

## 🤝 Acknowledgments

- Original DTD model by Qu et al. (CVPR 2023)
- DocTamper dataset
- Hugging Face Spaces for hosting
app.py
ADDED
@@ -0,0 +1,152 @@
"""
DTD DocTamper - Gradio Application
Document Tampering Detection using the DTD model
"""
import os
import tempfile

import gradio as gr
import numpy as np
from PIL import Image

from inference import DTDPredictor

# Initialize the predictor once at startup
print("Loading DTD model...")
predictor = DTDPredictor(
    checkpoint_path='checkpoints/dtd_doctamper.pth',
    device='auto'
)
print("Model loaded!")

def predict_tampering(image, quality=90):
    """
    Predict document tampering.

    Args:
        image: Input image (PIL Image or numpy array)
        quality: JPEG compression quality for DCT analysis

    Returns:
        Tuple of (original, mask, heatmap)
    """
    # Save the uploaded image to a temporary JPEG so DCT coefficients can be extracted
    with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp:
        if hasattr(image, 'save'):
            image.save(tmp, 'JPEG', quality=95)
        else:
            Image.fromarray(image).save(tmp, 'JPEG', quality=95)
        tmp_path = tmp.name

    try:
        # Run prediction
        result = predictor.predict(tmp_path, quality=quality)
        return (
            result['original'],
            result['mask'],
            result['heatmap']
        )
    finally:
        os.unlink(tmp_path)

# Create Gradio interface
with gr.Blocks(title="DTD Document Tampering Detection") as demo:
    gr.Markdown("""
    # 🔍 DTD: Document Tampering Detection

    Upload a document image to detect forged or tampered regions using the DTD (Document Tampering Detector) model.

    **How it works:**
    - The model analyzes JPEG compression artifacts (DCT coefficients)
    - Red regions indicate potential tampering
    - Works best on JPEG images of documents

    **Paper:** [Towards Robust Tampered Text Detection in Document Image](https://openaccess.thecvf.com/content/CVPR2023/papers/Qu_Towards_Robust_Tampered_Text_Detection_in_Document_Image_New_Dataset_CVPR_2023_paper.pdf)
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Document Image",
                type="pil"
            )

            quality_slider = gr.Slider(
                minimum=75,
                maximum=95,
                value=90,
                step=5,
                label="JPEG Quality for DCT Analysis",
                info="Higher quality = more sensitive detection"
            )

            submit_btn = gr.Button("Detect Tampering", variant="primary")

        with gr.Column():
            with gr.Tab("Heatmap Overlay"):
                output_heatmap = gr.Image(label="Tampering Heatmap")

            with gr.Tab("Binary Mask"):
                output_mask = gr.Image(label="Tampering Mask")

            with gr.Tab("Original"):
                output_original = gr.Image(label="Original Image")

    # Examples
    gr.Examples(
        examples=[
            ["examples/carte.jpeg", 90],
            ["examples/TamperedPaystub.jpg", 90],
            ["examples/Paystub.jpg", 90],
        ],
        inputs=[input_image, quality_slider],
        outputs=[output_original, output_mask, output_heatmap],
        fn=predict_tampering,
        cache_examples=False,
    )

    # Event handlers
    submit_btn.click(
        fn=predict_tampering,
        inputs=[input_image, quality_slider],
        outputs=[output_original, output_mask, output_heatmap]
    )

    gr.Markdown("""
    ---
    ### ℹ️ About

    **DTD (Document Tampering Detector)** is a deep learning model designed to detect forged text in document images.

    **Features:**
    - Analyzes JPEG compression artifacts using DCT (Discrete Cosine Transform)
    - Detects copy-paste, splicing, and text manipulation
    - Works on scanned documents, photos of documents, and digital documents

    **Citation:**
    ```bibtex
    @inproceedings{qu2023towards,
      title={Towards Robust Tampered Text Detection in Document Image: New Dataset and New Solution},
      author={Qu, Chenfan and Liu, Chongyu and Liu, Yuliang and Chen, Xinhong and Peng, Dezhi and Guo, Fengjun and Jin, Lianwen},
      booktitle={CVPR},
      year={2023}
    }
    ```

    **Model Architecture:**
    - Backbone: VPH (Vision Pyramid Hybrid) + Swin Transformer
    - Decoder: Multi-scale Iterative Decoder (MID)
    - Input: RGB image + DCT coefficients + Quantization tables
    - Output: Binary segmentation mask

    **Limitations:**
    - Requires JPEG images for DCT analysis
    - May produce false positives on low-quality scans
    - Performance varies with JPEG compression quality
    """)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
checkpoints/dtd_doctamper.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:81291d19e0c92fd56a8e76f422114a9c6bc6f67f4ac03a0facc18a045894a8c1
size 269695109
checkpoints/qt_table.pk
ADDED
Binary file (7.74 kB)
checkpoints/swin_imagenet.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1436a2d793dbbb74c9578a68e52d0e7deaa3f305560a34d287a8e4edc866b245
size 196402845
checkpoints/vph_imagenet.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f4e99cbcbd5f17a0004278a57e3ff199a0de7189345d7e6924e104710d602898
size 5000275
examples/Paystub.jpg
ADDED
Git LFS file
examples/TamperedPaystub.jpg
ADDED
Git LFS file
examples/TamperedPaystubv1.jpg
ADDED
Git LFS file
examples/carte.jpeg
ADDED
inference.py
ADDED
@@ -0,0 +1,187 @@
"""
DTD DocTamper Inference Module
Simplified for Gradio deployment on Hugging Face
"""
import os
import pickle
import tempfile

import cv2
import torch
import jpegio
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

from models import fix_imports    # Apply timm compatibility fixes (import side effects)
from models import patch_gelu
from models import patch_droppath
from models.dtd import seg_dtd


class DTDPredictor:
    def __init__(self, checkpoint_path='checkpoints/dtd_doctamper.pth', device='cpu'):
        """
        Initialize the DTD model for inference.

        Args:
            checkpoint_path: Path to the model checkpoint
            device: Device to use ('cpu', 'cuda', or 'auto')
        """
        # Auto-detect device
        if device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        print(f'Using device: {self.device}')

        # Load the quantization-table lookup used by the FPH branch
        with open('checkpoints/qt_table.pk', 'rb') as fpk:
            pks = pickle.load(fpk)
        self.pks = {k: torch.LongTensor(v) for k, v in pks.items()}

        # Image transforms (standard ImageNet statistics)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        ])

        # Build the model (2 classes: authentic / tampered)
        self.model = seg_dtd('', 2, device=self.device).to(self.device)

        # Load checkpoint
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        state_dict = ckpt['state_dict']

        # Remove the 'module.' prefix left over from DataParallel training
        new_state_dict = {}
        for k, v in state_dict.items():
            new_state_dict[k[7:] if k.startswith('module.') else k] = v

        self.model.load_state_dict(new_state_dict)
        self.model.eval()

        print('Model loaded successfully!')

    def extract_dct(self, image_path, quality=90):
        """
        Extract DCT coefficients from a JPEG image.

        Args:
            image_path: Path to the JPEG image
            quality: JPEG quality for re-compression

        Returns:
            DCT coefficients and the flattened quantization table
        """
        # Load image
        im = Image.open(image_path).convert('RGB')

        # Re-compress to a grayscale JPEG with the specified quality
        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp:
            im.convert("L").save(tmp, "JPEG", quality=quality)
            tmp_path = tmp.name

        try:
            # Read the JPEG with jpegio
            jpg = jpegio.read(tmp_path)
            dct = jpg.coef_arrays[0].copy()

            # Quantization table (first 64 values, flattened)
            qt = jpg.quant_tables[0]
            qt_flat = qt.flatten()[:64]

            return dct, qt_flat
        finally:
            # Clean up temp file
            os.unlink(tmp_path)

    @torch.no_grad()
    def predict(self, image_path, quality=90):
        """
        Predict the tampering mask for an input image.

        Args:
            image_path: Path to the input JPEG image
            quality: JPEG quality for DCT extraction

        Returns:
            Dictionary containing:
            - original: Original image (numpy array)
            - mask: Binary tampering mask (numpy array)
            - heatmap: Colorized heatmap overlay
        """
        # Load image
        im = Image.open(image_path).convert('RGB')
        im_np = np.array(im)

        # Extract DCT coefficients
        dct, qt = self.extract_dct(image_path, quality)

        # Image input
        im_tensor = self.transform(im).unsqueeze(0).to(self.device)

        # DCT coefficients (absolute values clipped to [0, 20])
        dct_tensor = torch.from_numpy(np.clip(np.abs(dct), 0, 20)).unsqueeze(0).unsqueeze(0).float().to(self.device)

        # Map each quantization value to the closest value known to the model
        qt_indices = []
        for val in qt:
            if val in self.pks:
                qt_indices.append(val)
            else:
                closest = min(self.pks.keys(), key=lambda x: abs(x - val))
                qt_indices.append(closest)

        qt_tensor = torch.LongTensor(qt_indices[:64]).unsqueeze(0).to(self.device)

        # Forward pass
        output = self.model(im_tensor, dct_tensor, qt_tensor)

        # Prediction mask: argmax over the two classes
        pred_mask = output.argmax(1).squeeze().cpu().numpy()

        # Create heatmap overlay
        heatmap = self.create_heatmap(im_np, pred_mask)

        return {
            'original': im_np,
            'mask': (pred_mask * 255).astype(np.uint8),
            'heatmap': heatmap
        }

    def create_heatmap(self, image, mask):
        """
        Create a colorized heatmap overlay.

        Args:
            image: Original image (numpy array)
            mask: Binary mask (numpy array)

        Returns:
            Heatmap overlay (numpy array)
        """
        # Red overlay on tampered regions
        colored_mask = np.zeros_like(image)
        colored_mask[mask == 1] = [255, 0, 0]

        # Blend with the original image
        alpha = 0.5
        heatmap = cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)

        return heatmap
models/__init__.py
ADDED
@@ -0,0 +1 @@
# DTD Models Module
models/dtd.py
ADDED
@@ -0,0 +1,356 @@
import os
from . import fix_imports  # Apply import fixes (was `import fix_imports`, which fails inside the package)
import cv2
import lmdb
import torch
import jpegio
import numpy as np
import torch.nn as nn
import gc
import math
import time
import copy
import logging
import torch.optim as optim
import torch.distributed as dist
import random
import pickle
import six
from glob import glob
from PIL import Image
from tqdm import tqdm
from torch.autograd import Variable
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler  # needs pytorch>1.6
# NOTE: the original `from losses import DiceLoss, FocalLoss, SoftCrossEntropyLoss, LovaszLoss`
# referenced a training-only module that is not part of this repo; removed for inference.
from .fph import FPH
import albumentations as A
from .swins import *
from albumentations.pytorch import ToTensorV2
import torchvision
import torch.nn.functional as F
try:
    from timm.models.layers import trunc_normal_, DropPath
except ImportError:
    from timm.layers import trunc_normal_, DropPath
from functools import partial
from segmentation_models_pytorch.base import modules as md
from typing import Optional, Union, List
from segmentation_models_pytorch.base import SegmentationModel

# Custom GELU for compatibility
class GELU(nn.Module):
    def forward(self, x):
        return F.gelu(x)

class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

class SCSEModule(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        return x * self.cSE(x) + x * self.sSE(x)

class ConvBlock(nn.Module):
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        ipt = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)
        x = ipt + self.drop_path(x)
        return x

class AddCoords(nn.Module):
    def __init__(self, with_r=True):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        batch_size, _, x_dim, y_dim = input_tensor.size()
        xx_c, yy_c = torch.meshgrid(torch.arange(x_dim, dtype=input_tensor.dtype), torch.arange(y_dim, dtype=input_tensor.dtype))
        xx_c = xx_c.to(input_tensor.device) / (x_dim - 1) * 2 - 1
        yy_c = yy_c.to(input_tensor.device) / (y_dim - 1) * 2 - 1
        xx_c = xx_c.expand(batch_size, 1, x_dim, y_dim)
        yy_c = yy_c.expand(batch_size, 1, x_dim, y_dim)
        ret = torch.cat((input_tensor, xx_c, yy_c), dim=1)
        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_c - 0.5, 2) + torch.pow(yy_c - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)
        return ret

class VPH(nn.Module):
    # NOTE: this class mainly has to be importable so the pretrained module in
    # vph_imagenet.pt can be unpickled; torch.load restores the real weights and
    # configuration. The defaults below are a best-effort fix: `depths` was
    # undefined in the original, and `dims` had fewer entries than are indexed.
    def __init__(self, dims=(96, 192, 384, 768), depths=(3, 3), drop_path_rate=0.4, layer_scale_init_value=1e-6):
        super().__init__()
        self.dims = dims
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.downsample_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(6, dims[0], kernel_size=4, stride=4),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first")),
            nn.Sequential(
                LayerNorm(dims[1], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2))])
        self.stages = nn.ModuleList([
            nn.Sequential(*[ConvBlock(dim=dims[0], drop_path=dp_rates[j],
                                      layer_scale_init_value=layer_scale_init_value) for j in range(3)]),
            nn.Sequential(*[ConvBlock(dim=dims[1], drop_path=dp_rates[3 + j],
                                      layer_scale_init_value=layer_scale_init_value) for j in range(3)])])
        self.apply(self._init_weights)

    def initnorm(self):
        norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
        for i_layer in range(4):
            layer = norm_layer(self.dims[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def init_weights(self, pretrained=None):
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        self.apply(_init_weights)

    def forward(self, x):
        x = self.stages[0](self.downsample_layers[0](x))
        outs = [self.norm0(x)]
        x = self.stages[1](self.downsample_layers[1](x))
        outs.append(self.norm1(x))
        return outs

class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        activation = md.Activation(activation)
        super().__init__(conv2d, upsampling, activation)

class DecoderBlock(nn.Module):
    def __init__(self, cin, cadd, cout):
        super().__init__()
        self.cin = (cin + cadd)
        self.cout = cout
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.cin, self.cout, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(self.cout),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(self.cout, self.cout, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(self.cout),
            nn.ReLU(inplace=True)
        )

    def forward(self, x1, x2=None):
        x1 = F.interpolate(x1, scale_factor=2.0, mode="nearest")
        if x2 is not None:
            x1 = torch.cat([x1, x2], dim=1)
        x1 = self.conv1(x1[:, :self.cin])
        x1 = self.conv2(x1)
        return x1

class ConvBNReLU(nn.Module):
    def __init__(self, in_c, out_c, ks, stride=1, norm=True, res=False):
        super(ConvBNReLU, self).__init__()
        if norm:
            self.conv = nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=ks, padding=ks // 2, stride=stride, bias=False),
                nn.BatchNorm2d(out_c),
                nn.ReLU(True))
        else:
            self.conv = nn.Conv2d(in_c, out_c, kernel_size=ks, padding=ks // 2, stride=stride, bias=False)
        self.res = res

    def forward(self, x):
        if self.res:
            return (x + self.conv(x))
        else:
            return self.conv(x)

class FUSE1(nn.Module):
    def __init__(self, in_channels_list=(96, 192, 384, 768)):
        super(FUSE1, self).__init__()
        self.c31 = ConvBNReLU(in_channels_list[2], in_channels_list[2], 1)
        self.c32 = ConvBNReLU(in_channels_list[3], in_channels_list[2], 1)
        self.c33 = ConvBNReLU(in_channels_list[2], in_channels_list[2], 3)

        self.c21 = ConvBNReLU(in_channels_list[1], in_channels_list[1], 1)
        self.c22 = ConvBNReLU(in_channels_list[2], in_channels_list[1], 1)
        self.c23 = ConvBNReLU(in_channels_list[1], in_channels_list[1], 3)

        self.c11 = ConvBNReLU(in_channels_list[0], in_channels_list[0], 1)
        self.c12 = ConvBNReLU(in_channels_list[1], in_channels_list[0], 1)
        self.c13 = ConvBNReLU(in_channels_list[0], in_channels_list[0], 3)

    def forward(self, x):
        x, x1, x2, x3 = x
        h, w = x2.shape[-2:]
        x2 = self.c33(F.interpolate(self.c32(x3), size=(h, w)) + self.c31(x2))
        h, w = x1.shape[-2:]
        x1 = self.c23(F.interpolate(self.c22(x2), size=(h, w)) + self.c21(x1))
        h, w = x.shape[-2:]
        x = self.c13(F.interpolate(self.c12(x1), size=(h, w)) + self.c11(x))
        return x, x1, x2, x3

class FUSE2(nn.Module):
    def __init__(self, in_channels_list=(96, 192, 384)):
        super(FUSE2, self).__init__()

        self.c21 = ConvBNReLU(in_channels_list[1], in_channels_list[1], 1)
        self.c22 = ConvBNReLU(in_channels_list[2], in_channels_list[1], 1)
        self.c23 = ConvBNReLU(in_channels_list[1], in_channels_list[1], 3)

        self.c11 = ConvBNReLU(in_channels_list[0], in_channels_list[0], 1)
        self.c12 = ConvBNReLU(in_channels_list[1], in_channels_list[0], 1)
        self.c13 = ConvBNReLU(in_channels_list[0], in_channels_list[0], 3)

    def forward(self, x):
        x, x1, x2 = x
        h, w = x1.shape[-2:]
        x1 = self.c23(F.interpolate(self.c22(x2), size=(h, w), mode='bilinear', align_corners=True) + self.c21(x1))
        h, w = x.shape[-2:]
        x = self.c13(F.interpolate(self.c12(x1), size=(h, w), mode='bilinear', align_corners=True) + self.c11(x))
        return x, x1, x2

class FUSE3(nn.Module):
    def __init__(self, in_channels_list=(96, 192)):
        super(FUSE3, self).__init__()

        self.c11 = ConvBNReLU(in_channels_list[0], in_channels_list[0], 1)
        self.c12 = ConvBNReLU(in_channels_list[1], in_channels_list[0], 1)
        self.c13 = ConvBNReLU(in_channels_list[0], in_channels_list[0], 3)

    def forward(self, x):
        x, x1 = x
        h, w = x.shape[-2:]
        x = self.c13(F.interpolate(self.c12(x1), size=(h, w), mode='bilinear', align_corners=True) + self.c11(x))
        return x, x1

class MID(nn.Module):
    def __init__(self, encoder_channels, decoder_channels):
        super().__init__()
        encoder_channels = encoder_channels[1:][::-1]
        self.in_channels = [encoder_channels[0]] + list(decoder_channels[:-1])
        self.add_channels = list(encoder_channels[1:]) + [96]
        self.out_channels = decoder_channels
        self.fuse1 = FUSE1()
        self.fuse2 = FUSE2()
        self.fuse3 = FUSE3()
        decoder_convs = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(layer_idx + 1):
                if depth_idx == 0:
                    in_ch = self.in_channels[layer_idx]
                    skip_ch = self.add_channels[layer_idx] * (layer_idx + 1)
                    out_ch = self.out_channels[layer_idx]
                else:
                    out_ch = self.add_channels[layer_idx]
                    skip_ch = self.add_channels[layer_idx] * (layer_idx + 1 - depth_idx)
                    in_ch = self.add_channels[layer_idx - 1]
                decoder_convs[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock(in_ch, skip_ch, out_ch)
        decoder_convs[f"x_{0}_{len(self.in_channels)-1}"] = DecoderBlock(self.in_channels[-1], 0, self.out_channels[-1])
        self.decoder_convs = nn.ModuleDict(decoder_convs)

    def forward(self, *features):
        decoder_features = {}
        features = self.fuse1(features)[::-1]
        decoder_features["x_0_0"] = self.decoder_convs["x_0_0"](features[0], features[1])
        decoder_features["x_1_1"] = self.decoder_convs["x_1_1"](features[1], features[2])
        decoder_features["x_2_2"] = self.decoder_convs["x_2_2"](features[2], features[3])
        decoder_features["x_2_2"], decoder_features["x_1_1"], decoder_features["x_0_0"] = self.fuse2(
            (decoder_features["x_2_2"], decoder_features["x_1_1"], decoder_features["x_0_0"]))
        decoder_features["x_0_1"] = self.decoder_convs["x_0_1"](
            decoder_features["x_0_0"], torch.cat((decoder_features["x_1_1"], features[2]), 1))
        decoder_features["x_1_2"] = self.decoder_convs["x_1_2"](
            decoder_features["x_1_1"], torch.cat((decoder_features["x_2_2"], features[3]), 1))
        decoder_features["x_1_2"], decoder_features["x_0_1"] = self.fuse3(
            (decoder_features["x_1_2"], decoder_features["x_0_1"]))
        decoder_features["x_0_2"] = self.decoder_convs["x_0_2"](
            decoder_features["x_0_1"], torch.cat((decoder_features["x_1_2"], decoder_features["x_2_2"], features[3]), 1))
        return self.decoder_convs["x_0_3"](
            torch.cat((decoder_features["x_0_2"], decoder_features["x_1_2"], decoder_features["x_2_2"]), 1))


class DTD(SegmentationModel):
    def __init__(self, encoder_name="resnet18", decoder_channels=(384, 192, 96, 64), classes=1, device='cpu'):
        super().__init__()
        # Load the pretrained backbones with proper device mapping. The repo
        # stores them under checkpoints/ (the original pointed at a
        # non-existent pths/ directory).
        model_dir = os.path.dirname(os.path.abspath(__file__))
        vph_path = os.path.join(model_dir, '..', 'checkpoints', 'vph_imagenet.pt')
        swin_path = os.path.join(model_dir, '..', 'checkpoints', 'swin_imagenet.pt')

        if device == 'mps':
            # MPS cannot be used as a map_location target here; load on CPU
            self.vph = torch.load(vph_path, map_location=torch.device('cpu'))
            self.swin = torch.load(swin_path, map_location=torch.device('cpu'))
        else:
            self.vph = torch.load(vph_path, map_location=device)
            self.swin = torch.load(swin_path, map_location=device)
        self.fph = FPH()
        self.decoder = MID(encoder_channels=(96, 192, 384, 768), decoder_channels=decoder_channels)
        self.segmentation_head = SegmentationHead(in_channels=decoder_channels[-1], out_channels=classes, upsampling=2.0)
        self.addcoords = AddCoords()
        self.FU = nn.Sequential(SCSEModule(448), nn.Conv2d(448, 192, 3, 1, 1), nn.BatchNorm2d(192), nn.ReLU(True))
        self.classification_head = None
        self.initialize()

    def forward(self, x, dct, qt):
        features = self.vph(self.addcoords(x))
        features[1] = self.FU(torch.cat((features[1], self.fph(dct, qt)), 1))
        rst = self.swin[0](features[1].flatten(2).transpose(1, 2).contiguous())
        N, L, C = rst.shape
        H = W = int(L ** 0.5)
        features.append(self.vph.norm2(rst.transpose(1, 2).contiguous().view(N, C, H, W)))
        features.append(self.vph.norm3(self.swin[2](self.swin[1](rst)).transpose(1, 2).contiguous().view(N, C * 2, H // 2, W // 2)))
        decoder_output = self.decoder(*features)
        return self.segmentation_head(decoder_output)

class seg_dtd(nn.Module):
    def __init__(self, model_name='resnet18', n_class=1, device='cpu'):
        super().__init__()
        self.model = DTD(encoder_name=model_name, classes=n_class, device=device)
        self.device = device

    def forward(self, x, dct, qt):
        # Use mixed-precision autocast only on CUDA, not on CPU or MPS
        if self.device == 'cuda':
            with autocast():
                x = self.model(x, dct, qt)
        else:
            x = self.model(x, dct, qt)
        return x
models/fix_imports.py
ADDED
@@ -0,0 +1,31 @@
"""
Simple import compatibility fix for timm
"""
import sys
import torch.nn as nn

try:
    import timm.layers as new_layers

    # Alias the new module paths under the old names for backward compatibility
    sys.modules['timm.models.layers.drop'] = new_layers.drop
    sys.modules['timm.models.layers'] = new_layers

    # Also ensure the imports work
    from timm.layers import DropPath, trunc_normal_

    # Patch DropPath.__init__ so the newer attributes are always set
    def patched_droppath_init(self, drop_prob=0., scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    # Save the original in case it is needed
    _original_droppath_init = DropPath.__init__

    # Apply patch
    DropPath.__init__ = patched_droppath_init

    print("Import compatibility fixes applied")
except ImportError:
    # Older timm still exposes timm.models.layers; nothing to patch
    pass
models/fph.py
ADDED
@@ -0,0 +1,132 @@
| 1 |
+
from efficientnet_pytorch.utils import *
|
| 2 |
+
import os
|
| 3 |
+
import logging
|
| 4 |
+
import functools
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch._utils
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from functools import partial
|
| 11 |
+
try:
|
| 12 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 13 |
+
except ImportError:
|
| 14 |
+
from timm.layers import trunc_normal_, DropPath
|
| 15 |
+
import collections
|
| 16 |
+
|
| 17 |
+
BlockArgs = collections.namedtuple('BlockArgs', ['num_repeat', 'kernel_size', 'stride', 'expand_ratio','input_filters', 'output_filters', 'se_ratio', 'id_skip'])
|
| 18 |
+
GlobalParams = collections.namedtuple('GlobalParams', ['width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate','num_classes', 'batch_norm_momentum', 'batch_norm_epsilon','drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])
|
| 19 |
+
global_params = GlobalParams(width_coefficient=1.8, depth_coefficient=2.6, image_size=528, dropout_rate=0.0, num_classes=1000, batch_norm_momentum=0.99, batch_norm_epsilon=0.001, drop_connect_rate=0.0, depth_divisor=8, min_depth=None, include_top=True)
|
| 20 |
+
|
| 21 |
+
def get_width_and_height_from_size(x):
|
| 22 |
+
if isinstance(x, int):
|
| 23 |
+
return x, x
|
| 24 |
+
if isinstance(x, list) or isinstance(x, tuple):
|
| 25 |
+
return x
|
| 26 |
+
else:
|
| 27 |
+
raise TypeError()
|
| 28 |
+
|
| 29 |
+
def calculate_output_image_size(input_image_size, stride):
|
| 30 |
+
if input_image_size is None:
|
| 31 |
+
return None
|
| 32 |
+
image_height, image_width = get_width_and_height_from_size(input_image_size)
|
| 33 |
+
stride = stride if isinstance(stride, int) else stride[0]
|
| 34 |
+
image_height = int(math.ceil(image_height / stride))
|
| 35 |
+
image_width = int(math.ceil(image_width / stride))
|
| 36 |
+
return [image_height, image_width]
|
| 37 |
+
|
| 38 |
+
class MBConvBlock(nn.Module):
|
| 39 |
+
def __init__(self, block_args, global_params, image_size=25):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self._block_args = block_args
|
| 42 |
+
self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
|
| 43 |
+
self._bn_eps = global_params.batch_norm_epsilon
|
| 44 |
+
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect

        # Expansion phase (inverted bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze-and-excitation layers, if enabled
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Pointwise (projection) convolution phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)

        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection and drop connect
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()


class AddCoords(nn.Module):
    """Append normalized x/y coordinate channels (and optionally a radial channel), CoordConv-style."""

    def __init__(self, with_r=True):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        batch_size, _, x_dim, y_dim = input_tensor.size()
        xx_c, yy_c = torch.meshgrid(torch.arange(x_dim, dtype=input_tensor.dtype),
                                    torch.arange(y_dim, dtype=input_tensor.dtype), indexing='ij')
        xx_c = xx_c.to(input_tensor.device) / (x_dim - 1) * 2 - 1
        yy_c = yy_c.to(input_tensor.device) / (y_dim - 1) * 2 - 1
        xx_c = xx_c.expand(batch_size, 1, x_dim, y_dim)
        yy_c = yy_c.expand(batch_size, 1, x_dim, y_dim)
        ret = torch.cat((input_tensor, xx_c, yy_c), dim=1)
        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_c - 0.5, 2) + torch.pow(yy_c - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)
        return ret


class FPH(nn.Module):
    """Frequency head: fuses JPEG DCT-coefficient features with their quantization table."""

    def __init__(self):
        super(FPH, self).__init__()
        self.obembed = nn.Embedding(21, 21).from_pretrained(torch.eye(21))  # frozen one-hot embedding
        self.qtembed = nn.Embedding(64, 16)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=21, out_channels=64, kernel_size=3, stride=1, dilation=8, padding=8),
            nn.BatchNorm2d(64, momentum=0.01),
            nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=16, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(16, momentum=0.01),
            nn.ReLU(inplace=True))
        self.addcoords = AddCoords()
        repeats = (1, 1, 1)
        in_channels = (256, 256, 256)
        out_channels = (256, 256, 512)  # kept for reference; conv0 stays at 256 channels
        # 35 input channels = 16 modulated + 16 raw feature channels + 2 coordinate + 1 radial
        self.conv0 = nn.Sequential(
            nn.Conv2d(in_channels=35, out_channels=256, kernel_size=8, stride=8, padding=0, bias=False),
            nn.BatchNorm2d(256, momentum=0.01),
            nn.ReLU(inplace=True),
            MBConvBlock(BlockArgs(num_repeat=repeats[0], kernel_size=3, stride=[1], expand_ratio=6,
                                  input_filters=in_channels[0], output_filters=in_channels[1],
                                  se_ratio=0.25, id_skip=True), global_params),
            MBConvBlock(BlockArgs(num_repeat=repeats[0], kernel_size=3, stride=[1], expand_ratio=6,
                                  input_filters=in_channels[1], output_filters=in_channels[1],
                                  se_ratio=0.25, id_skip=True), global_params),
            MBConvBlock(BlockArgs(num_repeat=repeats[0], kernel_size=3, stride=[1], expand_ratio=6,
                                  input_filters=in_channels[1], output_filters=in_channels[1],
                                  se_ratio=0.25, id_skip=True), global_params),
        )

    def forward(self, x, qtable):
        # One-hot embed the coefficient indices, then two conv stages -> B x 16 x H x W
        x = self.conv2(self.conv1(self.obembed(x).permute(0, 3, 1, 2).contiguous()))
        B, C, H, W = x.shape
        # Scale each 8x8 JPEG block of the features by its embedded quantization-table
        # entry, concatenate with the unscaled features, append coordinate channels,
        # and project down to the 256-channel frequency feature map.
        qt = self.qtembed(qtable.unsqueeze(-1).unsqueeze(-1).long()).transpose(1, 6).squeeze(6).contiguous()
        blocks = x.reshape(B, C, H // 8, 8, W // 8, 8).permute(0, 1, 3, 5, 2, 4)
        modulated = (blocks * qt).permute(0, 1, 4, 2, 5, 3).reshape(B, C, H, W)
        return self.conv0(self.addcoords(torch.cat((modulated, x), dim=1)))
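
A quick shape check for the head above (a sketch, not part of the commit; it assumes `models/fph.py` defines `global_params` at module level, outside the excerpt shown). `dct` holds per-pixel coefficient indices in `[0, 21)`, `qtable` one 8×8 quantization table with entries in `[0, 64)`, and the spatial dimensions must be multiples of 8:

```python
import torch
from models.fph import FPH

fph = FPH().eval()
dct = torch.randint(0, 21, (1, 512, 512))     # per-pixel DCT-coefficient indices
qtable = torch.randint(0, 64, (1, 1, 8, 8))   # one JPEG quantization table
with torch.no_grad():
    feat = fph(dct, qtable)
print(feat.shape)  # torch.Size([1, 256, 64, 64]) -- H/8 x W/8 frequency features
```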
models/patch_droppath.py
ADDED
@@ -0,0 +1,17 @@
"""
Patch DropPath for compatibility
"""
try:
    from timm.layers import DropPath

    # Give DropPath instances a default `scale_by_keep` attribute so modules
    # created under older timm versions keep working.  Overriding __getattr__
    # is safe here only because DropPath registers no parameters, buffers, or
    # submodules that rely on nn.Module's attribute lookup.
    def patched_droppath_getattr(self, name):
        if name == 'scale_by_keep':
            return True
        return object.__getattribute__(self, name)

    DropPath.__getattr__ = patched_droppath_getattr

    print("DropPath patched for compatibility")
except ImportError:
    pass
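
A quick check of the patch above (a sketch; importing the module is what applies it, and the `pop` simulates an instance created under an older timm that never set the attribute):

```python
import models.patch_droppath            # applying the patch is an import side effect
from timm.layers import DropPath

dp = DropPath(0.1)
dp.__dict__.pop('scale_by_keep', None)  # pretend the attribute was never set
print(dp.scale_by_keep)                 # True, supplied by the patched __getattr__
```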
models/patch_gelu.py
ADDED
@@ -0,0 +1,34 @@
"""
Monkey patch GELU to fix compatibility issues
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

# Patch the forward method of existing GELU instances so modules created
# before the `approximate` argument existed still run.
def patched_gelu_forward(self, input):
    return F.gelu(input)

# Save the original so it can be restored if needed
_original_gelu_forward = nn.GELU.forward

# Apply the patch
nn.GELU.forward = patched_gelu_forward

# Also provide a drop-in class for newly constructed GELUs
class PatchedGELU(nn.Module):
    def __init__(self, approximate='none'):
        super().__init__()

    def forward(self, input):
        return F.gelu(input)

    def __getattr__(self, name):
        # Only reached when normal lookup fails; PatchedGELU registers no
        # parameters or submodules, so this shortcut is safe.
        if name == 'approximate':
            return 'none'
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

# Replace the class too
nn.GELU = PatchedGELU

print("GELU patched for compatibility")
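
A quick check of the patch above (a sketch): after the import, `nn.GELU` is `PatchedGELU`, which always applies the exact (non-approximate) GELU and reports `approximate='none'`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.patch_gelu   # replaces nn.GELU on import

act = nn.GELU()            # actually a PatchedGELU
x = torch.randn(4)
print(torch.allclose(act(x), F.gelu(x)))  # True
print(act.approximate)                    # 'none'
```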
models/swins.py
ADDED
@@ -0,0 +1,454 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
try:
    from timm.models.layers import DropPath, to_2tuple, trunc_normal_
except ImportError:
    from timm.layers import DropPath, to_2tuple, trunc_normal_
import numpy as np


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()  # act_layer() bypassed; forward calls F.gelu directly
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    # (num_windows*B, window_size, window_size, C) -> (B, H, W, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

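# Sanity sketch: window_partition and window_reverse are exact inverses for any
# H, W divisible by window_size, e.g.
#
#     x = torch.randn(2, 16, 16, 32)                       # B, H, W, C
#     w = window_partition(x, 8)                           # (8, 8, 8, 32)
#     assert torch.equal(window_reverse(w, 8, 16, 16), x)
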
class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0]):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads

        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(512, num_heads, bias=False))

        # get relative_coords_table
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(
            torch.meshgrid([relative_coords_h,
                            relative_coords_w], indexing='ij')).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
        else:
            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            torch.abs(relative_coords_table) + 1.0) / np.log2(8)

        self.register_buffer("relative_coords_table", relative_coords_table)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # k gets no bias: concatenate q_bias, zeros, v_bias
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # cosine attention
        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01, device=attn.device))).exp()
        attn = attn * logit_scale

        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, ' \
               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'


class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
            pretrained_window_size=to_2tuple(pretrained_window_size))

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        # H, W = self.input_resolution  # resolution is inferred from L instead
        B, L, C = x.shape
        H = W = int(L ** 0.5)
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # residual post-norm (Swin V2); drop_path disabled: self.drop_path(self.norm1(x))
        x = shortcut + self.norm1(x)

        # FFN, also residual post-norm; drop_path disabled: self.drop_path(self.norm2(self.mlp(x)))
        x = x + self.norm2(self.mlp(x))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

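# Shape sketch: a block maps (B, L, C) -> (B, L, C) for square L = H*W, e.g.
#
#     blk = SwinTransformerBlock(dim=96, input_resolution=(16, 16),
#                                num_heads=4, window_size=8)
#     tokens = torch.randn(2, 256, 96)
#     print(blk(tokens).shape)  # torch.Size([2, 256, 96])
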
class PatchMerging(nn.Module):

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        # H, W = self.input_resolution  # resolution is inferred from L instead
        B, L, C = x.shape
        H = W = int(L ** 0.5)
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.reduction(x)
        x = self.norm(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

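# Shape sketch: PatchMerging quarters the token count and doubles the width:
#
#     pm = PatchMerging(input_resolution=(16, 16), dim=96)
#     tokens = torch.randn(2, 256, 96)
#     print(pm(tokens).shape)  # torch.Size([2, 64, 192])
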
class BasicLayer(nn.Module):
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 pretrained_window_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks: even blocks use plain W-MSA, odd blocks use shifted windows
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 pretrained_window_size=pretrained_window_size)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def _init_respostnorm(self):
        # zero-init the residual post-norms so each block starts close to identity
        for blk in self.blocks:
            nn.init.constant_(blk.norm1.bias, 0)
            nn.init.constant_(blk.norm1.weight, 0)
            nn.init.constant_(blk.norm2.bias, 0)
            nn.init.constant_(blk.norm2.weight, 0)

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

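# Shape sketch: a 256x256 RGB image becomes 64x64 = 4096 patch tokens:
#
#     pe = PatchEmbed(img_size=256, patch_size=4, in_chans=3, embed_dim=128)
#     img = torch.randn(1, 3, 256, 256)
#     print(pe(img).shape)  # torch.Size([1, 4096, 128])
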
class SwinTransformerV2(nn.Module):
    def __init__(self, img_size=256, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],
                 window_size=8, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.0,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, pretrained_window_sizes=[8, 8, 8, 6], **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint,
                               pretrained_window_size=pretrained_window_sizes[i_layer])
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
        for bly in self.layers:
            bly._init_respostnorm()

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'}

    def forward(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        return x
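
Note that `forward` returns normalized token features; `self.avgpool` and `self.head` are built but never applied, so the class acts as a feature backbone here. A minimal smoke test for the default configuration (a sketch):

```python
import torch
from models.swins import SwinTransformerV2

model = SwinTransformerV2()  # defaults: img_size=256, patch_size=4, embed_dim=128
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 64, 1024]) -- 8x8 tokens at the final stage
```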
requirements.txt
ADDED
@@ -0,0 +1,11 @@
gradio==4.44.0
torch==2.0.1
torchvision==0.15.2
opencv-python-headless==4.8.1.78
numpy==1.24.3
pillow==10.0.0
jpegio==0.2.3
segmentation-models-pytorch==0.3.3
timm==0.9.12
efficientnet-pytorch==0.7.1
albumentations==1.3.1