astrosbd seifbenayed committed
Commit 58d7142 · verified · 1 parent: b35b448

let's go (#2)


- let's go (28a202cb896ecb7388dcd60a130a61a763b87287)


Co-authored-by: Seif benayed <seifbenayed@users.noreply.huggingface.co>

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/Paystub.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/TamperedPaystub.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/TamperedPaystubv1.jpg filter=lfs diff=lfs merge=lfs -text
DEPLOYMENT.md ADDED
@@ -0,0 +1,194 @@
1
+ # 🚀 Deployment Guide - DTD Document Tampering Detection
2
+
3
+ ## Hugging Face Spaces Deployment
4
+
5
+ ### Prerequisites
6
+
7
+ - Hugging Face account
8
+ - Git installed locally
9
+ - Git LFS installed (for large model files)
10
+
11
+ ### Step 1: Install Git LFS
12
+
13
+ ```bash
14
+ # Mac
15
+ brew install git-lfs
16
+
17
+ # Linux
18
+ sudo apt-get install git-lfs
19
+
20
+ # Initialize Git LFS
21
+ git lfs install
22
+ ```
23
+
24
+ ### Step 2: Create Hugging Face Space
25
+
26
+ 1. Go to https://huggingface.co/new-space
27
+ 2. Choose a name (e.g., `dtd-doctamper-detection`)
28
+ 3. Select **Gradio** as SDK
29
+ 4. Choose license: **MIT**
30
+ 5. Click **Create Space**
31
+
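
If you prefer to script this step, the Space can also be created with the `huggingface_hub` Python client. This is an optional sketch, not part of the guide's required commands; the repo name is a placeholder and it assumes you are already logged in via `huggingface-cli login`:

```python
# Optional: create the Space programmatically instead of using the web form.
from huggingface_hub import HfApi

api = HfApi()
api.create_repo(
    repo_id="YOUR_USERNAME/dtd-doctamper-detection",  # placeholder name
    repo_type="space",
    space_sdk="gradio",
)
```
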
32
+ ### Step 3: Clone and Setup
33
+
34
+ ```bash
35
+ # Clone your new space
36
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/dtd-doctamper-detection
37
+ cd dtd-doctamper-detection
38
+
39
+ # Copy app files
40
+ cp -r /path/to/gradio_dtd_app/* .
41
+
42
+ # Track large files with Git LFS
43
+ git lfs track "*.pth"
44
+ git lfs track "*.pt"
45
+ git add .gitattributes
46
+
47
+ # Add all files
48
+ git add .
49
+
50
+ # Commit
51
+ git commit -m "Initial commit: DTD document tampering detection app"
52
+
53
+ # Push to Hugging Face
54
+ git push
55
+ ```
56
+
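
As an optional alternative to the git workflow above, the same folder can be pushed with the `huggingface_hub` client. This sketch assumes the Space already exists and that you run it from the app directory; the repo id is a placeholder:

```python
# Optional alternative to `git push`: upload the app folder in one call.
# Large checkpoint files are stored through the Hub's large-file backend.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path=".",                                  # local app directory
    repo_id="YOUR_USERNAME/dtd-doctamper-detection",  # placeholder Space id
    repo_type="space",
    commit_message="Initial commit: DTD document tampering detection app",
)
```
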
57
+ ### Step 4: Configure Space Settings
58
+
59
+ After pushing, Hugging Face will automatically:
60
+ - Install dependencies from `requirements.txt`
61
+ - Build the Docker container
62
+ - Start the Gradio app
63
+ - Assign a public URL
64
+
65
+ ### Step 5: Test Your Space
66
+
67
+ Visit: `https://huggingface.co/spaces/YOUR_USERNAME/dtd-doctamper-detection`
68
+
69
+ ## Local Testing
70
+
71
+ Before deploying, test locally:
72
+
73
+ ```bash
74
+ cd gradio_dtd_app
75
+
76
+ # Create virtual environment
77
+ python -m venv venv
78
+ source venv/bin/activate # On Windows: venv\Scripts\activate
79
+
80
+ # Install dependencies
81
+ pip install -r requirements.txt
82
+
83
+ # Run app
84
+ python app.py
85
+ ```
86
+
87
+ Open browser to: `http://localhost:7860`
88
+
89
+ ## Troubleshooting
90
+
91
+ ### Issue: Git LFS bandwidth limit
92
+
93
+ **Solution**: Use Hugging Face's built-in LFS storage:
94
+
95
+ ```bash
96
+ # Track checkpoint files
97
+ git lfs track "checkpoints/*.pth"
98
+ git lfs track "checkpoints/*.pt"
99
+ git add .gitattributes checkpoints/
100
+ git commit -m "Add model checkpoints"
101
+ git push
102
+ ```
103
+
104
+ ### Issue: Build timeout
105
+
106
+ **Solution**: Trim heavy dependencies in `requirements.txt` and pin the SDK in the `README.md` YAML front matter (this is where Spaces reads its build configuration):
+
+ ```yaml
+ # In the README.md front matter
+ sdk: gradio
+ sdk_version: 4.44.0
+ python_version: "3.10"
+ ```
114
+
115
+ ### Issue: Out of memory
116
+
117
+ **Solution**: Enable GPU hardware in Space settings:
118
+ 1. Go to Space settings
119
+ 2. Select **Hardware**: a GPU option such as T4
120
+ 3. Save changes
121
+
122
+ ## File Size Optimization
123
+
124
+ Current app size: **~450MB**
125
+
126
+ To reduce size:
127
+
128
+ 1. **Quantize models** (reduce precision):
129
+ ```python
130
+ # In inference.py
131
+ model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
132
+ ```
133
+
134
+ 2. **Use model compression**:
135
+ ```bash
136
+ pip install onnx onnxruntime
137
+ # Convert to ONNX format (smaller); see the export sketch after this list
138
+ ```
139
+
140
+ 3. **Lazy loading**:
141
+ ```python
142
+ # Load models on first request instead of startup
143
+ from functools import lru_cache
+
+ @lru_cache()
144
+ def get_model():
145
+ return DTDPredictor()
146
+ ```
147
+
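
Expanding on option 2 above, an ONNX export could look roughly like the sketch below. This is illustrative only: the dummy input shapes, dtypes, and opset are assumptions, and the three-input signature (image, DCT map, quantization table) mirrors how `inference.py` calls the model; adjust them if the export complains.

```python
# Rough ONNX export sketch (shapes and dtypes are assumptions; adjust as needed).
import torch
from inference import DTDPredictor

predictor = DTDPredictor(checkpoint_path="checkpoints/dtd_doctamper.pth", device="cpu")
model = predictor.model.eval()

dummy_image = torch.randn(1, 3, 512, 512)                   # normalized RGB input
dummy_dct = torch.zeros(1, 1, 512, 512, dtype=torch.long)   # |DCT| values clipped to [0, 20]
dummy_qt = torch.zeros(1, 64, dtype=torch.long)             # quantization table indices

torch.onnx.export(
    model,
    (dummy_image, dummy_dct, dummy_qt),
    "dtd_doctamper.onnx",
    input_names=["image", "dct", "qt"],
    output_names=["logits"],
    opset_version=17,
)
```
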
148
+ ## Environment Variables
149
+
150
+ Add to Space secrets:
151
+
152
+ ```bash
153
+ # Optional: Analytics tracking
154
+ ANALYTICS_TOKEN=your_token
155
+
156
+ # Optional: Rate limiting
157
+ MAX_REQUESTS_PER_HOUR=100
158
+ ```
159
+
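
The app does not read these variables yet; if you wire them in, a minimal pattern for picking them up (names as listed above, illustrative only) is:

```python
# Read optional Space secrets with safe defaults.
import os

analytics_token = os.environ.get("ANALYTICS_TOKEN")                          # None if unset
max_requests_per_hour = int(os.environ.get("MAX_REQUESTS_PER_HOUR", "100"))  # default 100
```
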
160
+ ## Monitoring
161
+
162
+ Check Space logs:
163
+ 1. Go to Space page
164
+ 2. Click **Logs** tab
165
+ 3. Monitor real-time inference
166
+
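
Logs are easiest to read in the web UI, but the Space's runtime state can also be polled from Python. A sketch using `huggingface_hub`, assuming your installed version exposes `get_space_runtime` (the repo id is a placeholder):

```python
# Poll the Space's runtime stage (e.g., BUILDING, RUNNING) and hardware.
from huggingface_hub import HfApi

api = HfApi()
runtime = api.get_space_runtime("YOUR_USERNAME/dtd-doctamper-detection")
print(runtime.stage, runtime.hardware)
```
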
167
+ ## Custom Domain (Optional)
168
+
169
+ 1. Go to Space settings
170
+ 2. Add custom domain
171
+ 3. Configure DNS records
172
+
173
+ ## Cost Optimization
174
+
175
+ **Free Tier Limits:**
176
+ - CPU: Free (slower inference)
177
+ - GPU (e.g., T4): available as a hardware upgrade (see current Hugging Face pricing)
178
+ - Storage: 50GB LFS
179
+ - Bandwidth: Limited
180
+
181
+ **Upgrade Options:**
182
+ - GPU A10G: Faster inference
183
+ - Persistent storage
184
+ - Higher bandwidth
185
+
186
+ ## Support
187
+
188
+ - [Hugging Face Docs](https://huggingface.co/docs/hub/spaces)
189
+ - [Gradio Docs](https://gradio.app/docs/)
190
+ - [Git LFS](https://git-lfs.github.com/)
191
+
192
+ ## License
193
+
194
+ MIT License - See LICENSE file
README.md CHANGED
@@ -1,13 +1,111 @@
1
  ---
2
- title: Dtd Document Tampering
3
- emoji: 📈
4
- colorFrom: pink
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: DTD Document Tampering Detection
3
+ emoji: 🔍
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ # 🔍 DTD: Document Tampering Detection
14
+
15
+ Detect forged or tampered regions in document images using the **DTD (Document Tampering Detector)** model.
16
+
17
+ ## 📝 Description
18
+
19
+ This application uses state-of-the-art deep learning to identify manipulated text in document images by analyzing JPEG compression artifacts (DCT coefficients).
20
+
21
+ ### ✨ Features
22
+
23
+ - **DCT Analysis**: Examines JPEG compression patterns to detect inconsistencies
24
+ - **Real-time Detection**: Fast inference on CPU or GPU
25
+ - **Visual Heatmaps**: Clear visualization of tampered regions
26
+ - **High Accuracy**: Trained on DocTamper dataset with 120K+ document images
27
+
28
+ ### 🎯 Use Cases
29
+
30
+ - Verify document authenticity
31
+ - Detect forged receipts, invoices, and forms
32
+ - Identify copy-paste text manipulation
33
+ - Detect splicing and content insertion
34
+
35
+ ## 🚀 How It Works
36
+
37
+ 1. **Upload** a document image (JPEG format works best)
38
+ 2. **Adjust** JPEG quality setting for DCT analysis (default: 90)
39
+ 3. **View** tampering detection results:
40
+ - **Heatmap**: Red overlay shows tampered regions
41
+ - **Binary Mask**: Clear segmentation of authentic vs tampered
42
+ - **Original**: Compare with input
43
+
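
The same steps can be driven from Python with the `DTDPredictor` class in this Space's `inference.py`; a minimal sketch (the example path is one of the bundled images):

```python
# Programmatic use of the Space's predictor (see inference.py in this repo).
from inference import DTDPredictor

predictor = DTDPredictor(checkpoint_path="checkpoints/dtd_doctamper.pth", device="auto")
result = predictor.predict("examples/Paystub.jpg", quality=90)

original = result["original"]  # input image as a numpy array
mask = result["mask"]          # tampering mask (0 = authentic, 255 = tampered)
heatmap = result["heatmap"]    # red overlay blended onto the original image
```
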
44
+ ## 🏗️ Model Architecture
45
+
46
+ - **Backbone**: VPH (Vision Pyramid Hybrid) + Swin Transformer
47
+ - **Decoder**: Multi-scale Iterative Decoder (MID)
48
+ - **Inputs**: RGB image + DCT coefficients + Quantization tables
49
+ - **Output**: Binary segmentation mask (0=authentic, 1=tampered)
50
+
51
+ ## 📚 Citation
52
+
53
+ ```bibtex
54
+ @inproceedings{qu2023towards,
55
+ title={Towards Robust Tampered Text Detection in Document Image: New Dataset and New Solution},
56
+ author={Qu, Chenfan and Liu, Chongyu and Liu, Yuliang and Chen, Xinhong and Peng, Dezhi and Guo, Fengjun and Jin, Lianwen},
57
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
58
+ pages={5937--5946},
59
+ year={2023}
60
+ }
61
+ ```
62
+
63
+ ## 📖 Paper
64
+
65
+ [Towards Robust Tampered Text Detection in Document Image: New Dataset and New Solution](https://openaccess.thecvf.com/content/CVPR2023/papers/Qu_Towards_Robust_Tampered_Text_Detection_in_Document_Image_New_Dataset_CVPR_2023_paper.pdf) (CVPR 2023)
66
+
67
+ ## ⚠️ Limitations
68
+
69
+ - **JPEG Dependency**: Requires JPEG format for DCT analysis
70
+ - **Quality Sensitivity**: Detection accuracy varies with compression quality
71
+ - **False Positives**: May occur on low-quality scans or heavily compressed images
72
+ - **Preprocessing**: Images must contain text/document content
73
+
74
+ ## 🛠️ Technical Details
75
+
76
+ ### Model Weights
77
+
78
+ - **Main Model**: `dtd_doctamper.pth` (257MB)
79
+ - **VPH Backbone**: `vph_imagenet.pt` (4.8MB)
80
+ - **Swin Transformer**: `swin_imagenet.pt` (187MB)
81
+ - **Total Size**: ~449MB
82
+
83
+ ### Performance
84
+
85
+ - **Input Size**: Variable (auto-resized)
86
+ - **Inference Time**: ~2-5 seconds on CPU
87
+ - **GPU Acceleration**: Supported (CUDA)
88
+
89
+ ## 📦 Local Installation
90
+
91
+ ```bash
92
+ # Clone repository
93
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/dtd-doctamper-detection
94
+ cd dtd-doctamper-detection
95
+
96
+ # Install dependencies
97
+ pip install -r requirements.txt
98
+
99
+ # Run application
100
+ python app.py
101
+ ```
102
+
103
+ ## 📄 License
104
+
105
+ MIT License - See LICENSE file for details
106
+
107
+ ## 🤝 Acknowledgments
108
+
109
+ - Original DTD model by Qu et al. (CVPR 2023)
110
+ - DocTamper dataset
111
+ - Hugging Face Spaces for hosting
app.py ADDED
@@ -0,0 +1,152 @@
1
+ """
2
+ DTD DocTamper - Gradio Application
3
+ Document Tampering Detection using DTD model
4
+ """
5
+ import gradio as gr
6
+ import numpy as np
7
+ from inference import DTDPredictor
8
+
9
+ # Initialize predictor
10
+ print("Loading DTD model...")
11
+ predictor = DTDPredictor(
12
+ checkpoint_path='checkpoints/dtd_doctamper.pth',
13
+ device='auto'
14
+ )
15
+ print("Model loaded!")
16
+
17
+ def predict_tampering(image, quality=90):
18
+ """
19
+ Predict document tampering
20
+
21
+ Args:
22
+ image: Input image (PIL Image or numpy array)
23
+ quality: JPEG compression quality for DCT analysis
24
+
25
+ Returns:
26
+ Tuple of (original, mask, heatmap)
27
+ """
28
+ # Save uploaded image temporarily
29
+ import tempfile
30
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp:
31
+ if hasattr(image, 'save'):
32
+ image.save(tmp, 'JPEG', quality=95)
33
+ else:
34
+ from PIL import Image
35
+ Image.fromarray(image).save(tmp, 'JPEG', quality=95)
36
+ tmp_path = tmp.name
37
+
38
+ try:
39
+ # Run prediction
40
+ result = predictor.predict(tmp_path, quality=quality)
41
+
42
+ return (
43
+ result['original'],
44
+ result['mask'],
45
+ result['heatmap']
46
+ )
47
+ finally:
48
+ import os
49
+ os.unlink(tmp_path)
50
+
51
+ # Create Gradio interface
52
+ with gr.Blocks(title="DTD Document Tampering Detection") as demo:
53
+ gr.Markdown("""
54
+ # 🔍 DTD: Document Tampering Detection
55
+
56
+ Upload a document image to detect forged or tampered regions using the DTD (Document Tampering Detector) model.
57
+
58
+ **How it works:**
59
+ - The model analyzes JPEG compression artifacts (DCT coefficients)
60
+ - Red regions indicate potential tampering
61
+ - Works best on JPEG images of documents
62
+
63
+ **Paper:** [Towards Robust Tampered Text Detection in Document Image](https://openaccess.thecvf.com/content/CVPR2023/papers/Qu_Towards_Robust_Tampered_Text_Detection_in_Document_Image_New_Dataset_CVPR_2023_paper.pdf)
64
+ """)
65
+
66
+ with gr.Row():
67
+ with gr.Column():
68
+ input_image = gr.Image(
69
+ label="Upload Document Image",
70
+ type="pil"
71
+ )
72
+
73
+ quality_slider = gr.Slider(
74
+ minimum=75,
75
+ maximum=95,
76
+ value=90,
77
+ step=5,
78
+ label="JPEG Quality for DCT Analysis",
79
+ info="Higher quality = more sensitive detection"
80
+ )
81
+
82
+ submit_btn = gr.Button("Detect Tampering", variant="primary")
83
+
84
+ with gr.Column():
85
+ with gr.Tab("Heatmap Overlay"):
86
+ output_heatmap = gr.Image(label="Tampering Heatmap")
87
+
88
+ with gr.Tab("Binary Mask"):
89
+ output_mask = gr.Image(label="Tampering Mask")
90
+
91
+ with gr.Tab("Original"):
92
+ output_original = gr.Image(label="Original Image")
93
+
94
+ # Examples
95
+ gr.Examples(
96
+ examples=[
97
+ ["examples/carte.jpeg", 90],
98
+ ["examples/TamperedPaystub.jpg", 90],
99
+ ["examples/Paystub.jpg", 90],
100
+ ],
101
+ inputs=[input_image, quality_slider],
102
+ outputs=[output_original, output_mask, output_heatmap],
103
+ fn=predict_tampering,
104
+ cache_examples=False,
105
+ )
106
+
107
+ # Event handlers
108
+ submit_btn.click(
109
+ fn=predict_tampering,
110
+ inputs=[input_image, quality_slider],
111
+ outputs=[output_original, output_mask, output_heatmap]
112
+ )
113
+
114
+ gr.Markdown("""
115
+ ---
116
+ ### ℹ️ About
117
+
118
+ **DTD (Document Tampering Detector)** is a deep learning model designed to detect forged text in document images.
119
+
120
+ **Features:**
121
+ - Analyzes JPEG compression artifacts using DCT (Discrete Cosine Transform)
122
+ - Detects copy-paste, splicing, and text manipulation
123
+ - Works on scanned documents, photos of documents, and digital documents
124
+
125
+ **Citation:**
126
+ ```bibtex
127
+ @inproceedings{qu2023towards,
128
+ title={Towards Robust Tampered Text Detection in Document Image: New Dataset and New Solution},
129
+ author={Qu, Chenfan and Liu, Chongyu and Liu, Yuliang and Chen, Xinhong and Peng, Dezhi and Guo, Fengjun and Jin, Lianwen},
130
+ booktitle={CVPR},
131
+ year={2023}
132
+ }
133
+ ```
134
+
135
+ **Model Architecture:**
136
+ - Backbone: VPH (Vision Pyramid Hybrid) + Swin Transformer
137
+ - Decoder: Multi-scale Iterative Decoder (MID)
138
+ - Input: RGB image + DCT coefficients + Quantization tables
139
+ - Output: Binary segmentation mask
140
+
141
+ **Limitations:**
142
+ - Requires JPEG images for DCT analysis
143
+ - May produce false positives on low-quality scans
144
+ - Performance varies with JPEG compression quality
145
+ """)
146
+
147
+ if __name__ == "__main__":
148
+ demo.launch(
149
+ server_name="0.0.0.0",
150
+ server_port=7860,
151
+ share=False
152
+ )
checkpoints/dtd_doctamper.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81291d19e0c92fd56a8e76f422114a9c6bc6f67f4ac03a0facc18a045894a8c1
3
+ size 269695109
checkpoints/qt_table.pk ADDED
Binary file (7.74 kB)
 
checkpoints/swin_imagenet.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1436a2d793dbbb74c9578a68e52d0e7deaa3f305560a34d287a8e4edc866b245
3
+ size 196402845
checkpoints/vph_imagenet.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4e99cbcbd5f17a0004278a57e3ff199a0de7189345d7e6924e104710d602898
3
+ size 5000275
examples/Paystub.jpg ADDED

Git LFS Details

  • SHA256: e29a9c35b4d22e54500486a4d3ab8e9501dd5ebe2797ae51e2e2e6933dea60f8
  • Pointer size: 131 Bytes
  • Size of remote file: 293 kB
examples/TamperedPaystub.jpg ADDED

Git LFS Details

  • SHA256: 6b5dca0dcc2d057cfd06f78137fa5bf36c0b14cb58d9d62c1482833b0d14e152
  • Pointer size: 131 Bytes
  • Size of remote file: 287 kB
examples/TamperedPaystubv1.jpg ADDED

Git LFS Details

  • SHA256: 8295f2441596ec7caddbbfa4df94bd8705ad367244485c812f3bb1885d4b3386
  • Pointer size: 131 Bytes
  • Size of remote file: 686 kB
examples/carte.jpeg ADDED
inference.py ADDED
@@ -0,0 +1,187 @@
1
+ """
2
+ DTD DocTamper Inference Module
3
+ Simplified for Gradio deployment on Hugging Face
4
+ """
5
+ import os
6
+ import sys
7
+ import cv2
8
+ import torch
9
+ import jpegio
10
+ import numpy as np
11
+ import pickle
12
+ import tempfile
13
+ from PIL import Image
14
+ from models import fix_imports # Apply compatibility fixes
15
+ from models import patch_gelu
16
+ from models import patch_droppath
17
+ from models.dtd import seg_dtd
18
+ import torchvision.transforms as transforms
19
+
20
+ class DTDPredictor:
21
+ def __init__(self, checkpoint_path='checkpoints/dtd_doctamper.pth', device='cpu'):
22
+ """
23
+ Initialize DTD model for inference
24
+
25
+ Args:
26
+ checkpoint_path: Path to model checkpoint
27
+ device: Device to use ('cpu', 'cuda', or 'auto')
28
+ """
29
+ # Auto-detect device
30
+ if device == 'auto':
31
+ if torch.cuda.is_available():
32
+ self.device = 'cuda'
33
+ else:
34
+ self.device = 'cpu'
35
+ else:
36
+ self.device = device
37
+
38
+ print(f'Using device: {self.device}')
39
+
40
+ # Load QT table
41
+ with open('checkpoints/qt_table.pk', 'rb') as fpk:
42
+ pks = pickle.load(fpk)
43
+ self.pks = {}
44
+ for k, v in pks.items():
45
+ self.pks[k] = torch.LongTensor(v)
46
+
47
+ # Image transforms
48
+ self.transform = transforms.Compose([
49
+ transforms.ToTensor(),
50
+ transforms.Normalize(mean=(0.485, 0.456, 0.406),
51
+ std=(0.229, 0.224, 0.225))
52
+ ])
53
+
54
+ # Load model
55
+ self.model = seg_dtd('', 2, device=self.device)
56
+ if self.device == 'cuda':
57
+ self.model = self.model.cuda()
58
+ self.model = self.model.to(self.device)
59
+
60
+ # Load checkpoint
61
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
62
+ state_dict = ckpt['state_dict']
63
+
64
+ # Remove 'module.' prefix if present
65
+ new_state_dict = {}
66
+ for k, v in state_dict.items():
67
+ if k.startswith('module.'):
68
+ new_state_dict[k[7:]] = v
69
+ else:
70
+ new_state_dict[k] = v
71
+
72
+ self.model.load_state_dict(new_state_dict)
73
+ self.model.eval()
74
+
75
+ print('Model loaded successfully!')
76
+
77
+ def extract_dct(self, image_path, quality=90):
78
+ """
79
+ Extract DCT coefficients from JPEG image
80
+
81
+ Args:
82
+ image_path: Path to JPEG image
83
+ quality: JPEG quality for re-compression
84
+
85
+ Returns:
86
+ DCT coefficients and quantization table
87
+ """
88
+ # Load image
89
+ im = Image.open(image_path).convert('RGB')
90
+
91
+ # Re-compress to JPEG with specified quality
92
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp:
93
+ im_gray = im.convert("L")
94
+ im_gray.save(tmp, "JPEG", quality=quality)
95
+ tmp_path = tmp.name
96
+
97
+ try:
98
+ # Read JPEG with jpegio
99
+ jpg = jpegio.read(tmp_path)
100
+ dct = jpg.coef_arrays[0].copy()
101
+
102
+ # Get quantization table
103
+ qt = jpg.quant_tables[0]
104
+ qt_flat = qt.flatten()[:64] # First 64 values
105
+
106
+ return dct, qt_flat
107
+ finally:
108
+ # Clean up temp file
109
+ os.unlink(tmp_path)
110
+
111
+ @torch.no_grad()
112
+ def predict(self, image_path, quality=90):
113
+ """
114
+ Predict tampering mask for input image
115
+
116
+ Args:
117
+ image_path: Path to input JPEG image
118
+ quality: JPEG quality for DCT extraction
119
+
120
+ Returns:
121
+ Dictionary containing:
122
+ - original: Original image (numpy array)
123
+ - mask: Binary tampering mask (numpy array)
124
+ - heatmap: Colorized heatmap overlay
125
+ """
126
+ # Load image
127
+ im = Image.open(image_path).convert('RGB')
128
+ im_np = np.array(im)
129
+
130
+ # Extract DCT coefficients
131
+ dct, qt = self.extract_dct(image_path, quality)
132
+
133
+ # Prepare inputs
134
+ # Image
135
+ im_tensor = self.transform(im).unsqueeze(0).to(self.device)
136
+
137
+ # DCT coefficients (clip to [0, 20])
138
+ dct_tensor = torch.from_numpy(np.clip(np.abs(dct), 0, 20)).unsqueeze(0).unsqueeze(0).float().to(self.device)
139
+
140
+ # Quantization table
141
+ qt_indices = []
142
+ for val in qt:
143
+ # Find closest match in quantization table
144
+ if val in self.pks:
145
+ qt_indices.append(val)
146
+ else:
147
+ # Find closest
148
+ closest = min(self.pks.keys(), key=lambda x: abs(x - val))
149
+ qt_indices.append(closest)
150
+
151
+ qt_tensor = torch.LongTensor(qt_indices[:64]).unsqueeze(0).to(self.device)
152
+
153
+ # Forward pass
154
+ output = self.model(im_tensor, dct_tensor, qt_tensor)
155
+
156
+ # Get prediction mask
157
+ pred_mask = output.argmax(1).squeeze().cpu().numpy()
158
+
159
+ # Create heatmap overlay
160
+ heatmap = self.create_heatmap(im_np, pred_mask)
161
+
162
+ return {
163
+ 'original': im_np,
164
+ 'mask': (pred_mask * 255).astype(np.uint8),
165
+ 'heatmap': heatmap
166
+ }
167
+
168
+ def create_heatmap(self, image, mask):
169
+ """
170
+ Create colorized heatmap overlay
171
+
172
+ Args:
173
+ image: Original image (numpy array)
174
+ mask: Binary mask (numpy array)
175
+
176
+ Returns:
177
+ Heatmap overlay (numpy array)
178
+ """
179
+ # Create colored mask
180
+ colored_mask = np.zeros_like(image)
181
+ colored_mask[mask == 1] = [255, 0, 0] # Red for tampered regions
182
+
183
+ # Blend with original image
184
+ alpha = 0.5
185
+ heatmap = cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)
186
+
187
+ return heatmap
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # DTD Models Module
models/dtd.py ADDED
@@ -0,0 +1,356 @@
1
+ import os
2
+ import fix_imports # Apply import fixes
3
+ import cv2
4
+ import lmdb
5
+ import torch
6
+ import jpegio
7
+ import numpy as np
8
+ import torch.nn as nn
9
+ import gc
10
+ import math
11
+ import time
12
+ import copy
13
+ import logging
14
+ import torch.optim as optim
15
+ import torch.distributed as dist
16
+ import random
17
+ import pickle
18
+ import six
19
+ from glob import glob
20
+ from PIL import Image
21
+ from tqdm import tqdm
22
+ from torch.autograd import Variable
23
+ from torch.cuda.amp import autocast
24
+ import segmentation_models_pytorch as smp
25
+ from torch.utils.data import Dataset, DataLoader
26
+ from torch.cuda.amp import autocast, GradScaler#need pytorch>1.6
27
+ from losses import DiceLoss,FocalLoss,SoftCrossEntropyLoss,LovaszLoss
28
+ from fph import FPH
29
+ import albumentations as A
30
+ from swins import *
31
+ from albumentations.pytorch import ToTensorV2
32
+ import torchvision
33
+ import torch.nn.functional as F
34
+ try:
35
+ from timm.models.layers import trunc_normal_, DropPath
36
+ except ImportError:
37
+ from timm.layers import trunc_normal_, DropPath
38
+ from functools import partial
39
+ from segmentation_models_pytorch.base import modules as md
40
+ from typing import Optional, Union, List
41
+ from segmentation_models_pytorch.base import SegmentationModel
42
+
43
+ # Custom GELU for compatibility
44
+ class GELU(nn.Module):
45
+ def forward(self, x):
46
+ return F.gelu(x)
47
+
48
+ class LayerNorm(nn.Module):
49
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
50
+ super().__init__()
51
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
52
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
53
+ self.eps = eps
54
+ self.data_format = data_format
55
+ if self.data_format not in ["channels_last", "channels_first"]:
56
+ raise NotImplementedError
57
+ self.normalized_shape = (normalized_shape, )
58
+
59
+ def forward(self, x):
60
+ if self.data_format == "channels_last":
61
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
62
+ elif self.data_format == "channels_first":
63
+ u = x.mean(1, keepdim=True)
64
+ s = (x - u).pow(2).mean(1, keepdim=True)
65
+ x = (x - u) / torch.sqrt(s + self.eps)
66
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
67
+ return x
68
+
69
+ class SCSEModule(nn.Module):
70
+ def __init__(self, in_channels, reduction=16):
71
+ super().__init__()
72
+ self.cSE = nn.Sequential(
73
+ nn.AdaptiveAvgPool2d(1),
74
+ nn.Conv2d(in_channels, in_channels // reduction, 1),
75
+ nn.ReLU(inplace=True),
76
+ nn.Conv2d(in_channels // reduction, in_channels, 1),
77
+ nn.Sigmoid(),
78
+ )
79
+ self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
80
+
81
+ def forward(self, x):
82
+ return x * self.cSE(x) + x * self.sSE(x)
83
+
84
+ class ConvBlock(nn.Module):
85
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
86
+ super().__init__()
87
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
88
+ self.norm = LayerNorm(dim, eps=1e-6)
89
+ self.pwconv1 = nn.Linear(dim, 4 * dim)
90
+ self.act = GELU()
91
+ self.pwconv2 = nn.Linear(4 * dim, dim)
92
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) if layer_scale_init_value > 0 else None
93
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
94
+
95
+ def forward(self, x):
96
+ ipt = x
97
+ x = self.dwconv(x)
98
+ x = x.permute(0, 2, 3, 1)
99
+ x = self.norm(x)
100
+ x = self.pwconv1(x)
101
+ x = self.act(x)
102
+ x = self.pwconv2(x)
103
+ if self.gamma is not None:
104
+ x = self.gamma * x
105
+ x = x.permute(0, 3, 1, 2)
106
+ x = ipt + self.drop_path(x)
107
+ return x
108
+
109
+ class AddCoords(nn.Module):
110
+ def __init__(self, with_r=True):
111
+ super().__init__()
112
+ self.with_r = with_r
113
+ def forward(self, input_tensor):
114
+ batch_size, _, x_dim, y_dim = input_tensor.size()
115
+ xx_c, yy_c = torch.meshgrid(torch.arange(x_dim,dtype=input_tensor.dtype), torch.arange(y_dim,dtype=input_tensor.dtype))
116
+ xx_c = xx_c.to(input_tensor.device) / (x_dim - 1) * 2 - 1
117
+ yy_c = yy_c.to(input_tensor.device) / (y_dim - 1) * 2 - 1
118
+ xx_c = xx_c.expand(batch_size,1,x_dim,y_dim)
119
+ yy_c = yy_c.expand(batch_size,1,x_dim,y_dim)
120
+ ret = torch.cat((input_tensor,xx_c,yy_c), dim=1)
121
+ if self.with_r:
122
+ rr = torch.sqrt(torch.pow(xx_c - 0.5, 2) + torch.pow(yy_c - 0.5, 2))
123
+ ret = torch.cat([ret, rr], dim=1)
124
+ return ret
125
+
126
+ class VPH(nn.Module):
127
+ def __init__(self, dims=[96, 192], drop_path_rate=0.4, layer_scale_init_value=1e-6):
128
+ super().__init__()
129
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
130
+ self.downsample_layers = nn.ModuleList([nn.Sequential(nn.Conv2d(6, dims[0], kernel_size=4, stride=4), LayerNorm(dims[0], eps=1e-6, data_format="channels_first")), nn.Sequential(LayerNorm(dims[1], eps=1e-6, data_format="channels_first"),nn.Conv2d(dims[1], dims[2], kernel_size=2, stride=2))])
131
+ self.stages = nn.ModuleList([nn.Sequential(*[ConvBlock(dim=dims[0], drop_path=dp_rates[j],layer_scale_init_value=layer_scale_init_value) for j in range(3)]), nn.Sequential(*[ConvBlock(dim=dims[1], drop_path=dp_rates[3 + j],layer_scale_init_value=layer_scale_init_value) for j in range(3)])])
132
+ self.apply(self._init_weights)
133
+
134
+ def initnorm(self):
135
+ norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
136
+ for i_layer in range(4):
137
+ layer = norm_layer(self.dims[i_layer])
138
+ layer_name = f'norm{i_layer}'
139
+ self.add_module(layer_name, layer)
140
+
141
+ def _init_weights(self, m):
142
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
143
+ trunc_normal_(m.weight, std=.02)
144
+ nn.init.constant_(m.bias, 0)
145
+
146
+ def init_weights(self, pretrained=None):
147
+ def _init_weights(m):
148
+ if isinstance(m, nn.Linear):
149
+ trunc_normal_(m.weight, std=.02)
150
+ if isinstance(m, nn.Linear) and m.bias is not None:
151
+ nn.init.constant_(m.bias, 0)
152
+ elif isinstance(m, nn.LayerNorm):
153
+ nn.init.constant_(m.bias, 0)
154
+ nn.init.constant_(m.weight, 1.0)
155
+ self.apply(_init_weights)
156
+
157
+ def forward(self, x):
158
+ outs = []
159
+ x = self.stages[0](self.downsample_layers[0](x))
160
+ outs = [self.norm0(x)]
161
+ x = self.stages[1](self.downsample_layers[1](x))
162
+ outs.append(self.norm1(x))
163
+ return outs
164
+
165
+ class SegmentationHead(nn.Sequential):
166
+ def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
167
+ upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
168
+ conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
169
+ activation = md.Activation(activation)
170
+ super().__init__(conv2d, upsampling, activation)
171
+
172
+ class DecoderBlock(nn.Module):
173
+ def __init__(self,cin,cadd,cout,):
174
+ super().__init__()
175
+ self.cin = (cin + cadd)
176
+ self.cout = cout
177
+ self.conv1 = nn.Sequential(
178
+ nn.Conv2d(self.cin, self.cout, kernel_size=3, padding=1, bias=False),
179
+ nn.BatchNorm2d(self.cout),
180
+ nn.ReLU(inplace=True)
181
+ )
182
+ self.conv2 = nn.Sequential(
183
+ nn.Conv2d(self.cout, self.cout, kernel_size=3, padding=1, bias=False),
184
+ nn.BatchNorm2d(self.cout),
185
+ nn.ReLU(inplace=True)
186
+ )
187
+
188
+ def forward(self, x1, x2=None):
189
+ x1 = F.interpolate(x1, scale_factor=2.0, mode="nearest")
190
+ if x2 is not None:
191
+ x1 = torch.cat([x1, x2], dim=1)
192
+ x1 = self.conv1(x1[:,:self.cin])
193
+ x1 = self.conv2(x1)
194
+ return x1
195
+
196
+ class ConvBNReLU(nn.Module):
197
+ def __init__(self,in_c,out_c,ks,stride=1,norm=True,res=False):
198
+ super(ConvBNReLU, self).__init__()
199
+ if norm:
200
+ self.conv = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=ks, padding = ks//2, stride=stride,bias=False),nn.BatchNorm2d(out_c),nn.ReLU(True))
201
+ else:
202
+ self.conv = nn.Conv2d(in_c, out_c, kernel_size=ks, padding = ks//2, stride=stride,bias=False)
203
+ self.res = res
204
+ def forward(self,x):
205
+ if self.res:
206
+ return (x + self.conv(x))
207
+ else:
208
+ return self.conv(x)
209
+
210
+ class FUSE1(nn.Module):
211
+ def __init__(self,in_channels_list=(96,192,384,768)):
212
+ super(FUSE1, self).__init__()
213
+ self.c31 = ConvBNReLU(in_channels_list[2],in_channels_list[2],1)
214
+ self.c32 = ConvBNReLU(in_channels_list[3],in_channels_list[2],1)
215
+ self.c33 = ConvBNReLU(in_channels_list[2],in_channels_list[2],3)
216
+
217
+ self.c21 = ConvBNReLU(in_channels_list[1],in_channels_list[1],1)
218
+ self.c22 = ConvBNReLU(in_channels_list[2],in_channels_list[1],1)
219
+ self.c23 = ConvBNReLU(in_channels_list[1],in_channels_list[1],3)
220
+
221
+ self.c11 = ConvBNReLU(in_channels_list[0],in_channels_list[0],1)
222
+ self.c12 = ConvBNReLU(in_channels_list[1],in_channels_list[0],1)
223
+ self.c13 = ConvBNReLU(in_channels_list[0],in_channels_list[0],3)
224
+
225
+ def forward(self,x):
226
+ x,x1,x2,x3 = x
227
+ h,w = x2.shape[-2:]
228
+ x2 = self.c33(F.interpolate(self.c32(x3),size=(h,w))+self.c31(x2))
229
+ h,w = x1.shape[-2:]
230
+ x1 = self.c23(F.interpolate(self.c22(x2),size=(h,w))+self.c21(x1))
231
+ h,w = x.shape[-2:]
232
+ x = self.c13(F.interpolate(self.c12(x1),size=(h,w))+self.c11(x))
233
+ return x,x1,x2,x3
234
+
235
+ class FUSE2(nn.Module):
236
+ def __init__(self,in_channels_list=(96,192,384)):
237
+ super(FUSE2, self).__init__()
238
+
239
+ self.c21 = ConvBNReLU(in_channels_list[1],in_channels_list[1],1)
240
+ self.c22 = ConvBNReLU(in_channels_list[2],in_channels_list[1],1)
241
+ self.c23 = ConvBNReLU(in_channels_list[1],in_channels_list[1],3)
242
+
243
+ self.c11 = ConvBNReLU(in_channels_list[0],in_channels_list[0],1)
244
+ self.c12 = ConvBNReLU(in_channels_list[1],in_channels_list[0],1)
245
+ self.c13 = ConvBNReLU(in_channels_list[0],in_channels_list[0],3)
246
+
247
+ def forward(self,x):
248
+ x,x1,x2 = x
249
+ h,w = x1.shape[-2:]
250
+ x1 = self.c23(F.interpolate(self.c22(x2),size=(h,w),mode='bilinear',align_corners=True)+self.c21(x1))
251
+ h,w = x.shape[-2:]
252
+ x = self.c13(F.interpolate(self.c12(x1),size=(h,w),mode='bilinear',align_corners=True)+self.c11(x))
253
+ return x,x1,x2
254
+
255
+ class FUSE3(nn.Module):
256
+ def __init__(self,in_channels_list=(96,192)):
257
+ super(FUSE3, self).__init__()
258
+
259
+ self.c11 = ConvBNReLU(in_channels_list[0],in_channels_list[0],1)
260
+ self.c12 = ConvBNReLU(in_channels_list[1],in_channels_list[0],1)
261
+ self.c13 = ConvBNReLU(in_channels_list[0],in_channels_list[0],3)
262
+
263
+ def forward(self,x):
264
+ x,x1 = x
265
+ h,w = x.shape[-2:]
266
+ x = self.c13(F.interpolate(self.c12(x1),size=(h,w),mode='bilinear',align_corners=True)+self.c11(x))
267
+ return x,x1
268
+
269
+ class MID(nn.Module):
270
+ def __init__(self, encoder_channels, decoder_channels):
271
+ super().__init__()
272
+ encoder_channels = encoder_channels[1:][::-1]
273
+ self.in_channels = [encoder_channels[0]] + list(decoder_channels[:-1])
274
+ self.add_channels = list(encoder_channels[1:]) + [96]
275
+ self.out_channels = decoder_channels
276
+ self.fuse1 = FUSE1()
277
+ self.fuse2 = FUSE2()
278
+ self.fuse3 = FUSE3()
279
+ decoder_convs = {}
280
+ for layer_idx in range(len(self.in_channels) - 1):
281
+ for depth_idx in range(layer_idx + 1):
282
+ if depth_idx == 0:
283
+ in_ch = self.in_channels[layer_idx]
284
+ skip_ch = self.add_channels[layer_idx] * (layer_idx + 1)
285
+ out_ch = self.out_channels[layer_idx]
286
+ else:
287
+ out_ch = self.add_channels[layer_idx]
288
+ skip_ch = self.add_channels[layer_idx] * (layer_idx + 1 - depth_idx)
289
+ in_ch = self.add_channels[layer_idx - 1]
290
+ decoder_convs[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock(in_ch, skip_ch, out_ch)
291
+ decoder_convs[f"x_{0}_{len(self.in_channels)-1}"] = DecoderBlock(self.in_channels[-1], 0, self.out_channels[-1])
292
+ self.decoder_convs = nn.ModuleDict(decoder_convs)
293
+
294
+ def forward(self, *features):
295
+ decoder_features = {}
296
+ features = self.fuse1(features)[::-1]
297
+ decoder_features["x_0_0"] = self.decoder_convs["x_0_0"](features[0],features[1])
298
+ decoder_features["x_1_1"] = self.decoder_convs["x_1_1"](features[1],features[2])
299
+ decoder_features["x_2_2"] = self.decoder_convs["x_2_2"](features[2],features[3])
300
+ decoder_features["x_2_2"], decoder_features["x_1_1"], decoder_features["x_0_0"] = self.fuse2((decoder_features["x_2_2"], decoder_features["x_1_1"], decoder_features["x_0_0"]))
301
+ decoder_features["x_0_1"] = self.decoder_convs["x_0_1"](decoder_features["x_0_0"], torch.cat((decoder_features["x_1_1"], features[2]),1))
302
+ decoder_features["x_1_2"] = self.decoder_convs["x_1_2"](decoder_features["x_1_1"], torch.cat((decoder_features["x_2_2"], features[3]),1))
303
+ decoder_features["x_1_2"], decoder_features["x_0_1"] = self.fuse3((decoder_features["x_1_2"], decoder_features["x_0_1"]))
304
+ decoder_features["x_0_2"] = self.decoder_convs["x_0_2"](decoder_features["x_0_1"], torch.cat((decoder_features["x_1_2"], decoder_features["x_2_2"], features[3]),1))
305
+ return self.decoder_convs["x_0_3"](torch.cat((decoder_features["x_0_2"], decoder_features["x_1_2"], decoder_features["x_2_2"]),1))
306
+
307
+
308
+ class DTD(SegmentationModel):
309
+ def __init__(self, encoder_name = "resnet18", decoder_channels = (384, 192, 96, 64), classes = 1, device='cpu'):
310
+ super().__init__()
311
+ # Load models with proper device mapping
312
+ import os
313
+ model_dir = os.path.dirname(os.path.abspath(__file__))
314
+ # Backbone weights are shipped in the checkpoints/ directory added in this commit
+ vph_path = os.path.join(model_dir, '..', 'checkpoints', 'vph_imagenet.pt')
+ swin_path = os.path.join(model_dir, '..', 'checkpoints', 'swin_imagenet.pt')
316
+
317
+ if device == 'mps':
318
+ self.vph = torch.load(vph_path, map_location=torch.device('cpu'))
319
+ self.swin = torch.load(swin_path, map_location=torch.device('cpu'))
320
+ else:
321
+ self.vph = torch.load(vph_path, map_location=device)
322
+ self.swin = torch.load(swin_path, map_location=device)
323
+ self.fph = FPH()
324
+ self.decoder = MID(encoder_channels=(96, 192, 384, 768), decoder_channels=decoder_channels)
325
+ self.segmentation_head = SegmentationHead(in_channels=decoder_channels[-1], out_channels=classes, upsampling=2.0)
326
+ self.addcoords = AddCoords()
327
+ self.FU = nn.Sequential(SCSEModule(448),nn.Conv2d(448,192,3,1,1),nn.BatchNorm2d(192),nn.ReLU(True))
328
+ self.classification_head = None
329
+ self.initialize()
330
+
331
+ def forward(self,x,dct,qt):
332
+ features = self.vph(self.addcoords(x))
333
+ features[1] = self.FU(torch.cat((features[1],self.fph(dct,qt)),1))
334
+ rst = self.swin[0](features[1].flatten(2).transpose(1,2).contiguous())
335
+ N,L,C = rst.shape
336
+ H = W = int(L**(1/2))
337
+ features.append(self.vph.norm2(rst.transpose(1,2).contiguous().view(N,C,H,W)))
338
+ features.append(self.vph.norm3(self.swin[2](self.swin[1](rst)).transpose(1,2).contiguous().view(N,C*2,H//2,W//2)))
339
+ decoder_output = self.decoder(*features)
340
+ return self.segmentation_head(decoder_output)
341
+
342
+ class seg_dtd(nn.Module):
343
+ def __init__(self, model_name='resnet18', n_class=1, device='cpu'):
344
+ super().__init__()
345
+ self.model = DTD(encoder_name=model_name, classes=n_class, device=device)
346
+ self.device = device
347
+
348
+ def forward(self, x, dct, qt):
349
+ # Use autocast only for CUDA, not for MPS
350
+ if self.device == 'cuda':
351
+ with autocast():
352
+ x = self.model(x, dct, qt)
353
+ else:
354
+ x = self.model(x, dct, qt)
355
+ return x
356
+
models/fix_imports.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ Simple import compatibility fix for timm
3
+ """
4
+ import sys
5
+ import torch.nn as nn
6
+ try:
7
+ import timm.layers as new_layers
8
+
9
+ # Create fake modules for backward compatibility
10
+ sys.modules['timm.models.layers.drop'] = new_layers.drop
11
+ sys.modules['timm.models.layers'] = new_layers
12
+
13
+ # Also ensure the imports work
14
+ from timm.layers import DropPath, trunc_normal_
15
+
16
+ # Patch DropPath to add missing attribute
17
+ def patched_droppath_init(self, drop_prob=0., scale_by_keep=True):
18
+ super(DropPath, self).__init__()
19
+ self.drop_prob = drop_prob
20
+ self.scale_by_keep = scale_by_keep
21
+
22
+ # Save original
23
+ _original_droppath_init = DropPath.__init__
24
+
25
+ # Apply patch
26
+ DropPath.__init__ = patched_droppath_init
27
+
28
+ except ImportError:
29
+ pass
30
+
31
+ print("Import compatibility fixes applied")
models/fph.py ADDED
@@ -0,0 +1,132 @@
1
+ from efficientnet_pytorch.utils import *
2
+ import os
3
+ import logging
4
+ import functools
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch._utils
9
+ import torch.nn.functional as F
10
+ from functools import partial
11
+ try:
12
+ from timm.models.layers import trunc_normal_, DropPath
13
+ except ImportError:
14
+ from timm.layers import trunc_normal_, DropPath
15
+ import collections
16
+
17
+ BlockArgs = collections.namedtuple('BlockArgs', ['num_repeat', 'kernel_size', 'stride', 'expand_ratio','input_filters', 'output_filters', 'se_ratio', 'id_skip'])
18
+ GlobalParams = collections.namedtuple('GlobalParams', ['width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate','num_classes', 'batch_norm_momentum', 'batch_norm_epsilon','drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])
19
+ global_params = GlobalParams(width_coefficient=1.8, depth_coefficient=2.6, image_size=528, dropout_rate=0.0, num_classes=1000, batch_norm_momentum=0.99, batch_norm_epsilon=0.001, drop_connect_rate=0.0, depth_divisor=8, min_depth=None, include_top=True)
20
+
21
+ def get_width_and_height_from_size(x):
22
+ if isinstance(x, int):
23
+ return x, x
24
+ if isinstance(x, list) or isinstance(x, tuple):
25
+ return x
26
+ else:
27
+ raise TypeError()
28
+
29
+ def calculate_output_image_size(input_image_size, stride):
30
+ if input_image_size is None:
31
+ return None
32
+ image_height, image_width = get_width_and_height_from_size(input_image_size)
33
+ stride = stride if isinstance(stride, int) else stride[0]
34
+ image_height = int(math.ceil(image_height / stride))
35
+ image_width = int(math.ceil(image_width / stride))
36
+ return [image_height, image_width]
37
+
38
+ class MBConvBlock(nn.Module):
39
+ def __init__(self, block_args, global_params, image_size=25):
40
+ super().__init__()
41
+ self._block_args = block_args
42
+ self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
43
+ self._bn_eps = global_params.batch_norm_epsilon
44
+ self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
45
+ self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
46
+ inp = self._block_args.input_filters # number of input channels
47
+ oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
48
+ if self._block_args.expand_ratio != 1:
49
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
50
+ self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
51
+ self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
52
+ k = self._block_args.kernel_size
53
+ s = self._block_args.stride
54
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
55
+ self._depthwise_conv = Conv2d(
56
+ in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
57
+ kernel_size=k, stride=s, bias=False)
58
+ self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
59
+ image_size = calculate_output_image_size(image_size, s)
60
+ if self.has_se:
61
+ Conv2d = get_same_padding_conv2d(image_size=(1, 1))
62
+ num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
63
+ self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
64
+ self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
65
+ final_oup = self._block_args.output_filters
66
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
67
+ self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
68
+ self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
69
+ self._swish = MemoryEfficientSwish()
70
+
71
+ def forward(self, inputs, drop_connect_rate=None):
72
+ x = inputs
73
+ if self._block_args.expand_ratio != 1:
74
+ x = self._expand_conv(inputs)
75
+ x = self._bn0(x)
76
+ x = self._swish(x)
77
+ x = self._depthwise_conv(x)
78
+ x = self._bn1(x)
79
+ x = self._swish(x)
80
+ if self.has_se:
81
+ x_squeezed = F.adaptive_avg_pool2d(x, 1)
82
+ x_squeezed = self._se_reduce(x_squeezed)
83
+ x_squeezed = self._swish(x_squeezed)
84
+ x_squeezed = self._se_expand(x_squeezed)
85
+ x = torch.sigmoid(x_squeezed) * x
86
+ x = self._project_conv(x)
87
+ x = self._bn2(x)
88
+ input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
89
+ if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
90
+ if drop_connect_rate:
91
+ x = drop_connect(x, p=drop_connect_rate, training=self.training)
92
+ x = x + inputs # skip connection
93
+ return x
94
+
95
+ def set_swish(self, memory_efficient=True):
96
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
97
+
98
+ class AddCoords(nn.Module):
99
+ def __init__(self, with_r=True):
100
+ super().__init__()
101
+ self.with_r = with_r
102
+ def forward(self, input_tensor):
103
+ batch_size, _, x_dim, y_dim = input_tensor.size()
104
+ xx_c, yy_c = torch.meshgrid(torch.arange(x_dim,dtype=input_tensor.dtype), torch.arange(y_dim,dtype=input_tensor.dtype))
105
+ xx_c = xx_c.to(input_tensor.device) / (x_dim - 1) * 2 - 1
106
+ yy_c = yy_c.to(input_tensor.device) / (y_dim - 1) * 2 - 1
107
+ xx_c = xx_c.expand(batch_size,1,x_dim,y_dim)
108
+ yy_c = yy_c.expand(batch_size,1,x_dim,y_dim)
109
+ ret = torch.cat((input_tensor,xx_c,yy_c), dim=1)
110
+ if self.with_r:
111
+ rr = torch.sqrt(torch.pow(xx_c - 0.5, 2) + torch.pow(yy_c - 0.5, 2))
112
+ ret = torch.cat([ret, rr], dim=1)
113
+ return ret
114
+
115
+ class FPH(nn.Module):
116
+
117
+ def __init__(self):
118
+ super(FPH, self).__init__()
119
+ self.obembed = nn.Embedding(21,21).from_pretrained(torch.eye(21))
120
+ self.qtembed = nn.Embedding(64,16)
121
+ self.conv1 = nn.Sequential(nn.Conv2d(in_channels=21,out_channels=64,kernel_size=3,stride=1,dilation=8,padding=8),nn.BatchNorm2d(64, momentum=0.01),nn.ReLU(inplace=True))
122
+ self.conv2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=16, kernel_size=1, stride=1, padding=0, bias=False),nn.BatchNorm2d(16, momentum=0.01),nn.ReLU(inplace=True))
123
+ self.addcoords = AddCoords()
124
+ repeats = (1,1,1)
125
+ in_channles = (256,256,256)
126
+ out_channles = (256,256,512)
127
+ self.conv0 = nn.Sequential(nn.Conv2d(in_channels=35, out_channels=256, kernel_size=8, stride=8, padding=0, bias=False),nn.BatchNorm2d(256, momentum=0.01),nn.ReLU(inplace=True),MBConvBlock(BlockArgs(num_repeat=repeats[0], kernel_size=3, stride=[1], expand_ratio=6, input_filters=in_channles[0], output_filters=in_channles[1], se_ratio=0.25, id_skip=True), global_params),MBConvBlock(BlockArgs(num_repeat=repeats[0], kernel_size=3, stride=[1], expand_ratio=6, input_filters=in_channles[1], output_filters=in_channles[1], se_ratio=0.25, id_skip=True), global_params),MBConvBlock(BlockArgs(num_repeat=repeats[0], kernel_size=3, stride=[1], expand_ratio=6, input_filters=in_channles[1], output_filters=in_channles[1], se_ratio=0.25, id_skip=True), global_params),)
128
+
129
+ def forward(self, x, qtable):
130
+ x = self.conv2(self.conv1(self.obembed(x).permute(0,3,1,2).contiguous()))
131
+ B, C, H, W = x.shape
132
+ return self.conv0(self.addcoords(torch.cat(((x.reshape(B,C,H//8,8,W//8,8).permute(0,1,3,5,2,4)*self.qtembed(qtable.unsqueeze(-1).unsqueeze(-1).long()).transpose(1,6).squeeze(6).contiguous()).permute(0,1,4,2,5,3).reshape(B,C,H,W),x), dim=1)))
models/patch_droppath.py ADDED
@@ -0,0 +1,17 @@
1
+ """
2
+ Patch DropPath for compatibility
3
+ """
4
+ try:
5
+ from timm.layers import DropPath
6
+
7
+ # Patch existing instances to add scale_by_keep
8
+ def patched_droppath_getattr(self, name):
9
+ if name == 'scale_by_keep':
10
+ return True
11
+ return object.__getattribute__(self, name)
12
+
13
+ DropPath.__getattr__ = patched_droppath_getattr
14
+
15
+ print("DropPath patched for compatibility")
16
+ except ImportError:
17
+ pass
models/patch_gelu.py ADDED
@@ -0,0 +1,34 @@
1
+ """
2
+ Monkey patch GELU to fix compatibility issues
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ # Patch the forward method of existing GELU instances
9
+ def patched_gelu_forward(self, input):
10
+ return F.gelu(input)
11
+
12
+ # Save original
13
+ _original_gelu_forward = nn.GELU.forward
14
+
15
+ # Apply patch
16
+ nn.GELU.forward = patched_gelu_forward
17
+
18
+ # Also create a new GELU class
19
+ class PatchedGELU(nn.Module):
20
+ def __init__(self, approximate='none'):
21
+ super().__init__()
22
+
23
+ def forward(self, input):
24
+ return F.gelu(input)
25
+
26
+ def __getattr__(self, name):
27
+ if name == 'approximate':
28
+ return 'none'
29
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
30
+
31
+ # Replace the class too
32
+ nn.GELU = PatchedGELU
33
+
34
+ print("GELU patched for compatibility")
models/swins.py ADDED
@@ -0,0 +1,454 @@
1
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint as checkpoint
6
+ try:
7
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
8
+ except ImportError:
9
+ from timm.layers import DropPath, to_2tuple, trunc_normal_
10
+ import numpy as np
11
+
12
+
13
+ class Mlp(nn.Module):
14
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
15
+ super().__init__()
16
+ out_features = out_features or in_features
17
+ hidden_features = hidden_features or in_features
18
+ self.fc1 = nn.Linear(in_features, hidden_features)
19
+ self.act = nn.GELU()# act_layer()
20
+ self.fc2 = nn.Linear(hidden_features, out_features)
21
+ self.drop = nn.Dropout(drop)
22
+
23
+ def forward(self, x):
24
+ x = self.fc1(x)
25
+ x = F.gelu(x)
26
+ x = self.drop(x)
27
+ x = self.fc2(x)
28
+ x = self.drop(x)
29
+ return x
30
+
31
+
32
+ def window_partition(x, window_size):
33
+ B, H, W, C = x.shape
34
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
35
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
36
+ return windows
37
+
38
+
39
+ def window_reverse(windows, window_size, H, W):
40
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
41
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
42
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
43
+ return x
44
+
45
+
46
+ class WindowAttention(nn.Module):
47
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
48
+ pretrained_window_size=[0, 0]):
49
+
50
+ super().__init__()
51
+ self.dim = dim
52
+ self.window_size = window_size # Wh, Ww
53
+ self.pretrained_window_size = pretrained_window_size
54
+ self.num_heads = num_heads
55
+
56
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
57
+
58
+ # mlp to generate continuous relative position bias
59
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
60
+ nn.ReLU(inplace=True),
61
+ nn.Linear(512, num_heads, bias=False))
62
+
63
+ # get relative_coords_table
64
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
65
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
66
+ relative_coords_table = torch.stack(
67
+ torch.meshgrid([relative_coords_h,
68
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
69
+ if pretrained_window_size[0] > 0:
70
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
71
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
72
+ else:
73
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
74
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
75
+ relative_coords_table *= 8 # normalize to -8, 8
76
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
77
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
78
+
79
+ self.register_buffer("relative_coords_table", relative_coords_table)
80
+
81
+ # get pair-wise relative position index for each token inside the window
82
+ coords_h = torch.arange(self.window_size[0])
83
+ coords_w = torch.arange(self.window_size[1])
84
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
85
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
86
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
87
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
88
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
89
+ relative_coords[:, :, 1] += self.window_size[1] - 1
90
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
91
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
92
+ self.register_buffer("relative_position_index", relative_position_index)
93
+
94
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
95
+ if qkv_bias:
96
+ self.q_bias = nn.Parameter(torch.zeros(dim))
97
+ self.v_bias = nn.Parameter(torch.zeros(dim))
98
+ else:
99
+ self.q_bias = None
100
+ self.v_bias = None
101
+ self.attn_drop = nn.Dropout(attn_drop)
102
+ self.proj = nn.Linear(dim, dim)
103
+ self.proj_drop = nn.Dropout(proj_drop)
104
+ self.softmax = nn.Softmax(dim=-1)
105
+
106
+ def forward(self, x, mask=None):
107
+ B_, N, C = x.shape
108
+ qkv_bias = None
109
+ if self.q_bias is not None:
110
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
111
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
112
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
113
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
114
+
115
+ # cosine attention
116
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
117
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01,device=attn.device))).exp()
118
+ attn = attn * logit_scale
119
+
120
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
121
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
122
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
123
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
124
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
125
+ attn = attn + relative_position_bias.unsqueeze(0)
126
+
127
+ if mask is not None:
128
+ nW = mask.shape[0]
129
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
130
+ attn = attn.view(-1, self.num_heads, N, N)
131
+ attn = self.softmax(attn)
132
+ else:
133
+ attn = self.softmax(attn)
134
+
135
+ attn = self.attn_drop(attn)
136
+
137
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
138
+ x = self.proj(x)
139
+ x = self.proj_drop(x)
140
+ return x
141
+
142
+ def extra_repr(self) -> str:
143
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
144
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
+
+ class SwinTransformerBlock(nn.Module):
+     def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+                  mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+                  act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
+         super().__init__()
+         self.dim = dim
+         self.input_resolution = input_resolution
+         self.num_heads = num_heads
+         self.window_size = window_size
+         self.shift_size = shift_size
+         self.mlp_ratio = mlp_ratio
+         if min(self.input_resolution) <= self.window_size:
+             # if the window size exceeds the input resolution, don't partition windows
+             self.shift_size = 0
+             self.window_size = min(self.input_resolution)
+         assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+         self.norm1 = norm_layer(dim)
+         self.attn = WindowAttention(
+             dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+             qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+             pretrained_window_size=to_2tuple(pretrained_window_size))
+
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+         if self.shift_size > 0:
+             # calculate attention mask for SW-MSA
+             H, W = self.input_resolution
+             img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
+             h_slices = (slice(0, -self.window_size),
+                         slice(-self.window_size, -self.shift_size),
+                         slice(-self.shift_size, None))
+             w_slices = (slice(0, -self.window_size),
+                         slice(-self.window_size, -self.shift_size),
+                         slice(-self.shift_size, None))
+             cnt = 0
+             for h in h_slices:
+                 for w in w_slices:
+                     img_mask[:, h, w, :] = cnt
+                     cnt += 1
+
+             mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+             mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+         else:
+             attn_mask = None
+
+         self.register_buffer("attn_mask", attn_mask)
+
+     def forward(self, x):
+         B, L, C = x.shape
+         # infer a square spatial resolution from the token count
+         H = W = int(L ** 0.5)
+         assert L == H * W, "input feature has wrong size"
+
+         shortcut = x
+         x = x.view(B, H, W, C)
+
+         # cyclic shift
+         if self.shift_size > 0:
+             shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+         else:
+             shifted_x = x
+
+         # partition windows
+         x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+         x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+
+         # W-MSA/SW-MSA
+         attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
+
+         # merge windows
+         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+         shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+
+         # reverse cyclic shift
+         if self.shift_size > 0:
+             x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+         else:
+             x = shifted_x
+         x = x.view(B, H * W, C)
+         # residual post-norm (SwinV2); drop_path disabled here,
+         # original: shortcut + self.drop_path(self.norm1(x))
+         x = shortcut + self.norm1(x)
+
+         # FFN; drop_path likewise disabled here
+         x = x + self.norm2(self.mlp(x))
+
+         return x
+
+     def extra_repr(self) -> str:
+         return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+                f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+
+ class PatchMerging(nn.Module):
+
+     def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.input_resolution = input_resolution
+         self.dim = dim
+         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+         self.norm = norm_layer(2 * dim)
+
+     def forward(self, x):
+         """
+         x: B, H*W, C
+         """
+         B, L, C = x.shape
+         # infer a square spatial resolution from the token count
+         H = W = int(L ** 0.5)
+         assert L == H * W, "input feature has wrong size"
+         assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+         x = x.view(B, H, W, C)
+
+         x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+         x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+         x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+         x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+         x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+         x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+         x = self.reduction(x)
+         x = self.norm(x)
+
+         return x
+
+     def extra_repr(self) -> str:
+         return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+
+ class BasicLayer(nn.Module):
+     def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+                  mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+                  drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+                  pretrained_window_size=0):
+
+         super().__init__()
+         self.dim = dim
+         self.input_resolution = input_resolution
+         self.depth = depth
+         self.use_checkpoint = use_checkpoint
+
+         # build blocks
+         self.blocks = nn.ModuleList([
+             SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+                                  num_heads=num_heads, window_size=window_size,
+                                  shift_size=0 if (i % 2 == 0) else window_size // 2,
+                                  mlp_ratio=mlp_ratio,
+                                  qkv_bias=qkv_bias,
+                                  drop=drop, attn_drop=attn_drop,
+                                  drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                  norm_layer=norm_layer,
+                                  pretrained_window_size=pretrained_window_size)
+             for i in range(depth)])
+
+         # patch merging layer
+         if downsample is not None:
+             self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+         else:
+             self.downsample = None
+
+     def forward(self, x):
+         for blk in self.blocks:
+             if self.use_checkpoint:
+                 x = checkpoint.checkpoint(blk, x)
+             else:
+                 x = blk(x)
+         if self.downsample is not None:
+             x = self.downsample(x)
+         return x
+
+     def extra_repr(self) -> str:
+         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+     def _init_respostnorm(self):
+         # zero-init the post-norms so each residual branch starts as an identity map
+         for blk in self.blocks:
+             nn.init.constant_(blk.norm1.bias, 0)
+             nn.init.constant_(blk.norm1.weight, 0)
+             nn.init.constant_(blk.norm2.bias, 0)
+             nn.init.constant_(blk.norm2.weight, 0)
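The `shift_size=0 if (i % 2 == 0) else window_size // 2` argument above is what alternates regular W-MSA and shifted SW-MSA blocks within a stage; a minimal illustration:

```python
# Even-indexed blocks use regular windows (shift 0); odd-indexed blocks use
# windows shifted by half the window size, letting information cross window borders.
window_size, depth = 8, 6
shifts = [0 if i % 2 == 0 else window_size // 2 for i in range(depth)]
print(shifts)  # [0, 4, 0, 4, 0, 4]
```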
+
+
+ class PatchEmbed(nn.Module):
+     def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.patches_resolution = patches_resolution
+         self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+         if norm_layer is not None:
+             self.norm = norm_layer(embed_dim)
+         else:
+             self.norm = None
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         # FIXME look at relaxing size constraints
+         assert H == self.img_size[0] and W == self.img_size[1], \
+             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+         x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
+         if self.norm is not None:
+             x = self.norm(x)
+         return x
+
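A quick shape check for `PatchEmbed`, assuming the class and the `to_2tuple` helper defined earlier in this file are in scope: a 256×256 RGB image with `patch_size=4` becomes (256/4)² = 4096 tokens of width 128:

```python
import torch
import torch.nn as nn

embed = PatchEmbed(img_size=256, patch_size=4, in_chans=3, embed_dim=128, norm_layer=nn.LayerNorm)
img = torch.randn(1, 3, 256, 256)
print(embed(img).shape)  # torch.Size([1, 4096, 128])
```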
+
+
+ class SwinTransformerV2(nn.Module):
+     def __init__(self, img_size=256, patch_size=4, in_chans=3, num_classes=1000,
+                  embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],
+                  window_size=8, mlp_ratio=4., qkv_bias=True,
+                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.0,
+                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+                  use_checkpoint=False, pretrained_window_sizes=[8, 8, 8, 6], **kwargs):
+         super().__init__()
+
+         self.num_classes = num_classes
+         self.num_layers = len(depths)
+         self.embed_dim = embed_dim
+         self.ape = ape
+         self.patch_norm = patch_norm
+         self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+         self.mlp_ratio = mlp_ratio
+
+         # split image into non-overlapping patches
+         self.patch_embed = PatchEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+             norm_layer=norm_layer if self.patch_norm else None)
+         num_patches = self.patch_embed.num_patches
+         patches_resolution = self.patch_embed.patches_resolution
+         self.patches_resolution = patches_resolution
+
+         # absolute position embedding
+         if self.ape:
+             self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+             trunc_normal_(self.absolute_pos_embed, std=.02)
+
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         # stochastic depth
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+         # build layers
+         self.layers = nn.ModuleList()
+         for i_layer in range(self.num_layers):
+             layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
+                                input_resolution=(patches_resolution[0] // (2 ** i_layer),
+                                                  patches_resolution[1] // (2 ** i_layer)),
+                                depth=depths[i_layer],
+                                num_heads=num_heads[i_layer],
+                                window_size=window_size,
+                                mlp_ratio=self.mlp_ratio,
+                                qkv_bias=qkv_bias,
+                                drop=drop_rate, attn_drop=attn_drop_rate,
+                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                                norm_layer=norm_layer,
+                                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                                use_checkpoint=use_checkpoint,
+                                pretrained_window_size=pretrained_window_sizes[i_layer])
+             self.layers.append(layer)
+
+         self.norm = norm_layer(self.num_features)
+         self.avgpool = nn.AdaptiveAvgPool1d(1)
+         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+         self.apply(self._init_weights)
+         for bly in self.layers:
+             bly._init_respostnorm()
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'absolute_pos_embed'}
+
+     @torch.jit.ignore
+     def no_weight_decay_keywords(self):
+         return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'}
+
+     def forward(self, x):
+         x = self.patch_embed(x)
+         if self.ape:
+             x = x + self.absolute_pos_embed
+         x = self.pos_drop(x)
+
+         for layer in self.layers:
+             x = layer(x)
+
+         x = self.norm(x)
+         return x
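A smoke test for the full backbone, assuming the classes above and the helpers defined earlier in this file (`Mlp`, `DropPath`, `window_partition`, `window_reverse`, `to_2tuple`, `trunc_normal_`) are in scope. Note that `forward` as written returns the normalized token sequence rather than logits: `avgpool` and `head` are constructed but never applied, so the class behaves as a feature extractor here:

```python
import torch

# Swin-B-like defaults: embed_dim=128, depths [2, 2, 18, 2], window 8, 256x256 input.
model = SwinTransformerV2(img_size=256, num_classes=0).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 256, 256))
print(feats.shape)  # torch.Size([1, 64, 1024]): (256/4/2/2/2)^2 tokens, 128 * 2**3 channels
```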
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ gradio==4.44.0
+ torch==2.0.1
+ torchvision==0.15.2
+ opencv-python-headless==4.8.1.78
+ numpy==1.24.3
+ pillow==10.0.0
+ jpegio==0.2.3
+ segmentation-models-pytorch==0.3.3
+ timm==0.9.12
+ efficientnet-pytorch==0.7.1
+ albumentations==1.3.1