lukhsaankumar commited on
Commit
df4a21a
·
1 Parent(s): 14a1b30

Deploy DeepFake Detector API - 2026-03-07 09:12:00

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .env.example +24 -0
  2. Dockerfile +43 -0
  3. Dockerfile.huggingface +43 -0
  4. README.md +178 -8
  5. README_HF.md +182 -0
  6. app/__init__.py +1 -0
  7. app/__pycache__/__init__.cpython-312.pyc +0 -0
  8. app/__pycache__/main.cpython-312.pyc +0 -0
  9. app/api/__init__.py +1 -0
  10. app/api/__pycache__/__init__.cpython-312.pyc +0 -0
  11. app/api/__pycache__/routes_health.cpython-312.pyc +0 -0
  12. app/api/__pycache__/routes_models.cpython-312.pyc +0 -0
  13. app/api/__pycache__/routes_predict.cpython-312.pyc +0 -0
  14. app/api/routes_health.py +62 -0
  15. app/api/routes_models.py +51 -0
  16. app/api/routes_predict.py +286 -0
  17. app/core/__init__.py +1 -0
  18. app/core/__pycache__/__init__.cpython-312.pyc +0 -0
  19. app/core/__pycache__/config.cpython-312.pyc +0 -0
  20. app/core/__pycache__/errors.cpython-312.pyc +0 -0
  21. app/core/__pycache__/logging.cpython-312.pyc +0 -0
  22. app/core/config.py +64 -0
  23. app/core/errors.py +53 -0
  24. app/core/logging.py +61 -0
  25. app/main.py +128 -0
  26. app/models/__init__.py +1 -0
  27. app/models/__pycache__/__init__.cpython-312.pyc +0 -0
  28. app/models/wrappers/__init__.py +1 -0
  29. app/models/wrappers/__pycache__/__init__.cpython-312.pyc +0 -0
  30. app/models/wrappers/__pycache__/base_wrapper.cpython-312.pyc +0 -0
  31. app/models/wrappers/__pycache__/cnn_transfer_wrapper.cpython-312.pyc +0 -0
  32. app/models/wrappers/__pycache__/deit_distilled_wrapper.cpython-312.pyc +0 -0
  33. app/models/wrappers/__pycache__/dummy_majority_fusion_wrapper.cpython-312.pyc +0 -0
  34. app/models/wrappers/__pycache__/dummy_random_wrapper.cpython-312.pyc +0 -0
  35. app/models/wrappers/__pycache__/gradfield_cnn_wrapper.cpython-312.pyc +0 -0
  36. app/models/wrappers/__pycache__/logreg_fusion_wrapper.cpython-312.pyc +0 -0
  37. app/models/wrappers/__pycache__/vit_base_wrapper.cpython-312.pyc +0 -0
  38. app/models/wrappers/base_wrapper.py +150 -0
  39. app/models/wrappers/cnn_transfer_wrapper.py +226 -0
  40. app/models/wrappers/deit_distilled_wrapper.py +312 -0
  41. app/models/wrappers/dummy_majority_fusion_wrapper.py +171 -0
  42. app/models/wrappers/dummy_random_wrapper.py +168 -0
  43. app/models/wrappers/gradfield_cnn_wrapper.py +401 -0
  44. app/models/wrappers/logreg_fusion_wrapper.py +161 -0
  45. app/models/wrappers/vit_base_wrapper.py +331 -0
  46. app/schemas/__init__.py +1 -0
  47. app/schemas/__pycache__/__init__.cpython-312.pyc +0 -0
  48. app/schemas/__pycache__/models.cpython-312.pyc +0 -0
  49. app/schemas/__pycache__/predict.cpython-312.pyc +0 -0
  50. app/schemas/models.py +53 -0
.env.example ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DeepFake Detector Backend - Environment Variables
2
+ # Copy this file to .env and update with your values
3
+
4
+ # Hugging Face Configuration
5
+ # Available fusion models:
6
+ # - DeepFakeDetector/fusion-logreg-final (Logistic Regression - default)
7
+ # - DeepFakeDetector/fusion-meta-final (Meta-classifier)
8
+ HF_FUSION_REPO_ID=DeepFakeDetector/fusion-logreg-final
9
+ HF_CACHE_DIR=.hf_cache
10
+ # HF_TOKEN=your_huggingface_token_here # Optional: for private repos
11
+
12
+ # Google Gemini API (Optional - for LLM explanations)
13
+ # GOOGLE_API_KEY=your_google_api_key_here
14
+
15
+ # Server Configuration
16
+ HOST=0.0.0.0
17
+ PORT=8000
18
+
19
+ # CORS Configuration (comma-separated list of allowed origins)
20
+ CORS_ORIGINS=http://localhost:8082,https://www.deepfake-detector.app,https://deepfake-detector.app
21
+
22
+ # Debugging
23
+ ENABLE_DEBUG=false
24
+ LOG_LEVEL=INFO
Dockerfile ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DeepFake Detector API - Hugging Face Spaces Docker Image
2
+ # Optimized for HF Spaces deployment with GPU support
3
+
4
+ FROM python:3.11-slim
5
+
6
+ # Set working directory
7
+ WORKDIR /app
8
+
9
+ # Set environment variables
10
+ ENV PYTHONDONTWRITEBYTECODE=1 \
11
+ PYTHONUNBUFFERED=1 \
12
+ PIP_NO_CACHE_DIR=1 \
13
+ PIP_DISABLE_PIP_VERSION_CHECK=1 \
14
+ PORT=7860
15
+
16
+ # Install system dependencies
17
+ RUN apt-get update && apt-get install -y --no-install-recommends \
18
+ curl \
19
+ git \
20
+ && rm -rf /var/lib/apt/lists/*
21
+
22
+ # Create non-root user (HF Spaces requirement)
23
+ RUN useradd -m -u 1000 user
24
+ USER user
25
+
26
+ # Set PATH for user-installed packages
27
+ ENV PATH="/home/user/.local/bin:$PATH"
28
+
29
+ # Copy requirements and install dependencies as user
30
+ COPY --chown=user:user requirements.txt .
31
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
32
+
33
+ # Copy application code
34
+ COPY --chown=user:user . /app
35
+
36
+ # Create cache directory for Hugging Face models
37
+ RUN mkdir -p /app/.hf_cache
38
+
39
+ # Expose HF Spaces port
40
+ EXPOSE 7860
41
+
42
+ # Run the application (start.sh already defaults to port 7860)
43
+ CMD ["./start.sh"]
Dockerfile.huggingface ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DeepFake Detector API - Hugging Face Spaces Docker Image
2
+ # Optimized for HF Spaces deployment with GPU support
3
+
4
+ FROM python:3.11-slim
5
+
6
+ # Set working directory
7
+ WORKDIR /app
8
+
9
+ # Set environment variables
10
+ ENV PYTHONDONTWRITEBYTECODE=1 \
11
+ PYTHONUNBUFFERED=1 \
12
+ PIP_NO_CACHE_DIR=1 \
13
+ PIP_DISABLE_PIP_VERSION_CHECK=1 \
14
+ PORT=7860
15
+
16
+ # Install system dependencies
17
+ RUN apt-get update && apt-get install -y --no-install-recommends \
18
+ curl \
19
+ git \
20
+ && rm -rf /var/lib/apt/lists/*
21
+
22
+ # Create non-root user (HF Spaces requirement)
23
+ RUN useradd -m -u 1000 user
24
+ USER user
25
+
26
+ # Set PATH for user-installed packages
27
+ ENV PATH="/home/user/.local/bin:$PATH"
28
+
29
+ # Copy requirements and install dependencies as user
30
+ COPY --chown=user:user requirements.txt .
31
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
32
+
33
+ # Copy application code
34
+ COPY --chown=user:user . /app
35
+
36
+ # Create cache directory for Hugging Face models
37
+ RUN mkdir -p /app/.hf_cache
38
+
39
+ # Expose HF Spaces port
40
+ EXPOSE 7860
41
+
42
+ # Run the application (start.sh already defaults to port 7860)
43
+ CMD ["./start.sh"]
README.md CHANGED
@@ -1,12 +1,182 @@
1
  ---
2
- title: DeepFakeDetectorBackend
3
- emoji: 👁
4
- colorFrom: gray
5
- colorTo: yellow
6
  sdk: docker
7
- pinned: false
8
- license: mit
9
- short_description: FastAPI Backend for MacAI Society DeepFake Detector
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: DeepFake Detector API
3
+ emoji: 🎭
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: docker
7
+ app_port: 7860
 
 
8
  ---
9
 
10
+ # 🎭 DeepFake Detector API
11
+
12
+ FastAPI backend for detecting AI-generated (deepfake) images using an ensemble of state-of-the-art deep learning models.
13
+
14
+ ## 🤖 Models
15
+
16
+ This API uses a fusion ensemble of 5 deep learning models:
17
+
18
+ - **CNN Transfer** (EfficientNet-B0) - Transfer learning from ImageNet
19
+ - **ViT Base** (Vision Transformer) - Attention-based architecture
20
+ - **DeiT Distilled** (Data-efficient Image Transformer) - Distilled ViT variant
21
+ - **Gradient Field CNN** - Custom architecture analyzing gradient patterns
22
+ - **FFT CNN** - Frequency domain analysis using Fast Fourier Transform
23
+
24
+ All models are combined using a **Logistic Regression stacking ensemble** for optimal accuracy.
25
+
26
+ ## 🔗 API Endpoints
27
+
28
+ | Endpoint | Method | Description |
29
+ |----------|--------|-------------|
30
+ | `/health` | GET | Health check - returns API status |
31
+ | `/ready` | GET | Model readiness check - confirms models are loaded |
32
+ | `/models` | GET | List all loaded models with metadata |
33
+ | `/predict` | POST | Predict if an image is real or AI-generated |
34
+ | `/docs` | GET | Interactive Swagger API documentation |
35
+ | `/redoc` | GET | Alternative API documentation |
36
+
37
+ ## 🚀 Usage Example
38
+
39
+ ### Using cURL
40
+
41
+ ```bash
42
+ # Check if API is ready
43
+ curl https://lukhsaankumar-deepfakedetectorbackend.hf.space/ready
44
+
45
+ # Make a prediction
46
+ curl -X POST "https://lukhsaankumar-deepfakedetectorbackend.hf.space/predict" \
47
+ -F "file=@image.jpg" \
48
+ -F "explain=true"
49
+ ```
50
+
51
+ ### Using Python
52
+
53
+ ```python
54
+ import requests
55
+
56
+ # Upload an image for prediction
57
+ url = "https://lukhsaankumar-deepfakedetectorbackend.hf.space/predict"
58
+ files = {"file": open("image.jpg", "rb")}
59
+ data = {"explain": True}
60
+
61
+ response = requests.post(url, files=files, data=data)
62
+ result = response.json()
63
+
64
+ print(f"Prediction: {result['prediction']}")
65
+ print(f"Confidence: {result['confidence']:.2%}")
66
+ print(f"Explanation: {result['explanation']}")
67
+ ```
68
+
69
+ ## 🎯 Response Format
70
+
71
+ ```json
72
+ {
73
+ "prediction": "fake",
74
+ "confidence": 0.8734,
75
+ "probabilities": {
76
+ "real": 0.1266,
77
+ "fake": 0.8734
78
+ },
79
+ "model_predictions": {
80
+ "cnn_transfer": {"prediction": "fake", "confidence": 0.89},
81
+ "vit_base": {"prediction": "fake", "confidence": 0.92},
82
+ "deit": {"prediction": "fake", "confidence": 0.85},
83
+ "gradient_field": {"prediction": "real", "confidence": 0.55},
84
+ "fft_cnn": {"prediction": "fake", "confidence": 0.78}
85
+ },
86
+ "fusion_confidence": 0.8734,
87
+ "explanation": "AI-powered analysis of the prediction...",
88
+ "processing_time_ms": 342
89
+ }
90
+ ```
91
+
92
+ ## 🔧 Configuration
93
+
94
+ ### Required Secrets
95
+
96
+ Set these in your Space Settings → Repository secrets:
97
+
98
+ | Secret | Description | Required |
99
+ |--------|-------------|----------|
100
+ | `GOOGLE_API_KEY` | Google Gemini API key for AI explanations | Yes |
101
+ | `HF_TOKEN` | Hugging Face token (auto-set by Spaces) | No |
102
+
103
+ ### Optional Environment Variables
104
+
105
+ | Variable | Default | Description |
106
+ |----------|---------|-------------|
107
+ | `HF_FUSION_REPO_ID` | `DeepFakeDetector/fusion-logreg-final` | Hugging Face model repository |
108
+ | `CORS_ORIGINS` | Multiple defaults | Comma-separated allowed CORS origins |
109
+ | `GEMINI_MODEL` | `gemini-2.5-flash` | Gemini model for explanations |
110
+
111
+ ## 🏗️ Architecture
112
+
113
+ ```
114
+ ┌─────────────┐
115
+ │ Client │
116
+ └──────┬──────┘
117
+
118
+
119
+ ┌─────────────────────────────────┐
120
+ │ FastAPI Backend │
121
+ │ ┌──────────────────────────┐ │
122
+ │ │ Model Registry │ │
123
+ │ │ ┌────────────────────┐ │ │
124
+ │ │ │ CNN Transfer │ │ │
125
+ │ │ │ ViT Base │ │ │
126
+ │ │ │ DeiT Distilled │ │ │
127
+ │ │ │ Gradient Field │ │ │
128
+ │ │ │ FFT CNN │ │ │
129
+ │ │ └────────────────────┘ │ │
130
+ │ │ ┌────────────────────┐ │ │
131
+ │ │ │ Fusion Ensemble │ │ │
132
+ │ │ │ (LogReg Stacking) │ │ │
133
+ │ │ └────────────────────┘ │ │
134
+ │ └──────────────────────────┘ │
135
+ │ ┌──────────────────────────┐ │
136
+ │ │ Gemini Explainer │ │
137
+ │ └──────────────��───────────┘ │
138
+ └─────────────────────────────────┘
139
+ ```
140
+
141
+ ## 📊 Performance
142
+
143
+ - **Accuracy**: ~87% on test set (OpenFake dataset)
144
+ - **Inference Time**: ~200-500ms per image (with GPU)
145
+ - **Model Size**: ~500MB total
146
+ - **Supported Formats**: JPG, PNG, WEBP
147
+
148
+ ## 🐛 Troubleshooting
149
+
150
+ ### Models not loading?
151
+ - Check the Logs tab for specific errors
152
+ - Verify `HF_FUSION_REPO_ID` points to a valid repository
153
+ - Ensure the repository is public or `HF_TOKEN` is set
154
+
155
+ ### Explanations not working?
156
+ - Verify `GOOGLE_API_KEY` is set in Space Settings
157
+ - Check if you have Gemini API quota remaining
158
+ - Review logs for API errors
159
+
160
+ ### CORS errors?
161
+ - Add your frontend domain to `CORS_ORIGINS` in Space Settings
162
+ - Format: `https://yourdomain.com,https://www.yourdomain.com`
163
+
164
+ ## 📚 Documentation
165
+
166
+ - **Interactive Docs**: Visit `/docs` for Swagger UI
167
+ - **ReDoc**: Visit `/redoc` for alternative documentation
168
+ - **Source Code**: [GitHub Repository](https://github.com/lukhsaankumar/DeepFakeDetector)
169
+
170
+ ## 📝 License
171
+
172
+ This project is part of the MacAI Society research initiative.
173
+
174
+ ## 🙏 Acknowledgments
175
+
176
+ - Models trained on OpenFake, ImageNet, and custom datasets
177
+ - Powered by PyTorch, Hugging Face, and FastAPI
178
+ - AI explanations by Google Gemini
179
+
180
+ ---
181
+
182
+ **Built with ❤️ by MacAI Society**
README_HF.md ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: DeepFake Detector API
3
+ emoji: 🎭
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ app_port: 7860
8
+ ---
9
+
10
+ # 🎭 DeepFake Detector API
11
+
12
+ FastAPI backend for detecting AI-generated (deepfake) images using an ensemble of state-of-the-art deep learning models.
13
+
14
+ ## 🤖 Models
15
+
16
+ This API uses a fusion ensemble of 5 deep learning models:
17
+
18
+ - **CNN Transfer** (EfficientNet-B0) - Transfer learning from ImageNet
19
+ - **ViT Base** (Vision Transformer) - Attention-based architecture
20
+ - **DeiT Distilled** (Data-efficient Image Transformer) - Distilled ViT variant
21
+ - **Gradient Field CNN** - Custom architecture analyzing gradient patterns
22
+ - **FFT CNN** - Frequency domain analysis using Fast Fourier Transform
23
+
24
+ All models are combined using a **Logistic Regression stacking ensemble** for optimal accuracy.
25
+
26
+ ## 🔗 API Endpoints
27
+
28
+ | Endpoint | Method | Description |
29
+ |----------|--------|-------------|
30
+ | `/health` | GET | Health check - returns API status |
31
+ | `/ready` | GET | Model readiness check - confirms models are loaded |
32
+ | `/models` | GET | List all loaded models with metadata |
33
+ | `/predict` | POST | Predict if an image is real or AI-generated |
34
+ | `/docs` | GET | Interactive Swagger API documentation |
35
+ | `/redoc` | GET | Alternative API documentation |
36
+
37
+ ## 🚀 Usage Example
38
+
39
+ ### Using cURL
40
+
41
+ ```bash
42
+ # Check if API is ready
43
+ curl https://lukhsaankumar-deepfakedetectorbackend.hf.space/ready
44
+
45
+ # Make a prediction
46
+ curl -X POST "https://lukhsaankumar-deepfakedetectorbackend.hf.space/predict" \
47
+ -F "file=@image.jpg" \
48
+ -F "explain=true"
49
+ ```
50
+
51
+ ### Using Python
52
+
53
+ ```python
54
+ import requests
55
+
56
+ # Upload an image for prediction
57
+ url = "https://lukhsaankumar-deepfakedetectorbackend.hf.space/predict"
58
+ files = {"file": open("image.jpg", "rb")}
59
+ data = {"explain": True}
60
+
61
+ response = requests.post(url, files=files, data=data)
62
+ result = response.json()
63
+
64
+ print(f"Prediction: {result['prediction']}")
65
+ print(f"Confidence: {result['confidence']:.2%}")
66
+ print(f"Explanation: {result['explanation']}")
67
+ ```
68
+
69
+ ## 🎯 Response Format
70
+
71
+ ```json
72
+ {
73
+ "prediction": "fake",
74
+ "confidence": 0.8734,
75
+ "probabilities": {
76
+ "real": 0.1266,
77
+ "fake": 0.8734
78
+ },
79
+ "model_predictions": {
80
+ "cnn_transfer": {"prediction": "fake", "confidence": 0.89},
81
+ "vit_base": {"prediction": "fake", "confidence": 0.92},
82
+ "deit": {"prediction": "fake", "confidence": 0.85},
83
+ "gradient_field": {"prediction": "real", "confidence": 0.55},
84
+ "fft_cnn": {"prediction": "fake", "confidence": 0.78}
85
+ },
86
+ "fusion_confidence": 0.8734,
87
+ "explanation": "AI-powered analysis of the prediction...",
88
+ "processing_time_ms": 342
89
+ }
90
+ ```
91
+
92
+ ## 🔧 Configuration
93
+
94
+ ### Required Secrets
95
+
96
+ Set these in your Space Settings → Repository secrets:
97
+
98
+ | Secret | Description | Required |
99
+ |--------|-------------|----------|
100
+ | `GOOGLE_API_KEY` | Google Gemini API key for AI explanations | Yes |
101
+ | `HF_TOKEN` | Hugging Face token (auto-set by Spaces) | No |
102
+
103
+ ### Optional Environment Variables
104
+
105
+ | Variable | Default | Description |
106
+ |----------|---------|-------------|
107
+ | `HF_FUSION_REPO_ID` | `DeepFakeDetector/fusion-logreg-final` | Hugging Face model repository |
108
+ | `CORS_ORIGINS` | Multiple defaults | Comma-separated allowed CORS origins |
109
+ | `GEMINI_MODEL` | `gemini-2.5-flash` | Gemini model for explanations |
110
+
111
+ ## 🏗️ Architecture
112
+
113
+ ```
114
+ ┌─────────────┐
115
+ │ Client │
116
+ └──────┬──────┘
117
+
118
+
119
+ ┌─────────────────────────────────┐
120
+ │ FastAPI Backend │
121
+ │ ┌──────────────────────────┐ │
122
+ │ │ Model Registry │ │
123
+ │ │ ┌────────────────────┐ │ │
124
+ │ │ │ CNN Transfer │ │ │
125
+ │ │ │ ViT Base │ │ │
126
+ │ │ │ DeiT Distilled │ │ │
127
+ │ │ │ Gradient Field │ │ │
128
+ │ │ │ FFT CNN │ │ │
129
+ │ │ └────────────────────┘ │ │
130
+ │ │ ┌────────────────────┐ │ │
131
+ │ │ │ Fusion Ensemble │ │ │
132
+ │ │ │ (LogReg Stacking) │ │ │
133
+ │ │ └────────────────────┘ │ │
134
+ │ └──────────────────────────┘ │
135
+ │ ┌──────────────────────────┐ │
136
+ │ │ Gemini Explainer │ │
137
+ │ └──────────────────────────┘ │
138
+ └─────────────────────────────────┘
139
+ ```
140
+
141
+ ## 📊 Performance
142
+
143
+ - **Accuracy**: ~87% on test set (OpenFake dataset)
144
+ - **Inference Time**: ~200-500ms per image (with GPU)
145
+ - **Model Size**: ~500MB total
146
+ - **Supported Formats**: JPG, PNG, WEBP
147
+
148
+ ## 🐛 Troubleshooting
149
+
150
+ ### Models not loading?
151
+ - Check the Logs tab for specific errors
152
+ - Verify `HF_FUSION_REPO_ID` points to a valid repository
153
+ - Ensure the repository is public or `HF_TOKEN` is set
154
+
155
+ ### Explanations not working?
156
+ - Verify `GOOGLE_API_KEY` is set in Space Settings
157
+ - Check if you have Gemini API quota remaining
158
+ - Review logs for API errors
159
+
160
+ ### CORS errors?
161
+ - Add your frontend domain to `CORS_ORIGINS` in Space Settings
162
+ - Format: `https://yourdomain.com,https://www.yourdomain.com`
163
+
164
+ ## 📚 Documentation
165
+
166
+ - **Interactive Docs**: Visit `/docs` for Swagger UI
167
+ - **ReDoc**: Visit `/redoc` for alternative documentation
168
+ - **Source Code**: [GitHub Repository](https://github.com/lukhsaankumar/DeepFakeDetector)
169
+
170
+ ## 📝 License
171
+
172
+ This project is part of the MacAI Society research initiative.
173
+
174
+ ## 🙏 Acknowledgments
175
+
176
+ - Models trained on OpenFake, ImageNet, and custom datasets
177
+ - Powered by PyTorch, Hugging Face, and FastAPI
178
+ - AI explanations by Google Gemini
179
+
180
+ ---
181
+
182
+ **Built with ❤️ by MacAI Society**
app/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # DeepFake Detector Backend Application
app/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (188 Bytes). View file
 
app/__pycache__/main.cpython-312.pyc ADDED
Binary file (4.92 kB). View file
 
app/api/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # API module
app/api/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (192 Bytes). View file
 
app/api/__pycache__/routes_health.cpython-312.pyc ADDED
Binary file (2.08 kB). View file
 
app/api/__pycache__/routes_models.cpython-312.pyc ADDED
Binary file (1.74 kB). View file
 
app/api/__pycache__/routes_predict.cpython-312.pyc ADDED
Binary file (10.7 kB). View file
 
app/api/routes_health.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Health check routes.
3
+ """
4
+
5
+ from fastapi import APIRouter
6
+
7
+ from app.core.logging import get_logger
8
+ from app.schemas.models import HealthResponse, ReadyResponse
9
+ from app.services.model_registry import get_model_registry
10
+
11
+ logger = get_logger(__name__)
12
+ router = APIRouter(tags=["health"])
13
+
14
+
15
@router.get(
    "/health",
    response_model=HealthResponse,
    summary="Health check",
    description="Simple health check to verify the API is running"
)
async def health_check() -> HealthResponse:
    """Liveness probe: confirms the API process is up and responding."""
    return HealthResponse(status="ok")
28
+
29
+
30
@router.get(
    "/ready",
    response_model=ReadyResponse,
    summary="Readiness check",
    description="Check if models are loaded and the API is ready to serve predictions"
)
async def readiness_check() -> ReadyResponse:
    """Readiness probe.

    Reports whether the model registry has finished loading. When ready,
    the response includes the fusion repo id and the repo ids of every
    registered submodel; otherwise a "not_ready" payload with empty fields.
    """
    registry = get_model_registry()

    if registry.is_loaded:
        submodel_repos = [
            entry["repo_id"]
            for entry in registry.list_models()
            if entry["model_type"] == "submodel"
        ]
        return ReadyResponse(
            status="ready",
            models_loaded=True,
            fusion_repo=registry.get_fusion_repo_id(),
            submodels=submodel_repos
        )

    return ReadyResponse(
        status="not_ready",
        models_loaded=False,
        fusion_repo=None,
        submodels=[]
    )
app/api/routes_models.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model listing routes.
3
+ """
4
+
5
+ from fastapi import APIRouter
6
+
7
+ from app.core.logging import get_logger
8
+ from app.schemas.models import ModelsListResponse, ModelInfo
9
+ from app.services.model_registry import get_model_registry
10
+
11
+ logger = get_logger(__name__)
12
+ router = APIRouter(tags=["models"])
13
+
14
+
15
@router.get(
    "/models",
    response_model=ModelsListResponse,
    summary="List loaded models",
    description="Get information about all loaded models including fusion and submodels"
)
async def list_models() -> ModelsListResponse:
    """Describe every loaded model.

    Splits the registry contents into the single fusion model (if present)
    and the list of submodels, preserving registry order, and reports the
    total number of registered models.
    """
    entries = get_model_registry().list_models()

    fusion_info = None
    submodels_info = []

    for entry in entries:
        info = ModelInfo(
            repo_id=entry["repo_id"],
            name=entry["name"],
            model_type=entry["model_type"],
            config=entry.get("config")
        )
        if entry["model_type"] == "fusion":
            fusion_info = info
        else:
            submodels_info.append(info)

    return ModelsListResponse(
        fusion=fusion_info,
        submodels=submodels_info,
        total_count=len(entries)
    )
app/api/routes_predict.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prediction routes.
3
+ """
4
+
5
+ import base64
6
+ from typing import Optional
7
+
8
+ from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
9
+
10
+ from app.core.errors import (
11
+ DeepFakeDetectorError,
12
+ ImageProcessingError,
13
+ InferenceError,
14
+ FusionError,
15
+ ModelNotFoundError,
16
+ ModelNotLoadedError
17
+ )
18
+ from app.core.logging import get_logger
19
+ from app.schemas.predict import (
20
+ PredictResponse,
21
+ PredictionResult,
22
+ TimingInfo,
23
+ ErrorResponse,
24
+ FusionMeta,
25
+ ModelDisplayInfo,
26
+ ExplainModelResponse,
27
+ SingleModelInsight
28
+ )
29
+ from app.services.inference_service import get_inference_service
30
+ from app.services.fusion_service import get_fusion_service
31
+ from app.services.preprocess_service import get_preprocess_service
32
+ from app.services.model_registry import get_model_registry
33
+ from app.services.llm_service import get_llm_service, get_model_display_info, MODEL_DISPLAY_INFO
34
+ from app.utils.timing import Timer
35
+
36
+ logger = get_logger(__name__)
37
+ router = APIRouter(tags=["predict"])
38
+
39
+
40
@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Predict if image is real or fake",
    description="Upload an image to get a deepfake detection prediction",
    responses={
        400: {"model": ErrorResponse, "description": "Invalid image or request"},
        404: {"model": ErrorResponse, "description": "Model not found"},
        500: {"model": ErrorResponse, "description": "Inference error"}
    }
)
async def predict(
    image: UploadFile = File(..., description="Image file to analyze"),
    use_fusion: bool = Query(
        True,
        description="Use fusion model (majority vote) across all submodels"
    ),
    model: Optional[str] = Query(
        None,
        description="Specific submodel to use (name or repo_id). Only used when use_fusion=false"
    ),
    return_submodels: Optional[bool] = Query(
        None,
        description="Include individual submodel predictions in response. Defaults to true when use_fusion=true"
    ),
    explain: bool = Query(
        True,
        description="Generate explainability heatmaps (Grad-CAM for CNNs, attention rollout for transformers)"
    )
) -> PredictResponse:
    """
    Predict if an uploaded image is real or fake.

    When use_fusion=true (default):
    - Runs all submodels on the image
    - Combines predictions via the fusion service
    - Returns the fused result plus optionally individual submodel results

    When use_fusion=false:
    - Runs only the specified submodel (or the first available if not specified)
    - Returns just that model's prediction

    Response includes timing information for each step.

    Raises:
        HTTPException 400: invalid/unreadable image
        HTTPException 404: requested submodel not found
        HTTPException 503: models not yet loaded
        HTTPException 500: inference/fusion failure or unexpected error
    """
    timer = Timer()
    timer.start_total()

    # Submodel results default to included exactly when fusion is used.
    should_return_submodels = (
        return_submodels if return_submodels is not None else use_fusion
    )

    try:
        # Read the upload into memory.
        with timer.measure("download"):
            image_bytes = await image.read()

        # Reject non-image / corrupt payloads before any model work.
        with timer.measure("preprocess"):
            get_preprocess_service().validate_image(image_bytes)

        # Services are resolved inside the helpers so each request only
        # touches the services its path actually needs.
        if use_fusion:
            return _predict_fused(image_bytes, explain, should_return_submodels, timer)
        return _predict_single(image_bytes, model, explain, timer)

    except ImageProcessingError as e:
        logger.warning(f"Image processing error: {e.message}")
        raise HTTPException(
            status_code=400,
            detail={"error": "ImageProcessingError", "message": e.message, "details": e.details}
        )

    except ModelNotFoundError as e:
        logger.warning(f"Model not found: {e.message}")
        raise HTTPException(
            status_code=404,
            detail={"error": "ModelNotFoundError", "message": e.message, "details": e.details}
        )

    except ModelNotLoadedError as e:
        logger.error(f"Models not loaded: {e.message}")
        raise HTTPException(
            status_code=503,
            detail={"error": "ModelNotLoadedError", "message": e.message, "details": e.details}
        )

    except (InferenceError, FusionError) as e:
        logger.error(f"Inference/Fusion error: {e.message}")
        raise HTTPException(
            status_code=500,
            detail={"error": type(e).__name__, "message": e.message, "details": e.details}
        )

    except Exception as e:
        logger.exception(f"Unexpected error in predict endpoint: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": "InternalError", "message": str(e)}
        )


def _predict_fused(
    image_bytes: bytes,
    explain: bool,
    should_return_submodels: bool,
    timer: Timer
) -> PredictResponse:
    """Run all submodels, fuse their outputs, and assemble the fusion response."""
    inference_service = get_inference_service()
    fusion_service = get_fusion_service()

    # Run every submodel on the image.
    with timer.measure("inference"):
        submodel_outputs = inference_service.predict_all_submodels(
            image_bytes=image_bytes,
            explain=explain
        )

    # Combine per-model outputs into the final verdict.
    with timer.measure("fusion"):
        final_result = fusion_service.fuse(submodel_outputs=submodel_outputs)

    timer.stop_total()

    # Contribution data reported by the fusion step (may be absent).
    fusion_meta_dict = final_result.get("meta", {})
    contribution_percentages = fusion_meta_dict.get("contribution_percentages", {})

    fusion_meta = FusionMeta(
        submodel_weights=fusion_meta_dict.get("submodel_weights", {}),
        weighted_contributions=fusion_meta_dict.get("weighted_contributions", {}),
        contribution_percentages=contribution_percentages
    ) if fusion_meta_dict else None

    # Display metadata per model, for frontend rendering.
    model_display_info = {
        name: ModelDisplayInfo(**get_model_display_info(name))
        for name in submodel_outputs
    }

    return PredictResponse(
        final=PredictionResult(
            pred=final_result["pred"],
            pred_int=final_result["pred_int"],
            prob_fake=final_result["prob_fake"]
        ),
        fusion_used=True,
        submodels={
            name: PredictionResult(
                pred=output["pred"],
                pred_int=output["pred_int"],
                prob_fake=output["prob_fake"],
                heatmap_base64=output.get("heatmap_base64"),
                explainability_type=output.get("explainability_type"),
                focus_summary=output.get("focus_summary"),
                contribution_percentage=contribution_percentages.get(name)
            )
            for name, output in submodel_outputs.items()
        } if should_return_submodels else None,
        fusion_meta=fusion_meta,
        model_display_info=model_display_info if should_return_submodels else None,
        timing_ms=TimingInfo(**timer.get_timings())
    )


def _predict_single(
    image_bytes: bytes,
    model: Optional[str],
    explain: bool,
    timer: Timer
) -> PredictResponse:
    """Run exactly one submodel and assemble a fusion-free response."""
    registry = get_model_registry()
    # Fall back to the first registered submodel when none is requested.
    model_key = model or registry.get_submodel_names()[0]

    inference_service = get_inference_service()
    with timer.measure("inference"):
        result = inference_service.predict_single(
            model_key=model_key,
            image_bytes=image_bytes,
            explain=explain
        )

    timer.stop_total()

    return PredictResponse(
        final=PredictionResult(
            pred=result["pred"],
            pred_int=result["pred_int"],
            prob_fake=result["prob_fake"],
            heatmap_base64=result.get("heatmap_base64"),
            explainability_type=result.get("explainability_type"),
            focus_summary=result.get("focus_summary")
        ),
        fusion_used=False,
        submodels=None,
        timing_ms=TimingInfo(**timer.get_timings())
    )
+ )
221
+
222
+
223
@router.post("/explain-model", response_model=ExplainModelResponse)
async def explain_model(
    image: UploadFile = File(...),
    model_name: str = Form(...),
    prob_fake: float = Form(...),
    contribution_percentage: float = Form(None),
    heatmap_base64: str = Form(None),
    focus_summary: str = Form(None)
):
    """
    Generate an on-demand LLM explanation for a single model's prediction.

    Token-efficient: only called when the user explicitly requests insights,
    rather than on every /predict call.

    Raises:
        HTTPException 400: empty image upload
        HTTPException 503: LLM service disabled (GOOGLE_API_KEY unset)
        HTTPException 500: LLM failed to produce an explanation, or unexpected error
    """
    try:
        # Read and validate image.
        image_bytes = await image.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty image file")

        # Encode image to base64 for the LLM prompt.
        original_b64 = base64.b64encode(image_bytes).decode('utf-8')

        llm_service = get_llm_service()
        if not llm_service.enabled:
            # Bug fix: the setting is GOOGLE_API_KEY (see Settings in
            # app/core/config.py and the README), not GEMINI_API_KEY —
            # the old message told users to set a variable that is never read.
            raise HTTPException(
                status_code=503,
                detail="LLM service is not enabled. Set GOOGLE_API_KEY environment variable."
            )

        # Generate the explanation via the LLM service.
        result = llm_service.generate_single_model_explanation(
            model_name=model_name,
            original_image_b64=original_b64,
            prob_fake=prob_fake,
            heatmap_b64=heatmap_base64,
            contribution_percentage=contribution_percentage,
            focus_summary=focus_summary
        )

        if result is None:
            raise HTTPException(
                status_code=500,
                detail="Failed to generate explanation from LLM"
            )

        return ExplainModelResponse(
            model_name=model_name,
            insight=SingleModelInsight(
                key_finding=result["key_finding"],
                what_model_saw=result["what_model_saw"],
                important_regions=result["important_regions"],
                confidence_qualifier=result["confidence_qualifier"]
            )
        )

    except HTTPException:
        # Re-raise our own deliberate HTTP errors untouched.
        raise
    except Exception as e:
        logger.exception(f"Error generating model explanation: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": "ExplanationError", "message": str(e)}
        )
+ )
app/core/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Core module
app/core/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (193 Bytes). View file
 
app/core/__pycache__/config.cpython-312.pyc ADDED
Binary file (2.22 kB). View file
 
app/core/__pycache__/errors.cpython-312.pyc ADDED
Binary file (2.54 kB). View file
 
app/core/__pycache__/logging.cpython-312.pyc ADDED
Binary file (2.52 kB). View file
 
app/core/config.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Application configuration with environment variable support.
3
+ """
4
+
5
+ import os
6
+ from functools import lru_cache
7
+ from pydantic_settings import BaseSettings
8
+ from typing import Optional
9
+
10
+
11
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Values are read from the process environment and, when present, from a
    local `.env` file (see the nested Config class). Field names are the
    environment-variable names (case-sensitive).
    """

    # Hugging Face configuration
    # Available fusion models:
    # - DeepFakeDetector/fusion-logreg-final (Logistic Regression - default)
    # - DeepFakeDetector/fusion-meta-final (Meta-classifier)
    HF_FUSION_REPO_ID: str = "DeepFakeDetector/fusion-logreg-final"
    # Local directory used to cache downloaded Hugging Face artifacts.
    HF_CACHE_DIR: str = ".hf_cache"
    # Optional auth token (needed only for private repositories).
    HF_TOKEN: Optional[str] = None

    # Google Gemini API configuration
    # NOTE(review): user-facing error messages elsewhere in this service tell
    # callers to set GEMINI_API_KEY, but this field reads GOOGLE_API_KEY —
    # confirm which environment variable is actually intended.
    GOOGLE_API_KEY: Optional[str] = None
    GEMINI_MODEL: str = "gemini-2.5-flash"

    @property
    def llm_enabled(self) -> bool:
        """Check if LLM explanations are available (non-empty API key set)."""
        return self.GOOGLE_API_KEY is not None and len(self.GOOGLE_API_KEY) > 0

    # Application configuration
    ENABLE_DEBUG: bool = False
    LOG_LEVEL: str = "INFO"

    # Server configuration
    HOST: str = "0.0.0.0"
    PORT: int = 8000

    # CORS configuration — comma-separated list of allowed origins.
    CORS_ORIGINS: str = "http://localhost:5173,http://localhost:3000,https://www.deepfake-detector.app,https://deepfake-detector.app"

    @property
    def cors_origins_list(self) -> list[str]:
        """Parse CORS origins from comma-separated string (empty items dropped)."""
        return [origin.strip() for origin in self.CORS_ORIGINS.split(",") if origin.strip()]

    # API configuration
    API_V1_PREFIX: str = "/api/v1"
    PROJECT_NAME: str = "DeepFake Detector API"
    VERSION: str = "0.1.0"

    class Config:
        # pydantic-settings: also read a local .env file, keep names case-sensitive.
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = True
56
+
57
+
58
@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide Settings instance, constructing it once.

    ``lru_cache`` memoises the first call, so every caller shares a single
    Settings object for the lifetime of the process.
    """
    instance = Settings()
    return instance


# Module-level singleton used throughout the application.
settings = get_settings()
app/core/errors.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom exceptions and error handling for the application.
3
+ """
4
+
5
+ from typing import Any, Dict, Optional
6
+
7
+
8
class DeepFakeDetectorError(Exception):
    """Root of the application's exception hierarchy.

    Carries a human-readable ``message`` plus an optional ``details`` dict of
    structured context for error responses and logs.
    """

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        # Keep the message and structured details available as attributes so
        # the global exception handler can serialise them.
        self.message = message
        self.details = details or {}
        super().__init__(message)
19
+
20
+
21
class ModelNotLoadedError(DeepFakeDetectorError):
    """A model was used before (or without) being loaded."""


class ModelNotFoundError(DeepFakeDetectorError):
    """The requested model does not exist in the registry."""


class HuggingFaceDownloadError(DeepFakeDetectorError):
    """A download from Hugging Face did not complete successfully."""


class ImageProcessingError(DeepFakeDetectorError):
    """An image could not be decoded or processed."""


class InferenceError(DeepFakeDetectorError):
    """A model's forward pass / prediction failed."""


class FusionError(DeepFakeDetectorError):
    """Combining submodel outputs in a fusion model failed."""


class ConfigurationError(DeepFakeDetectorError):
    """Required configuration is invalid or missing."""
app/core/logging.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Logging configuration for the application.
3
+ """
4
+
5
+ import logging
6
+ import sys
7
+ from typing import Optional
8
+
9
+ from app.core.config import settings
10
+
11
+
12
def setup_logging(level: Optional[str] = None) -> logging.Logger:
    """
    Configure the root logger for the application.

    Replaces any previously attached handlers with a single stdout handler
    using a timestamped format, and quiets noisy third-party loggers.

    Args:
        level: Log level string (DEBUG, INFO, WARNING, ERROR, CRITICAL);
            falls back to settings.LOG_LEVEL when omitted.

    Returns:
        The configured root logger.
    """
    chosen = level or settings.LOG_LEVEL

    root = logging.getLogger()
    # Unknown level names silently fall back to INFO.
    root.setLevel(getattr(logging, chosen.upper(), logging.INFO))

    # Drop whatever handlers are already attached so we fully own the output.
    while root.handlers:
        root.removeHandler(root.handlers[0])

    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    root.addHandler(handler)

    # Set third-party loggers to WARNING to reduce noise.
    for noisy in ("uvicorn", "httpx", "huggingface_hub"):
        logging.getLogger(noisy).setLevel(logging.WARNING)

    return root
49
+
50
+
51
def get_logger(name: str) -> logging.Logger:
    """
    Return the logger registered under *name* (typically ``__name__``).

    Thin convenience wrapper around :func:`logging.getLogger` so call sites
    only need to import from ``app.core.logging``.
    """
    return logging.getLogger(name)
app/main.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI application entry point.
3
+
4
+ DeepFake Detector API - Milestone 1: Hugging Face hosted dummy models.
5
+ """
6
+
7
+ from contextlib import asynccontextmanager
8
+ from typing import AsyncGenerator
9
+
10
+ from fastapi import FastAPI, Request
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from fastapi.responses import JSONResponse
13
+
14
+ from app.api import routes_health, routes_models, routes_predict
15
+ from app.core.config import settings
16
+ from app.core.errors import DeepFakeDetectorError
17
+ from app.core.logging import setup_logging, get_logger
18
+ from app.services.model_registry import get_model_registry
19
+
20
+ # Set up logging
21
+ setup_logging()
22
+ logger = get_logger(__name__)
23
+
24
+
25
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    Application lifespan manager.

    Handles startup and shutdown events:
    - Startup: Load models from Hugging Face
    - Shutdown: Cleanup resources
    """
    # Startup
    logger.info("Starting DeepFake Detector API...")
    logger.info(f"Configuration: HF_FUSION_REPO_ID={settings.HF_FUSION_REPO_ID}")
    logger.info(f"Configuration: HF_CACHE_DIR={settings.HF_CACHE_DIR}")

    # Load models from Hugging Face. A failure here is deliberately non-fatal:
    # the server still starts (so health endpoints respond) but readiness
    # stays negative until models are loaded.
    try:
        registry = get_model_registry()
        await registry.load_from_fusion_repo(settings.HF_FUSION_REPO_ID)
        logger.info("Models loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load models on startup: {e}")
        logger.warning("API will start but /ready will report not_ready until models are loaded")

    yield  # Application runs here

    # Shutdown
    logger.info("Shutting down DeepFake Detector API...")
52
+
53
+
54
# Create FastAPI application
app = FastAPI(
    title=settings.PROJECT_NAME,
    version=settings.VERSION,
    description="""
    DeepFake Detector API - Analyze images to detect AI-generated content.

    ## Features

    - **Fusion prediction**: Combines multiple model predictions using majority vote
    - **Individual model prediction**: Run specific submodels directly
    - **Timing information**: Detailed performance metrics for each request

    ## Milestone 1

    This is the initial milestone using dummy random models hosted on Hugging Face
    for testing the API infrastructure.
    """,
    lifespan=lifespan,
    debug=settings.ENABLE_DEBUG
)

# Add CORS middleware. Origins come from settings (comma-separated env var);
# credentials are allowed, so origins must stay an explicit list, not "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logger.info(f"CORS enabled for origins: {settings.cors_origins_list}")
86
+
87
+
88
# Global exception handler for custom errors
@app.exception_handler(DeepFakeDetectorError)
async def deepfake_error_handler(request: Request, exc: DeepFakeDetectorError):
    """Serialise a DeepFakeDetectorError into a structured JSON 500 response."""
    payload = {
        "error": type(exc).__name__,
        "message": exc.message,
        "details": exc.details,
    }
    return JSONResponse(status_code=500, content=payload)
100
+
101
+
102
# Include routers (health/readiness, model metadata, prediction endpoints).
app.include_router(routes_health.router)
app.include_router(routes_models.router)
app.include_router(routes_predict.router)
106
+
107
+
108
# Root endpoint
@app.get("/", tags=["root"])
async def root():
    """Describe the API and point at its docs and health endpoints."""
    info = {
        "name": settings.PROJECT_NAME,
        "version": settings.VERSION,
        "docs": "/docs",
        "health": "/health",
        "ready": "/ready",
    }
    return info
119
+
120
+
121
if __name__ == "__main__":
    # Run the API with uvicorn when this module is executed directly.
    import uvicorn
    uvicorn.run(
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=settings.ENABLE_DEBUG  # auto-reload only when debug is enabled
    )
app/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Models module
app/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (195 Bytes). View file
 
app/models/wrappers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Model wrappers module
app/models/wrappers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (204 Bytes). View file
 
app/models/wrappers/__pycache__/base_wrapper.cpython-312.pyc ADDED
Binary file (5.75 kB). View file
 
app/models/wrappers/__pycache__/cnn_transfer_wrapper.cpython-312.pyc ADDED
Binary file (9.91 kB). View file
 
app/models/wrappers/__pycache__/deit_distilled_wrapper.cpython-312.pyc ADDED
Binary file (13.7 kB). View file
 
app/models/wrappers/__pycache__/dummy_majority_fusion_wrapper.cpython-312.pyc ADDED
Binary file (7.04 kB). View file
 
app/models/wrappers/__pycache__/dummy_random_wrapper.cpython-312.pyc ADDED
Binary file (6.52 kB). View file
 
app/models/wrappers/__pycache__/gradfield_cnn_wrapper.cpython-312.pyc ADDED
Binary file (18.1 kB). View file
 
app/models/wrappers/__pycache__/logreg_fusion_wrapper.cpython-312.pyc ADDED
Binary file (6.73 kB). View file
 
app/models/wrappers/__pycache__/vit_base_wrapper.cpython-312.pyc ADDED
Binary file (14.6 kB). View file
 
app/models/wrappers/base_wrapper.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Base wrapper class for model wrappers.
3
+ """
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, Callable, Dict, Optional
7
+
8
+ from PIL import Image
9
+
10
+
11
class BaseModelWrapper(ABC):
    """
    Abstract base class for model wrappers.

    All model wrappers should inherit from this class and implement
    the abstract methods (`load` and `predict`). A wrapper is considered
    loaded once `load()` has assigned its internal predict function.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        self.repo_id = repo_id
        self.config = config
        self.local_path = local_path
        # Set by load(); doubles as the "is loaded" flag.
        self._predict_fn: Optional[Callable] = None

    @property
    def name(self) -> str:
        """
        Get the short name of the model.

        Prefers 'name' from config if available, otherwise derives from repo_id.
        Strips a trailing '-final' suffix to ensure consistency with fusion configs.
        """
        # Fix: use removesuffix so only a TRAILING '-final' is stripped.
        # str.replace would also mangle '-final' occurring mid-name
        # (e.g. 'x-final-y' -> 'x-y'), contradicting the documented intent.
        config_name = self.config.get("name")
        if config_name:
            return config_name.removesuffix("-final")

        # Fall back to the last path component of the repo_id.
        repo_name = self.repo_id.split("/")[-1]
        return repo_name.removesuffix("-final")

    @abstractmethod
    def load(self) -> None:
        """
        Load the model and prepare for inference.

        This method should import the predict function from the downloaded
        repository and store it for later use (in `self._predict_fn`).
        """
        pass

    @abstractmethod
    def predict(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Run prediction.

        Returns:
            Dictionary with standardized prediction fields:
            - pred_int: 0 (real) or 1 (fake)
            - pred: "real" or "fake"
            - prob_fake: float probability
            - meta: dict with any additional metadata
        """
        pass

    def is_loaded(self) -> bool:
        """Check if the model is loaded and ready for inference."""
        return self._predict_fn is not None

    def get_info(self) -> Dict[str, Any]:
        """
        Get model information.

        Returns:
            Dictionary with repo_id, short name, raw config, local path,
            and current load state.
        """
        return {
            "repo_id": self.repo_id,
            "name": self.name,
            "config": self.config,
            "local_path": self.local_path,
            "is_loaded": self.is_loaded()
        }
98
+
99
+
100
class BaseSubmodelWrapper(BaseModelWrapper):
    """Base wrapper for submodels that process images.

    Concrete subclasses accept either a PIL Image or raw bytes and return
    the standardized prediction dictionary described on `predict`.
    """

    @abstractmethod
    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (alternative to image)
            explain: If True, include explainability heatmap in output
            **kwargs: Additional arguments

        Returns:
            Standardized prediction dictionary with:
            - pred_int: 0 (real) or 1 (fake)
            - pred: "real" or "fake"
            - prob_fake: float probability
            - heatmap_base64: Optional[str] (when explain=True)
            - explainability_type: Optional[str] (when explain=True)
        """
        pass
129
+
130
+
131
class BaseFusionWrapper(BaseModelWrapper):
    """Base wrapper for fusion models that combine submodel outputs."""

    @abstractmethod
    def predict(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run fusion prediction on submodel outputs.

        Args:
            submodel_outputs: Dictionary mapping submodel name to its output
                (each output is a standardized prediction dictionary)
            **kwargs: Additional arguments

        Returns:
            Standardized prediction dictionary
        """
        pass
app/models/wrappers/cnn_transfer_wrapper.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Wrapper for CNN Transfer (EfficientNet-B0) submodel.
3
+ """
4
+
5
+ import json
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from pathlib import Path
11
+ from typing import Any, Dict, Optional, Tuple
12
+ from PIL import Image
13
+ from torchvision import transforms
14
+ from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
15
+
16
+ from app.core.errors import InferenceError, ConfigurationError
17
+ from app.core.logging import get_logger
18
+ from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
19
+ from app.services.explainability import GradCAM, heatmap_to_base64, compute_focus_summary
20
+
21
+ logger = get_logger(__name__)
22
+
23
+
24
class CNNTransferWrapper(BaseSubmodelWrapper):
    """
    Wrapper for CNN Transfer model using EfficientNet-B0 backbone.

    Model expects 224x224 RGB images with ImageNet normalization.
    Supports optional Grad-CAM explainability heatmaps.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        # Model and transform are created lazily in load().
        super().__init__(repo_id, config, local_path)
        self._model: Optional[nn.Module] = None
        self._transform: Optional[transforms.Compose] = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Decision threshold on prob_fake; configurable, defaults to 0.5.
        self._threshold = config.get("threshold", 0.5)
        logger.info(f"Initialized CNNTransferWrapper for {repo_id}")

    def load(self) -> None:
        """Load the EfficientNet-B0 model with trained weights.

        Raises:
            ConfigurationError: if model.pth is missing or loading fails.
        """
        weights_path = Path(self.local_path) / "model.pth"
        preprocess_path = Path(self.local_path) / "preprocess.json"

        if not weights_path.exists():
            raise ConfigurationError(
                message=f"model.pth not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(weights_path)}
            )

        try:
            # Load preprocessing config (optional; defaults below if absent).
            preprocess_config = {}
            if preprocess_path.exists():
                with open(preprocess_path, "r") as f:
                    preprocess_config = json.load(f)

            # Build transform pipeline
            input_size = preprocess_config.get("input_size", [224, 224])
            if isinstance(input_size, int):
                input_size = [input_size, input_size]

            # ImageNet mean/std defaults.
            normalize_config = preprocess_config.get("normalize", {})
            mean = normalize_config.get("mean", [0.485, 0.456, 0.406])
            std = normalize_config.get("std", [0.229, 0.224, 0.225])

            self._transform = transforms.Compose([
                transforms.Resize(input_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])

            # Create model architecture (no pretrained weights; we load our own).
            num_classes = self.config.get("num_classes", 2)
            self._model = efficientnet_b0(weights=None)

            # Replace classifier for binary classification — must match the
            # head used during training so the state dict keys line up.
            in_features = self._model.classifier[1].in_features
            self._model.classifier = nn.Sequential(
                nn.Dropout(p=0.2, inplace=True),
                nn.Linear(in_features, num_classes)
            )

            # Load trained weights (weights_only=True avoids pickle code exec).
            state_dict = torch.load(weights_path, map_location=self._device, weights_only=True)
            self._model.load_state_dict(state_dict)
            self._model.to(self._device)
            self._model.eval()

            # Mark as loaded
            self._predict_fn = self._run_inference
            logger.info(f"Loaded CNN Transfer model from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load CNN Transfer model: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _run_inference(
        self,
        image_tensor: torch.Tensor,
        explain: bool = False
    ) -> Dict[str, Any]:
        """Run model inference on preprocessed tensor.

        Returns a dict with raw logits, prob_fake (softmax class 1), pred_int,
        and optionally a 'heatmap' array when explain=True.
        """
        heatmap = None

        if explain:
            # Use GradCAM for explainability (requires gradients)
            target_layer = self._model.features[-1]  # Last MBConv block
            gradcam = GradCAM(self._model, target_layer)
            try:
                # GradCAM needs gradients, so don't use no_grad
                logits = self._model(image_tensor)
                probs = F.softmax(logits, dim=1)
                prob_fake = probs[0, 1].item()
                pred_int = 1 if prob_fake >= self._threshold else 0

                # Compute heatmap for predicted class.
                # NOTE(review): presumably GradCAM runs its own forward pass on
                # the cloned tensor — confirm against the GradCAM implementation.
                heatmap = gradcam(
                    image_tensor.clone(),
                    target_class=pred_int,
                    output_size=(224, 224)
                )
            finally:
                # Always detach hooks, even on failure, to avoid leaking them.
                gradcam.remove_hooks()
        else:
            with torch.no_grad():
                logits = self._model(image_tensor)
                probs = F.softmax(logits, dim=1)
                prob_fake = probs[0, 1].item()
                pred_int = 1 if prob_fake >= self._threshold else 0

        result = {
            "logits": logits[0].detach().cpu().numpy().tolist(),
            "prob_fake": prob_fake,
            "pred_int": pred_int
        }

        if heatmap is not None:
            result["heatmap"] = heatmap

        return result

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (will be converted to PIL Image)
            explain: If True, compute GradCAM heatmap

        Returns:
            Standardized prediction dictionary with optional heatmap

        Raises:
            InferenceError: if the model isn't loaded, no image was supplied,
                or inference fails.
        """
        if self._model is None or self._transform is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Convert bytes to PIL Image if needed; always force RGB.
            if image is None and image_bytes is not None:
                import io
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            elif image is not None:
                image = image.convert("RGB")
            else:
                raise InferenceError(
                    message="No image provided",
                    details={"repo_id": self.repo_id}
                )

            # Preprocess: add batch dimension and move to the model's device.
            image_tensor = self._transform(image).unsqueeze(0).to(self._device)

            # Run inference
            result = self._run_inference(image_tensor, explain=explain)

            # Standardize output.
            # NOTE(review): this wrapper reads a "labels" config key while the
            # DeiT wrapper reads "class_mapping" — confirm the repos' config.json
            # schemas really differ, or unify the key.
            labels = self.config.get("labels", {"0": "real", "1": "fake"})
            pred_int = result["pred_int"]

            output = {
                "pred_int": pred_int,
                "pred": labels.get(str(pred_int), "unknown"),
                "prob_fake": result["prob_fake"],
                "meta": {
                    "model": self.name,
                    "threshold": self._threshold,
                    "logits": result["logits"]
                }
            }

            # Add heatmap if requested
            if explain and "heatmap" in result:
                heatmap = result["heatmap"]
                output["heatmap_base64"] = heatmap_to_base64(heatmap)
                output["explainability_type"] = "grad_cam"
                output["focus_summary"] = compute_focus_summary(heatmap)

            return output

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )
app/models/wrappers/deit_distilled_wrapper.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Wrapper for DeiT Distilled submodel.
3
+ """
4
+
5
+ import json
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+ from PIL import Image
13
+ from torchvision import transforms
14
+
15
+ try:
16
+ import timm
17
+ TIMM_AVAILABLE = True
18
+ except ImportError:
19
+ TIMM_AVAILABLE = False
20
+
21
+ from app.core.errors import InferenceError, ConfigurationError
22
+ from app.core.logging import get_logger
23
+ from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
24
+ from app.services.explainability import attention_rollout, heatmap_to_base64, compute_focus_summary
25
+
26
+ logger = get_logger(__name__)
27
+
28
+
29
def create_custom_mlp_head(in_features: int = 768, num_classes: int = 2) -> nn.Sequential:
    """
    Build the custom MLP classification head the DeiT model was trained with.

    The head is an ``nn.Sequential`` so the parameterised layers land at
    indices 0, 1 and 4 — matching the keys of the saved state dict.
    """
    layers = [
        nn.LayerNorm(in_features),    # index 0 (has params)
        nn.Linear(in_features, 512),  # index 1 (has params)
        nn.GELU(),                    # index 2 (no params)
        nn.Dropout(p=0.2),            # index 3 (no params)
        nn.Linear(512, num_classes),  # index 4 (has params)
    ]
    return nn.Sequential(*layers)
42
+
43
+
44
class DeiTDistilledWrapper(BaseSubmodelWrapper):
    """
    Wrapper for DeiT Distilled model.

    Model expects 224x224 RGB images with ImageNet normalization.
    Uses a custom MLP head for classification and supports attention-rollout
    explainability heatmaps.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        # Model and transform are created lazily in load().
        super().__init__(repo_id, config, local_path)
        self._model: Optional[nn.Module] = None
        self._transform: Optional[transforms.Compose] = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Decision threshold on prob_fake; configurable, defaults to 0.5.
        self._threshold = config.get("threshold", 0.5)
        logger.info(f"Initialized DeiTDistilledWrapper for {repo_id}")

    def load(self) -> None:
        """Load the DeiT model with custom head and trained weights.

        Raises:
            ConfigurationError: if timm is missing, the weights file is
                absent, or loading fails.
        """
        if not TIMM_AVAILABLE:
            raise ConfigurationError(
                message="timm package not installed. Run: pip install timm",
                details={"repo_id": self.repo_id}
            )

        weights_path = Path(self.local_path) / "deit_distilled_final.pt"
        preprocess_path = Path(self.local_path) / "preprocess.json"

        if not weights_path.exists():
            raise ConfigurationError(
                message=f"deit_distilled_final.pt not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(weights_path)}
            )

        try:
            # Load preprocessing config (optional; defaults below if absent).
            preprocess_config = {}
            if preprocess_path.exists():
                with open(preprocess_path, "r") as f:
                    preprocess_config = json.load(f)

            # Build transform pipeline — DeiT uses a square input size.
            input_size = preprocess_config.get("input_size", 224)
            if isinstance(input_size, list):
                input_size = input_size[0]

            # ImageNet mean/std defaults.
            normalize_config = preprocess_config.get("normalize", {})
            mean = normalize_config.get("mean", [0.485, 0.456, 0.406])
            std = normalize_config.get("std", [0.229, 0.224, 0.225])

            # Use bicubic interpolation as specified
            interpolation = preprocess_config.get("interpolation", "bicubic")
            interp_mode = transforms.InterpolationMode.BICUBIC if interpolation == "bicubic" else transforms.InterpolationMode.BILINEAR

            self._transform = transforms.Compose([
                transforms.Resize((input_size, input_size), interpolation=interp_mode),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])

            # Create model architecture
            model_name = self.config.get("model_name", "deit_base_distilled_patch16_224")
            num_classes = self.config.get("num_classes", 2)

            # Create base model without pretrained weights (num_classes=0
            # drops timm's default head; we attach our own below).
            self._model = timm.create_model(model_name, pretrained=False, num_classes=0)

            # Replace heads with custom MLP heads (Sequential assigned directly)
            # Note: state dict has separate keys for head and head_dist, so don't share
            hidden_dim = 768  # DeiT base hidden dimension
            self._model.head = create_custom_mlp_head(hidden_dim, num_classes)
            self._model.head_dist = create_custom_mlp_head(hidden_dim, num_classes)

            # Load trained weights (weights_only=True avoids pickle code exec).
            state_dict = torch.load(weights_path, map_location=self._device, weights_only=True)
            self._model.load_state_dict(state_dict)
            self._model.to(self._device)
            self._model.eval()

            # Mark as loaded
            self._predict_fn = self._run_inference
            logger.info(f"Loaded DeiT Distilled model from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load DeiT Distilled model: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _run_inference(
        self,
        image_tensor: torch.Tensor,
        explain: bool = False
    ) -> Dict[str, Any]:
        """Run model inference on preprocessed tensor.

        When explain=True, forward hooks on every transformer block recompute
        the attention weights (timm's fused attention does not expose them)
        and an attention-rollout heatmap is derived from them.
        """
        heatmap = None

        if explain:
            # Collect attention weights from all blocks
            attentions: List[torch.Tensor] = []
            handles = []

            # Hook into attention modules to capture weights
            # DeiT blocks structure: blocks[i].attn
            def create_attn_hook():
                stored_attn = []

                def hook(module, inputs, outputs):
                    # Get q, k from the module's forward computation
                    # inputs[0] is x of shape [B, N, C]
                    x = inputs[0]
                    B, N, C = x.shape

                    # Recompute qkv with the module's own projection weights;
                    # this duplicates work the block already did, but is the
                    # only way to observe per-head attention maps.
                    qkv = module.qkv(x)  # [B, N, 3*dim]
                    qkv = qkv.reshape(B, N, 3, module.num_heads, C // module.num_heads)
                    qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, dim_head]
                    q, k, v = qkv[0], qkv[1], qkv[2]

                    # Compute attention weights
                    scale = (C // module.num_heads) ** -0.5
                    attn = (q @ k.transpose(-2, -1)) * scale
                    attn = attn.softmax(dim=-1)  # [B, heads, N, N]

                    # Average over heads
                    attn_avg = attn.mean(dim=1)  # [B, N, N]
                    stored_attn.append(attn_avg.detach())

                return hook, stored_attn

            all_stored_attns = []
            for block in self._model.blocks:
                hook_fn, stored = create_attn_hook()
                all_stored_attns.append(stored)
                handle = block.attn.register_forward_hook(hook_fn)
                handles.append(handle)

            try:
                with torch.no_grad():
                    logits = self._model(image_tensor)
                    probs = F.softmax(logits, dim=1)
                    prob_fake = probs[0, 1].item()
                    pred_int = 1 if prob_fake >= self._threshold else 0

                    # Get attention from hooks
                    attention_list = [stored[0] for stored in all_stored_attns if len(stored) > 0]

                    if attention_list:
                        # Stack: [num_layers, B, N, N]
                        attention_stack = torch.stack(attention_list, dim=0)
                        # Compute rollout - returns (grid_size, grid_size) heatmap
                        attention_map = attention_rollout(
                            attention_stack[:, 0],  # [num_layers, N, N]
                            head_fusion="mean",  # Already averaged
                            discard_ratio=0.0,
                            num_prefix_tokens=2  # DeiT has CLS + distillation token
                        )  # Returns (14, 14) for DeiT-Base

                        # Resize to image size
                        from PIL import Image as PILImage
                        heatmap_img = PILImage.fromarray(
                            (attention_map * 255).astype(np.uint8)
                        ).resize((224, 224), PILImage.BILINEAR)
                        heatmap = np.array(heatmap_img).astype(np.float32) / 255.0

            finally:
                # Always remove hooks, even on failure, to avoid leaking them.
                for handle in handles:
                    handle.remove()
        else:
            with torch.no_grad():
                # In eval mode, DeiT returns single tensor
                logits = self._model(image_tensor)
                probs = F.softmax(logits, dim=1)
                prob_fake = probs[0, 1].item()
                pred_int = 1 if prob_fake >= self._threshold else 0

        result = {
            "logits": logits[0].cpu().numpy().tolist(),
            "prob_fake": prob_fake,
            "pred_int": pred_int
        }

        if heatmap is not None:
            result["heatmap"] = heatmap

        return result

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (will be converted to PIL Image)
            explain: If True, compute attention rollout heatmap

        Returns:
            Standardized prediction dictionary with optional heatmap

        Raises:
            InferenceError: if the model isn't loaded, no image was supplied,
                or inference fails.
        """
        if self._model is None or self._transform is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Convert bytes to PIL Image if needed; always force RGB.
            if image is None and image_bytes is not None:
                import io
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            elif image is not None:
                image = image.convert("RGB")
            else:
                raise InferenceError(
                    message="No image provided",
                    details={"repo_id": self.repo_id}
                )

            # Preprocess: add batch dimension and move to the model's device.
            image_tensor = self._transform(image).unsqueeze(0).to(self._device)

            # Run inference
            result = self._run_inference(image_tensor, explain=explain)

            # Standardize output (this wrapper's config uses "class_mapping").
            class_mapping = self.config.get("class_mapping", {"0": "real", "1": "fake"})
            pred_int = result["pred_int"]

            output = {
                "pred_int": pred_int,
                "pred": class_mapping.get(str(pred_int), "unknown"),
                "prob_fake": result["prob_fake"],
                "meta": {
                    "model": self.name,
                    "threshold": self._threshold,
                    "logits": result["logits"]
                }
            }

            # Add heatmap if requested
            if explain and "heatmap" in result:
                heatmap = result["heatmap"]
                output["heatmap_base64"] = heatmap_to_base64(heatmap)
                output["explainability_type"] = "attention_rollout"
                output["focus_summary"] = compute_focus_summary(heatmap)

            return output

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )
app/models/wrappers/dummy_majority_fusion_wrapper.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Wrapper for dummy majority vote fusion model.
3
+ """
4
+
5
+ import importlib.util
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List
9
+
10
+ from app.core.errors import FusionError, ConfigurationError
11
+ from app.core.logging import get_logger
12
+ from app.models.wrappers.base_wrapper import BaseFusionWrapper
13
+
14
+ logger = get_logger(__name__)
15
+
16
+
17
class DummyMajorityFusionWrapper(BaseFusionWrapper):
    """
    Wrapper for dummy majority vote fusion models.

    These models are hosted on Hugging Face and contain a predict.py
    with a predict() function that performs majority voting on submodel
    outputs. (The loader reads predict.py, not fusion.py.)
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/fusion-majority-test")
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)
        self._submodel_repos: List[str] = config.get("submodels", [])
        logger.info(f"Initialized DummyMajorityFusionWrapper for {repo_id}")
        logger.info(f"Submodels: {self._submodel_repos}")

    @property
    def submodel_repos(self) -> List[str]:
        """Get list of submodel repository IDs."""
        return self._submodel_repos

    def load(self) -> None:
        """
        Load the fusion predict function from the downloaded repository.

        Dynamically imports predict.py and extracts its predict() function.

        Raises:
            ConfigurationError: If predict.py is missing, cannot be imported,
                or does not define a predict() function.
        """
        fusion_path = Path(self.local_path) / "predict.py"

        if not fusion_path.exists():
            raise ConfigurationError(
                message=f"predict.py not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(fusion_path)}
            )

        try:
            # Unique module name so repeated loads of different repos
            # don't clobber each other in sys.modules
            module_name = f"hf_model_{self.name.replace('-', '_')}_fusion"

            # Load the module dynamically from the downloaded file
            spec = importlib.util.spec_from_file_location(module_name, fusion_path)
            if spec is None or spec.loader is None:
                raise ConfigurationError(
                    message=f"Could not load spec for {fusion_path}",
                    details={"repo_id": self.repo_id}
                )

            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)

            # The repo must expose a top-level predict() function
            if not hasattr(module, "predict"):
                raise ConfigurationError(
                    message="predict.py does not have a 'predict' function",
                    details={"repo_id": self.repo_id}
                )

            self._predict_fn = module.predict
            logger.info(f"Loaded fusion predict function from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load fusion function from {self.repo_id}: {e}")
            raise ConfigurationError(
                message=f"Failed to load fusion model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def predict(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run fusion prediction on submodel outputs.

        Args:
            submodel_outputs: Dictionary mapping submodel name to its prediction output
            **kwargs: Additional arguments passed to the fusion function

        Returns:
            Standardized prediction dictionary with:
            - pred_int: 0 or 1
            - pred: "real" or "fake"
            - prob_fake: float (average of pred_ints)
            - meta: dict

        Raises:
            FusionError: If the fusion model is not loaded or prediction fails.
        """
        if self._predict_fn is None:
            raise FusionError(
                message="Fusion model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Call the actual fusion predict function from the HF repo
            result = self._predict_fn(submodel_outputs=submodel_outputs, **kwargs)

            # Validate and standardize the output
            return self._standardize_output(result)

        except FusionError:
            raise
        except Exception as e:
            logger.error(f"Fusion prediction failed for {self.repo_id}: {e}")
            raise FusionError(
                message=f"Fusion prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _standardize_output(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Standardize the fusion output to ensure consistent format.

        Args:
            result: Raw fusion output

        Returns:
            Standardized dictionary with int pred_int, str pred,
            float prob_fake, and dict meta.
        """
        pred_int = result.get("pred_int", 0)

        # Coerce to a binary decision
        if pred_int not in (0, 1):
            pred_int = 1 if pred_int > 0.5 else 0
        # The membership test above also accepts bool True and float 1.0
        # (both compare equal to 1), so cast to guarantee a plain int 0/1
        # in the JSON-facing response.
        pred_int = int(pred_int)

        # Generate pred label if not present
        pred = result.get("pred")
        if pred is None:
            pred = "fake" if pred_int == 1 else "real"

        # Generate prob_fake if not present (degenerate 0.0/1.0 probability)
        prob_fake = result.get("prob_fake")
        if prob_fake is None:
            prob_fake = float(pred_int)

        return {
            "pred_int": pred_int,
            "pred": pred,
            "prob_fake": float(prob_fake),
            "meta": result.get("meta", {})
        }
app/models/wrappers/dummy_random_wrapper.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Wrapper for dummy random submodels.
3
+ """
4
+
5
+ import importlib.util
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional
9
+
10
+ from PIL import Image
11
+
12
+ from app.core.errors import InferenceError, ConfigurationError
13
+ from app.core.logging import get_logger
14
+ from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
class DummyRandomWrapper(BaseSubmodelWrapper):
    """
    Wrapper for dummy random prediction models.

    These models are hosted on Hugging Face and contain a predict.py
    with a predict() function that returns random predictions.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/test-random-a")
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)
        logger.info(f"Initialized DummyRandomWrapper for {repo_id}")

    def load(self) -> None:
        """
        Load the predict function from the downloaded repository.

        Dynamically imports predict.py and extracts its predict() function.

        Raises:
            ConfigurationError: If predict.py is missing, cannot be imported,
                or does not define a predict() function.
        """
        predict_path = Path(self.local_path) / "predict.py"

        if not predict_path.exists():
            raise ConfigurationError(
                message=f"predict.py not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(predict_path)}
            )

        try:
            # Unique module name so repeated loads of different repos
            # don't clobber each other in sys.modules
            module_name = f"hf_model_{self.name.replace('-', '_')}_predict"

            # Load the module dynamically from the downloaded file
            spec = importlib.util.spec_from_file_location(module_name, predict_path)
            if spec is None or spec.loader is None:
                raise ConfigurationError(
                    message=f"Could not load spec for {predict_path}",
                    details={"repo_id": self.repo_id}
                )

            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)

            # The repo must expose a top-level predict() function
            if not hasattr(module, "predict"):
                raise ConfigurationError(
                    message="predict.py does not have a 'predict' function",
                    details={"repo_id": self.repo_id}
                )

            self._predict_fn = module.predict
            logger.info(f"Loaded predict function from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load predict function from {self.repo_id}: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object (accepted for interface compatibility;
                only image_bytes is forwarded to the dummy model)
            image_bytes: Raw image bytes (optional for dummy model)
            **kwargs: Additional arguments passed to the model

        Returns:
            Standardized prediction dictionary with:
            - pred_int: 0 or 1
            - pred: "real" or "fake"
            - prob_fake: float
            - meta: dict

        Raises:
            InferenceError: If the model is not loaded or prediction fails.
        """
        if self._predict_fn is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Call the actual predict function from the HF repo.
            # NOTE(review): the `image` argument is intentionally not
            # forwarded; the dummy predict() only takes image_bytes.
            result = self._predict_fn(image_bytes=image_bytes, **kwargs)

            # Validate and standardize the output
            return self._standardize_output(result)

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _standardize_output(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Standardize the model output to ensure consistent format.

        Args:
            result: Raw model output

        Returns:
            Standardized dictionary with int pred_int, str pred,
            float prob_fake, and dict meta.
        """
        pred_int = result.get("pred_int", 0)

        # Coerce to a binary decision
        if pred_int not in (0, 1):
            pred_int = 1 if pred_int > 0.5 else 0
        # The membership test above also accepts bool True and float 1.0
        # (both compare equal to 1), so cast to guarantee a plain int 0/1
        # in the JSON-facing response.
        pred_int = int(pred_int)

        # Generate pred label if not present
        pred = result.get("pred")
        if pred is None:
            pred = "fake" if pred_int == 1 else "real"

        # Generate prob_fake if not present (degenerate 0.0/1.0 probability)
        prob_fake = result.get("prob_fake")
        if prob_fake is None:
            prob_fake = float(pred_int)

        return {
            "pred_int": pred_int,
            "pred": pred,
            "prob_fake": float(prob_fake),
            "meta": result.get("meta", {})
        }
app/models/wrappers/gradfield_cnn_wrapper.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Wrapper for Gradient Field CNN submodel.
3
+ """
4
+
5
+ import json
6
+ import math
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from pathlib import Path
12
+ from typing import Any, Dict, Optional, Tuple
13
+ from PIL import Image
14
+ from torchvision import transforms
15
+
16
+ from app.core.errors import InferenceError, ConfigurationError
17
+ from app.core.logging import get_logger
18
+ from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
19
+ from app.services.explainability import heatmap_to_base64, compute_focus_summary
20
+
21
+ logger = get_logger(__name__)
22
+
23
+
24
class CompactGradientNet(nn.Module):
    """
    Compact CNN that classifies images by their luminance gradient field.

    The network takes a single-channel luminance map and expands it on the
    fly into six channels [luminance, Gx, Gy, log-scaled magnitude, angle,
    coherence] using fixed Sobel / Gaussian kernels, then feeds that stack
    through a small conv trunk to produce a single logit plus an embedding.
    """

    def __init__(self, depth=4, base_filters=32, dropout=0.3, embedding_dim=128):
        super().__init__()

        # Fixed Sobel kernels, registered as buffers so they move with the
        # module across devices without being treated as parameters.
        kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                                dtype=torch.float32).view(1, 1, 3, 3)
        kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                                dtype=torch.float32).view(1, 1, 3, 3)
        self.register_buffer('sobel_x', kernel_x)
        self.register_buffer('sobel_y', kernel_y)

        # 5x5 binomial smoothing kernel used on the structure tensor
        blur = torch.tensor([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4],
                             [6, 24, 36, 24, 6], [4, 16, 24, 16, 4],
                             [1, 4, 6, 4, 1]], dtype=torch.float32) / 256.0
        self.register_buffer('gaussian', blur.view(1, 1, 5, 5))

        # Normalize the 6 derived channels, then let a 1x1 conv remix them
        self.input_norm = nn.BatchNorm2d(6)
        self.channel_mix = nn.Sequential(
            nn.Conv2d(6, 6, kernel_size=1),
            nn.ReLU()
        )

        # Conv trunk: each stage doubles channel count and halves H/W
        stages = []
        channels = 6
        for stage in range(depth):
            width = base_filters * (2**stage)
            stages.append(nn.Conv2d(channels, width, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(width))
            stages.append(nn.ReLU())
            stages.append(nn.MaxPool2d(2))
            if dropout > 0:
                stages.append(nn.Dropout2d(dropout))
            channels = width

        self.cnn = nn.Sequential(*stages)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.embedding = nn.Linear(channels, embedding_dim)
        self.classifier = nn.Linear(embedding_dim, 1)

    def compute_gradient_field(self, luminance):
        """Expand a luminance batch into the 6-channel gradient field."""
        gx = F.conv2d(luminance, self.sobel_x, padding=1)
        gy = F.conv2d(luminance, self.sobel_y, padding=1)

        # Epsilon keeps sqrt differentiable where the gradient vanishes
        magnitude = torch.sqrt(gx**2 + gy**2 + 1e-8)
        angle = torch.atan2(gy, gx) / math.pi  # normalized to [-1, 1]

        # Smoothed structure tensor entries
        sxx = F.conv2d(gx * gx, self.gaussian, padding=2)
        sxy = F.conv2d(gx * gy, self.gaussian, padding=2)
        syy = F.conv2d(gy * gy, self.gaussian, padding=2)

        # Eigenvalues of the 2x2 tensor give local anisotropy (coherence)
        trace = sxx + syy
        root = torch.sqrt((sxx - syy)**2 + 4 * sxy**2 + 1e-8)
        lam_hi = 0.5 * (trace + root)
        lam_lo = 0.5 * (trace - root)
        coherence = ((lam_hi - lam_lo) / (lam_hi + lam_lo + 1e-8))**2

        # Compress the magnitude's dynamic range
        scaled_mag = torch.log1p(magnitude * 10)

        return torch.cat([luminance, gx, gy, scaled_mag, angle, coherence], dim=1)

    def forward(self, luminance):
        """Return (logit, embedding) for a batch of 1-channel luminance maps."""
        feats = self.compute_gradient_field(luminance)
        feats = self.input_norm(feats)
        feats = self.channel_mix(feats)
        feats = self.cnn(feats)
        pooled = self.global_pool(feats).flatten(1)
        emb = self.embedding(pooled)
        logit = self.classifier(emb)
        return logit.squeeze(1), emb
+ return logit.squeeze(1), emb
109
+
110
+
111
+ class GradfieldCNNWrapper(BaseSubmodelWrapper):
112
+ """
113
+ Wrapper for Gradient Field CNN model.
114
+
115
+ Model expects 256x256 luminance images.
116
+ Internally computes Sobel gradients and other discriminative features.
117
+ """
118
+
119
+ # BT.709 luminance coefficients
120
+ R_COEFF = 0.2126
121
+ G_COEFF = 0.7152
122
+ B_COEFF = 0.0722
123
+
124
+ def __init__(
125
+ self,
126
+ repo_id: str,
127
+ config: Dict[str, Any],
128
+ local_path: str
129
+ ):
130
+ super().__init__(repo_id, config, local_path)
131
+ self._model: Optional[nn.Module] = None
132
+ self._resize: Optional[transforms.Resize] = None
133
+ self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134
+ self._threshold = config.get("threshold", 0.5)
135
+ logger.info(f"Initialized GradfieldCNNWrapper for {repo_id}")
136
+
137
+ def load(self) -> None:
138
+ """Load the Gradient Field CNN model with trained weights."""
139
+ # Try different weight file names
140
+ weights_path = None
141
+ for fname in ["gradient_field_cnn_v3_finetuned.pth", "gradient_field_cnn_v2.pth", "weights.pt", "model.pth"]:
142
+ candidate = Path(self.local_path) / fname
143
+ if candidate.exists():
144
+ weights_path = candidate
145
+ break
146
+
147
+ preprocess_path = Path(self.local_path) / "preprocess.json"
148
+
149
+ if weights_path is None:
150
+ raise ConfigurationError(
151
+ message=f"No weights file found in {self.local_path}",
152
+ details={"repo_id": self.repo_id}
153
+ )
154
+
155
+ try:
156
+ # Load preprocessing config
157
+ preprocess_config = {}
158
+ if preprocess_path.exists():
159
+ with open(preprocess_path, "r") as f:
160
+ preprocess_config = json.load(f)
161
+
162
+ # Get input size (default 256 for gradient field)
163
+ input_size = preprocess_config.get("input_size", 256)
164
+ if isinstance(input_size, list):
165
+ input_size = input_size[0]
166
+
167
+ self._resize = transforms.Resize((input_size, input_size))
168
+
169
+ # Get model parameters from config
170
+ model_params = self.config.get("model_parameters", {})
171
+ depth = model_params.get("depth", 4)
172
+ base_filters = model_params.get("base_filters", 32)
173
+ dropout = model_params.get("dropout", 0.3)
174
+ embedding_dim = model_params.get("embedding_dim", 128)
175
+
176
+ # Create model
177
+ self._model = CompactGradientNet(
178
+ depth=depth,
179
+ base_filters=base_filters,
180
+ dropout=dropout,
181
+ embedding_dim=embedding_dim
182
+ )
183
+
184
+ # Load trained weights
185
+ # Note: weights_only=False needed because checkpoint contains numpy types
186
+ state_dict = torch.load(weights_path, map_location=self._device, weights_only=False)
187
+
188
+ # Handle different checkpoint formats
189
+ if isinstance(state_dict, dict):
190
+ if "model_state_dict" in state_dict:
191
+ state_dict = state_dict["model_state_dict"]
192
+ elif "state_dict" in state_dict:
193
+ state_dict = state_dict["state_dict"]
194
+ elif "model" in state_dict:
195
+ state_dict = state_dict["model"]
196
+
197
+ self._model.load_state_dict(state_dict)
198
+ self._model.to(self._device)
199
+ self._model.eval()
200
+
201
+ # Mark as loaded
202
+ self._predict_fn = self._run_inference
203
+ logger.info(f"Loaded Gradient Field CNN model from {self.repo_id}")
204
+
205
+ except ConfigurationError:
206
+ raise
207
+ except Exception as e:
208
+ logger.error(f"Failed to load Gradient Field CNN model: {e}")
209
+ raise ConfigurationError(
210
+ message=f"Failed to load model: {e}",
211
+ details={"repo_id": self.repo_id, "error": str(e)}
212
+ )
213
+
214
+ def _rgb_to_luminance(self, img_tensor: torch.Tensor) -> torch.Tensor:
215
+ """
216
+ Convert RGB tensor to luminance using BT.709 coefficients.
217
+
218
+ Args:
219
+ img_tensor: RGB tensor of shape (3, H, W) with values in [0, 1]
220
+
221
+ Returns:
222
+ Luminance tensor of shape (1, H, W)
223
+ """
224
+ luminance = (
225
+ self.R_COEFF * img_tensor[0] +
226
+ self.G_COEFF * img_tensor[1] +
227
+ self.B_COEFF * img_tensor[2]
228
+ )
229
+ return luminance.unsqueeze(0)
230
+
231
+ def _run_inference(
232
+ self,
233
+ luminance_tensor: torch.Tensor,
234
+ explain: bool = False
235
+ ) -> Dict[str, Any]:
236
+ """Run model inference on preprocessed luminance tensor."""
237
+ heatmap = None
238
+
239
+ if explain:
240
+ # Custom GradCAM implementation for single-logit binary model
241
+ # Using absolute CAM values to capture both positive and negative contributions
242
+ # Target the last Conv2d layer (cnn[-5])
243
+ target_layer = self._model.cnn[-5]
244
+
245
+ activations = None
246
+ gradients = None
247
+
248
+ def forward_hook(module, input, output):
249
+ nonlocal activations
250
+ activations = output.detach()
251
+
252
+ def backward_hook(module, grad_input, grad_output):
253
+ nonlocal gradients
254
+ gradients = grad_output[0].detach()
255
+
256
+ h_fwd = target_layer.register_forward_hook(forward_hook)
257
+ h_bwd = target_layer.register_full_backward_hook(backward_hook)
258
+
259
+ try:
260
+ # Forward pass with gradients
261
+ input_tensor = luminance_tensor.clone().requires_grad_(True)
262
+ logits, embedding = self._model(input_tensor)
263
+ prob_fake = torch.sigmoid(logits).item()
264
+ pred_int = 1 if prob_fake >= self._threshold else 0
265
+
266
+ # Backward pass
267
+ self._model.zero_grad()
268
+ logits.backward()
269
+
270
+ if gradients is not None and activations is not None:
271
+ # Compute Grad-CAM weights (global average pooled gradients)
272
+ weights = gradients.mean(dim=(2, 3), keepdim=True) # [1, C, 1, 1]
273
+
274
+ # Weighted combination of activation maps
275
+ cam = (weights * activations).sum(dim=1, keepdim=True) # [1, 1, H, W]
276
+
277
+ # Use absolute values instead of ReLU to capture all contributions
278
+ # This is important for models where negative gradients carry meaning
279
+ cam = torch.abs(cam)
280
+
281
+ # Normalize to [0, 1]
282
+ cam = cam - cam.min()
283
+ cam_max = cam.max()
284
+ if cam_max > 0:
285
+ cam = cam / cam_max
286
+
287
+ # Resize to output size (256x256)
288
+ cam = F.interpolate(
289
+ cam,
290
+ size=(256, 256),
291
+ mode='bilinear',
292
+ align_corners=False
293
+ )
294
+
295
+ heatmap = cam.squeeze().cpu().numpy()
296
+ else:
297
+ logger.warning("GradCAM: gradients or activations not captured")
298
+ heatmap = np.zeros((256, 256), dtype=np.float32)
299
+
300
+ finally:
301
+ h_fwd.remove()
302
+ h_bwd.remove()
303
+ else:
304
+ with torch.no_grad():
305
+ logits, embedding = self._model(luminance_tensor)
306
+ prob_fake = torch.sigmoid(logits).item()
307
+ pred_int = 1 if prob_fake >= self._threshold else 0
308
+
309
+ result = {
310
+ "logits": logits.detach().cpu().numpy().tolist() if hasattr(logits, 'detach') else logits.cpu().numpy().tolist(),
311
+ "prob_fake": prob_fake,
312
+ "pred_int": pred_int,
313
+ "embedding": embedding.detach().cpu().numpy().tolist() if explain else embedding.cpu().numpy().tolist()
314
+ }
315
+
316
+ if heatmap is not None:
317
+ result["heatmap"] = heatmap
318
+
319
+ return result
320
+
321
+ def predict(
322
+ self,
323
+ image: Optional[Image.Image] = None,
324
+ image_bytes: Optional[bytes] = None,
325
+ explain: bool = False,
326
+ **kwargs
327
+ ) -> Dict[str, Any]:
328
+ """
329
+ Run prediction on an image.
330
+
331
+ Args:
332
+ image: PIL Image object
333
+ image_bytes: Raw image bytes (will be converted to PIL Image)
334
+ explain: If True, compute GradCAM heatmap
335
+
336
+ Returns:
337
+ Standardized prediction dictionary with optional heatmap
338
+ """
339
+ if self._model is None or self._resize is None:
340
+ raise InferenceError(
341
+ message="Model not loaded",
342
+ details={"repo_id": self.repo_id}
343
+ )
344
+
345
+ try:
346
+ # Convert bytes to PIL Image if needed
347
+ if image is None and image_bytes is not None:
348
+ import io
349
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
350
+ elif image is not None:
351
+ image = image.convert("RGB")
352
+ else:
353
+ raise InferenceError(
354
+ message="No image provided",
355
+ details={"repo_id": self.repo_id}
356
+ )
357
+
358
+ # Resize
359
+ image = self._resize(image)
360
+
361
+ # Convert to tensor
362
+ img_tensor = transforms.functional.to_tensor(image)
363
+
364
+ # Convert to luminance
365
+ luminance = self._rgb_to_luminance(img_tensor)
366
+ luminance = luminance.unsqueeze(0).to(self._device) # Add batch dim
367
+
368
+ # Run inference
369
+ result = self._run_inference(luminance, explain=explain)
370
+
371
+ # Standardize output
372
+ labels = self.config.get("labels", {"0": "real", "1": "fake"})
373
+ pred_int = result["pred_int"]
374
+
375
+ output = {
376
+ "pred_int": pred_int,
377
+ "pred": labels.get(str(pred_int), "unknown"),
378
+ "prob_fake": result["prob_fake"],
379
+ "meta": {
380
+ "model": self.name,
381
+ "threshold": self._threshold
382
+ }
383
+ }
384
+
385
+ # Add heatmap if requested
386
+ if explain and "heatmap" in result:
387
+ heatmap = result["heatmap"]
388
+ output["heatmap_base64"] = heatmap_to_base64(heatmap)
389
+ output["explainability_type"] = "grad_cam"
390
+ output["focus_summary"] = compute_focus_summary(heatmap) + " (edge-based analysis)"
391
+
392
+ return output
393
+
394
+ except InferenceError:
395
+ raise
396
+ except Exception as e:
397
+ logger.error(f"Prediction failed for {self.repo_id}: {e}")
398
+ raise InferenceError(
399
+ message=f"Prediction failed: {e}",
400
+ details={"repo_id": self.repo_id, "error": str(e)}
401
+ )
app/models/wrappers/logreg_fusion_wrapper.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Wrapper for logistic regression stacking fusion model.
3
+ """
4
+
5
+ import pickle
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List
8
+
9
+ import joblib
10
+ import numpy as np
11
+
12
+ from app.core.errors import FusionError, ConfigurationError
13
+ from app.core.logging import get_logger
14
+ from app.models.wrappers.base_wrapper import BaseFusionWrapper
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
+ class LogRegFusionWrapper(BaseFusionWrapper):
20
+ """
21
+ Wrapper for probability stacking fusion with logistic regression.
22
+
23
+ This fusion model takes probability outputs from submodels,
24
+ stacks them into a feature vector, and runs them through a
25
+ trained logistic regression classifier.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ repo_id: str,
31
+ config: Dict[str, Any],
32
+ local_path: str
33
+ ):
34
+ """
35
+ Initialize the wrapper.
36
+
37
+ Args:
38
+ repo_id: Hugging Face repository ID
39
+ config: Configuration from config.json
40
+ local_path: Local path where the model files are stored
41
+ """
42
+ super().__init__(repo_id, config, local_path)
43
+ self._model = None
44
+ self._submodel_order: List[str] = config.get("submodel_order", [])
45
+ self._threshold: float = config.get("threshold", 0.5)
46
+ logger.info(f"Initialized LogRegFusionWrapper for {repo_id}")
47
+ logger.info(f"Submodel order: {self._submodel_order}")
48
+
49
+ @property
50
+ def submodel_repos(self) -> List[str]:
51
+ """Get list of submodel repository IDs."""
52
+ return self.config.get("submodels", [])
53
+
54
+ def load(self) -> None:
55
+ """
56
+ Load the logistic regression model from the downloaded repository.
57
+
58
+ Loads fusion_logreg.pkl using joblib (sklearn models are saved with joblib).
59
+ """
60
+ model_path = Path(self.local_path) / "fusion_logreg.pkl"
61
+
62
+ if not model_path.exists():
63
+ raise ConfigurationError(
64
+ message=f"fusion_logreg.pkl not found in {self.local_path}",
65
+ details={"repo_id": self.repo_id, "expected_path": str(model_path)}
66
+ )
67
+
68
+ try:
69
+ # Use joblib for sklearn models instead of pickle
70
+ self._model = joblib.load(model_path)
71
+ logger.info(f"Loaded logistic regression fusion model from {self.repo_id}")
72
+
73
+ except Exception as e:
74
+ logger.error(f"Failed to load fusion model from {self.repo_id}: {e}")
75
+ raise ConfigurationError(
76
+ message=f"Failed to load fusion model: {e}",
77
+ details={"repo_id": self.repo_id, "error": str(e)}
78
+ )
79
+
80
+ def predict(
81
+ self,
82
+ submodel_outputs: Dict[str, Dict[str, Any]],
83
+ **kwargs
84
+ ) -> Dict[str, Any]:
85
+ """
86
+ Run fusion prediction on submodel outputs.
87
+
88
+ Stacks submodel probabilities in the correct order and runs
89
+ through the logistic regression classifier.
90
+
91
+ Args:
92
+ submodel_outputs: Dictionary mapping submodel name to its prediction output
93
+ Each output must contain "prob_fake" key
94
+ **kwargs: Additional arguments (unused)
95
+
96
+ Returns:
97
+ Standardized prediction dictionary with:
98
+ - pred_int: 0 or 1
99
+ - pred: "real" or "fake"
100
+ - prob_fake: float probability of being fake
101
+ - meta: dict with submodel probabilities
102
+ """
103
+ if self._model is None:
104
+ raise FusionError(
105
+ message="Fusion model not loaded",
106
+ details={"repo_id": self.repo_id}
107
+ )
108
+
109
+ try:
110
+ # Stack submodel probabilities in the correct order
111
+ probs = []
112
+ for submodel_name in self._submodel_order:
113
+ if submodel_name not in submodel_outputs:
114
+ raise FusionError(
115
+ message=f"Missing output from submodel: {submodel_name}",
116
+ details={
117
+ "repo_id": self.repo_id,
118
+ "missing_submodel": submodel_name,
119
+ "available_submodels": list(submodel_outputs.keys())
120
+ }
121
+ )
122
+
123
+ output = submodel_outputs[submodel_name]
124
+ if "prob_fake" not in output:
125
+ raise FusionError(
126
+ message=f"Submodel output missing 'prob_fake': {submodel_name}",
127
+ details={
128
+ "repo_id": self.repo_id,
129
+ "submodel": submodel_name,
130
+ "output_keys": list(output.keys())
131
+ }
132
+ )
133
+
134
+ probs.append(output["prob_fake"])
135
+
136
+ # Convert to numpy array and reshape for sklearn
137
+ X = np.array(probs).reshape(1, -1)
138
+
139
+ # Get prediction and probability
140
+ prob_fake = float(self._model.predict_proba(X)[0, 1])
141
+ pred_int = 1 if prob_fake >= self._threshold else 0
142
+ pred = "fake" if pred_int == 1 else "real"
143
+
144
+ return {
145
+ "pred_int": pred_int,
146
+ "pred": pred,
147
+ "prob_fake": prob_fake,
148
+ "meta": {
149
+ "submodel_probs": dict(zip(self._submodel_order, probs)),
150
+ "threshold": self._threshold
151
+ }
152
+ }
153
+
154
+ except FusionError:
155
+ raise
156
+ except Exception as e:
157
+ logger.error(f"Fusion prediction failed for {self.repo_id}: {e}")
158
+ raise FusionError(
159
+ message=f"Fusion prediction failed: {e}",
160
+ details={"repo_id": self.repo_id, "error": str(e)}
161
+ )
app/models/wrappers/vit_base_wrapper.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Wrapper for ViT Base submodel.
3
+ """
4
+
5
+ import json
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+ from PIL import Image
13
+ from torchvision import transforms
14
+
15
+ try:
16
+ import timm
17
+ TIMM_AVAILABLE = True
18
+ except ImportError:
19
+ TIMM_AVAILABLE = False
20
+
21
+ from app.core.errors import InferenceError, ConfigurationError
22
+ from app.core.logging import get_logger
23
+ from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
24
+ from app.services.explainability import attention_rollout, heatmap_to_base64, compute_focus_summary
25
+
26
+ logger = get_logger(__name__)
27
+
28
+
29
class ViTWithMLPHead(nn.Module):
    """
    ViT backbone followed by a two-layer MLP classification head.

    The attribute layout mirrors the training checkpoint exactly:
      - ``vit``: timm ViT backbone created with ``num_classes=0`` (headless)
      - ``fc1``: Linear(embed_dim, hidden)
      - ``fc2``: Linear(hidden, num_classes)
    """

    def __init__(self, arch: str = "vit_base_patch16_224", num_classes: int = 2, hidden_dim: int = 512):
        super().__init__()
        # Headless backbone: forward() yields pooled features of size embed_dim
        # (768 for ViT-Base).
        self.vit = timm.create_model(arch, pretrained=False, num_classes=0)
        feature_dim = self.vit.embed_dim
        self.fc1 = nn.Linear(feature_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return classification logits of shape [B, num_classes]."""
        pooled = self.vit(x)  # [B, embed_dim]
        hidden = F.relu(self.fc1(pooled))
        return self.fc2(hidden)
52
+
53
+
54
class ViTBaseWrapper(BaseSubmodelWrapper):
    """
    Wrapper for ViT Base model (Vision Transformer).

    Model expects 224x224 RGB images with ImageNet normalization.
    When ``explain=True`` the wrapper additionally returns an
    attention-rollout heatmap derived from per-block attention weights.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        super().__init__(repo_id, config, local_path)
        self._model: Optional[nn.Module] = None
        self._transform: Optional[transforms.Compose] = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Decision threshold applied to prob_fake (defaults to 0.5).
        self._threshold = config.get("threshold", 0.5)
        logger.info(f"Initialized ViTBaseWrapper for {repo_id}")

    def load(self) -> None:
        """Load the ViT Base model with trained weights.

        Reads an optional ``preprocess.json`` for the transform pipeline and
        restores weights from ``deepfake_vit_finetuned_wildfake.pth``.

        Raises:
            ConfigurationError: If timm is unavailable, the weights file is
                missing, or the checkpoint cannot be loaded.
        """
        if not TIMM_AVAILABLE:
            raise ConfigurationError(
                message="timm package not installed. Run: pip install timm",
                details={"repo_id": self.repo_id}
            )

        weights_path = Path(self.local_path) / "deepfake_vit_finetuned_wildfake.pth"
        preprocess_path = Path(self.local_path) / "preprocess.json"

        if not weights_path.exists():
            raise ConfigurationError(
                message=f"deepfake_vit_finetuned_wildfake.pth not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(weights_path)}
            )

        try:
            # Load preprocessing config (optional; defaults below are the
            # standard ImageNet values).
            preprocess_config = {}
            if preprocess_path.exists():
                with open(preprocess_path, "r") as f:
                    preprocess_config = json.load(f)

            # Build transform pipeline
            input_size = preprocess_config.get("input_size", 224)
            if isinstance(input_size, list):
                input_size = input_size[0]

            normalize_config = preprocess_config.get("normalize", {})
            mean = normalize_config.get("mean", [0.485, 0.456, 0.406])
            std = normalize_config.get("std", [0.229, 0.224, 0.225])

            # Use bicubic interpolation as specified in preprocess.json
            interpolation = preprocess_config.get("interpolation", "bicubic")
            if interpolation == "bicubic":
                interp_mode = transforms.InterpolationMode.BICUBIC
            else:
                interp_mode = transforms.InterpolationMode.BILINEAR

            self._transform = transforms.Compose([
                transforms.Resize((input_size, input_size), interpolation=interp_mode),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])

            # Create model architecture matching the training checkpoint format
            arch = self.config.get("arch", "vit_base_patch16_224")
            num_classes = self.config.get("num_classes", 2)
            # MLP hidden dim is 512 per training notebook (fc1: 768->512, fc2: 512->2)
            # Note: config.hidden_dim (768) is ViT embedding dim, not MLP hidden dim
            mlp_hidden_dim = self.config.get("mlp_hidden_dim", 512)

            # Use custom wrapper that matches checkpoint structure (vit.* + fc1/fc2)
            self._model = ViTWithMLPHead(arch=arch, num_classes=num_classes, hidden_dim=mlp_hidden_dim)

            # Load trained weights.
            # SECURITY NOTE: weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted repositories.
            checkpoint = torch.load(weights_path, map_location=self._device, weights_only=False)

            # Handle training checkpoint format (has "model", "optimizer_state", "epoch" keys)
            if isinstance(checkpoint, dict) and "model" in checkpoint:
                state_dict = checkpoint["model"]
            else:
                state_dict = checkpoint

            self._model.load_state_dict(state_dict)
            self._model.to(self._device)
            self._model.eval()

            # Mark as loaded
            self._predict_fn = self._run_inference
            logger.info(f"Loaded ViT Base model from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load ViT Base model: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _run_inference(
        self,
        image_tensor: torch.Tensor,
        explain: bool = False
    ) -> Dict[str, Any]:
        """Run model inference on a preprocessed [1, 3, H, W] tensor.

        When ``explain`` is True, a forward hook on each transformer block
        recomputes the softmaxed attention weights (timm's attention module
        only returns ``attn @ v``, not the weights themselves) so that
        attention rollout can be applied afterwards.

        Returns:
            Dict with "logits", "prob_fake", "pred_int" and, when an
            explanation was produced, a "heatmap" float array in [0, 1].
        """
        heatmap = None

        if explain:
            handles = []

            # timm ViT attention modules do not expose raw attention weights,
            # so each hook recomputes q/k from the module's own qkv projection.
            def create_attn_hook():
                stored_attn = []

                def hook(module, inputs, outputs):
                    # inputs[0] is the block input x of shape [B, N, C]
                    x = inputs[0]
                    B, N, C = x.shape

                    # Access the attention module's parameters
                    qkv = module.qkv(x)  # [B, N, 3*dim]
                    qkv = qkv.reshape(B, N, 3, module.num_heads, C // module.num_heads)
                    qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, dim_head]
                    q, k, v = qkv[0], qkv[1], qkv[2]

                    # Compute attention weights
                    scale = (C // module.num_heads) ** -0.5
                    attn = (q @ k.transpose(-2, -1)) * scale
                    attn = attn.softmax(dim=-1)  # [B, heads, N, N]

                    # Average over heads
                    attn_avg = attn.mean(dim=1)  # [B, N, N]
                    stored_attn.append(attn_avg.detach())

                return hook, stored_attn

            all_stored_attns = []
            for block in self._model.vit.blocks:
                hook_fn, stored = create_attn_hook()
                all_stored_attns.append(stored)
                handle = block.attn.register_forward_hook(hook_fn)
                handles.append(handle)

            try:
                with torch.no_grad():
                    logits = self._model(image_tensor)
                    probs = F.softmax(logits, dim=1)
                    prob_fake = probs[0, 1].item()
                    pred_int = 1 if prob_fake >= self._threshold else 0

                    # Get attention from hooks
                    attention_list = [stored[0] for stored in all_stored_attns if len(stored) > 0]

                    if attention_list:
                        # Stack: [num_layers, B, N, N]
                        attention_stack = torch.stack(attention_list, dim=0)
                        # Compute rollout - returns (grid_size, grid_size) heatmap
                        attention_map = attention_rollout(
                            attention_stack[:, 0],  # [num_layers, N, N]
                            head_fusion="mean",  # Already averaged
                            discard_ratio=0.0,
                            num_prefix_tokens=1  # ViT has 1 CLS token
                        )  # Returns (14, 14) for ViT-Base

                        # Upsample the patch-grid map to the model input size.
                        # NOTE(review): 224 is hardcoded here; if a non-224
                        # input_size is ever configured in load(), this should
                        # use that size instead.
                        heatmap_img = Image.fromarray(
                            (attention_map * 255).astype(np.uint8)
                        ).resize((224, 224), Image.BILINEAR)
                        heatmap = np.array(heatmap_img).astype(np.float32) / 255.0

            finally:
                # Always detach hooks, even if the forward pass raised.
                for handle in handles:
                    handle.remove()
        else:
            with torch.no_grad():
                logits = self._model(image_tensor)
                probs = F.softmax(logits, dim=1)
                prob_fake = probs[0, 1].item()
                pred_int = 1 if prob_fake >= self._threshold else 0

        result = {
            "logits": logits[0].cpu().numpy().tolist(),
            "prob_fake": prob_fake,
            "pred_int": pred_int
        }

        if heatmap is not None:
            result["heatmap"] = heatmap

        return result

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (will be converted to PIL Image)
            explain: If True, compute attention rollout heatmap
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            Standardized prediction dictionary with optional heatmap

        Raises:
            InferenceError: If the model is not loaded, no image is provided,
                or inference fails.
        """
        if self._model is None or self._transform is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Convert bytes to PIL Image if needed
            if image is None and image_bytes is not None:
                import io
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            elif image is not None:
                image = image.convert("RGB")
            else:
                raise InferenceError(
                    message="No image provided",
                    details={"repo_id": self.repo_id}
                )

            # Preprocess
            image_tensor = self._transform(image).unsqueeze(0).to(self._device)

            # Run inference
            result = self._run_inference(image_tensor, explain=explain)

            # Standardize output
            labels = self.config.get("labels", {"0": "real", "1": "fake"})
            pred_int = result["pred_int"]

            output = {
                "pred_int": pred_int,
                "pred": labels.get(str(pred_int), "unknown"),
                "prob_fake": result["prob_fake"],
                "meta": {
                    "model": self.name,
                    "threshold": self._threshold,
                    "logits": result["logits"]
                }
            }

            # Add heatmap if requested
            if explain and "heatmap" in result:
                heatmap = result["heatmap"]
                output["heatmap_base64"] = heatmap_to_base64(heatmap)
                output["explainability_type"] = "attention_rollout"
                output["focus_summary"] = compute_focus_summary(heatmap)

            return output

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )
app/schemas/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Schemas module
app/schemas/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (196 Bytes). View file
 
app/schemas/__pycache__/models.cpython-312.pyc ADDED
Binary file (2.74 kB). View file
 
app/schemas/__pycache__/predict.cpython-312.pyc ADDED
Binary file (8.17 kB). View file
 
app/schemas/models.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic schemas for model-related endpoints.
3
+ """
4
+
5
+ from typing import Dict, List, Literal, Optional, Any
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
class ModelInfo(BaseModel):
    """Information about a loaded model.

    Serialized view of a single wrapper (individual submodel or the
    fusion model) as returned by the model-listing endpoint.
    """

    repo_id: str = Field(..., description="Hugging Face repository ID")
    name: str = Field(..., description="Short name of the model")
    # Distinguishes individual detectors from the ensemble fusion model.
    model_type: Literal["submodel", "fusion"] = Field(
        ...,
        description="Type of model"
    )
    # Raw contents of the repo's config.json; None when no config was bundled.
    config: Optional[Dict[str, Any]] = Field(
        None,
        description="Model configuration from config.json"
    )
22
+
23
+
24
class ModelsListResponse(BaseModel):
    """Response schema for listing models.

    Bundles the (optional) fusion model entry with the list of loaded
    submodels.
    """

    # None when no fusion model has been loaded.
    fusion: Optional[ModelInfo] = Field(
        None,
        description="Fusion model information"
    )
    submodels: List[ModelInfo] = Field(
        default_factory=list,
        description="List of loaded submodels"
    )
    # Presumably len(submodels) plus the fusion model when present —
    # TODO confirm against the route handler that fills this response.
    total_count: int = Field(..., description="Total number of loaded models")
36
+
37
+
38
class HealthResponse(BaseModel):
    """Response schema for health check.

    Minimal payload: "ok" when the service is responsive, "error" otherwise.
    """

    status: Literal["ok", "error"] = Field(..., description="Health status")
42
+
43
+
44
class ReadyResponse(BaseModel):
    """Response schema for readiness check.

    Reports whether model weights have been loaded and which repositories
    are currently available for inference.
    """

    status: Literal["ready", "not_ready"] = Field(..., description="Readiness status")
    models_loaded: bool = Field(..., description="Whether models are loaded")
    # None when no fusion model is configured/loaded.
    fusion_repo: Optional[str] = Field(None, description="Fusion repository ID")
    submodels: List[str] = Field(
        default_factory=list,
        description="List of loaded submodel repository IDs"
    )