Commit ·
df4a21a
1
Parent(s): 14a1b30
Deploy DeepFake Detector API - 2026-03-07 09:12:00
Browse files — This view is limited to 50 files because it contains too many changes.
See raw diff
- .env.example +24 -0
- Dockerfile +43 -0
- Dockerfile.huggingface +43 -0
- README.md +178 -8
- README_HF.md +182 -0
- app/__init__.py +1 -0
- app/__pycache__/__init__.cpython-312.pyc +0 -0
- app/__pycache__/main.cpython-312.pyc +0 -0
- app/api/__init__.py +1 -0
- app/api/__pycache__/__init__.cpython-312.pyc +0 -0
- app/api/__pycache__/routes_health.cpython-312.pyc +0 -0
- app/api/__pycache__/routes_models.cpython-312.pyc +0 -0
- app/api/__pycache__/routes_predict.cpython-312.pyc +0 -0
- app/api/routes_health.py +62 -0
- app/api/routes_models.py +51 -0
- app/api/routes_predict.py +286 -0
- app/core/__init__.py +1 -0
- app/core/__pycache__/__init__.cpython-312.pyc +0 -0
- app/core/__pycache__/config.cpython-312.pyc +0 -0
- app/core/__pycache__/errors.cpython-312.pyc +0 -0
- app/core/__pycache__/logging.cpython-312.pyc +0 -0
- app/core/config.py +64 -0
- app/core/errors.py +53 -0
- app/core/logging.py +61 -0
- app/main.py +128 -0
- app/models/__init__.py +1 -0
- app/models/__pycache__/__init__.cpython-312.pyc +0 -0
- app/models/wrappers/__init__.py +1 -0
- app/models/wrappers/__pycache__/__init__.cpython-312.pyc +0 -0
- app/models/wrappers/__pycache__/base_wrapper.cpython-312.pyc +0 -0
- app/models/wrappers/__pycache__/cnn_transfer_wrapper.cpython-312.pyc +0 -0
- app/models/wrappers/__pycache__/deit_distilled_wrapper.cpython-312.pyc +0 -0
- app/models/wrappers/__pycache__/dummy_majority_fusion_wrapper.cpython-312.pyc +0 -0
- app/models/wrappers/__pycache__/dummy_random_wrapper.cpython-312.pyc +0 -0
- app/models/wrappers/__pycache__/gradfield_cnn_wrapper.cpython-312.pyc +0 -0
- app/models/wrappers/__pycache__/logreg_fusion_wrapper.cpython-312.pyc +0 -0
- app/models/wrappers/__pycache__/vit_base_wrapper.cpython-312.pyc +0 -0
- app/models/wrappers/base_wrapper.py +150 -0
- app/models/wrappers/cnn_transfer_wrapper.py +226 -0
- app/models/wrappers/deit_distilled_wrapper.py +312 -0
- app/models/wrappers/dummy_majority_fusion_wrapper.py +171 -0
- app/models/wrappers/dummy_random_wrapper.py +168 -0
- app/models/wrappers/gradfield_cnn_wrapper.py +401 -0
- app/models/wrappers/logreg_fusion_wrapper.py +161 -0
- app/models/wrappers/vit_base_wrapper.py +331 -0
- app/schemas/__init__.py +1 -0
- app/schemas/__pycache__/__init__.cpython-312.pyc +0 -0
- app/schemas/__pycache__/models.cpython-312.pyc +0 -0
- app/schemas/__pycache__/predict.cpython-312.pyc +0 -0
- app/schemas/models.py +53 -0
.env.example
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DeepFake Detector Backend - Environment Variables
|
| 2 |
+
# Copy this file to .env and update with your values
|
| 3 |
+
|
| 4 |
+
# Hugging Face Configuration
|
| 5 |
+
# Available fusion models:
|
| 6 |
+
# - DeepFakeDetector/fusion-logreg-final (Logistic Regression - default)
|
| 7 |
+
# - DeepFakeDetector/fusion-meta-final (Meta-classifier)
|
| 8 |
+
HF_FUSION_REPO_ID=DeepFakeDetector/fusion-logreg-final
|
| 9 |
+
HF_CACHE_DIR=.hf_cache
|
| 10 |
+
# HF_TOKEN=your_huggingface_token_here # Optional: for private repos
|
| 11 |
+
|
| 12 |
+
# Google Gemini API (Optional - for LLM explanations)
|
| 13 |
+
# GOOGLE_API_KEY=your_google_api_key_here
|
| 14 |
+
|
| 15 |
+
# Server Configuration
|
| 16 |
+
HOST=0.0.0.0
|
| 17 |
+
PORT=8000
|
| 18 |
+
|
| 19 |
+
# CORS Configuration (comma-separated list of allowed origins)
|
| 20 |
+
CORS_ORIGINS=http://localhost:8082,https://www.deepfake-detector.app,https://deepfake-detector.app
|
| 21 |
+
|
| 22 |
+
# Debugging
|
| 23 |
+
ENABLE_DEBUG=false
|
| 24 |
+
LOG_LEVEL=INFO
|
Dockerfile
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DeepFake Detector API - Hugging Face Spaces Docker Image
|
| 2 |
+
# Optimized for HF Spaces deployment with GPU support
|
| 3 |
+
|
| 4 |
+
FROM python:3.11-slim
|
| 5 |
+
|
| 6 |
+
# Set working directory
|
| 7 |
+
WORKDIR /app
|
| 8 |
+
|
| 9 |
+
# Set environment variables
|
| 10 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 11 |
+
PYTHONUNBUFFERED=1 \
|
| 12 |
+
PIP_NO_CACHE_DIR=1 \
|
| 13 |
+
PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
| 14 |
+
PORT=7860
|
| 15 |
+
|
| 16 |
+
# Install system dependencies
|
| 17 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 18 |
+
curl \
|
| 19 |
+
git \
|
| 20 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 21 |
+
|
| 22 |
+
# Create non-root user (HF Spaces requirement)
|
| 23 |
+
RUN useradd -m -u 1000 user
|
| 24 |
+
USER user
|
| 25 |
+
|
| 26 |
+
# Set PATH for user-installed packages
|
| 27 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 28 |
+
|
| 29 |
+
# Copy requirements and install dependencies as user
|
| 30 |
+
COPY --chown=user:user requirements.txt .
|
| 31 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 32 |
+
|
| 33 |
+
# Copy application code
|
| 34 |
+
COPY --chown=user:user . /app
|
| 35 |
+
|
| 36 |
+
# Create cache directory for Hugging Face models
|
| 37 |
+
RUN mkdir -p /app/.hf_cache
|
| 38 |
+
|
| 39 |
+
# Expose HF Spaces port
|
| 40 |
+
EXPOSE 7860
|
| 41 |
+
|
| 42 |
+
# Run the application (start.sh already defaults to port 7860)
|
| 43 |
+
CMD ["./start.sh"]
|
Dockerfile.huggingface
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DeepFake Detector API - Hugging Face Spaces Docker Image
|
| 2 |
+
# Optimized for HF Spaces deployment with GPU support
|
| 3 |
+
|
| 4 |
+
FROM python:3.11-slim
|
| 5 |
+
|
| 6 |
+
# Set working directory
|
| 7 |
+
WORKDIR /app
|
| 8 |
+
|
| 9 |
+
# Set environment variables
|
| 10 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 11 |
+
PYTHONUNBUFFERED=1 \
|
| 12 |
+
PIP_NO_CACHE_DIR=1 \
|
| 13 |
+
PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
| 14 |
+
PORT=7860
|
| 15 |
+
|
| 16 |
+
# Install system dependencies
|
| 17 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 18 |
+
curl \
|
| 19 |
+
git \
|
| 20 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 21 |
+
|
| 22 |
+
# Create non-root user (HF Spaces requirement)
|
| 23 |
+
RUN useradd -m -u 1000 user
|
| 24 |
+
USER user
|
| 25 |
+
|
| 26 |
+
# Set PATH for user-installed packages
|
| 27 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 28 |
+
|
| 29 |
+
# Copy requirements and install dependencies as user
|
| 30 |
+
COPY --chown=user:user requirements.txt .
|
| 31 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 32 |
+
|
| 33 |
+
# Copy application code
|
| 34 |
+
COPY --chown=user:user . /app
|
| 35 |
+
|
| 36 |
+
# Create cache directory for Hugging Face models
|
| 37 |
+
RUN mkdir -p /app/.hf_cache
|
| 38 |
+
|
| 39 |
+
# Expose HF Spaces port
|
| 40 |
+
EXPOSE 7860
|
| 41 |
+
|
| 42 |
+
# Run the application (start.sh already defaults to port 7860)
|
| 43 |
+
CMD ["./start.sh"]
|
README.md
CHANGED
|
@@ -1,12 +1,182 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
| 8 |
-
license: mit
|
| 9 |
-
short_description: FastAPI Backend for MacAI Society DeepFake Detector
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: DeepFake Detector API
|
| 3 |
+
emoji: 🎭
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# 🎭 DeepFake Detector API
|
| 11 |
+
|
| 12 |
+
FastAPI backend for detecting AI-generated (deepfake) images using an ensemble of state-of-the-art deep learning models.
|
| 13 |
+
|
| 14 |
+
## 🤖 Models
|
| 15 |
+
|
| 16 |
+
This API uses a fusion ensemble of 5 deep learning models:
|
| 17 |
+
|
| 18 |
+
- **CNN Transfer** (EfficientNet-B0) - Transfer learning from ImageNet
|
| 19 |
+
- **ViT Base** (Vision Transformer) - Attention-based architecture
|
| 20 |
+
- **DeiT Distilled** (Data-efficient Image Transformer) - Distilled ViT variant
|
| 21 |
+
- **Gradient Field CNN** - Custom architecture analyzing gradient patterns
|
| 22 |
+
- **FFT CNN** - Frequency domain analysis using Fast Fourier Transform
|
| 23 |
+
|
| 24 |
+
All models are combined using a **Logistic Regression stacking ensemble** for optimal accuracy.
|
| 25 |
+
|
| 26 |
+
## 🔗 API Endpoints
|
| 27 |
+
|
| 28 |
+
| Endpoint | Method | Description |
|
| 29 |
+
|----------|--------|-------------|
|
| 30 |
+
| `/health` | GET | Health check - returns API status |
|
| 31 |
+
| `/ready` | GET | Model readiness check - confirms models are loaded |
|
| 32 |
+
| `/models` | GET | List all loaded models with metadata |
|
| 33 |
+
| `/predict` | POST | Predict if an image is real or AI-generated |
|
| 34 |
+
| `/docs` | GET | Interactive Swagger API documentation |
|
| 35 |
+
| `/redoc` | GET | Alternative API documentation |
|
| 36 |
+
|
| 37 |
+
## 🚀 Usage Example
|
| 38 |
+
|
| 39 |
+
### Using cURL
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
# Check if API is ready
|
| 43 |
+
curl https://lukhsaankumar-deepfakedetectorbackend.hf.space/ready
|
| 44 |
+
|
| 45 |
+
# Make a prediction
|
| 46 |
+
curl -X POST "https://lukhsaankumar-deepfakedetectorbackend.hf.space/predict" \
|
| 47 |
+
-F "file=@image.jpg" \
|
| 48 |
+
-F "explain=true"
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### Using Python
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
import requests
|
| 55 |
+
|
| 56 |
+
# Upload an image for prediction
|
| 57 |
+
url = "https://lukhsaankumar-deepfakedetectorbackend.hf.space/predict"
|
| 58 |
+
files = {"file": open("image.jpg", "rb")}
|
| 59 |
+
data = {"explain": True}
|
| 60 |
+
|
| 61 |
+
response = requests.post(url, files=files, data=data)
|
| 62 |
+
result = response.json()
|
| 63 |
+
|
| 64 |
+
print(f"Prediction: {result['prediction']}")
|
| 65 |
+
print(f"Confidence: {result['confidence']:.2%}")
|
| 66 |
+
print(f"Explanation: {result['explanation']}")
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## 🎯 Response Format
|
| 70 |
+
|
| 71 |
+
```json
|
| 72 |
+
{
|
| 73 |
+
"prediction": "fake",
|
| 74 |
+
"confidence": 0.8734,
|
| 75 |
+
"probabilities": {
|
| 76 |
+
"real": 0.1266,
|
| 77 |
+
"fake": 0.8734
|
| 78 |
+
},
|
| 79 |
+
"model_predictions": {
|
| 80 |
+
"cnn_transfer": {"prediction": "fake", "confidence": 0.89},
|
| 81 |
+
"vit_base": {"prediction": "fake", "confidence": 0.92},
|
| 82 |
+
"deit": {"prediction": "fake", "confidence": 0.85},
|
| 83 |
+
"gradient_field": {"prediction": "real", "confidence": 0.55},
|
| 84 |
+
"fft_cnn": {"prediction": "fake", "confidence": 0.78}
|
| 85 |
+
},
|
| 86 |
+
"fusion_confidence": 0.8734,
|
| 87 |
+
"explanation": "AI-powered analysis of the prediction...",
|
| 88 |
+
"processing_time_ms": 342
|
| 89 |
+
}
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
## 🔧 Configuration
|
| 93 |
+
|
| 94 |
+
### Required Secrets
|
| 95 |
+
|
| 96 |
+
Set these in your Space Settings → Repository secrets:
|
| 97 |
+
|
| 98 |
+
| Secret | Description | Required |
|
| 99 |
+
|--------|-------------|----------|
|
| 100 |
+
| `GOOGLE_API_KEY` | Google Gemini API key for AI explanations | Yes |
|
| 101 |
+
| `HF_TOKEN` | Hugging Face token (auto-set by Spaces) | No |
|
| 102 |
+
|
| 103 |
+
### Optional Environment Variables
|
| 104 |
+
|
| 105 |
+
| Variable | Default | Description |
|
| 106 |
+
|----------|---------|-------------|
|
| 107 |
+
| `HF_FUSION_REPO_ID` | `DeepFakeDetector/fusion-logreg-final` | Hugging Face model repository |
|
| 108 |
+
| `CORS_ORIGINS` | Multiple defaults | Comma-separated allowed CORS origins |
|
| 109 |
+
| `GEMINI_MODEL` | `gemini-2.5-flash` | Gemini model for explanations |
|
| 110 |
+
|
| 111 |
+
## 🏗️ Architecture
|
| 112 |
+
|
| 113 |
+
```
|
| 114 |
+
┌─────────────┐
|
| 115 |
+
│ Client │
|
| 116 |
+
└──────┬──────┘
|
| 117 |
+
│
|
| 118 |
+
▼
|
| 119 |
+
┌─────────────────────────────────┐
|
| 120 |
+
│ FastAPI Backend │
|
| 121 |
+
│ ┌──────────────────────────┐ │
|
| 122 |
+
│ │ Model Registry │ │
|
| 123 |
+
│ │ ┌────────────────────┐ │ │
|
| 124 |
+
│ │ │ CNN Transfer │ │ │
|
| 125 |
+
│ │ │ ViT Base │ │ │
|
| 126 |
+
│ │ │ DeiT Distilled │ │ │
|
| 127 |
+
│ │ │ Gradient Field │ │ │
|
| 128 |
+
│ │ │ FFT CNN │ │ │
|
| 129 |
+
│ │ └────────────────────┘ │ │
|
| 130 |
+
│ │ ┌────────────────────┐ │ │
|
| 131 |
+
│ │ │ Fusion Ensemble │ │ │
|
| 132 |
+
│ │ │ (LogReg Stacking) │ │ │
|
| 133 |
+
│ │ └────────────────────┘ │ │
|
| 134 |
+
│ └──────────────────────────┘ │
|
| 135 |
+
│ ┌──────────────────────────┐ │
|
| 136 |
+
│ │ Gemini Explainer │ │
|
| 137 |
+
│ └──────────────��───────────┘ │
|
| 138 |
+
└─────────────────────────────────┘
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
## 📊 Performance
|
| 142 |
+
|
| 143 |
+
- **Accuracy**: ~87% on test set (OpenFake dataset)
|
| 144 |
+
- **Inference Time**: ~200-500ms per image (with GPU)
|
| 145 |
+
- **Model Size**: ~500MB total
|
| 146 |
+
- **Supported Formats**: JPG, PNG, WEBP
|
| 147 |
+
|
| 148 |
+
## 🐛 Troubleshooting
|
| 149 |
+
|
| 150 |
+
### Models not loading?
|
| 151 |
+
- Check the Logs tab for specific errors
|
| 152 |
+
- Verify `HF_FUSION_REPO_ID` points to a valid repository
|
| 153 |
+
- Ensure the repository is public or `HF_TOKEN` is set
|
| 154 |
+
|
| 155 |
+
### Explanations not working?
|
| 156 |
+
- Verify `GOOGLE_API_KEY` is set in Space Settings
|
| 157 |
+
- Check if you have Gemini API quota remaining
|
| 158 |
+
- Review logs for API errors
|
| 159 |
+
|
| 160 |
+
### CORS errors?
|
| 161 |
+
- Add your frontend domain to `CORS_ORIGINS` in Space Settings
|
| 162 |
+
- Format: `https://yourdomain.com,https://www.yourdomain.com`
|
| 163 |
+
|
| 164 |
+
## 📚 Documentation
|
| 165 |
+
|
| 166 |
+
- **Interactive Docs**: Visit `/docs` for Swagger UI
|
| 167 |
+
- **ReDoc**: Visit `/redoc` for alternative documentation
|
| 168 |
+
- **Source Code**: [GitHub Repository](https://github.com/lukhsaankumar/DeepFakeDetector)
|
| 169 |
+
|
| 170 |
+
## 📝 License
|
| 171 |
+
|
| 172 |
+
This project is part of the MacAI Society research initiative.
|
| 173 |
+
|
| 174 |
+
## 🙏 Acknowledgments
|
| 175 |
+
|
| 176 |
+
- Models trained on OpenFake, ImageNet, and custom datasets
|
| 177 |
+
- Powered by PyTorch, Hugging Face, and FastAPI
|
| 178 |
+
- AI explanations by Google Gemini
|
| 179 |
+
|
| 180 |
+
---
|
| 181 |
+
|
| 182 |
+
**Built with ❤️ by MacAI Society**
|
README_HF.md
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: DeepFake Detector API
|
| 3 |
+
emoji: 🎭
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# 🎭 DeepFake Detector API
|
| 11 |
+
|
| 12 |
+
FastAPI backend for detecting AI-generated (deepfake) images using an ensemble of state-of-the-art deep learning models.
|
| 13 |
+
|
| 14 |
+
## 🤖 Models
|
| 15 |
+
|
| 16 |
+
This API uses a fusion ensemble of 5 deep learning models:
|
| 17 |
+
|
| 18 |
+
- **CNN Transfer** (EfficientNet-B0) - Transfer learning from ImageNet
|
| 19 |
+
- **ViT Base** (Vision Transformer) - Attention-based architecture
|
| 20 |
+
- **DeiT Distilled** (Data-efficient Image Transformer) - Distilled ViT variant
|
| 21 |
+
- **Gradient Field CNN** - Custom architecture analyzing gradient patterns
|
| 22 |
+
- **FFT CNN** - Frequency domain analysis using Fast Fourier Transform
|
| 23 |
+
|
| 24 |
+
All models are combined using a **Logistic Regression stacking ensemble** for optimal accuracy.
|
| 25 |
+
|
| 26 |
+
## 🔗 API Endpoints
|
| 27 |
+
|
| 28 |
+
| Endpoint | Method | Description |
|
| 29 |
+
|----------|--------|-------------|
|
| 30 |
+
| `/health` | GET | Health check - returns API status |
|
| 31 |
+
| `/ready` | GET | Model readiness check - confirms models are loaded |
|
| 32 |
+
| `/models` | GET | List all loaded models with metadata |
|
| 33 |
+
| `/predict` | POST | Predict if an image is real or AI-generated |
|
| 34 |
+
| `/docs` | GET | Interactive Swagger API documentation |
|
| 35 |
+
| `/redoc` | GET | Alternative API documentation |
|
| 36 |
+
|
| 37 |
+
## 🚀 Usage Example
|
| 38 |
+
|
| 39 |
+
### Using cURL
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
# Check if API is ready
|
| 43 |
+
curl https://lukhsaankumar-deepfakedetectorbackend.hf.space/ready
|
| 44 |
+
|
| 45 |
+
# Make a prediction
|
| 46 |
+
curl -X POST "https://lukhsaankumar-deepfakedetectorbackend.hf.space/predict" \
|
| 47 |
+
-F "file=@image.jpg" \
|
| 48 |
+
-F "explain=true"
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### Using Python
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
import requests
|
| 55 |
+
|
| 56 |
+
# Upload an image for prediction
|
| 57 |
+
url = "https://lukhsaankumar-deepfakedetectorbackend.hf.space/predict"
|
| 58 |
+
files = {"file": open("image.jpg", "rb")}
|
| 59 |
+
data = {"explain": True}
|
| 60 |
+
|
| 61 |
+
response = requests.post(url, files=files, data=data)
|
| 62 |
+
result = response.json()
|
| 63 |
+
|
| 64 |
+
print(f"Prediction: {result['prediction']}")
|
| 65 |
+
print(f"Confidence: {result['confidence']:.2%}")
|
| 66 |
+
print(f"Explanation: {result['explanation']}")
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## 🎯 Response Format
|
| 70 |
+
|
| 71 |
+
```json
|
| 72 |
+
{
|
| 73 |
+
"prediction": "fake",
|
| 74 |
+
"confidence": 0.8734,
|
| 75 |
+
"probabilities": {
|
| 76 |
+
"real": 0.1266,
|
| 77 |
+
"fake": 0.8734
|
| 78 |
+
},
|
| 79 |
+
"model_predictions": {
|
| 80 |
+
"cnn_transfer": {"prediction": "fake", "confidence": 0.89},
|
| 81 |
+
"vit_base": {"prediction": "fake", "confidence": 0.92},
|
| 82 |
+
"deit": {"prediction": "fake", "confidence": 0.85},
|
| 83 |
+
"gradient_field": {"prediction": "real", "confidence": 0.55},
|
| 84 |
+
"fft_cnn": {"prediction": "fake", "confidence": 0.78}
|
| 85 |
+
},
|
| 86 |
+
"fusion_confidence": 0.8734,
|
| 87 |
+
"explanation": "AI-powered analysis of the prediction...",
|
| 88 |
+
"processing_time_ms": 342
|
| 89 |
+
}
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
## 🔧 Configuration
|
| 93 |
+
|
| 94 |
+
### Required Secrets
|
| 95 |
+
|
| 96 |
+
Set these in your Space Settings → Repository secrets:
|
| 97 |
+
|
| 98 |
+
| Secret | Description | Required |
|
| 99 |
+
|--------|-------------|----------|
|
| 100 |
+
| `GOOGLE_API_KEY` | Google Gemini API key for AI explanations | Yes |
|
| 101 |
+
| `HF_TOKEN` | Hugging Face token (auto-set by Spaces) | No |
|
| 102 |
+
|
| 103 |
+
### Optional Environment Variables
|
| 104 |
+
|
| 105 |
+
| Variable | Default | Description |
|
| 106 |
+
|----------|---------|-------------|
|
| 107 |
+
| `HF_FUSION_REPO_ID` | `DeepFakeDetector/fusion-logreg-final` | Hugging Face model repository |
|
| 108 |
+
| `CORS_ORIGINS` | Multiple defaults | Comma-separated allowed CORS origins |
|
| 109 |
+
| `GEMINI_MODEL` | `gemini-2.5-flash` | Gemini model for explanations |
|
| 110 |
+
|
| 111 |
+
## 🏗️ Architecture
|
| 112 |
+
|
| 113 |
+
```
|
| 114 |
+
┌─────────────┐
|
| 115 |
+
│ Client │
|
| 116 |
+
└──────┬──────┘
|
| 117 |
+
│
|
| 118 |
+
▼
|
| 119 |
+
┌─────────────────────────────────┐
|
| 120 |
+
│ FastAPI Backend │
|
| 121 |
+
│ ┌──────────────────────────┐ │
|
| 122 |
+
│ │ Model Registry │ │
|
| 123 |
+
│ │ ┌────────────────────┐ │ │
|
| 124 |
+
│ │ │ CNN Transfer │ │ │
|
| 125 |
+
│ │ │ ViT Base │ │ │
|
| 126 |
+
│ │ │ DeiT Distilled │ │ │
|
| 127 |
+
│ │ │ Gradient Field │ │ │
|
| 128 |
+
│ │ │ FFT CNN │ │ │
|
| 129 |
+
│ │ └────────────────────┘ │ │
|
| 130 |
+
│ │ ┌────────────────────┐ │ │
|
| 131 |
+
│ │ │ Fusion Ensemble │ │ │
|
| 132 |
+
│ │ │ (LogReg Stacking) │ │ │
|
| 133 |
+
│ │ └────────────────────┘ │ │
|
| 134 |
+
│ └──────────────────────────┘ │
|
| 135 |
+
│ ┌──────────────────────────┐ │
|
| 136 |
+
│ │ Gemini Explainer │ │
|
| 137 |
+
│ └──────────────────────────┘ │
|
| 138 |
+
└─────────────────────────────────┘
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
## 📊 Performance
|
| 142 |
+
|
| 143 |
+
- **Accuracy**: ~87% on test set (OpenFake dataset)
|
| 144 |
+
- **Inference Time**: ~200-500ms per image (with GPU)
|
| 145 |
+
- **Model Size**: ~500MB total
|
| 146 |
+
- **Supported Formats**: JPG, PNG, WEBP
|
| 147 |
+
|
| 148 |
+
## 🐛 Troubleshooting
|
| 149 |
+
|
| 150 |
+
### Models not loading?
|
| 151 |
+
- Check the Logs tab for specific errors
|
| 152 |
+
- Verify `HF_FUSION_REPO_ID` points to a valid repository
|
| 153 |
+
- Ensure the repository is public or `HF_TOKEN` is set
|
| 154 |
+
|
| 155 |
+
### Explanations not working?
|
| 156 |
+
- Verify `GOOGLE_API_KEY` is set in Space Settings
|
| 157 |
+
- Check if you have Gemini API quota remaining
|
| 158 |
+
- Review logs for API errors
|
| 159 |
+
|
| 160 |
+
### CORS errors?
|
| 161 |
+
- Add your frontend domain to `CORS_ORIGINS` in Space Settings
|
| 162 |
+
- Format: `https://yourdomain.com,https://www.yourdomain.com`
|
| 163 |
+
|
| 164 |
+
## 📚 Documentation
|
| 165 |
+
|
| 166 |
+
- **Interactive Docs**: Visit `/docs` for Swagger UI
|
| 167 |
+
- **ReDoc**: Visit `/redoc` for alternative documentation
|
| 168 |
+
- **Source Code**: [GitHub Repository](https://github.com/lukhsaankumar/DeepFakeDetector)
|
| 169 |
+
|
| 170 |
+
## 📝 License
|
| 171 |
+
|
| 172 |
+
This project is part of the MacAI Society research initiative.
|
| 173 |
+
|
| 174 |
+
## 🙏 Acknowledgments
|
| 175 |
+
|
| 176 |
+
- Models trained on OpenFake, ImageNet, and custom datasets
|
| 177 |
+
- Powered by PyTorch, Hugging Face, and FastAPI
|
| 178 |
+
- AI explanations by Google Gemini
|
| 179 |
+
|
| 180 |
+
---
|
| 181 |
+
|
| 182 |
+
**Built with ❤️ by MacAI Society**
|
app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# DeepFake Detector Backend Application
|
app/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (188 Bytes). View file
|
|
|
app/__pycache__/main.cpython-312.pyc
ADDED
|
Binary file (4.92 kB). View file
|
|
|
app/api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# API module
|
app/api/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (192 Bytes). View file
|
|
|
app/api/__pycache__/routes_health.cpython-312.pyc
ADDED
|
Binary file (2.08 kB). View file
|
|
|
app/api/__pycache__/routes_models.cpython-312.pyc
ADDED
|
Binary file (1.74 kB). View file
|
|
|
app/api/__pycache__/routes_predict.cpython-312.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
app/api/routes_health.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Health check routes.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter
|
| 6 |
+
|
| 7 |
+
from app.core.logging import get_logger
|
| 8 |
+
from app.schemas.models import HealthResponse, ReadyResponse
|
| 9 |
+
from app.services.model_registry import get_model_registry
|
| 10 |
+
|
| 11 |
+
logger = get_logger(__name__)
|
| 12 |
+
router = APIRouter(tags=["health"])
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@router.get(
|
| 16 |
+
"/health",
|
| 17 |
+
response_model=HealthResponse,
|
| 18 |
+
summary="Health check",
|
| 19 |
+
description="Simple health check to verify the API is running"
|
| 20 |
+
)
|
| 21 |
+
async def health_check() -> HealthResponse:
|
| 22 |
+
"""
|
| 23 |
+
Health check endpoint.
|
| 24 |
+
|
| 25 |
+
Returns OK if the API server is running.
|
| 26 |
+
"""
|
| 27 |
+
return HealthResponse(status="ok")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@router.get(
|
| 31 |
+
"/ready",
|
| 32 |
+
response_model=ReadyResponse,
|
| 33 |
+
summary="Readiness check",
|
| 34 |
+
description="Check if models are loaded and the API is ready to serve predictions"
|
| 35 |
+
)
|
| 36 |
+
async def readiness_check() -> ReadyResponse:
|
| 37 |
+
"""
|
| 38 |
+
Readiness check endpoint.
|
| 39 |
+
|
| 40 |
+
Verifies that models are loaded and ready for inference.
|
| 41 |
+
Returns detailed information about loaded models.
|
| 42 |
+
"""
|
| 43 |
+
registry = get_model_registry()
|
| 44 |
+
|
| 45 |
+
if not registry.is_loaded:
|
| 46 |
+
return ReadyResponse(
|
| 47 |
+
status="not_ready",
|
| 48 |
+
models_loaded=False,
|
| 49 |
+
fusion_repo=None,
|
| 50 |
+
submodels=[]
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
return ReadyResponse(
|
| 54 |
+
status="ready",
|
| 55 |
+
models_loaded=True,
|
| 56 |
+
fusion_repo=registry.get_fusion_repo_id(),
|
| 57 |
+
submodels=[
|
| 58 |
+
model["repo_id"]
|
| 59 |
+
for model in registry.list_models()
|
| 60 |
+
if model["model_type"] == "submodel"
|
| 61 |
+
]
|
| 62 |
+
)
|
app/api/routes_models.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model listing routes.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter
|
| 6 |
+
|
| 7 |
+
from app.core.logging import get_logger
|
| 8 |
+
from app.schemas.models import ModelsListResponse, ModelInfo
|
| 9 |
+
from app.services.model_registry import get_model_registry
|
| 10 |
+
|
| 11 |
+
logger = get_logger(__name__)
|
| 12 |
+
router = APIRouter(tags=["models"])
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@router.get(
|
| 16 |
+
"/models",
|
| 17 |
+
response_model=ModelsListResponse,
|
| 18 |
+
summary="List loaded models",
|
| 19 |
+
description="Get information about all loaded models including fusion and submodels"
|
| 20 |
+
)
|
| 21 |
+
async def list_models() -> ModelsListResponse:
|
| 22 |
+
"""
|
| 23 |
+
List all loaded models.
|
| 24 |
+
|
| 25 |
+
Returns information about the fusion model and all submodels,
|
| 26 |
+
including their Hugging Face repository IDs and configurations.
|
| 27 |
+
"""
|
| 28 |
+
registry = get_model_registry()
|
| 29 |
+
models = registry.list_models()
|
| 30 |
+
|
| 31 |
+
fusion_info = None
|
| 32 |
+
submodels_info = []
|
| 33 |
+
|
| 34 |
+
for model in models:
|
| 35 |
+
model_info = ModelInfo(
|
| 36 |
+
repo_id=model["repo_id"],
|
| 37 |
+
name=model["name"],
|
| 38 |
+
model_type=model["model_type"],
|
| 39 |
+
config=model.get("config")
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
if model["model_type"] == "fusion":
|
| 43 |
+
fusion_info = model_info
|
| 44 |
+
else:
|
| 45 |
+
submodels_info.append(model_info)
|
| 46 |
+
|
| 47 |
+
return ModelsListResponse(
|
| 48 |
+
fusion=fusion_info,
|
| 49 |
+
submodels=submodels_info,
|
| 50 |
+
total_count=len(models)
|
| 51 |
+
)
|
app/api/routes_predict.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prediction routes.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import base64
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
|
| 9 |
+
|
| 10 |
+
from app.core.errors import (
|
| 11 |
+
DeepFakeDetectorError,
|
| 12 |
+
ImageProcessingError,
|
| 13 |
+
InferenceError,
|
| 14 |
+
FusionError,
|
| 15 |
+
ModelNotFoundError,
|
| 16 |
+
ModelNotLoadedError
|
| 17 |
+
)
|
| 18 |
+
from app.core.logging import get_logger
|
| 19 |
+
from app.schemas.predict import (
|
| 20 |
+
PredictResponse,
|
| 21 |
+
PredictionResult,
|
| 22 |
+
TimingInfo,
|
| 23 |
+
ErrorResponse,
|
| 24 |
+
FusionMeta,
|
| 25 |
+
ModelDisplayInfo,
|
| 26 |
+
ExplainModelResponse,
|
| 27 |
+
SingleModelInsight
|
| 28 |
+
)
|
| 29 |
+
from app.services.inference_service import get_inference_service
|
| 30 |
+
from app.services.fusion_service import get_fusion_service
|
| 31 |
+
from app.services.preprocess_service import get_preprocess_service
|
| 32 |
+
from app.services.model_registry import get_model_registry
|
| 33 |
+
from app.services.llm_service import get_llm_service, get_model_display_info, MODEL_DISPLAY_INFO
|
| 34 |
+
from app.utils.timing import Timer
|
| 35 |
+
|
| 36 |
+
logger = get_logger(__name__)
router = APIRouter(tags=["predict"])


@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Predict if image is real or fake",
    description="Upload an image to get a deepfake detection prediction",
    responses={
        400: {"model": ErrorResponse, "description": "Invalid image or request"},
        404: {"model": ErrorResponse, "description": "Model not found"},
        500: {"model": ErrorResponse, "description": "Inference error"}
    }
)
async def predict(
    image: UploadFile = File(..., description="Image file to analyze"),
    use_fusion: bool = Query(
        True,
        description="Use fusion model (majority vote) across all submodels"
    ),
    model: Optional[str] = Query(
        None,
        description="Specific submodel to use (name or repo_id). Only used when use_fusion=false"
    ),
    return_submodels: Optional[bool] = Query(
        None,
        description="Include individual submodel predictions in response. Defaults to true when use_fusion=true"
    ),
    explain: bool = Query(
        True,
        description="Generate explainability heatmaps (Grad-CAM for CNNs, attention rollout for transformers)"
    )
) -> PredictResponse:
    """
    Predict if an uploaded image is real or fake.

    When use_fusion=true (default):
    - Runs all submodels on the image
    - Combines predictions using majority vote fusion
    - Returns the fused result plus optionally individual submodel results

    When use_fusion=false:
    - Runs only the specified submodel (or the first available if not specified)
    - Returns just that model's prediction

    Response includes timing information for each step.

    Raises:
        HTTPException: 400 (bad image), 404 (unknown model), 503 (models not
            loaded), 500 (inference/fusion/unexpected failure).
    """
    timer = Timer()
    timer.start_total()

    # Determine if we should return submodel results:
    # an explicit query param wins; otherwise default to use_fusion.
    should_return_submodels = return_submodels if return_submodels is not None else use_fusion

    try:
        # Read image bytes (timed as "download" — drains the upload stream)
        with timer.measure("download"):
            image_bytes = await image.read()

        # Validate and preprocess (raises ImageProcessingError on bad input)
        with timer.measure("preprocess"):
            preprocess_service = get_preprocess_service()
            preprocess_service.validate_image(image_bytes)

        inference_service = get_inference_service()
        fusion_service = get_fusion_service()
        registry = get_model_registry()

        if use_fusion:
            # Run all submodels
            with timer.measure("inference"):
                submodel_outputs = inference_service.predict_all_submodels(
                    image_bytes=image_bytes,
                    explain=explain
                )

            # Run fusion over the collected submodel outputs
            with timer.measure("fusion"):
                final_result = fusion_service.fuse(submodel_outputs=submodel_outputs)

            timer.stop_total()

            # Extract fusion meta (contribution percentages)
            fusion_meta_dict = final_result.get("meta", {})
            contribution_percentages = fusion_meta_dict.get("contribution_percentages", {})

            # Build fusion meta object (None when the fusion model reported no meta)
            fusion_meta = FusionMeta(
                submodel_weights=fusion_meta_dict.get("submodel_weights", {}),
                weighted_contributions=fusion_meta_dict.get("weighted_contributions", {}),
                contribution_percentages=contribution_percentages
            ) if fusion_meta_dict else None

            # Build model display info for the frontend (one entry per submodel)
            model_display_info = {
                name: ModelDisplayInfo(**get_model_display_info(name))
                for name in submodel_outputs.keys()
            }

            # Build response
            return PredictResponse(
                final=PredictionResult(
                    pred=final_result["pred"],
                    pred_int=final_result["pred_int"],
                    prob_fake=final_result["prob_fake"]
                ),
                fusion_used=True,
                submodels={
                    name: PredictionResult(
                        pred=output["pred"],
                        pred_int=output["pred_int"],
                        prob_fake=output["prob_fake"],
                        heatmap_base64=output.get("heatmap_base64"),
                        explainability_type=output.get("explainability_type"),
                        focus_summary=output.get("focus_summary"),
                        contribution_percentage=contribution_percentages.get(name)
                    )
                    for name, output in submodel_outputs.items()
                } if should_return_submodels else None,
                fusion_meta=fusion_meta,
                model_display_info=model_display_info if should_return_submodels else None,
                timing_ms=TimingInfo(**timer.get_timings())
            )

        else:
            # Single model prediction; fall back to the first registered submodel
            model_key = model or registry.get_submodel_names()[0]

            with timer.measure("inference"):
                result = inference_service.predict_single(
                    model_key=model_key,
                    image_bytes=image_bytes,
                    explain=explain
                )

            timer.stop_total()

            return PredictResponse(
                final=PredictionResult(
                    pred=result["pred"],
                    pred_int=result["pred_int"],
                    prob_fake=result["prob_fake"],
                    heatmap_base64=result.get("heatmap_base64"),
                    explainability_type=result.get("explainability_type"),
                    focus_summary=result.get("focus_summary")
                ),
                fusion_used=False,
                submodels=None,
                timing_ms=TimingInfo(**timer.get_timings())
            )

    except ImageProcessingError as e:
        logger.warning(f"Image processing error: {e.message}")
        raise HTTPException(
            status_code=400,
            detail={"error": "ImageProcessingError", "message": e.message, "details": e.details}
        )

    except ModelNotFoundError as e:
        logger.warning(f"Model not found: {e.message}")
        raise HTTPException(
            status_code=404,
            detail={"error": "ModelNotFoundError", "message": e.message, "details": e.details}
        )

    except ModelNotLoadedError as e:
        logger.error(f"Models not loaded: {e.message}")
        raise HTTPException(
            status_code=503,
            detail={"error": "ModelNotLoadedError", "message": e.message, "details": e.details}
        )

    except (InferenceError, FusionError) as e:
        logger.error(f"Inference/Fusion error: {e.message}")
        raise HTTPException(
            status_code=500,
            detail={"error": type(e).__name__, "message": e.message, "details": e.details}
        )

    except Exception as e:
        # Catch-all so clients never see a raw stack trace.
        logger.exception(f"Unexpected error in predict endpoint: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": "InternalError", "message": str(e)}
        )
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@router.post("/explain-model", response_model=ExplainModelResponse)
async def explain_model(
    image: UploadFile = File(...),
    model_name: str = Form(...),
    prob_fake: float = Form(...),
    # Fix: these fields accept None defaults, so their annotations must be
    # Optional — with Pydantic v2, FastAPI treats the annotation as
    # authoritative and a bare `float = Form(None)` can reject omitted fields.
    contribution_percentage: Optional[float] = Form(None),
    heatmap_base64: Optional[str] = Form(None),
    focus_summary: Optional[str] = Form(None)
):
    """
    Generate an on-demand LLM explanation for a single model's prediction.

    This endpoint is token-efficient - only called when the user requests
    insights for one model.

    Args:
        image: The original uploaded image (re-sent by the client).
        model_name: Name of the submodel whose prediction is explained.
        prob_fake: That model's fake-probability for the image.
        contribution_percentage: Optional fusion contribution of the model.
        heatmap_base64: Optional explainability heatmap produced earlier.
        focus_summary: Optional textual summary of heatmap focus regions.

    Raises:
        HTTPException: 400 for an empty upload, 503 when the LLM service is
            disabled, 500 when explanation generation fails.
    """
    try:
        # Read and validate image
        image_bytes = await image.read()
        if len(image_bytes) == 0:
            raise HTTPException(status_code=400, detail="Empty image file")

        # Encode image to base64 for the multimodal LLM prompt
        original_b64 = base64.b64encode(image_bytes).decode('utf-8')

        # Get LLM service
        llm_service = get_llm_service()
        if not llm_service.enabled:
            raise HTTPException(
                status_code=503,
                detail="LLM service is not enabled. Set GEMINI_API_KEY environment variable."
            )

        # Generate explanation
        result = llm_service.generate_single_model_explanation(
            model_name=model_name,
            original_image_b64=original_b64,
            prob_fake=prob_fake,
            heatmap_b64=heatmap_base64,
            contribution_percentage=contribution_percentage,
            focus_summary=focus_summary
        )

        if result is None:
            raise HTTPException(
                status_code=500,
                detail="Failed to generate explanation from LLM"
            )

        return ExplainModelResponse(
            model_name=model_name,
            insight=SingleModelInsight(
                key_finding=result["key_finding"],
                what_model_saw=result["what_model_saw"],
                important_regions=result["important_regions"],
                confidence_qualifier=result["confidence_qualifier"]
            )
        )

    except HTTPException:
        # Re-raise HTTP errors untouched so status codes are preserved.
        raise
    except Exception as e:
        logger.exception(f"Error generating model explanation: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": "ExplanationError", "message": str(e)}
        )
|
app/core/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Core module
|
app/core/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (193 Bytes). View file
|
|
|
app/core/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (2.22 kB). View file
|
|
|
app/core/__pycache__/errors.cpython-312.pyc
ADDED
|
Binary file (2.54 kB). View file
|
|
|
app/core/__pycache__/logging.cpython-312.pyc
ADDED
|
Binary file (2.52 kB). View file
|
|
|
app/core/config.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Application configuration with environment variable support.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
from pydantic_settings import BaseSettings
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Settings(BaseSettings):
    """Application settings loaded from environment variables (and .env)."""

    # Hugging Face configuration
    # Available fusion models:
    # - DeepFakeDetector/fusion-logreg-final (Logistic Regression - default)
    # - DeepFakeDetector/fusion-meta-final (Meta-classifier)
    HF_FUSION_REPO_ID: str = "DeepFakeDetector/fusion-logreg-final"
    # Directory where downloaded model repos are cached locally.
    HF_CACHE_DIR: str = ".hf_cache"
    # Optional token for private Hugging Face repos.
    HF_TOKEN: Optional[str] = None

    # Google Gemini API configuration
    GOOGLE_API_KEY: Optional[str] = None
    GEMINI_MODEL: str = "gemini-2.5-flash"

    @property
    def llm_enabled(self) -> bool:
        """Check if LLM explanations are available (a non-empty API key is set)."""
        return self.GOOGLE_API_KEY is not None and len(self.GOOGLE_API_KEY) > 0

    # Application configuration
    ENABLE_DEBUG: bool = False
    LOG_LEVEL: str = "INFO"

    # Server configuration
    HOST: str = "0.0.0.0"
    PORT: int = 8000

    # CORS configuration (comma-separated list of allowed origins)
    CORS_ORIGINS: str = "http://localhost:5173,http://localhost:3000,https://www.deepfake-detector.app,https://deepfake-detector.app"

    @property
    def cors_origins_list(self) -> list[str]:
        """Parse CORS origins from the comma-separated string, dropping blanks."""
        return [origin.strip() for origin in self.CORS_ORIGINS.split(",") if origin.strip()]

    # API configuration
    API_V1_PREFIX: str = "/api/v1"
    PROJECT_NAME: str = "DeepFake Detector API"
    VERSION: str = "0.1.0"

    class Config:
        # Pydantic settings: load overrides from a local .env file.
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = True
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@lru_cache()
def get_settings() -> Settings:
    """Get the cached singleton Settings instance (constructed once per process)."""
    return Settings()


# Module-level singleton used throughout the app.
settings = get_settings()
|
app/core/errors.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom exceptions and error handling for the application.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DeepFakeDetectorError(Exception):
    """Base exception for DeepFake Detector application.

    Carries a human-readable ``message`` plus a ``details`` dict of
    structured context that API error handlers serialize into responses.
    """

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        super().__init__(message)
        self.message = message
        # Normalize a missing/empty details payload to an empty dict so
        # handlers can always serialize it without a None check.
        self.details = details if details else {}
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Concrete error types; a docstring alone is a valid class body, so the
# redundant `pass` statements have been removed.
class ModelNotLoadedError(DeepFakeDetectorError):
    """Raised when attempting to use a model that hasn't been loaded."""


class ModelNotFoundError(DeepFakeDetectorError):
    """Raised when a requested model is not found in the registry."""


class HuggingFaceDownloadError(DeepFakeDetectorError):
    """Raised when downloading from Hugging Face fails."""


class ImageProcessingError(DeepFakeDetectorError):
    """Raised when image processing/decoding fails."""


class InferenceError(DeepFakeDetectorError):
    """Raised when model inference fails."""


class FusionError(DeepFakeDetectorError):
    """Raised when fusion prediction fails."""


class ConfigurationError(DeepFakeDetectorError):
    """Raised when configuration is invalid or missing."""
|
app/core/logging.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Logging configuration for the application.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import sys
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from app.core.config import settings
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def setup_logging(level: Optional[str] = None) -> logging.Logger:
    """Configure root logging for the application.

    Args:
        level: Log level name (DEBUG, INFO, WARNING, ERROR, CRITICAL);
            falls back to settings.LOG_LEVEL when not given.

    Returns:
        The configured root logger.
    """
    chosen = level or settings.LOG_LEVEL

    root = logging.getLogger()
    # Unknown level names silently fall back to INFO.
    root.setLevel(getattr(logging, chosen.upper(), logging.INFO))

    # Replace any pre-existing handlers with a single stdout handler.
    for existing in list(root.handlers):
        root.removeHandler(existing)

    stream = logging.StreamHandler(sys.stdout)
    stream.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    root.addHandler(stream)

    # Quiet noisy third-party loggers.
    for noisy in ("uvicorn", "httpx", "huggingface_hub"):
        logging.getLogger(noisy).setLevel(logging.WARNING)

    return root
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_logger(name: str) -> logging.Logger:
    """Return the logger registered under *name* (typically ``__name__``)."""
    named_logger = logging.getLogger(name)
    return named_logger
|
app/main.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI application entry point.
|
| 3 |
+
|
| 4 |
+
DeepFake Detector API - Milestone 1: Hugging Face hosted dummy models.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from contextlib import asynccontextmanager
|
| 8 |
+
from typing import AsyncGenerator
|
| 9 |
+
|
| 10 |
+
from fastapi import FastAPI, Request
|
| 11 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
+
from fastapi.responses import JSONResponse
|
| 13 |
+
|
| 14 |
+
from app.api import routes_health, routes_models, routes_predict
|
| 15 |
+
from app.core.config import settings
|
| 16 |
+
from app.core.errors import DeepFakeDetectorError
|
| 17 |
+
from app.core.logging import setup_logging, get_logger
|
| 18 |
+
from app.services.model_registry import get_model_registry
|
| 19 |
+
|
| 20 |
+
# Set up logging
|
| 21 |
+
setup_logging()
|
| 22 |
+
logger = get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@asynccontextmanager
|
| 26 |
+
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
| 27 |
+
"""
|
| 28 |
+
Application lifespan manager.
|
| 29 |
+
|
| 30 |
+
Handles startup and shutdown events:
|
| 31 |
+
- Startup: Load models from Hugging Face
|
| 32 |
+
- Shutdown: Cleanup resources
|
| 33 |
+
"""
|
| 34 |
+
# Startup
|
| 35 |
+
logger.info("Starting DeepFake Detector API...")
|
| 36 |
+
logger.info(f"Configuration: HF_FUSION_REPO_ID={settings.HF_FUSION_REPO_ID}")
|
| 37 |
+
logger.info(f"Configuration: HF_CACHE_DIR={settings.HF_CACHE_DIR}")
|
| 38 |
+
|
| 39 |
+
# Load models from Hugging Face
|
| 40 |
+
try:
|
| 41 |
+
registry = get_model_registry()
|
| 42 |
+
await registry.load_from_fusion_repo(settings.HF_FUSION_REPO_ID)
|
| 43 |
+
logger.info("Models loaded successfully!")
|
| 44 |
+
except Exception as e:
|
| 45 |
+
logger.error(f"Failed to load models on startup: {e}")
|
| 46 |
+
logger.warning("API will start but /ready will report not_ready until models are loaded")
|
| 47 |
+
|
| 48 |
+
yield # Application runs here
|
| 49 |
+
|
| 50 |
+
# Shutdown
|
| 51 |
+
logger.info("Shutting down DeepFake Detector API...")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Create FastAPI application (docs served at /docs via FastAPI defaults)
app = FastAPI(
    title=settings.PROJECT_NAME,
    version=settings.VERSION,
    description="""
DeepFake Detector API - Analyze images to detect AI-generated content.

## Features

- **Fusion prediction**: Combines multiple model predictions using majority vote
- **Individual model prediction**: Run specific submodels directly
- **Timing information**: Detailed performance metrics for each request

## Milestone 1

This is the initial milestone using dummy random models hosted on Hugging Face
for testing the API infrastructure.
""",
    lifespan=lifespan,
    debug=settings.ENABLE_DEBUG
)

# Add CORS middleware so browser frontends listed in CORS_ORIGINS can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logger.info(f"CORS enabled for origins: {settings.cors_origins_list}")
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Global exception handler for custom errors
@app.exception_handler(DeepFakeDetectorError)
async def deepfake_error_handler(request: Request, exc: DeepFakeDetectorError):
    """Convert any uncaught DeepFakeDetectorError into a structured 500 JSON body."""
    return JSONResponse(
        status_code=500,
        content={
            "error": type(exc).__name__,
            "message": exc.message,
            "details": exc.details
        }
    )
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Include routers (health/readiness, model listing, prediction endpoints)
app.include_router(routes_health.router)
app.include_router(routes_models.router)
app.include_router(routes_predict.router)


# Root endpoint
@app.get("/", tags=["root"])
async def root():
    """Root endpoint with API information."""
    return {
        "name": settings.PROJECT_NAME,
        "version": settings.VERSION,
        "docs": "/docs",
        "health": "/health",
        "ready": "/ready"
    }


if __name__ == "__main__":
    # Local dev entry point; in production uvicorn is launched externally.
    import uvicorn
    uvicorn.run(
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=settings.ENABLE_DEBUG
    )
|
app/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Models module
|
app/models/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (195 Bytes). View file
|
|
|
app/models/wrappers/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Model wrappers module
|
app/models/wrappers/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (204 Bytes). View file
|
|
|
app/models/wrappers/__pycache__/base_wrapper.cpython-312.pyc
ADDED
|
Binary file (5.75 kB). View file
|
|
|
app/models/wrappers/__pycache__/cnn_transfer_wrapper.cpython-312.pyc
ADDED
|
Binary file (9.91 kB). View file
|
|
|
app/models/wrappers/__pycache__/deit_distilled_wrapper.cpython-312.pyc
ADDED
|
Binary file (13.7 kB). View file
|
|
|
app/models/wrappers/__pycache__/dummy_majority_fusion_wrapper.cpython-312.pyc
ADDED
|
Binary file (7.04 kB). View file
|
|
|
app/models/wrappers/__pycache__/dummy_random_wrapper.cpython-312.pyc
ADDED
|
Binary file (6.52 kB). View file
|
|
|
app/models/wrappers/__pycache__/gradfield_cnn_wrapper.cpython-312.pyc
ADDED
|
Binary file (18.1 kB). View file
|
|
|
app/models/wrappers/__pycache__/logreg_fusion_wrapper.cpython-312.pyc
ADDED
|
Binary file (6.73 kB). View file
|
|
|
app/models/wrappers/__pycache__/vit_base_wrapper.cpython-312.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
app/models/wrappers/base_wrapper.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base wrapper class for model wrappers.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
from typing import Any, Callable, Dict, Optional
|
| 7 |
+
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BaseModelWrapper(ABC):
    """
    Abstract base class for model wrappers.

    All model wrappers should inherit from this class and implement
    the abstract methods.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        self.repo_id = repo_id
        self.config = config
        self.local_path = local_path
        # Set by load(); None means the model is not yet ready for inference.
        self._predict_fn: Optional[Callable] = None

    @property
    def name(self) -> str:
        """
        Get the short name of the model.

        Prefers 'name' from config if available, otherwise derives it from
        the last path component of repo_id. Strips a trailing '-final'
        suffix to ensure consistency with fusion configs.
        """
        config_name = self.config.get("name")
        if config_name:
            # Fix: removesuffix only strips a trailing '-final'; the previous
            # str.replace("-final", "") also deleted '-final' occurring in the
            # middle of a name, contradicting the documented suffix semantics.
            return config_name.removesuffix("-final")

        # Fall back to the repo_id's last path component, same suffix rule.
        repo_name = self.repo_id.split("/")[-1]
        return repo_name.removesuffix("-final")

    @abstractmethod
    def load(self) -> None:
        """
        Load the model and prepare for inference.

        Implementations should import/construct the predict function from the
        downloaded repository and store it (e.g. in self._predict_fn).
        """
        pass

    @abstractmethod
    def predict(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Run prediction.

        Returns:
            Dictionary with standardized prediction fields:
            - pred_int: 0 (real) or 1 (fake)
            - pred: "real" or "fake"
            - prob_fake: float probability
            - meta: dict with any additional metadata
        """
        pass

    def is_loaded(self) -> bool:
        """Check if the model is loaded and ready for inference."""
        return self._predict_fn is not None

    def get_info(self) -> Dict[str, Any]:
        """
        Get model information.

        Returns:
            Dictionary with repo_id, name, config, local_path and load status.
        """
        return {
            "repo_id": self.repo_id,
            "name": self.name,
            "config": self.config,
            "local_path": self.local_path,
            "is_loaded": self.is_loaded()
        }
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class BaseSubmodelWrapper(BaseModelWrapper):
    """Base wrapper for submodels that run per-image predictions."""

    @abstractmethod
    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (alternative to image)
            explain: If True, include explainability heatmap in output
            **kwargs: Additional arguments

        Returns:
            Standardized prediction dictionary with:
            - pred_int: 0 (real) or 1 (fake)
            - pred: "real" or "fake"
            - prob_fake: float probability
            - heatmap_base64: Optional[str] (when explain=True)
            - explainability_type: Optional[str] (when explain=True)

        NOTE(review): precedence when both `image` and `image_bytes` are
        supplied is implementation-defined here — confirm in concrete wrappers.
        """
        pass
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class BaseFusionWrapper(BaseModelWrapper):
    """Base wrapper for fusion models that combine submodel outputs."""

    @abstractmethod
    def predict(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run fusion prediction on submodel outputs.

        Args:
            submodel_outputs: Dictionary mapping submodel name to its output
                (each output follows the standardized prediction schema:
                pred_int, pred, prob_fake, ...)
            **kwargs: Additional arguments

        Returns:
            Standardized prediction dictionary for the fused decision
        """
        pass
|
app/models/wrappers/cnn_transfer_wrapper.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Wrapper for CNN Transfer (EfficientNet-B0) submodel.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Dict, Optional, Tuple
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from torchvision import transforms
|
| 14 |
+
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
|
| 15 |
+
|
| 16 |
+
from app.core.errors import InferenceError, ConfigurationError
|
| 17 |
+
from app.core.logging import get_logger
|
| 18 |
+
from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
|
| 19 |
+
from app.services.explainability import GradCAM, heatmap_to_base64, compute_focus_summary
|
| 20 |
+
|
| 21 |
+
logger = get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class CNNTransferWrapper(BaseSubmodelWrapper):
    """
    Wrapper for CNN Transfer model using EfficientNet-B0 backbone.

    Model expects 224x224 RGB images with ImageNet normalization.

    Weights are read from ``model.pth`` in ``local_path``; optional
    preprocessing overrides are read from ``preprocess.json``. When
    ``explain=True`` is passed to :meth:`predict`, a Grad-CAM heatmap
    computed over the last feature block is attached to the output.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper. Does NOT load weights — call :meth:`load`.

        Args:
            repo_id: Hugging Face repository ID for the model.
            config: Configuration dict (keys used here: "threshold",
                "num_classes", "labels").
            local_path: Local directory containing model.pth and
                optionally preprocess.json.
        """
        super().__init__(repo_id, config, local_path)
        # Populated by load(); predict() raises InferenceError until then.
        self._model: Optional[nn.Module] = None
        self._transform: Optional[transforms.Compose] = None
        # Weights and input tensors are moved to this device.
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Decision threshold applied to prob_fake (default 0.5).
        self._threshold = config.get("threshold", 0.5)
        logger.info(f"Initialized CNNTransferWrapper for {repo_id}")

    def load(self) -> None:
        """Load the EfficientNet-B0 model with trained weights.

        Raises:
            ConfigurationError: If model.pth is missing or any step of
                model construction / weight loading fails.
        """
        weights_path = Path(self.local_path) / "model.pth"
        preprocess_path = Path(self.local_path) / "preprocess.json"

        if not weights_path.exists():
            raise ConfigurationError(
                message=f"model.pth not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(weights_path)}
            )

        try:
            # Load preprocessing config; defaults below are used when the
            # file is absent or a key is missing.
            preprocess_config = {}
            if preprocess_path.exists():
                with open(preprocess_path, "r") as f:
                    preprocess_config = json.load(f)

            # Build transform pipeline. A bare int is normalized to [H, W].
            input_size = preprocess_config.get("input_size", [224, 224])
            if isinstance(input_size, int):
                input_size = [input_size, input_size]

            # Defaults are the standard ImageNet normalization statistics.
            normalize_config = preprocess_config.get("normalize", {})
            mean = normalize_config.get("mean", [0.485, 0.456, 0.406])
            std = normalize_config.get("std", [0.229, 0.224, 0.225])

            self._transform = transforms.Compose([
                transforms.Resize(input_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])

            # Create model architecture with randomly initialized weights
            # (weights=None); trained weights are loaded below.
            num_classes = self.config.get("num_classes", 2)
            self._model = efficientnet_b0(weights=None)

            # Replace classifier for binary classification. The Sequential
            # layout (Dropout, Linear) must match the saved state dict keys.
            in_features = self._model.classifier[1].in_features
            self._model.classifier = nn.Sequential(
                nn.Dropout(p=0.2, inplace=True),
                nn.Linear(in_features, num_classes)
            )

            # Load trained weights. weights_only=True restricts torch.load
            # to plain tensors (no arbitrary pickled objects).
            state_dict = torch.load(weights_path, map_location=self._device, weights_only=True)
            self._model.load_state_dict(state_dict)
            self._model.to(self._device)
            self._model.eval()

            # Mark as loaded (base class treats a non-None _predict_fn as
            # the loaded indicator — TODO confirm against BaseSubmodelWrapper).
            self._predict_fn = self._run_inference
            logger.info(f"Loaded CNN Transfer model from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load CNN Transfer model: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _run_inference(
        self,
        image_tensor: torch.Tensor,
        explain: bool = False
    ) -> Dict[str, Any]:
        """Run model inference on preprocessed tensor.

        Args:
            image_tensor: Batched input of shape [1, C, H, W], already on
                ``self._device``.
            explain: If True, also compute a Grad-CAM heatmap.

        Returns:
            Dict with "logits" (list), "prob_fake" (float), "pred_int"
            (0/1) and, when explain is True, "heatmap".
        """
        heatmap = None

        if explain:
            # Use GradCAM for explainability (requires gradients)
            target_layer = self._model.features[-1]  # Last MBConv block
            gradcam = GradCAM(self._model, target_layer)
            try:
                # GradCAM needs gradients, so don't use no_grad here.
                logits = self._model(image_tensor)
                probs = F.softmax(logits, dim=1)
                # Class index 1 is "fake" (see labels default in predict()).
                prob_fake = probs[0, 1].item()
                pred_int = 1 if prob_fake >= self._threshold else 0

                # Compute heatmap for the predicted class; output_size
                # matches the 224x224 model input resolution.
                heatmap = gradcam(
                    image_tensor.clone(),
                    target_class=pred_int,
                    output_size=(224, 224)
                )
            finally:
                # Always detach the GradCAM hooks, even on failure.
                gradcam.remove_hooks()
        else:
            with torch.no_grad():
                logits = self._model(image_tensor)
                probs = F.softmax(logits, dim=1)
                prob_fake = probs[0, 1].item()
                pred_int = 1 if prob_fake >= self._threshold else 0

        result = {
            # detach() needed in the explain path where grads are enabled.
            "logits": logits[0].detach().cpu().numpy().tolist(),
            "prob_fake": prob_fake,
            "pred_int": pred_int
        }

        if heatmap is not None:
            result["heatmap"] = heatmap

        return result

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (will be converted to PIL Image)
            explain: If True, compute GradCAM heatmap

        Returns:
            Standardized prediction dictionary with optional heatmap

        Raises:
            InferenceError: If the model is not loaded, no image is
                provided, or inference fails.
        """
        if self._model is None or self._transform is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Convert bytes to PIL Image if needed; always force RGB so
            # grayscale/RGBA inputs match the 3-channel model input.
            if image is None and image_bytes is not None:
                import io
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            elif image is not None:
                image = image.convert("RGB")
            else:
                raise InferenceError(
                    message="No image provided",
                    details={"repo_id": self.repo_id}
                )

            # Preprocess: resize/normalize, add batch dim, move to device.
            image_tensor = self._transform(image).unsqueeze(0).to(self._device)

            # Run inference
            result = self._run_inference(image_tensor, explain=explain)

            # Standardize output. Keys of the labels map are string class
            # indices ("0"/"1"), hence the str(pred_int) lookup.
            labels = self.config.get("labels", {"0": "real", "1": "fake"})
            pred_int = result["pred_int"]

            output = {
                "pred_int": pred_int,
                "pred": labels.get(str(pred_int), "unknown"),
                "prob_fake": result["prob_fake"],
                "meta": {
                    "model": self.name,
                    "threshold": self._threshold,
                    "logits": result["logits"]
                }
            }

            # Add heatmap (base64 PNG plus summary stats) if requested.
            if explain and "heatmap" in result:
                heatmap = result["heatmap"]
                output["heatmap_base64"] = heatmap_to_base64(heatmap)
                output["explainability_type"] = "grad_cam"
                output["focus_summary"] = compute_focus_summary(heatmap)

            return output

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )
|
app/models/wrappers/deit_distilled_wrapper.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Wrapper for DeiT Distilled submodel.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from torchvision import transforms
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import timm
|
| 17 |
+
TIMM_AVAILABLE = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
TIMM_AVAILABLE = False
|
| 20 |
+
|
| 21 |
+
from app.core.errors import InferenceError, ConfigurationError
|
| 22 |
+
from app.core.logging import get_logger
|
| 23 |
+
from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
|
| 24 |
+
from app.services.explainability import attention_rollout, heatmap_to_base64, compute_focus_summary
|
| 25 |
+
|
| 26 |
+
logger = get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def create_custom_mlp_head(in_features: int = 768, num_classes: int = 2) -> nn.Sequential:
    """
    Build the custom MLP classification head used during DeiT training.

    The head must be an ``nn.Sequential`` so that the parameterized layers
    land at indices 0, 1 and 4 — exactly the keys present in the saved
    state dict (GELU and Dropout carry no parameters).

    Args:
        in_features: Input embedding width (768 for DeiT-Base).
        num_classes: Number of output classes.

    Returns:
        nn.Sequential: LayerNorm -> Linear(in, 512) -> GELU -> Dropout(0.2)
        -> Linear(512, num_classes).
    """
    head_layers = [
        nn.LayerNorm(in_features),           # index 0: has params
        nn.Linear(in_features, 512),         # index 1: has params
        nn.GELU(),                           # index 2: parameter-free
        nn.Dropout(p=0.2),                   # index 3: parameter-free
        nn.Linear(512, num_classes),         # index 4: has params
    ]
    return nn.Sequential(*head_layers)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class DeiTDistilledWrapper(BaseSubmodelWrapper):
    """
    Wrapper for DeiT Distilled model.

    Model expects 224x224 RGB images with ImageNet normalization.
    Uses a custom MLP head for classification.

    Weights are read from ``deit_distilled_final.pt`` in ``local_path``;
    optional preprocessing overrides come from ``preprocess.json``. When
    ``explain=True`` is passed to :meth:`predict`, an attention-rollout
    heatmap is attached to the output.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper. Does NOT load weights — call :meth:`load`.

        Args:
            repo_id: Hugging Face repository ID for the model.
            config: Configuration dict (keys used here: "threshold",
                "model_name", "num_classes", "class_mapping").
            local_path: Local directory containing the weights file and
                optionally preprocess.json.
        """
        super().__init__(repo_id, config, local_path)
        # Populated by load(); predict() raises InferenceError until then.
        self._model: Optional[nn.Module] = None
        self._transform: Optional[transforms.Compose] = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Decision threshold applied to prob_fake (default 0.5).
        self._threshold = config.get("threshold", 0.5)
        logger.info(f"Initialized DeiTDistilledWrapper for {repo_id}")

    def load(self) -> None:
        """Load the DeiT model with custom head and trained weights.

        Raises:
            ConfigurationError: If timm is unavailable, the weights file
                is missing, or model construction / loading fails.
        """
        if not TIMM_AVAILABLE:
            raise ConfigurationError(
                message="timm package not installed. Run: pip install timm",
                details={"repo_id": self.repo_id}
            )

        weights_path = Path(self.local_path) / "deit_distilled_final.pt"
        preprocess_path = Path(self.local_path) / "preprocess.json"

        if not weights_path.exists():
            raise ConfigurationError(
                message=f"deit_distilled_final.pt not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(weights_path)}
            )

        try:
            # Load preprocessing config; defaults below apply otherwise.
            preprocess_config = {}
            if preprocess_path.exists():
                with open(preprocess_path, "r") as f:
                    preprocess_config = json.load(f)

            # Build transform pipeline. A [H, W] list collapses to its
            # first element; a square resize is applied below.
            input_size = preprocess_config.get("input_size", 224)
            if isinstance(input_size, list):
                input_size = input_size[0]

            # Defaults are the standard ImageNet normalization statistics.
            normalize_config = preprocess_config.get("normalize", {})
            mean = normalize_config.get("mean", [0.485, 0.456, 0.406])
            std = normalize_config.get("std", [0.229, 0.224, 0.225])

            # Use bicubic interpolation as specified (ViT-family models are
            # commonly trained with bicubic resizing).
            interpolation = preprocess_config.get("interpolation", "bicubic")
            interp_mode = transforms.InterpolationMode.BICUBIC if interpolation == "bicubic" else transforms.InterpolationMode.BILINEAR

            self._transform = transforms.Compose([
                transforms.Resize((input_size, input_size), interpolation=interp_mode),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])

            # Create model architecture
            model_name = self.config.get("model_name", "deit_base_distilled_patch16_224")
            num_classes = self.config.get("num_classes", 2)

            # Create base model without pretrained weights; num_classes=0
            # strips timm's default heads so ours can be attached.
            self._model = timm.create_model(model_name, pretrained=False, num_classes=0)

            # Replace heads with custom MLP heads (Sequential assigned directly)
            # Note: state dict has separate keys for head and head_dist, so don't share
            hidden_dim = 768  # DeiT base hidden dimension
            self._model.head = create_custom_mlp_head(hidden_dim, num_classes)
            self._model.head_dist = create_custom_mlp_head(hidden_dim, num_classes)

            # Load trained weights. weights_only=True restricts torch.load
            # to plain tensors (no arbitrary pickled objects).
            state_dict = torch.load(weights_path, map_location=self._device, weights_only=True)
            self._model.load_state_dict(state_dict)
            self._model.to(self._device)
            self._model.eval()

            # Mark as loaded (base class treats a non-None _predict_fn as
            # the loaded indicator — TODO confirm against BaseSubmodelWrapper).
            self._predict_fn = self._run_inference
            logger.info(f"Loaded DeiT Distilled model from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load DeiT Distilled model: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _run_inference(
        self,
        image_tensor: torch.Tensor,
        explain: bool = False
    ) -> Dict[str, Any]:
        """Run model inference on preprocessed tensor.

        Args:
            image_tensor: Batched input of shape [1, C, H, W], already on
                ``self._device``.
            explain: If True, capture per-block attention and compute an
                attention-rollout heatmap.

        Returns:
            Dict with "logits" (list), "prob_fake" (float), "pred_int"
            (0/1) and, when explain is True and attention was captured,
            "heatmap" (float32 array in [0, 1], 224x224).
        """
        heatmap = None

        if explain:
            # Collect attention weights from all blocks
            # NOTE(review): `attentions` is unused — maps are gathered via
            # all_stored_attns below.
            attentions: List[torch.Tensor] = []
            handles = []

            # Hook into attention modules to capture weights
            # DeiT blocks structure: blocks[i].attn
            def create_attn_hook():
                stored_attn = []

                def hook(module, inputs, outputs):
                    # Recompute attention from the module's own qkv
                    # projection; inputs[0] is x of shape [B, N, C].
                    x = inputs[0]
                    B, N, C = x.shape

                    # Access the attention module's parameters
                    qkv = module.qkv(x)  # [B, N, 3*dim]
                    qkv = qkv.reshape(B, N, 3, module.num_heads, C // module.num_heads)
                    qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, dim_head]
                    q, k, v = qkv[0], qkv[1], qkv[2]

                    # Compute attention weights (scaled dot-product + softmax)
                    scale = (C // module.num_heads) ** -0.5
                    attn = (q @ k.transpose(-2, -1)) * scale
                    attn = attn.softmax(dim=-1)  # [B, heads, N, N]

                    # Average over heads
                    attn_avg = attn.mean(dim=1)  # [B, N, N]
                    stored_attn.append(attn_avg.detach())

                return hook, stored_attn

            # One hook (and one storage list) per transformer block, so
            # layer order is preserved for the rollout.
            all_stored_attns = []
            for block in self._model.blocks:
                hook_fn, stored = create_attn_hook()
                all_stored_attns.append(stored)
                handle = block.attn.register_forward_hook(hook_fn)
                handles.append(handle)

            try:
                with torch.no_grad():
                    logits = self._model(image_tensor)
                    probs = F.softmax(logits, dim=1)
                    # Class index 1 is "fake" (see class_mapping default).
                    prob_fake = probs[0, 1].item()
                    pred_int = 1 if prob_fake >= self._threshold else 0

                # Get attention from hooks; skip any block whose hook did
                # not fire.
                attention_list = [stored[0] for stored in all_stored_attns if len(stored) > 0]

                if attention_list:
                    # Stack: [num_layers, B, N, N]
                    attention_stack = torch.stack(attention_list, dim=0)
                    # Compute rollout - returns (grid_size, grid_size) heatmap
                    attention_map = attention_rollout(
                        attention_stack[:, 0],  # [num_layers, N, N] (batch item 0)
                        head_fusion="mean",  # Already averaged over heads above
                        discard_ratio=0.0,
                        num_prefix_tokens=2  # DeiT has CLS + distillation token
                    )  # Returns (14, 14) for DeiT-Base

                    # Resize to image size via PIL, then scale back to [0, 1].
                    from PIL import Image as PILImage
                    heatmap_img = PILImage.fromarray(
                        (attention_map * 255).astype(np.uint8)
                    ).resize((224, 224), PILImage.BILINEAR)
                    heatmap = np.array(heatmap_img).astype(np.float32) / 255.0

            finally:
                # Always remove forward hooks, even on failure.
                for handle in handles:
                    handle.remove()
        else:
            with torch.no_grad():
                # In eval mode, DeiT returns single tensor
                logits = self._model(image_tensor)
                probs = F.softmax(logits, dim=1)
                prob_fake = probs[0, 1].item()
                pred_int = 1 if prob_fake >= self._threshold else 0

        result = {
            "logits": logits[0].cpu().numpy().tolist(),
            "prob_fake": prob_fake,
            "pred_int": pred_int
        }

        if heatmap is not None:
            result["heatmap"] = heatmap

        return result

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (will be converted to PIL Image)
            explain: If True, compute attention rollout heatmap

        Returns:
            Standardized prediction dictionary with optional heatmap

        Raises:
            InferenceError: If the model is not loaded, no image is
                provided, or inference fails.
        """
        if self._model is None or self._transform is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Convert bytes to PIL Image if needed; always force RGB so
            # grayscale/RGBA inputs match the 3-channel model input.
            if image is None and image_bytes is not None:
                import io
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            elif image is not None:
                image = image.convert("RGB")
            else:
                raise InferenceError(
                    message="No image provided",
                    details={"repo_id": self.repo_id}
                )

            # Preprocess: resize/normalize, add batch dim, move to device.
            image_tensor = self._transform(image).unsqueeze(0).to(self._device)

            # Run inference
            result = self._run_inference(image_tensor, explain=explain)

            # Standardize output. Keys of class_mapping are string class
            # indices ("0"/"1"), hence the str(pred_int) lookup.
            class_mapping = self.config.get("class_mapping", {"0": "real", "1": "fake"})
            pred_int = result["pred_int"]

            output = {
                "pred_int": pred_int,
                "pred": class_mapping.get(str(pred_int), "unknown"),
                "prob_fake": result["prob_fake"],
                "meta": {
                    "model": self.name,
                    "threshold": self._threshold,
                    "logits": result["logits"]
                }
            }

            # Add heatmap (base64 PNG plus summary stats) if requested.
            if explain and "heatmap" in result:
                heatmap = result["heatmap"]
                output["heatmap_base64"] = heatmap_to_base64(heatmap)
                output["explainability_type"] = "attention_rollout"
                output["focus_summary"] = compute_focus_summary(heatmap)

            return output

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )
|
app/models/wrappers/dummy_majority_fusion_wrapper.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Wrapper for dummy majority vote fusion model.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import importlib.util
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, List
|
| 9 |
+
|
| 10 |
+
from app.core.errors import FusionError, ConfigurationError
|
| 11 |
+
from app.core.logging import get_logger
|
| 12 |
+
from app.models.wrappers.base_wrapper import BaseFusionWrapper
|
| 13 |
+
|
| 14 |
+
logger = get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DummyMajorityFusionWrapper(BaseFusionWrapper):
    """
    Wrapper around a Hugging Face-hosted dummy majority-vote fusion model.

    The repository is expected to ship a ``predict.py`` exposing a
    ``predict()`` function that performs majority voting over the outputs
    of the configured submodels. This wrapper loads that function
    dynamically and normalizes whatever it returns into the standard
    prediction schema.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/fusion-majority-test")
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)
        # Submodel repo IDs whose outputs this fusion consumes.
        self._submodel_repos: List[str] = config.get("submodels", [])
        logger.info(f"Initialized DummyMajorityFusionWrapper for {repo_id}")
        logger.info(f"Submodels: {self._submodel_repos}")

    @property
    def submodel_repos(self) -> List[str]:
        """Get list of submodel repository IDs."""
        return self._submodel_repos

    def load(self) -> None:
        """
        Load the fusion predict function from the downloaded repository.

        Dynamically imports predict.py and extracts the predict function.

        Raises:
            ConfigurationError: If predict.py is missing, cannot be
                imported, or lacks a ``predict`` function.
        """
        script_path = Path(self.local_path) / "predict.py"

        if not script_path.exists():
            raise ConfigurationError(
                message=f"predict.py not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(script_path)}
            )

        try:
            # Unique module name keeps concurrently loaded fusion repos
            # from clobbering each other in sys.modules.
            module_name = f"hf_model_{self.name.replace('-', '_')}_fusion"

            spec = importlib.util.spec_from_file_location(module_name, script_path)
            if spec is None or spec.loader is None:
                raise ConfigurationError(
                    message=f"Could not load spec for {script_path}",
                    details={"repo_id": self.repo_id}
                )

            module = importlib.util.module_from_spec(spec)
            # Register before exec so relative imports inside the script
            # can resolve the module by name.
            sys.modules[module_name] = module
            spec.loader.exec_module(module)

            predict_fn = getattr(module, "predict", None)
            if predict_fn is None:
                raise ConfigurationError(
                    message="predict.py does not have a 'predict' function",
                    details={"repo_id": self.repo_id}
                )

            self._predict_fn = predict_fn
            logger.info(f"Loaded fusion predict function from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load fusion function from {self.repo_id}: {e}")
            raise ConfigurationError(
                message=f"Failed to load fusion model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def predict(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run fusion prediction on submodel outputs.

        Args:
            submodel_outputs: Dictionary mapping submodel name to its prediction output
            **kwargs: Additional arguments passed to the fusion function

        Returns:
            Standardized prediction dictionary with:
            - pred_int: 0 or 1
            - pred: "real" or "fake"
            - prob_fake: float (average of pred_ints)
            - meta: dict

        Raises:
            FusionError: If the fusion function is not loaded or fails.
        """
        if self._predict_fn is None:
            raise FusionError(
                message="Fusion model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Delegate to the repo's own fusion function, then coerce its
            # raw output into the standard schema.
            raw = self._predict_fn(submodel_outputs=submodel_outputs, **kwargs)
            return self._standardize_output(raw)

        except FusionError:
            raise
        except Exception as e:
            logger.error(f"Fusion prediction failed for {self.repo_id}: {e}")
            raise FusionError(
                message=f"Fusion prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _standardize_output(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Standardize the fusion output to ensure consistent format.

        Args:
            result: Raw fusion output

        Returns:
            Standardized dictionary
        """
        # Coerce pred_int to a strict 0/1 (a fractional vote share is
        # rounded at the 0.5 boundary).
        pred_int = result.get("pred_int", 0)
        if pred_int not in (0, 1):
            pred_int = 1 if pred_int > 0.5 else 0

        # Derive the missing fields from pred_int when the fusion script
        # did not supply them.
        pred = result.get("pred") or None
        if pred is None:
            pred = "fake" if pred_int == 1 else "real"

        prob_fake = result.get("prob_fake")
        prob_fake = float(pred_int) if prob_fake is None else prob_fake

        return {
            "pred_int": pred_int,
            "pred": pred,
            "prob_fake": float(prob_fake),
            "meta": result.get("meta", {})
        }
|
app/models/wrappers/dummy_random_wrapper.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Wrapper for dummy random submodels.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import importlib.util
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, Optional
|
| 9 |
+
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
from app.core.errors import InferenceError, ConfigurationError
|
| 13 |
+
from app.core.logging import get_logger
|
| 14 |
+
from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
|
| 15 |
+
|
| 16 |
+
logger = get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DummyRandomWrapper(BaseSubmodelWrapper):
    """
    Wrapper for dummy random prediction models.

    These models are hosted on Hugging Face and contain a predict.py
    with a predict() function that returns random predictions.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/test-random-a")
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)
        logger.info(f"Initialized DummyRandomWrapper for {repo_id}")

    def load(self) -> None:
        """
        Load the predict function from the downloaded repository.

        Dynamically imports predict.py and extracts the predict function.

        Raises:
            ConfigurationError: If predict.py is missing, cannot be imported,
                or does not expose a ``predict`` callable.
        """
        predict_path = Path(self.local_path) / "predict.py"

        if not predict_path.exists():
            raise ConfigurationError(
                message=f"predict.py not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(predict_path)}
            )

        try:
            # Create a unique module name to avoid conflicts between models
            module_name = f"hf_model_{self.name.replace('-', '_')}_predict"

            # Load the module dynamically
            spec = importlib.util.spec_from_file_location(module_name, predict_path)
            if spec is None or spec.loader is None:
                raise ConfigurationError(
                    message=f"Could not load spec for {predict_path}",
                    details={"repo_id": self.repo_id}
                )

            module = importlib.util.module_from_spec(spec)
            # Register before exec so predict.py can find itself in sys.modules.
            sys.modules[module_name] = module
            try:
                spec.loader.exec_module(module)
            except Exception:
                # Bug fix: don't leave a half-initialized module behind on
                # failure; a later retry would otherwise import the broken one.
                sys.modules.pop(module_name, None)
                raise

            # Get the predict function
            if not hasattr(module, "predict"):
                raise ConfigurationError(
                    message="predict.py does not have a 'predict' function",
                    details={"repo_id": self.repo_id}
                )

            self._predict_fn = module.predict
            logger.info(f"Loaded predict function from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load predict function from {self.repo_id}: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object (optional for dummy model; not forwarded)
            image_bytes: Raw image bytes (optional for dummy model)
            **kwargs: Additional arguments passed to the model

        Returns:
            Standardized prediction dictionary with:
            - pred_int: 0 or 1
            - pred: "real" or "fake"
            - prob_fake: float
            - meta: dict

        Raises:
            InferenceError: If the model was never loaded or the underlying
                predict function fails.
        """
        if self._predict_fn is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Call the actual predict function from the HF repo
            result = self._predict_fn(image_bytes=image_bytes, **kwargs)

            # Validate and standardize the output
            standardized = self._standardize_output(result)
            return standardized

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _standardize_output(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Standardize the model output to ensure consistent format.

        Args:
            result: Raw model output (must support ``.get``)

        Returns:
            Standardized dictionary with pred_int, pred, prob_fake, meta
        """
        pred_int = result.get("pred_int", 0)

        # Ensure pred_int is 0 or 1 (threshold soft scores at 0.5)
        if pred_int not in (0, 1):
            pred_int = 1 if pred_int > 0.5 else 0
        # Normalize bools/floats so downstream JSON serializes 0/1, not true/0.0
        pred_int = int(pred_int)

        # Generate pred label if not present
        pred = result.get("pred")
        if pred is None:
            pred = "fake" if pred_int == 1 else "real"

        # Generate prob_fake if not present (degenerate 0/1 confidence)
        prob_fake = result.get("prob_fake")
        if prob_fake is None:
            prob_fake = float(pred_int)

        return {
            "pred_int": pred_int,
            "pred": pred,
            "prob_fake": float(prob_fake),
            "meta": result.get("meta", {})
        }
|
app/models/wrappers/gradfield_cnn_wrapper.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Wrapper for Gradient Field CNN submodel.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import math
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any, Dict, Optional, Tuple
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
|
| 16 |
+
from app.core.errors import InferenceError, ConfigurationError
|
| 17 |
+
from app.core.logging import get_logger
|
| 18 |
+
from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
|
| 19 |
+
from app.services.explainability import heatmap_to_base64, compute_focus_summary
|
| 20 |
+
|
| 21 |
+
logger = get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class CompactGradientNet(nn.Module):
    """
    CNN classifier over a gradient-field representation of a luminance image.

    Input: 1-channel luminance map.
    Internal: derives a 6-channel stack [luminance, Gx, Gy, log-magnitude,
    angle, coherence] on-device before the convolutional trunk.
    Output: (logit, embedding) pair.
    """

    def __init__(self, depth=4, base_filters=32, dropout=0.3, embedding_dim=128):
        super().__init__()

        # Fixed Sobel filters for horizontal / vertical derivatives.
        kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                          dtype=torch.float32).view(1, 1, 3, 3)
        ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                          dtype=torch.float32).view(1, 1, 3, 3)
        self.register_buffer('sobel_x', kx)
        self.register_buffer('sobel_y', ky)

        # 5x5 binomial kernel used to smooth the structure tensor.
        blur = torch.tensor([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4],
                             [6, 24, 36, 24, 6], [4, 16, 24, 16, 4],
                             [1, 4, 6, 4, 1]], dtype=torch.float32) / 256.0
        self.register_buffer('gaussian', blur.view(1, 1, 5, 5))

        # Normalize the 6 derived channels, then remix them with a 1x1 conv.
        self.input_norm = nn.BatchNorm2d(6)
        self.channel_mix = nn.Sequential(
            nn.Conv2d(6, 6, kernel_size=1),
            nn.ReLU()
        )

        # Convolutional trunk: each stage doubles the channel count and
        # halves the spatial resolution.
        stages = []
        channels = 6
        for stage in range(depth):
            width = base_filters * (2 ** stage)
            stages.extend([
                nn.Conv2d(channels, width, kernel_size=3, padding=1),
                nn.BatchNorm2d(width),
                nn.ReLU(),
                nn.MaxPool2d(2)
            ])
            if dropout > 0:
                stages.append(nn.Dropout2d(dropout))
            channels = width

        self.cnn = nn.Sequential(*stages)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.embedding = nn.Linear(width, embedding_dim)
        self.classifier = nn.Linear(embedding_dim, 1)

    def compute_gradient_field(self, luminance):
        """Derive the 6-channel gradient field on-device (includes luminance)."""
        gx = F.conv2d(luminance, self.sobel_x, padding=1)
        gy = F.conv2d(luminance, self.sobel_y, padding=1)

        magnitude = torch.sqrt(gx**2 + gy**2 + 1e-8)
        angle = torch.atan2(gy, gx) / math.pi

        # Structure tensor: smooth the gradient outer products, then derive
        # eigenvalues analytically to measure local orientation coherence.
        sxx = F.conv2d(gx * gx, self.gaussian, padding=2)
        sxy = F.conv2d(gx * gy, self.gaussian, padding=2)
        syy = F.conv2d(gy * gy, self.gaussian, padding=2)

        trace = sxx + syy
        root = torch.sqrt((sxx - syy)**2 + 4 * sxy**2 + 1e-8)
        lam1 = 0.5 * (trace + root)
        lam2 = 0.5 * (trace - root)
        coherence = ((lam1 - lam2) / (lam1 + lam2 + 1e-8))**2

        # Log-compress magnitudes so strong edges do not dominate.
        mag_scaled = torch.log1p(magnitude * 10)

        return torch.cat([luminance, gx, gy, mag_scaled, angle, coherence], dim=1)

    def forward(self, luminance):
        field = self.compute_gradient_field(luminance)
        field = self.input_norm(field)
        field = self.channel_mix(field)
        feat = self.cnn(field)
        feat = self.global_pool(feat).flatten(1)
        emb = self.embedding(feat)
        logit = self.classifier(emb)
        return logit.squeeze(1), emb
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class GradfieldCNNWrapper(BaseSubmodelWrapper):
    """
    Wrapper for Gradient Field CNN model.

    Model expects square luminance images (256x256 by default, configurable
    via preprocess.json). Internally computes Sobel gradients and other
    discriminative features.
    """

    # BT.709 luminance coefficients
    R_COEFF = 0.2126
    G_COEFF = 0.7152
    B_COEFF = 0.0722

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)
        self._model: Optional[nn.Module] = None
        self._resize: Optional[transforms.Resize] = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._threshold = config.get("threshold", 0.5)
        # Refined in load() from preprocess.json; used for heatmap sizing too.
        self._input_size: int = 256
        logger.info(f"Initialized GradfieldCNNWrapper for {repo_id}")

    def load(self) -> None:
        """Load the Gradient Field CNN model with trained weights.

        Raises:
            ConfigurationError: If no weights file is found or loading fails.
        """
        # Try different weight file names
        weights_path = None
        for fname in ["gradient_field_cnn_v3_finetuned.pth", "gradient_field_cnn_v2.pth", "weights.pt", "model.pth"]:
            candidate = Path(self.local_path) / fname
            if candidate.exists():
                weights_path = candidate
                break

        preprocess_path = Path(self.local_path) / "preprocess.json"

        if weights_path is None:
            raise ConfigurationError(
                message=f"No weights file found in {self.local_path}",
                details={"repo_id": self.repo_id}
            )

        try:
            # Load preprocessing config (optional file)
            preprocess_config = {}
            if preprocess_path.exists():
                with open(preprocess_path, "r") as f:
                    preprocess_config = json.load(f)

            # Get input size (default 256 for gradient field)
            input_size = preprocess_config.get("input_size", 256)
            if isinstance(input_size, list):
                input_size = input_size[0]

            self._input_size = int(input_size)
            self._resize = transforms.Resize((self._input_size, self._input_size))

            # Get model parameters from config
            model_params = self.config.get("model_parameters", {})
            depth = model_params.get("depth", 4)
            base_filters = model_params.get("base_filters", 32)
            dropout = model_params.get("dropout", 0.3)
            embedding_dim = model_params.get("embedding_dim", 128)

            # Create model
            self._model = CompactGradientNet(
                depth=depth,
                base_filters=base_filters,
                dropout=dropout,
                embedding_dim=embedding_dim
            )

            # Load trained weights
            # Note: weights_only=False needed because checkpoint contains numpy types
            state_dict = torch.load(weights_path, map_location=self._device, weights_only=False)

            # Handle different checkpoint formats
            if isinstance(state_dict, dict):
                if "model_state_dict" in state_dict:
                    state_dict = state_dict["model_state_dict"]
                elif "state_dict" in state_dict:
                    state_dict = state_dict["state_dict"]
                elif "model" in state_dict:
                    state_dict = state_dict["model"]

            self._model.load_state_dict(state_dict)
            self._model.to(self._device)
            self._model.eval()

            # Mark as loaded
            self._predict_fn = self._run_inference
            logger.info(f"Loaded Gradient Field CNN model from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load Gradient Field CNN model: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _rgb_to_luminance(self, img_tensor: torch.Tensor) -> torch.Tensor:
        """
        Convert RGB tensor to luminance using BT.709 coefficients.

        Args:
            img_tensor: RGB tensor of shape (3, H, W) with values in [0, 1]

        Returns:
            Luminance tensor of shape (1, H, W)
        """
        luminance = (
            self.R_COEFF * img_tensor[0] +
            self.G_COEFF * img_tensor[1] +
            self.B_COEFF * img_tensor[2]
        )
        return luminance.unsqueeze(0)

    def _last_conv_layer(self) -> nn.Module:
        """
        Return the last Conv2d in the CNN trunk (the GradCAM target layer).

        Bug fix: a fixed index (cnn[-5]) was only the last Conv2d when the
        model was built with dropout > 0 (5 modules per stage); with
        dropout == 0 each stage has 4 modules and the index pointed at the
        wrong layer. Searching by type is correct for any layout.
        """
        last = None
        for layer in self._model.cnn:
            if isinstance(layer, nn.Conv2d):
                last = layer
        if last is None:
            raise InferenceError(
                message="No Conv2d layer found for GradCAM",
                details={"repo_id": self.repo_id}
            )
        return last

    def _run_inference(
        self,
        luminance_tensor: torch.Tensor,
        explain: bool = False
    ) -> Dict[str, Any]:
        """Run model inference on a preprocessed luminance tensor.

        Args:
            luminance_tensor: Tensor of shape (1, 1, H, W) on self._device.
            explain: If True, also compute a GradCAM heatmap.

        Returns:
            Dict with logits, prob_fake, pred_int, embedding and, when
            explain is True, a "heatmap" numpy array in [0, 1].
        """
        heatmap = None

        if explain:
            # Custom GradCAM implementation for single-logit binary model.
            # Using absolute CAM values to capture both positive and negative
            # contributions.
            target_layer = self._last_conv_layer()

            activations = None
            gradients = None

            def forward_hook(module, input, output):
                nonlocal activations
                activations = output.detach()

            def backward_hook(module, grad_input, grad_output):
                nonlocal gradients
                gradients = grad_output[0].detach()

            h_fwd = target_layer.register_forward_hook(forward_hook)
            h_bwd = target_layer.register_full_backward_hook(backward_hook)

            try:
                # Forward pass with gradients
                input_tensor = luminance_tensor.clone().requires_grad_(True)
                logits, embedding = self._model(input_tensor)
                prob_fake = torch.sigmoid(logits).item()
                pred_int = 1 if prob_fake >= self._threshold else 0

                # Backward pass (logits has a single element, so implicit
                # gradient creation is allowed)
                self._model.zero_grad()
                logits.backward()

                if gradients is not None and activations is not None:
                    # Compute Grad-CAM weights (global average pooled gradients)
                    weights = gradients.mean(dim=(2, 3), keepdim=True)  # [1, C, 1, 1]

                    # Weighted combination of activation maps
                    cam = (weights * activations).sum(dim=1, keepdim=True)  # [1, 1, H, W]

                    # Use absolute values instead of ReLU to capture all
                    # contributions; negative gradients carry meaning here.
                    cam = torch.abs(cam)

                    # Normalize to [0, 1]
                    cam = cam - cam.min()
                    cam_max = cam.max()
                    if cam_max > 0:
                        cam = cam / cam_max

                    # Resize to the model's input resolution (256 by default)
                    cam = F.interpolate(
                        cam,
                        size=(self._input_size, self._input_size),
                        mode='bilinear',
                        align_corners=False
                    )

                    heatmap = cam.squeeze().cpu().numpy()
                else:
                    logger.warning("GradCAM: gradients or activations not captured")
                    heatmap = np.zeros((self._input_size, self._input_size), dtype=np.float32)

            finally:
                # Always detach hooks, even if the backward pass fails
                h_fwd.remove()
                h_bwd.remove()
        else:
            with torch.no_grad():
                logits, embedding = self._model(luminance_tensor)
                prob_fake = torch.sigmoid(logits).item()
                pred_int = 1 if prob_fake >= self._threshold else 0

        # detach() is safe in both branches (no-op under no_grad), so the
        # previous hasattr/explain conditionals were dead code.
        result = {
            "logits": logits.detach().cpu().numpy().tolist(),
            "prob_fake": prob_fake,
            "pred_int": pred_int,
            "embedding": embedding.detach().cpu().numpy().tolist()
        }

        if heatmap is not None:
            result["heatmap"] = heatmap

        return result

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (will be converted to PIL Image)
            explain: If True, compute GradCAM heatmap

        Returns:
            Standardized prediction dictionary with optional heatmap

        Raises:
            InferenceError: If the model is not loaded, no image is given,
                or inference fails.
        """
        if self._model is None or self._resize is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Convert bytes to PIL Image if needed
            if image is None and image_bytes is not None:
                import io
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            elif image is not None:
                image = image.convert("RGB")
            else:
                raise InferenceError(
                    message="No image provided",
                    details={"repo_id": self.repo_id}
                )

            # Resize
            image = self._resize(image)

            # Convert to tensor
            img_tensor = transforms.functional.to_tensor(image)

            # Convert to luminance
            luminance = self._rgb_to_luminance(img_tensor)
            luminance = luminance.unsqueeze(0).to(self._device)  # Add batch dim

            # Run inference
            result = self._run_inference(luminance, explain=explain)

            # Standardize output
            labels = self.config.get("labels", {"0": "real", "1": "fake"})
            pred_int = result["pred_int"]

            output = {
                "pred_int": pred_int,
                "pred": labels.get(str(pred_int), "unknown"),
                "prob_fake": result["prob_fake"],
                "meta": {
                    "model": self.name,
                    "threshold": self._threshold
                }
            }

            # Add heatmap if requested
            if explain and "heatmap" in result:
                heatmap = result["heatmap"]
                output["heatmap_base64"] = heatmap_to_base64(heatmap)
                output["explainability_type"] = "grad_cam"
                output["focus_summary"] = compute_focus_summary(heatmap) + " (edge-based analysis)"

            return output

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )
|
app/models/wrappers/logreg_fusion_wrapper.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Wrapper for logistic regression stacking fusion model.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pickle
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, List
|
| 8 |
+
|
| 9 |
+
import joblib
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from app.core.errors import FusionError, ConfigurationError
|
| 13 |
+
from app.core.logging import get_logger
|
| 14 |
+
from app.models.wrappers.base_wrapper import BaseFusionWrapper
|
| 15 |
+
|
| 16 |
+
logger = get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LogRegFusionWrapper(BaseFusionWrapper):
    """
    Wrapper for probability stacking fusion with logistic regression.

    Collects per-submodel fake probabilities, stacks them (in a fixed,
    configured order) into one feature vector, and feeds that vector to a
    trained scikit-learn logistic-regression classifier.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)
        self._model = None
        self._submodel_order: List[str] = config.get("submodel_order", [])
        self._threshold: float = config.get("threshold", 0.5)
        logger.info(f"Initialized LogRegFusionWrapper for {repo_id}")
        logger.info(f"Submodel order: {self._submodel_order}")

    @property
    def submodel_repos(self) -> List[str]:
        """Repository IDs of the submodels this fusion model expects."""
        return self.config.get("submodels", [])

    def load(self) -> None:
        """
        Load the logistic regression model from the downloaded repository.

        Loads fusion_logreg.pkl via joblib (the format sklearn models are
        saved in).

        Raises:
            ConfigurationError: If the pickle is missing or unreadable.
        """
        model_path = Path(self.local_path) / "fusion_logreg.pkl"

        if not model_path.exists():
            raise ConfigurationError(
                message=f"fusion_logreg.pkl not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(model_path)}
            )

        try:
            # joblib, not pickle: sklearn estimators are persisted with joblib
            self._model = joblib.load(model_path)
            logger.info(f"Loaded logistic regression fusion model from {self.repo_id}")
        except Exception as e:
            logger.error(f"Failed to load fusion model from {self.repo_id}: {e}")
            raise ConfigurationError(
                message=f"Failed to load fusion model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _stacked_probabilities(self, submodel_outputs: Dict[str, Dict[str, Any]]) -> List[float]:
        """
        Collect prob_fake from every expected submodel, in configured order.

        Raises:
            FusionError: If a submodel output is missing entirely or lacks
                the "prob_fake" key.
        """
        collected: List[float] = []
        for name in self._submodel_order:
            if name not in submodel_outputs:
                raise FusionError(
                    message=f"Missing output from submodel: {name}",
                    details={
                        "repo_id": self.repo_id,
                        "missing_submodel": name,
                        "available_submodels": list(submodel_outputs.keys())
                    }
                )

            entry = submodel_outputs[name]
            if "prob_fake" not in entry:
                raise FusionError(
                    message=f"Submodel output missing 'prob_fake': {name}",
                    details={
                        "repo_id": self.repo_id,
                        "submodel": name,
                        "output_keys": list(entry.keys())
                    }
                )

            collected.append(entry["prob_fake"])
        return collected

    def predict(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run fusion prediction on submodel outputs.

        Args:
            submodel_outputs: Maps submodel name to its prediction output;
                each output must contain a "prob_fake" key.
            **kwargs: Additional arguments (unused)

        Returns:
            Standardized prediction dictionary with:
            - pred_int: 0 or 1
            - pred: "real" or "fake"
            - prob_fake: float probability of being fake
            - meta: dict with submodel probabilities and the threshold

        Raises:
            FusionError: If the model is not loaded, inputs are malformed,
                or the classifier fails.
        """
        if self._model is None:
            raise FusionError(
                message="Fusion model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            probs = self._stacked_probabilities(submodel_outputs)

            # sklearn wants a 2D feature matrix: one row, one column per submodel
            features = np.array(probs).reshape(1, -1)

            # Column 1 of predict_proba is P(class == fake)
            prob_fake = float(self._model.predict_proba(features)[0, 1])
            is_fake = prob_fake >= self._threshold

            return {
                "pred_int": 1 if is_fake else 0,
                "pred": "fake" if is_fake else "real",
                "prob_fake": prob_fake,
                "meta": {
                    "submodel_probs": dict(zip(self._submodel_order, probs)),
                    "threshold": self._threshold
                }
            }

        except FusionError:
            raise
        except Exception as e:
            logger.error(f"Fusion prediction failed for {self.repo_id}: {e}")
            raise FusionError(
                message=f"Fusion prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )
|
app/models/wrappers/vit_base_wrapper.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Wrapper for ViT Base submodel.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from torchvision import transforms
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import timm
|
| 17 |
+
TIMM_AVAILABLE = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
TIMM_AVAILABLE = False
|
| 20 |
+
|
| 21 |
+
from app.core.errors import InferenceError, ConfigurationError
|
| 22 |
+
from app.core.logging import get_logger
|
| 23 |
+
from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
|
| 24 |
+
from app.services.explainability import attention_rollout, heatmap_to_base64, compute_focus_summary
|
| 25 |
+
|
| 26 |
+
logger = get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ViTWithMLPHead(nn.Module):
    """
    ViT backbone followed by a two-layer MLP classification head.

    Mirrors the layout of the training checkpoint exactly:
      - ``vit``: timm ViT backbone created with ``num_classes=0``
      - ``fc1``: Linear(embed_dim -> hidden_dim)
      - ``fc2``: Linear(hidden_dim -> num_classes)

    Attribute names must not change — they are the state-dict keys.
    """

    def __init__(self, arch: str = "vit_base_patch16_224", num_classes: int = 2, hidden_dim: int = 512):
        super().__init__()
        # num_classes=0 makes timm return pooled features instead of logits.
        self.vit = timm.create_model(arch, pretrained=False, num_classes=0)
        feature_dim = self.vit.embed_dim  # 768 for ViT-Base
        self.fc1 = nn.Linear(feature_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for a batch of images."""
        pooled = self.vit(x)  # [B, embed_dim]
        hidden = F.relu(self.fc1(pooled))
        return self.fc2(hidden)
| 53 |
+
|
| 54 |
+
class ViTBaseWrapper(BaseSubmodelWrapper):
    """
    Wrapper for the ViT Base (Vision Transformer) deepfake submodel.

    Expects 224x224 RGB input with ImageNet normalization; both can be
    overridden by an optional ``preprocess.json`` shipped with the weights.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Args:
            repo_id: Hugging Face repository ID the weights came from.
            config: Parsed model config (arch, num_classes, labels, threshold, ...).
            local_path: Local directory containing the downloaded artifacts.
        """
        super().__init__(repo_id, config, local_path)
        self._model: Optional[nn.Module] = None
        self._transform: Optional[transforms.Compose] = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Decision threshold on P(fake); prediction is "fake" when prob_fake >= threshold.
        self._threshold = config.get("threshold", 0.5)
        logger.info(f"Initialized ViTBaseWrapper for {repo_id}")

    def _build_transform(self, preprocess_path: Path) -> transforms.Compose:
        """Build the preprocessing pipeline, honoring preprocess.json if present."""
        preprocess_config: Dict[str, Any] = {}
        if preprocess_path.exists():
            with open(preprocess_path, "r") as f:
                preprocess_config = json.load(f)

        input_size = preprocess_config.get("input_size", 224)
        if isinstance(input_size, list):
            input_size = input_size[0]

        normalize_config = preprocess_config.get("normalize", {})
        mean = normalize_config.get("mean", [0.485, 0.456, 0.406])
        std = normalize_config.get("std", [0.229, 0.224, 0.225])

        # Training used bicubic resizing; anything else falls back to bilinear.
        interpolation = preprocess_config.get("interpolation", "bicubic")
        interp_mode = (
            transforms.InterpolationMode.BICUBIC
            if interpolation == "bicubic"
            else transforms.InterpolationMode.BILINEAR
        )

        return transforms.Compose([
            transforms.Resize((input_size, input_size), interpolation=interp_mode),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])

    def load(self) -> None:
        """Load the ViT Base model with trained weights.

        Raises:
            ConfigurationError: if timm is missing, the weights file is
                absent, or the checkpoint fails to load / match the model.
        """
        if not TIMM_AVAILABLE:
            raise ConfigurationError(
                message="timm package not installed. Run: pip install timm",
                details={"repo_id": self.repo_id}
            )

        weights_path = Path(self.local_path) / "deepfake_vit_finetuned_wildfake.pth"
        preprocess_path = Path(self.local_path) / "preprocess.json"

        if not weights_path.exists():
            raise ConfigurationError(
                message=f"deepfake_vit_finetuned_wildfake.pth not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(weights_path)}
            )

        try:
            self._transform = self._build_transform(preprocess_path)

            arch = self.config.get("arch", "vit_base_patch16_224")
            num_classes = self.config.get("num_classes", 2)
            # MLP hidden dim is 512 per the training notebook (fc1: 768->512,
            # fc2: 512->2). Note: config.hidden_dim (768) is the ViT embedding
            # dim, NOT the MLP hidden dim.
            mlp_hidden_dim = self.config.get("mlp_hidden_dim", 512)

            # Custom wrapper matching the checkpoint structure (vit.* + fc1/fc2).
            self._model = ViTWithMLPHead(arch=arch, num_classes=num_classes, hidden_dim=mlp_hidden_dim)

            # SECURITY NOTE: weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted repositories. Kept because the
            # training checkpoint stores non-tensor metadata (epoch, optimizer).
            checkpoint = torch.load(weights_path, map_location=self._device, weights_only=False)

            # Training checkpoints wrap the state dict under a "model" key
            # (alongside "optimizer_state", "epoch"); plain state dicts load as-is.
            if isinstance(checkpoint, dict) and "model" in checkpoint:
                state_dict = checkpoint["model"]
            else:
                state_dict = checkpoint

            self._model.load_state_dict(state_dict)
            self._model.to(self._device)
            self._model.eval()

            # Mark as loaded for the base wrapper.
            self._predict_fn = self._run_inference
            logger.info(f"Loaded ViT Base model from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load ViT Base model: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _forward(self, image_tensor: torch.Tensor) -> Tuple[torch.Tensor, float, int]:
        """Run a no-grad forward pass; return (logits, prob_fake, pred_int)."""
        with torch.no_grad():
            logits = self._model(image_tensor)
            probs = F.softmax(logits, dim=1)
            # Class index 1 is "fake" by convention of the label map below.
            prob_fake = probs[0, 1].item()
            pred_int = 1 if prob_fake >= self._threshold else 0
        return logits, prob_fake, pred_int

    def _run_inference(
        self,
        image_tensor: torch.Tensor,
        explain: bool = False
    ) -> Dict[str, Any]:
        """Run model inference on a preprocessed tensor.

        When ``explain`` is True, attention weights are captured from every
        transformer block via forward hooks and folded into an
        attention-rollout heatmap resized to 224x224.
        """
        heatmap = None

        if explain:
            handles = []

            # timm's attention module does not expose its softmax weights from
            # the forward output, so recompute q/k inside a hook on each
            # block's attn module and store the head-averaged attention.
            def create_attn_hook():
                stored_attn: List[torch.Tensor] = []

                def hook(module, inputs, outputs):
                    x = inputs[0]  # [B, N, C]
                    B, N, C = x.shape
                    qkv = module.qkv(x)  # [B, N, 3*C]
                    qkv = qkv.reshape(B, N, 3, module.num_heads, C // module.num_heads)
                    qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, head_dim]
                    q, k = qkv[0], qkv[1]  # v not needed for attention weights

                    scale = (C // module.num_heads) ** -0.5
                    attn = (q @ k.transpose(-2, -1)) * scale
                    attn = attn.softmax(dim=-1)  # [B, heads, N, N]
                    # Average over heads before storing.
                    stored_attn.append(attn.mean(dim=1).detach())

                return hook, stored_attn

            all_stored_attns = []
            for block in self._model.vit.blocks:
                hook_fn, stored = create_attn_hook()
                all_stored_attns.append(stored)
                handles.append(block.attn.register_forward_hook(hook_fn))

            try:
                logits, prob_fake, pred_int = self._forward(image_tensor)

                attention_list = [stored[0] for stored in all_stored_attns if stored]
                if attention_list:
                    # Stack: [num_layers, B, N, N]
                    attention_stack = torch.stack(attention_list, dim=0)
                    attention_map = attention_rollout(
                        attention_stack[:, 0],   # [num_layers, N, N]
                        head_fusion="mean",      # heads already averaged in the hook
                        discard_ratio=0.0,
                        num_prefix_tokens=1      # ViT has a single CLS token
                    )  # (grid, grid) heatmap — (14, 14) for ViT-Base

                    # Upsample the patch-grid heatmap to the input resolution.
                    heatmap_img = Image.fromarray(
                        (attention_map * 255).astype(np.uint8)
                    ).resize((224, 224), Image.BILINEAR)
                    heatmap = np.array(heatmap_img).astype(np.float32) / 255.0
            finally:
                # Always detach hooks, even if the forward pass raised.
                for handle in handles:
                    handle.remove()
        else:
            logits, prob_fake, pred_int = self._forward(image_tensor)

        result = {
            "logits": logits[0].cpu().numpy().tolist(),
            "prob_fake": prob_fake,
            "pred_int": pred_int
        }
        if heatmap is not None:
            result["heatmap"] = heatmap
        return result

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (will be converted to PIL Image)
            explain: If True, compute attention rollout heatmap

        Returns:
            Standardized prediction dictionary with optional heatmap

        Raises:
            InferenceError: if the model is not loaded, no image is
                provided, or inference fails.
        """
        if self._model is None or self._transform is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Convert bytes to PIL Image if needed; always force RGB.
            if image is None and image_bytes is not None:
                import io
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            elif image is not None:
                image = image.convert("RGB")
            else:
                raise InferenceError(
                    message="No image provided",
                    details={"repo_id": self.repo_id}
                )

            # Preprocess to a [1, C, H, W] tensor on the model device.
            image_tensor = self._transform(image).unsqueeze(0).to(self._device)

            # Run inference
            result = self._run_inference(image_tensor, explain=explain)

            # Standardize output. Accept both string and int keys in the
            # configured label map (config.json may use either).
            labels = self.config.get("labels", {"0": "real", "1": "fake"})
            pred_int = result["pred_int"]
            pred_label = labels.get(str(pred_int), labels.get(pred_int, "unknown"))

            output = {
                "pred_int": pred_int,
                "pred": pred_label,
                "prob_fake": result["prob_fake"],
                "meta": {
                    "model": self.name,
                    "threshold": self._threshold,
                    "logits": result["logits"]
                }
            }

            # Add heatmap if requested
            if explain and "heatmap" in result:
                heatmap = result["heatmap"]
                output["heatmap_base64"] = heatmap_to_base64(heatmap)
                output["explainability_type"] = "attention_rollout"
                output["focus_summary"] = compute_focus_summary(heatmap)

            return output

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )
app/schemas/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Schemas module
|
app/schemas/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
app/schemas/__pycache__/models.cpython-312.pyc
ADDED
|
Binary file (2.74 kB). View file
|
|
|
app/schemas/__pycache__/predict.cpython-312.pyc
ADDED
|
Binary file (8.17 kB). View file
|
|
|
app/schemas/models.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic schemas for model-related endpoints.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List, Literal, Optional, Any
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ModelInfo(BaseModel):
    """Information about a loaded model."""

    # NOTE(review): the "model_" prefix on model_type triggers a Pydantic v2
    # protected-namespace warning — confirm pydantic version and, if needed,
    # silence via model_config = ConfigDict(protected_namespaces=()).
    repo_id: str = Field(..., description="Hugging Face repository ID")
    name: str = Field(..., description="Short name of the model")
    model_type: Literal["submodel", "fusion"] = Field(..., description="Type of model")
    config: Optional[Dict[str, Any]] = Field(None, description="Model configuration from config.json")
+
|
| 24 |
+
class ModelsListResponse(BaseModel):
    """Response schema for listing models."""

    fusion: Optional[ModelInfo] = Field(None, description="Fusion model information")
    submodels: List[ModelInfo] = Field(default_factory=list, description="List of loaded submodels")
    total_count: int = Field(..., description="Total number of loaded models")
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class HealthResponse(BaseModel):
    """Response schema for health check."""

    # Constrained to the two literal values "ok" / "error".
    status: Literal["ok", "error"] = Field(
        ...,
        description="Health status",
    )
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ReadyResponse(BaseModel):
    """Response schema for readiness check."""

    status: Literal["ready", "not_ready"] = Field(..., description="Readiness status")
    models_loaded: bool = Field(..., description="Whether models are loaded")
    fusion_repo: Optional[str] = Field(None, description="Fusion repository ID")
    submodels: List[str] = Field(
        default_factory=list,
        description="List of loaded submodel repository IDs",
    )