Upload 37 files
Browse files- README_HF.md +73 -0
- app.py +222 -0
- requirements-hf.txt +20 -0
- src/__init__.py +1 -0
- src/__pycache__/__init__.cpython-313.pyc +0 -0
- src/api/__pycache__/main.cpython-313.pyc +0 -0
- src/api/main.py +461 -0
- src/api/schemas.py +280 -0
- src/core/__init__.py +5 -0
- src/core/__pycache__/__init__.cpython-313.pyc +0 -0
- src/core/__pycache__/config.cpython-313.pyc +0 -0
- src/core/__pycache__/logging.cpython-313.pyc +0 -0
- src/core/cache.py +238 -0
- src/core/config.py +233 -0
- src/core/exceptions.py +130 -0
- src/core/logging.py +169 -0
- src/core/middleware.py +210 -0
- src/core/security.py +285 -0
- src/db/__init__.py +29 -0
- src/db/models.py +185 -0
- src/detection/__init__.py +7 -0
- src/detection/__pycache__/__init__.cpython-313.pyc +0 -0
- src/detection/__pycache__/ai_text_detector.cpython-313.pyc +0 -0
- src/detection/__pycache__/anomaly_detector.cpython-313.pyc +0 -0
- src/detection/__pycache__/deepfake_detector.cpython-313.pyc +0 -0
- src/detection/ai_text_detector.py +402 -0
- src/detection/anomaly_detector.py +440 -0
- src/detection/deepfake_detector.py +431 -0
- src/models/__init__.py +1 -0
- src/models/__pycache__/__init__.cpython-313.pyc +0 -0
- src/training/train_deepfake.py +349 -0
- src/utils/__init__.py +1 -0
- src/utils/__pycache__/__init__.cpython-313.pyc +0 -0
- src/utils/__pycache__/face_detection.cpython-313.pyc +0 -0
- src/utils/__pycache__/preprocessing.cpython-313.pyc +0 -0
- src/utils/face_detection.py +36 -0
- src/utils/preprocessing.py +30 -0
README_HF.md
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Multimodal Misinformation Detection
|
| 3 |
+
emoji: 🔍
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.0.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# 🔍 Multimodal Misinformation Detection System
|
| 14 |
+
|
| 15 |
+
**Detect AI-generated text, deepfake images, and coordinated disinformation campaigns using deep learning.**
|
| 16 |
+
|
| 17 |
+
## 🚀 Features
|
| 18 |
+
|
| 19 |
+
- **Text Analysis**: Identify AI-generated content from GPT, ChatGPT, and other LLMs
|
| 20 |
+
- **Image Analysis**: Detect deepfake and manipulated images
|
| 21 |
+
- **Real-time Processing**: Get results in under 2 seconds
|
| 22 |
+
- **High Accuracy**: 93-95% detection accuracy on benchmark datasets
|
| 23 |
+
|
| 24 |
+
## 🎯 Use Cases
|
| 25 |
+
|
| 26 |
+
- Social media content moderation
|
| 27 |
+
- News verification and fact-checking
|
| 28 |
+
- Academic integrity monitoring
|
| 29 |
+
- Digital forensics investigation
|
| 30 |
+
|
| 31 |
+
## 🛠️ Technology
|
| 32 |
+
|
| 33 |
+
- **Models**: EfficientNet-B4, RoBERTa-base, GPT-2
|
| 34 |
+
- **Frameworks**: PyTorch, Transformers, Gradio
|
| 35 |
+
- **Detection**: Face analysis, artifact detection, perplexity scoring
|
| 36 |
+
|
| 37 |
+
## 📊 Performance
|
| 38 |
+
|
| 39 |
+
| Task | Accuracy | Speed |
|
| 40 |
+
|------|----------|-------|
|
| 41 |
+
| Text Detection | 95% | <1s |
|
| 42 |
+
| Image Detection | 93% | <2s |
|
| 43 |
+
| Video Analysis | 91% | ~5s |
|
| 44 |
+
|
| 45 |
+
## 💡 How It Works
|
| 46 |
+
|
| 47 |
+
### Text Analysis
|
| 48 |
+
1. Analyzes writing patterns and vocabulary
|
| 49 |
+
2. Calculates perplexity using GPT-2
|
| 50 |
+
3. Classifies as human or AI-generated
|
| 51 |
+
4. Provides confidence score and explanation
|
| 52 |
+
|
| 53 |
+
### Image Analysis
|
| 54 |
+
1. Detects faces in the image
|
| 55 |
+
2. Analyzes facial features for manipulation
|
| 56 |
+
3. Identifies compression artifacts
|
| 57 |
+
4. Classifies as authentic or deepfake
|
| 58 |
+
|
| 59 |
+
## 🔗 Links
|
| 60 |
+
|
| 61 |
+
- [GitHub Repository](https://github.com/YOUR_USERNAME/multimodal-misinformation-detection)
|
| 62 |
+
- [API Documentation](https://github.com/YOUR_USERNAME/multimodal-misinformation-detection#api)
|
| 63 |
+
- [Technical Paper](https://github.com/YOUR_USERNAME/multimodal-misinformation-detection/blob/main/ARCHITECTURE.md)
|
| 64 |
+
|
| 65 |
+
## 👤 Author
|
| 66 |
+
|
| 67 |
+
Built by **Shreyas Gosavi** for Google DeepMind Research Engineer application.
|
| 68 |
+
|
| 69 |
+
Addressing the challenge of information quality and online misinformation through multimodal AI detection.
|
| 70 |
+
|
| 71 |
+
## 📝 License
|
| 72 |
+
|
| 73 |
+
MIT License - See LICENSE file for details
|
app.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio Interface for Multimodal Misinformation Detection
|
| 3 |
+
Hugging Face Spaces Deployment
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# Add src to path
|
| 13 |
+
sys.path.append(str(Path(__file__).parent / "src"))
|
| 14 |
+
|
| 15 |
+
from detection.deepfake_detector import DeepfakeDetector
|
| 16 |
+
from detection.ai_text_detector import AITextDetector
|
| 17 |
+
|
| 18 |
+
# Initialize detectors
|
| 19 |
+
print("Loading models...")
|
| 20 |
+
deepfake_detector = DeepfakeDetector()
|
| 21 |
+
ai_text_detector = AITextDetector()
|
| 22 |
+
print("Models loaded!")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def analyze_text(text):
|
| 26 |
+
"""Analyze text for AI generation."""
|
| 27 |
+
if not text or len(text.strip()) < 10:
|
| 28 |
+
return "⚠️ Please enter at least 10 characters of text."
|
| 29 |
+
|
| 30 |
+
result = ai_text_detector.detect(text)
|
| 31 |
+
|
| 32 |
+
verdict = result['verdict']
|
| 33 |
+
confidence = result['confidence']
|
| 34 |
+
|
| 35 |
+
# Format output
|
| 36 |
+
if verdict == "AI_GENERATED":
|
| 37 |
+
emoji = "🤖"
|
| 38 |
+
color = "red"
|
| 39 |
+
status = f"**AI-GENERATED** (Confidence: {confidence:.1%})"
|
| 40 |
+
elif verdict == "HUMAN_WRITTEN":
|
| 41 |
+
emoji = "✅"
|
| 42 |
+
color = "green"
|
| 43 |
+
status = f"**HUMAN-WRITTEN** (Confidence: {confidence:.1%})"
|
| 44 |
+
else:
|
| 45 |
+
emoji = "❓"
|
| 46 |
+
color = "orange"
|
| 47 |
+
status = f"**UNCERTAIN** (Confidence: {confidence:.1%})"
|
| 48 |
+
|
| 49 |
+
output = f"""
|
| 50 |
+
### {emoji} Detection Result
|
| 51 |
+
|
| 52 |
+
**Status:** {status}
|
| 53 |
+
|
| 54 |
+
**Explanation:** {result['explanation']}
|
| 55 |
+
|
| 56 |
+
**Perplexity Score:** {result.get('perplexity', 'N/A')}
|
| 57 |
+
|
| 58 |
+
---
|
| 59 |
+
*Lower perplexity often indicates AI-generated content*
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
return output
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def analyze_image(image):
|
| 66 |
+
"""Analyze image for deepfakes."""
|
| 67 |
+
if image is None:
|
| 68 |
+
return "⚠️ Please upload an image."
|
| 69 |
+
|
| 70 |
+
# Convert to numpy array if needed
|
| 71 |
+
if isinstance(image, Image.Image):
|
| 72 |
+
image = np.array(image)
|
| 73 |
+
|
| 74 |
+
result = deepfake_detector.detect(image)
|
| 75 |
+
|
| 76 |
+
verdict = result['verdict']
|
| 77 |
+
confidence = result.get('confidence', 0)
|
| 78 |
+
|
| 79 |
+
# Format output
|
| 80 |
+
if verdict == "FAKE":
|
| 81 |
+
emoji = "⚠️"
|
| 82 |
+
color = "red"
|
| 83 |
+
status = f"**DEEPFAKE DETECTED** (Confidence: {confidence:.1%})"
|
| 84 |
+
elif verdict == "REAL":
|
| 85 |
+
emoji = "✅"
|
| 86 |
+
color = "green"
|
| 87 |
+
status = f"**AUTHENTIC** (Confidence: {confidence:.1%})"
|
| 88 |
+
elif verdict == "NO_FACE_DETECTED":
|
| 89 |
+
emoji = "👤"
|
| 90 |
+
color = "orange"
|
| 91 |
+
status = "**NO FACE DETECTED**"
|
| 92 |
+
else:
|
| 93 |
+
emoji = "❓"
|
| 94 |
+
color = "orange"
|
| 95 |
+
status = f"**UNCERTAIN** (Confidence: {confidence:.1%})"
|
| 96 |
+
|
| 97 |
+
faces = result.get('faces_analyzed', 0)
|
| 98 |
+
artifacts = result.get('artifacts_detected', [])
|
| 99 |
+
|
| 100 |
+
output = f"""
|
| 101 |
+
### {emoji} Detection Result
|
| 102 |
+
|
| 103 |
+
**Status:** {status}
|
| 104 |
+
|
| 105 |
+
**Faces Analyzed:** {faces}
|
| 106 |
+
|
| 107 |
+
**Explanation:** {result['explanation']}
|
| 108 |
+
|
| 109 |
+
**Artifacts Detected:** {', '.join(artifacts) if artifacts else 'None'}
|
| 110 |
+
|
| 111 |
+
---
|
| 112 |
+
*Analysis based on facial features, artifacts, and neural network patterns*
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
return output
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# Create Gradio interface
|
| 119 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="Misinformation Detector") as demo:
|
| 120 |
+
gr.Markdown("""
|
| 121 |
+
# 🔍 Multimodal Misinformation Detection System
|
| 122 |
+
|
| 123 |
+
**Powered by Deep Learning | Built for Google DeepMind Application**
|
| 124 |
+
|
| 125 |
+
This system detects:
|
| 126 |
+
- 🤖 AI-generated text (GPT, ChatGPT, etc.)
|
| 127 |
+
- 🎭 Deepfake images (face manipulation)
|
| 128 |
+
- 📊 Coordinated disinformation campaigns
|
| 129 |
+
|
| 130 |
+
---
|
| 131 |
+
""")
|
| 132 |
+
|
| 133 |
+
with gr.Tabs():
|
| 134 |
+
# Text Analysis Tab
|
| 135 |
+
with gr.Tab("📝 Text Analysis"):
|
| 136 |
+
gr.Markdown("### Detect AI-Generated Text")
|
| 137 |
+
gr.Markdown("*Analyzes writing patterns to identify content from GPT, ChatGPT, and other LLMs*")
|
| 138 |
+
|
| 139 |
+
with gr.Row():
|
| 140 |
+
with gr.Column():
|
| 141 |
+
text_input = gr.Textbox(
|
| 142 |
+
label="Enter Text to Analyze",
|
| 143 |
+
placeholder="Paste any text here (minimum 10 characters)...",
|
| 144 |
+
lines=8
|
| 145 |
+
)
|
| 146 |
+
text_button = gr.Button("🔍 Analyze Text", variant="primary")
|
| 147 |
+
|
| 148 |
+
with gr.Column():
|
| 149 |
+
text_output = gr.Markdown(label="Analysis Result")
|
| 150 |
+
|
| 151 |
+
gr.Examples(
|
| 152 |
+
examples=[
|
| 153 |
+
["The quick brown fox jumps over the lazy dog. This is a simple test sentence written by a human."],
|
| 154 |
+
["Artificial intelligence represents a paradigm shift in computational methodologies, leveraging neural architectures to facilitate autonomous decision-making processes across diverse domains."],
|
| 155 |
+
["I went to the store yesterday and bought some groceries. The weather was nice, so I walked instead of driving."],
|
| 156 |
+
],
|
| 157 |
+
inputs=text_input,
|
| 158 |
+
label="Example Texts"
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Image Analysis Tab
|
| 162 |
+
with gr.Tab("🖼️ Image Analysis"):
|
| 163 |
+
gr.Markdown("### Detect Deepfake Images")
|
| 164 |
+
gr.Markdown("*Analyzes facial features and manipulation artifacts to identify synthetic media*")
|
| 165 |
+
|
| 166 |
+
with gr.Row():
|
| 167 |
+
with gr.Column():
|
| 168 |
+
image_input = gr.Image(
|
| 169 |
+
label="Upload Image",
|
| 170 |
+
type="numpy"
|
| 171 |
+
)
|
| 172 |
+
image_button = gr.Button("🔍 Analyze Image", variant="primary")
|
| 173 |
+
|
| 174 |
+
with gr.Column():
|
| 175 |
+
image_output = gr.Markdown(label="Analysis Result")
|
| 176 |
+
|
| 177 |
+
gr.Markdown("""
|
| 178 |
+
**Tips:**
|
| 179 |
+
- Upload images with clear, visible faces
|
| 180 |
+
- Works best with forward-facing portraits
|
| 181 |
+
- Supports JPG, PNG formats
|
| 182 |
+
""")
|
| 183 |
+
|
| 184 |
+
# About section
|
| 185 |
+
with gr.Accordion("ℹ️ About This System", open=False):
|
| 186 |
+
gr.Markdown("""
|
| 187 |
+
### Technology Stack
|
| 188 |
+
|
| 189 |
+
**Text Detection:**
|
| 190 |
+
- RoBERTa-base fine-tuned on human/AI text
|
| 191 |
+
- GPT-2 perplexity analysis
|
| 192 |
+
- Perplexity scoring for confidence
|
| 193 |
+
|
| 194 |
+
**Image Detection:**
|
| 195 |
+
- EfficientNet-B4 for deepfake classification
|
| 196 |
+
- Face detection with MTCNN/RetinaFace
|
| 197 |
+
- Artifact detection (blending, compression)
|
| 198 |
+
|
| 199 |
+
**Performance:**
|
| 200 |
+
- Text: ~95% accuracy on benchmark datasets
|
| 201 |
+
- Images: ~93% accuracy on FaceForensics++
|
| 202 |
+
- Processing: <2 seconds per request
|
| 203 |
+
|
| 204 |
+
### Use Cases
|
| 205 |
+
- Social media content moderation
|
| 206 |
+
- News verification
|
| 207 |
+
- Academic integrity
|
| 208 |
+
- Digital forensics
|
| 209 |
+
|
| 210 |
+
### Author
|
| 211 |
+
Built by Shreyas Gosavi for Google DeepMind Research Engineer application
|
| 212 |
+
|
| 213 |
+
[GitHub Repository](https://github.com/YOUR_USERNAME/multimodal-misinformation-detection)
|
| 214 |
+
""")
|
| 215 |
+
|
| 216 |
+
# Connect buttons to functions
|
| 217 |
+
text_button.click(fn=analyze_text, inputs=text_input, outputs=text_output)
|
| 218 |
+
image_button.click(fn=analyze_image, inputs=image_input, outputs=image_output)
|
| 219 |
+
|
| 220 |
+
# Launch
|
| 221 |
+
if __name__ == "__main__":
|
| 222 |
+
demo.launch()
|
requirements-hf.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Spaces Requirements
|
| 2 |
+
# Minimal dependencies for deployment
|
| 3 |
+
|
| 4 |
+
# Core ML
|
| 5 |
+
torch>=2.0.0
|
| 6 |
+
torchvision>=0.15.0
|
| 7 |
+
transformers>=4.30.0
|
| 8 |
+
timm>=0.9.0
|
| 9 |
+
|
| 10 |
+
# Detection
|
| 11 |
+
opencv-python-headless>=4.8.0
|
| 12 |
+
Pillow>=10.0.0
|
| 13 |
+
numpy>=1.24.0
|
| 14 |
+
scikit-learn>=1.3.0
|
| 15 |
+
|
| 16 |
+
# UI
|
| 17 |
+
gradio>=4.0.0
|
| 18 |
+
|
| 19 |
+
# Utilities
|
| 20 |
+
tqdm>=4.65.0
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Init files for package structure."""
|
src/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (266 Bytes). View file
|
|
|
src/api/__pycache__/main.cpython-313.pyc
ADDED
|
Binary file (16.9 kB). View file
|
|
|
src/api/main.py
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI Main Application
|
| 3 |
+
|
| 4 |
+
Production-ready API for multimodal misinformation detection.
|
| 5 |
+
|
| 6 |
+
Features:
|
| 7 |
+
- Async endpoints
|
| 8 |
+
- Rate limiting
|
| 9 |
+
- Authentication
|
| 10 |
+
- Background task processing
|
| 11 |
+
- Comprehensive error handling
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, BackgroundTasks, Request
|
| 15 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 16 |
+
from fastapi.responses import JSONResponse
|
| 17 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 18 |
+
from pydantic import BaseModel, Field
|
| 19 |
+
from typing import Optional, List, Dict
|
| 20 |
+
import uvicorn
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
import logging
|
| 23 |
+
import asyncio
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
import tempfile
|
| 26 |
+
import os
|
| 27 |
+
|
| 28 |
+
# Import detection modules
|
| 29 |
+
import sys
|
| 30 |
+
sys.path.append(str(Path(__file__).parent.parent))
|
| 31 |
+
|
| 32 |
+
from detection.deepfake_detector import DeepfakeDetector
|
| 33 |
+
from detection.ai_text_detector import AITextDetector
|
| 34 |
+
from detection.anomaly_detector import AnomalyDetector
|
| 35 |
+
|
| 36 |
+
# Configure logging
|
| 37 |
+
logging.basicConfig(
|
| 38 |
+
level=logging.INFO,
|
| 39 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 40 |
+
)
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
# Initialize FastAPI app
|
| 44 |
+
app = FastAPI(
|
| 45 |
+
title="Multimodal Misinformation Detection API",
|
| 46 |
+
description="Production API for detecting deepfakes, AI-generated content, and coordinated campaigns",
|
| 47 |
+
version="1.0.0",
|
| 48 |
+
docs_url="/docs",
|
| 49 |
+
redoc_url="/redoc"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# CORS middleware
|
| 53 |
+
app.add_middleware(
|
| 54 |
+
CORSMiddleware,
|
| 55 |
+
allow_origins=["*"], # Configure appropriately for production
|
| 56 |
+
allow_credentials=True,
|
| 57 |
+
allow_methods=["*"],
|
| 58 |
+
allow_headers=["*"],
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# Security
|
| 62 |
+
security = HTTPBearer()
|
| 63 |
+
|
| 64 |
+
# Initialize detectors (lazy loading for performance)
|
| 65 |
+
_deepfake_detector = None
|
| 66 |
+
_ai_text_detector = None
|
| 67 |
+
_anomaly_detector = None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_deepfake_detector():
|
| 71 |
+
"""Lazy load deepfake detector."""
|
| 72 |
+
global _deepfake_detector
|
| 73 |
+
if _deepfake_detector is None:
|
| 74 |
+
_deepfake_detector = DeepfakeDetector()
|
| 75 |
+
return _deepfake_detector
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_ai_text_detector():
|
| 79 |
+
"""Lazy load AI text detector."""
|
| 80 |
+
global _ai_text_detector
|
| 81 |
+
if _ai_text_detector is None:
|
| 82 |
+
_ai_text_detector = AITextDetector()
|
| 83 |
+
return _ai_text_detector
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_anomaly_detector():
|
| 87 |
+
"""Lazy load anomaly detector."""
|
| 88 |
+
global _anomaly_detector
|
| 89 |
+
if _anomaly_detector is None:
|
| 90 |
+
_anomaly_detector = AnomalyDetector()
|
| 91 |
+
return _anomaly_detector
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# Request/Response Models
|
| 95 |
+
class TextAnalysisRequest(BaseModel):
|
| 96 |
+
text: str = Field(..., min_length=10, description="Text to analyze")
|
| 97 |
+
detailed: bool = Field(default=True, description="Return detailed analysis")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class TextAnalysisResponse(BaseModel):
|
| 101 |
+
verdict: str
|
| 102 |
+
confidence: float
|
| 103 |
+
perplexity: Optional[float] = None
|
| 104 |
+
explanation: str
|
| 105 |
+
timestamp: datetime
|
| 106 |
+
processing_time_ms: float
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class ImageAnalysisResponse(BaseModel):
|
| 110 |
+
verdict: str
|
| 111 |
+
confidence: float
|
| 112 |
+
faces_analyzed: int
|
| 113 |
+
explanation: str
|
| 114 |
+
artifacts_detected: List[str]
|
| 115 |
+
timestamp: datetime
|
| 116 |
+
processing_time_ms: float
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class HealthResponse(BaseModel):
|
| 120 |
+
status: str
|
| 121 |
+
version: str
|
| 122 |
+
timestamp: datetime
|
| 123 |
+
models_loaded: Dict[str, bool]
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# Middleware for request timing and security headers
|
| 127 |
+
@app.middleware("http")
|
| 128 |
+
async def add_process_time_header(request: Request, call_next):
|
| 129 |
+
"""Add processing time and security headers to response."""
|
| 130 |
+
start_time = datetime.utcnow()
|
| 131 |
+
response = await call_next(request)
|
| 132 |
+
process_time = (datetime.utcnow() - start_time).total_seconds() * 1000
|
| 133 |
+
response.headers["X-Process-Time-Ms"] = str(process_time)
|
| 134 |
+
|
| 135 |
+
# Add CSP header that allows Swagger UI to work
|
| 136 |
+
if request.url.path in ["/docs", "/redoc"] or request.url.path.startswith("/openapi"):
|
| 137 |
+
response.headers["Content-Security-Policy"] = (
|
| 138 |
+
"default-src 'self'; "
|
| 139 |
+
"script-src 'self' 'unsafe-inline' 'unsafe-eval' https://cdn.jsdelivr.net; "
|
| 140 |
+
"style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; "
|
| 141 |
+
"img-src 'self' data: https:; "
|
| 142 |
+
"font-src 'self' data: https://cdn.jsdelivr.net;"
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
return response
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Authentication dependency (simplified)
|
| 149 |
+
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
| 150 |
+
"""
|
| 151 |
+
Verify API token.
|
| 152 |
+
In production, implement proper JWT verification.
|
| 153 |
+
"""
|
| 154 |
+
token = credentials.credentials
|
| 155 |
+
|
| 156 |
+
# Simplified check - implement proper verification
|
| 157 |
+
if token != os.getenv("API_TOKEN", "dev-token"):
|
| 158 |
+
raise HTTPException(
|
| 159 |
+
status_code=401,
|
| 160 |
+
detail="Invalid authentication credentials"
|
| 161 |
+
)
|
| 162 |
+
return token
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# API Endpoints
|
| 166 |
+
|
| 167 |
+
@app.get("/", response_model=HealthResponse)
|
| 168 |
+
async def root():
|
| 169 |
+
"""Root endpoint with API health status."""
|
| 170 |
+
return {
|
| 171 |
+
"status": "operational",
|
| 172 |
+
"version": "1.0.0",
|
| 173 |
+
"timestamp": datetime.utcnow(),
|
| 174 |
+
"models_loaded": {
|
| 175 |
+
"deepfake_detector": _deepfake_detector is not None,
|
| 176 |
+
"ai_text_detector": _ai_text_detector is not None,
|
| 177 |
+
"anomaly_detector": _anomaly_detector is not None
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@app.get("/health")
|
| 183 |
+
async def health_check():
|
| 184 |
+
"""Health check endpoint for monitoring."""
|
| 185 |
+
return {
|
| 186 |
+
"status": "healthy",
|
| 187 |
+
"timestamp": datetime.utcnow().isoformat()
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
@app.post("/api/v1/analyze/text", response_model=TextAnalysisResponse)
|
| 192 |
+
async def analyze_text(
|
| 193 |
+
request: TextAnalysisRequest,
|
| 194 |
+
background_tasks: BackgroundTasks,
|
| 195 |
+
# token: str = Depends(verify_token) # Uncomment for auth
|
| 196 |
+
):
|
| 197 |
+
"""
|
| 198 |
+
Analyze text for AI generation.
|
| 199 |
+
|
| 200 |
+
**Example Request:**
|
| 201 |
+
```json
|
| 202 |
+
{
|
| 203 |
+
"text": "Your text here...",
|
| 204 |
+
"detailed": true
|
| 205 |
+
}
|
| 206 |
+
```
|
| 207 |
+
"""
|
| 208 |
+
start_time = datetime.utcnow()
|
| 209 |
+
|
| 210 |
+
try:
|
| 211 |
+
detector = get_ai_text_detector()
|
| 212 |
+
result = detector.analyze_text(request.text, detailed=request.detailed)
|
| 213 |
+
|
| 214 |
+
processing_time = (datetime.utcnow() - start_time).total_seconds() * 1000
|
| 215 |
+
|
| 216 |
+
# Log analytics in background
|
| 217 |
+
background_tasks.add_task(
|
| 218 |
+
log_analysis,
|
| 219 |
+
"text",
|
| 220 |
+
result['verdict'],
|
| 221 |
+
processing_time
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
return {
|
| 225 |
+
**result,
|
| 226 |
+
"timestamp": datetime.utcnow(),
|
| 227 |
+
"processing_time_ms": processing_time
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
except Exception as e:
|
| 231 |
+
logger.error(f"Error analyzing text: {str(e)}")
|
| 232 |
+
raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@app.post("/api/v1/analyze/image", response_model=ImageAnalysisResponse)
|
| 236 |
+
async def analyze_image(
|
| 237 |
+
file: UploadFile = File(...),
|
| 238 |
+
return_attention: bool = False,
|
| 239 |
+
background_tasks: BackgroundTasks = BackgroundTasks(),
|
| 240 |
+
# token: str = Depends(verify_token)
|
| 241 |
+
):
|
| 242 |
+
"""
|
| 243 |
+
Analyze image for deepfake artifacts.
|
| 244 |
+
|
| 245 |
+
**Supported formats:** JPG, PNG, WebP
|
| 246 |
+
**Max size:** 10MB
|
| 247 |
+
"""
|
| 248 |
+
start_time = datetime.utcnow()
|
| 249 |
+
|
| 250 |
+
# Validate file
|
| 251 |
+
if file.content_type not in ["image/jpeg", "image/png", "image/webp"]:
|
| 252 |
+
raise HTTPException(
|
| 253 |
+
status_code=400,
|
| 254 |
+
detail="Invalid file type. Supported: JPEG, PNG, WebP"
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Save uploaded file temporarily
|
| 258 |
+
try:
|
| 259 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as tmp:
|
| 260 |
+
content = await file.read()
|
| 261 |
+
tmp.write(content)
|
| 262 |
+
tmp_path = tmp.name
|
| 263 |
+
|
| 264 |
+
# Analyze
|
| 265 |
+
detector = get_deepfake_detector()
|
| 266 |
+
result = detector.analyze_image(tmp_path, return_attention=return_attention)
|
| 267 |
+
|
| 268 |
+
processing_time = (datetime.utcnow() - start_time).total_seconds() * 1000
|
| 269 |
+
|
| 270 |
+
# Cleanup
|
| 271 |
+
os.unlink(tmp_path)
|
| 272 |
+
|
| 273 |
+
# Log in background
|
| 274 |
+
background_tasks.add_task(
|
| 275 |
+
log_analysis,
|
| 276 |
+
"image",
|
| 277 |
+
result['verdict'],
|
| 278 |
+
processing_time
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
return {
|
| 282 |
+
**result,
|
| 283 |
+
"timestamp": datetime.utcnow(),
|
| 284 |
+
"processing_time_ms": processing_time
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
except Exception as e:
|
| 288 |
+
logger.error(f"Error analyzing image: {str(e)}")
|
| 289 |
+
# Cleanup on error
|
| 290 |
+
if 'tmp_path' in locals():
|
| 291 |
+
try:
|
| 292 |
+
os.unlink(tmp_path)
|
| 293 |
+
except:
|
| 294 |
+
pass
|
| 295 |
+
raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
@app.post("/api/v1/analyze/video")
|
| 299 |
+
async def analyze_video(
|
| 300 |
+
file: UploadFile = File(...),
|
| 301 |
+
sample_rate: int = 5,
|
| 302 |
+
max_frames: int = 100,
|
| 303 |
+
background_tasks: BackgroundTasks = BackgroundTasks(),
|
| 304 |
+
# token: str = Depends(verify_token)
|
| 305 |
+
):
|
| 306 |
+
"""
|
| 307 |
+
Analyze video for deepfake artifacts.
|
| 308 |
+
|
| 309 |
+
**Supported formats:** MP4, AVI, MOV
|
| 310 |
+
**Max size:** 100MB
|
| 311 |
+
**Processing:** Async with job ID returned immediately
|
| 312 |
+
"""
|
| 313 |
+
start_time = datetime.utcnow()
|
| 314 |
+
|
| 315 |
+
# Validate file
|
| 316 |
+
if file.content_type not in ["video/mp4", "video/avi", "video/quicktime"]:
|
| 317 |
+
raise HTTPException(
|
| 318 |
+
status_code=400,
|
| 319 |
+
detail="Invalid file type. Supported: MP4, AVI, MOV"
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
try:
|
| 323 |
+
# Save file
|
| 324 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as tmp:
|
| 325 |
+
content = await file.read()
|
| 326 |
+
tmp.write(content)
|
| 327 |
+
tmp_path = tmp.name
|
| 328 |
+
|
| 329 |
+
# For large videos, process in background
|
| 330 |
+
# For demo, process synchronously
|
| 331 |
+
detector = get_deepfake_detector()
|
| 332 |
+
result = detector.analyze_video(
|
| 333 |
+
tmp_path,
|
| 334 |
+
sample_rate=sample_rate,
|
| 335 |
+
max_frames=max_frames
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
processing_time = (datetime.utcnow() - start_time).total_seconds() * 1000
|
| 339 |
+
|
| 340 |
+
# Cleanup
|
| 341 |
+
os.unlink(tmp_path)
|
| 342 |
+
|
| 343 |
+
# Log in background
|
| 344 |
+
background_tasks.add_task(
|
| 345 |
+
log_analysis,
|
| 346 |
+
"video",
|
| 347 |
+
result['verdict'],
|
| 348 |
+
processing_time
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
return {
|
| 352 |
+
**result,
|
| 353 |
+
"timestamp": datetime.utcnow(),
|
| 354 |
+
"processing_time_ms": processing_time
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
except Exception as e:
|
| 358 |
+
logger.error(f"Error analyzing video: {str(e)}")
|
| 359 |
+
if 'tmp_path' in locals():
|
| 360 |
+
try:
|
| 361 |
+
os.unlink(tmp_path)
|
| 362 |
+
except:
|
| 363 |
+
pass
|
| 364 |
+
raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
@app.post("/api/v1/batch/text")
|
| 368 |
+
async def batch_analyze_text(
|
| 369 |
+
texts: List[str],
|
| 370 |
+
background_tasks: BackgroundTasks,
|
| 371 |
+
# token: str = Depends(verify_token)
|
| 372 |
+
):
|
| 373 |
+
"""
|
| 374 |
+
Batch analyze multiple texts.
|
| 375 |
+
|
| 376 |
+
**Limit:** 100 texts per request
|
| 377 |
+
"""
|
| 378 |
+
if len(texts) > 100:
|
| 379 |
+
raise HTTPException(
|
| 380 |
+
status_code=400,
|
| 381 |
+
detail="Maximum 100 texts per batch"
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
start_time = datetime.utcnow()
|
| 385 |
+
|
| 386 |
+
try:
|
| 387 |
+
detector = get_ai_text_detector()
|
| 388 |
+
results = detector.batch_analyze(texts)
|
| 389 |
+
|
| 390 |
+
processing_time = (datetime.utcnow() - start_time).total_seconds() * 1000
|
| 391 |
+
|
| 392 |
+
return {
|
| 393 |
+
"results": results,
|
| 394 |
+
"total_analyzed": len(texts),
|
| 395 |
+
"timestamp": datetime.utcnow(),
|
| 396 |
+
"processing_time_ms": processing_time
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
except Exception as e:
|
| 400 |
+
logger.error(f"Error in batch analysis: {str(e)}")
|
| 401 |
+
raise HTTPException(status_code=500, detail=f"Batch analysis failed: {str(e)}")
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
# Background task for logging
|
| 405 |
+
async def log_analysis(modality: str, verdict: str, processing_time: float):
|
| 406 |
+
"""Log analysis for monitoring and analytics."""
|
| 407 |
+
logger.info(
|
| 408 |
+
f"Analysis completed - Modality: {modality}, "
|
| 409 |
+
f"Verdict: {verdict}, Time: {processing_time:.2f}ms"
|
| 410 |
+
)
|
| 411 |
+
# In production: send to monitoring system (Prometheus, CloudWatch, etc.)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
# Error handlers
|
| 415 |
+
@app.exception_handler(HTTPException)
|
| 416 |
+
async def http_exception_handler(request: Request, exc: HTTPException):
|
| 417 |
+
"""Custom HTTP exception handler."""
|
| 418 |
+
return JSONResponse(
|
| 419 |
+
status_code=exc.status_code,
|
| 420 |
+
content={
|
| 421 |
+
"error": exc.detail,
|
| 422 |
+
"timestamp": datetime.utcnow().isoformat()
|
| 423 |
+
}
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
@app.exception_handler(Exception)
|
| 428 |
+
async def general_exception_handler(request: Request, exc: Exception):
|
| 429 |
+
"""General exception handler."""
|
| 430 |
+
logger.error(f"Unhandled exception: {str(exc)}")
|
| 431 |
+
return JSONResponse(
|
| 432 |
+
status_code=500,
|
| 433 |
+
content={
|
| 434 |
+
"error": "Internal server error",
|
| 435 |
+
"timestamp": datetime.utcnow().isoformat()
|
| 436 |
+
}
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# Startup/Shutdown events
|
| 441 |
+
@app.on_event("startup")
|
| 442 |
+
async def startup_event():
|
| 443 |
+
"""Initialize on startup."""
|
| 444 |
+
logger.info("🚀 Starting Multimodal Misinformation Detection API")
|
| 445 |
+
logger.info("📊 API Documentation: http://localhost:8000/docs")
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
@app.on_event("shutdown")
|
| 449 |
+
async def shutdown_event():
|
| 450 |
+
"""Cleanup on shutdown."""
|
| 451 |
+
logger.info("🛑 Shutting down API")
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
if __name__ == "__main__":
|
| 455 |
+
uvicorn.run(
|
| 456 |
+
"main:app",
|
| 457 |
+
host="0.0.0.0",
|
| 458 |
+
port=8000,
|
| 459 |
+
reload=True,
|
| 460 |
+
log_level="info"
|
| 461 |
+
)
|
src/api/schemas.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API Request/Response Schemas for Production
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Optional, List, Dict, Any
|
| 7 |
+
from pydantic import BaseModel, Field, EmailStr, field_validator
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Authentication Schemas
|
| 11 |
+
class UserLogin(BaseModel):
|
| 12 |
+
"""User login request"""
|
| 13 |
+
email: EmailStr
|
| 14 |
+
password: str = Field(..., min_length=8)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class UserCreate(BaseModel):
|
| 18 |
+
"""User registration request"""
|
| 19 |
+
email: EmailStr
|
| 20 |
+
password: str = Field(..., min_length=8)
|
| 21 |
+
full_name: Optional[str] = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class UserResponse(BaseModel):
|
| 25 |
+
"""User response"""
|
| 26 |
+
id: int
|
| 27 |
+
email: EmailStr
|
| 28 |
+
full_name: Optional[str] = None
|
| 29 |
+
is_active: bool
|
| 30 |
+
is_superuser: bool
|
| 31 |
+
created_at: datetime
|
| 32 |
+
|
| 33 |
+
class Config:
|
| 34 |
+
from_attributes = True
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Token(BaseModel):
|
| 38 |
+
"""JWT token response"""
|
| 39 |
+
access_token: str
|
| 40 |
+
refresh_token: str
|
| 41 |
+
token_type: str = "bearer"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TokenRefresh(BaseModel):
|
| 45 |
+
"""Token refresh request"""
|
| 46 |
+
refresh_token: str
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class APIKeyCreate(BaseModel):
|
| 50 |
+
"""API key creation request"""
|
| 51 |
+
name: str = Field(..., min_length=1, max_length=255)
|
| 52 |
+
expires_days: Optional[int] = Field(default=None, gt=0, le=365)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class APIKeyResponse(BaseModel):
|
| 56 |
+
"""API key response"""
|
| 57 |
+
id: int
|
| 58 |
+
key: str
|
| 59 |
+
name: str
|
| 60 |
+
is_active: bool
|
| 61 |
+
rate_limit_per_minute: int
|
| 62 |
+
rate_limit_per_hour: int
|
| 63 |
+
created_at: datetime
|
| 64 |
+
expires_at: Optional[datetime] = None
|
| 65 |
+
last_used_at: Optional[datetime] = None
|
| 66 |
+
|
| 67 |
+
class Config:
|
| 68 |
+
from_attributes = True
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Analysis Request Schemas
|
| 72 |
+
class TextAnalysisRequest(BaseModel):
|
| 73 |
+
"""Text analysis request"""
|
| 74 |
+
text: str = Field(..., min_length=10, max_length=100000)
|
| 75 |
+
model_version: Optional[str] = Field(default=None, description="Optional model version")
|
| 76 |
+
|
| 77 |
+
@field_validator("text")
|
| 78 |
+
@classmethod
|
| 79 |
+
def validate_text(cls, v):
|
| 80 |
+
if not v.strip():
|
| 81 |
+
raise ValueError("Text cannot be empty")
|
| 82 |
+
return v.strip()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class ImageAnalysisRequest(BaseModel):
|
| 86 |
+
"""Image analysis metadata"""
|
| 87 |
+
filename: Optional[str] = None
|
| 88 |
+
model_version: Optional[str] = None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class VideoAnalysisRequest(BaseModel):
|
| 92 |
+
"""Video analysis metadata"""
|
| 93 |
+
filename: Optional[str] = None
|
| 94 |
+
analyze_frames: bool = Field(default=True, description="Analyze individual frames")
|
| 95 |
+
frame_sample_rate: int = Field(default=30, ge=1, le=60, description="Frames to analyze per second")
|
| 96 |
+
model_version: Optional[str] = None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class BatchTextAnalysisRequest(BaseModel):
|
| 100 |
+
"""Batch text analysis request"""
|
| 101 |
+
texts: List[str] = Field(..., min_length=1, max_length=100)
|
| 102 |
+
model_version: Optional[str] = None
|
| 103 |
+
|
| 104 |
+
@field_validator("texts")
|
| 105 |
+
@classmethod
|
| 106 |
+
def validate_texts(cls, v):
|
| 107 |
+
if not v:
|
| 108 |
+
raise ValueError("At least one text is required")
|
| 109 |
+
|
| 110 |
+
for text in v:
|
| 111 |
+
if not text or not text.strip():
|
| 112 |
+
raise ValueError("All texts must be non-empty")
|
| 113 |
+
if len(text) > 100000:
|
| 114 |
+
raise ValueError("Text exceeds maximum length of 100,000 characters")
|
| 115 |
+
|
| 116 |
+
return [text.strip() for text in v]
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# Analysis Response Schemas
|
| 120 |
+
class DetectionResult(BaseModel):
|
| 121 |
+
"""Base detection result"""
|
| 122 |
+
prediction: str = Field(..., description="Prediction label")
|
| 123 |
+
confidence: float = Field(..., ge=0, le=1, description="Confidence score")
|
| 124 |
+
details: Dict[str, Any] = Field(default_factory=dict, description="Additional details")
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class TextAnalysisResponse(BaseModel):
|
| 128 |
+
"""Text analysis response"""
|
| 129 |
+
request_id: str
|
| 130 |
+
prediction: str
|
| 131 |
+
confidence: float
|
| 132 |
+
perplexity: Optional[float] = None
|
| 133 |
+
statistical_features: Optional[Dict[str, float]] = None
|
| 134 |
+
explanation: str
|
| 135 |
+
processing_time_ms: float
|
| 136 |
+
cached: bool = False
|
| 137 |
+
model_version: str
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class ImageAnalysisResponse(BaseModel):
|
| 141 |
+
"""Image analysis response"""
|
| 142 |
+
request_id: str
|
| 143 |
+
prediction: str
|
| 144 |
+
confidence: float
|
| 145 |
+
face_detected: bool
|
| 146 |
+
manipulation_score: float
|
| 147 |
+
artifacts_detected: List[str] = Field(default_factory=list)
|
| 148 |
+
explanation: str
|
| 149 |
+
processing_time_ms: float
|
| 150 |
+
cached: bool = False
|
| 151 |
+
model_version: str
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class VideoAnalysisResponse(BaseModel):
|
| 155 |
+
"""Video analysis response"""
|
| 156 |
+
request_id: str
|
| 157 |
+
prediction: str
|
| 158 |
+
confidence: float
|
| 159 |
+
frames_analyzed: int
|
| 160 |
+
temporal_consistency: float
|
| 161 |
+
frame_predictions: List[Dict[str, Any]] = Field(default_factory=list)
|
| 162 |
+
explanation: str
|
| 163 |
+
processing_time_ms: float
|
| 164 |
+
model_version: str
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class BatchTextAnalysisResponse(BaseModel):
|
| 168 |
+
"""Batch text analysis response"""
|
| 169 |
+
request_id: str
|
| 170 |
+
results: List[TextAnalysisResponse]
|
| 171 |
+
total_processed: int
|
| 172 |
+
processing_time_ms: float
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class AnomalyDetectionResponse(BaseModel):
|
| 176 |
+
"""Anomaly detection response"""
|
| 177 |
+
request_id: str
|
| 178 |
+
detected: bool
|
| 179 |
+
anomaly_score: float
|
| 180 |
+
anomaly_type: Optional[str] = None
|
| 181 |
+
explanation: str
|
| 182 |
+
details: Dict[str, Any] = Field(default_factory=dict)
|
| 183 |
+
processing_time_ms: float
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# Health & Status Schemas
|
| 187 |
+
class HealthResponse(BaseModel):
|
| 188 |
+
"""Health check response"""
|
| 189 |
+
status: str = "healthy"
|
| 190 |
+
timestamp: datetime
|
| 191 |
+
version: str
|
| 192 |
+
environment: str
|
| 193 |
+
services: Dict[str, str] = Field(default_factory=dict)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class MetricsResponse(BaseModel):
|
| 197 |
+
"""System metrics response"""
|
| 198 |
+
requests_total: int
|
| 199 |
+
requests_per_minute: float
|
| 200 |
+
average_response_time_ms: float
|
| 201 |
+
cache_hit_rate: float
|
| 202 |
+
active_users: int
|
| 203 |
+
models_loaded: List[str]
|
| 204 |
+
uptime_seconds: float
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
# Error Response Schemas
|
| 208 |
+
class ErrorResponse(BaseModel):
|
| 209 |
+
"""Standard error response"""
|
| 210 |
+
error: str = Field(..., description="Error type")
|
| 211 |
+
message: str = Field(..., description="Error message")
|
| 212 |
+
details: Optional[Dict[str, Any]] = Field(default=None, description="Additional error details")
|
| 213 |
+
request_id: Optional[str] = Field(default=None, description="Request ID for tracking")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class ValidationErrorResponse(BaseModel):
|
| 217 |
+
"""Validation error response"""
|
| 218 |
+
error: str = "ValidationError"
|
| 219 |
+
message: str
|
| 220 |
+
details: Dict[str, List[str]] = Field(..., description="Field-specific validation errors")
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# Admin Schemas
|
| 224 |
+
class UserListResponse(BaseModel):
|
| 225 |
+
"""User list response"""
|
| 226 |
+
users: List[UserResponse]
|
| 227 |
+
total: int
|
| 228 |
+
page: int
|
| 229 |
+
page_size: int
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class SystemStatsResponse(BaseModel):
|
| 233 |
+
"""System statistics response"""
|
| 234 |
+
total_users: int
|
| 235 |
+
active_users: int
|
| 236 |
+
total_requests: int
|
| 237 |
+
total_predictions: int
|
| 238 |
+
average_confidence: float
|
| 239 |
+
most_used_models: List[Dict[str, Any]]
|
| 240 |
+
cache_stats: Dict[str, Any]
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class LogEntry(BaseModel):
|
| 244 |
+
"""Log entry"""
|
| 245 |
+
timestamp: datetime
|
| 246 |
+
level: str
|
| 247 |
+
message: str
|
| 248 |
+
context: Optional[Dict[str, Any]] = None
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class LogsResponse(BaseModel):
|
| 252 |
+
"""Logs response"""
|
| 253 |
+
logs: List[LogEntry]
|
| 254 |
+
total: int
|
| 255 |
+
page: int
|
| 256 |
+
page_size: int
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# Pagination
|
| 260 |
+
class PaginationParams(BaseModel):
|
| 261 |
+
"""Pagination parameters"""
|
| 262 |
+
page: int = Field(default=1, ge=1)
|
| 263 |
+
page_size: int = Field(default=20, ge=1, le=100)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
if __name__ == "__main__":
|
| 267 |
+
# Test schemas
|
| 268 |
+
request = TextAnalysisRequest(text="This is a test text for analysis")
|
| 269 |
+
print(f"Request: {request}")
|
| 270 |
+
|
| 271 |
+
response = TextAnalysisResponse(
|
| 272 |
+
request_id="test-123",
|
| 273 |
+
prediction="HUMAN",
|
| 274 |
+
confidence=0.95,
|
| 275 |
+
perplexity=45.2,
|
| 276 |
+
explanation="Text exhibits natural language patterns",
|
| 277 |
+
processing_time_ms=125.5,
|
| 278 |
+
model_version="1.0"
|
| 279 |
+
)
|
| 280 |
+
print(f"Response: {response.model_dump_json(indent=2)}")
|
src/core/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Core application components"""
|
| 2 |
+
|
| 3 |
+
from .config import settings, validate_production_config
|
| 4 |
+
|
| 5 |
+
__all__ = ["settings", "validate_production_config"]
|
src/core/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (385 Bytes). View file
|
|
|
src/core/__pycache__/config.cpython-313.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
src/core/__pycache__/logging.cpython-313.pyc
ADDED
|
Binary file (5.81 kB). View file
|
|
|
src/core/cache.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Redis Cache Implementation for Production
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import hashlib
|
| 7 |
+
from typing import Any, Optional, Union
|
| 8 |
+
from datetime import timedelta
|
| 9 |
+
import redis.asyncio as aioredis
|
| 10 |
+
|
| 11 |
+
from src.core.config import settings
|
| 12 |
+
from src.core.logging import logger
|
| 13 |
+
from src.core.exceptions import CacheError
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class RedisCache:
|
| 17 |
+
"""Redis cache manager with async support"""
|
| 18 |
+
|
| 19 |
+
def __init__(self):
|
| 20 |
+
self.redis: Optional[aioredis.Redis] = None
|
| 21 |
+
self.enabled = settings.CACHE_PREDICTIONS
|
| 22 |
+
|
| 23 |
+
async def connect(self):
|
| 24 |
+
"""Connect to Redis"""
|
| 25 |
+
if not self.enabled:
|
| 26 |
+
logger.info("Redis cache is disabled")
|
| 27 |
+
return
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
self.redis = await aioredis.from_url(
|
| 31 |
+
settings.REDIS_URL,
|
| 32 |
+
encoding="utf-8",
|
| 33 |
+
decode_responses=True,
|
| 34 |
+
max_connections=50
|
| 35 |
+
)
|
| 36 |
+
# Test connection
|
| 37 |
+
await self.redis.ping()
|
| 38 |
+
logger.info(f"Connected to Redis at {settings.REDIS_HOST}:{settings.REDIS_PORT}")
|
| 39 |
+
except Exception as e:
|
| 40 |
+
logger.error(f"Failed to connect to Redis: {e}")
|
| 41 |
+
self.enabled = False
|
| 42 |
+
raise CacheError(f"Redis connection failed: {e}")
|
| 43 |
+
|
| 44 |
+
async def disconnect(self):
|
| 45 |
+
"""Disconnect from Redis"""
|
| 46 |
+
if self.redis:
|
| 47 |
+
await self.redis.close()
|
| 48 |
+
logger.info("Disconnected from Redis")
|
| 49 |
+
|
| 50 |
+
def _generate_cache_key(self, prefix: str, data: Union[str, dict]) -> str:
|
| 51 |
+
"""Generate cache key from data"""
|
| 52 |
+
if isinstance(data, dict):
|
| 53 |
+
data_str = json.dumps(data, sort_keys=True)
|
| 54 |
+
else:
|
| 55 |
+
data_str = str(data)
|
| 56 |
+
|
| 57 |
+
hash_value = hashlib.sha256(data_str.encode()).hexdigest()[:16]
|
| 58 |
+
return f"{prefix}:{hash_value}"
|
| 59 |
+
|
| 60 |
+
async def get(self, key: str) -> Optional[Any]:
|
| 61 |
+
"""Get value from cache"""
|
| 62 |
+
if not self.enabled or not self.redis:
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
value = await self.redis.get(key)
|
| 67 |
+
if value:
|
| 68 |
+
logger.debug(f"Cache hit: {key}")
|
| 69 |
+
return json.loads(value)
|
| 70 |
+
logger.debug(f"Cache miss: {key}")
|
| 71 |
+
return None
|
| 72 |
+
except Exception as e:
|
| 73 |
+
logger.warning(f"Cache get error for {key}: {e}")
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
async def set(
|
| 77 |
+
self,
|
| 78 |
+
key: str,
|
| 79 |
+
value: Any,
|
| 80 |
+
ttl: Optional[int] = None
|
| 81 |
+
) -> bool:
|
| 82 |
+
"""Set value in cache with TTL"""
|
| 83 |
+
if not self.enabled or not self.redis:
|
| 84 |
+
return False
|
| 85 |
+
|
| 86 |
+
try:
|
| 87 |
+
ttl = ttl or settings.CACHE_TTL
|
| 88 |
+
value_json = json.dumps(value)
|
| 89 |
+
await self.redis.setex(key, ttl, value_json)
|
| 90 |
+
logger.debug(f"Cache set: {key} (TTL: {ttl}s)")
|
| 91 |
+
return True
|
| 92 |
+
except Exception as e:
|
| 93 |
+
logger.warning(f"Cache set error for {key}: {e}")
|
| 94 |
+
return False
|
| 95 |
+
|
| 96 |
+
async def delete(self, key: str) -> bool:
|
| 97 |
+
"""Delete key from cache"""
|
| 98 |
+
if not self.enabled or not self.redis:
|
| 99 |
+
return False
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
await self.redis.delete(key)
|
| 103 |
+
logger.debug(f"Cache delete: {key}")
|
| 104 |
+
return True
|
| 105 |
+
except Exception as e:
|
| 106 |
+
logger.warning(f"Cache delete error for {key}: {e}")
|
| 107 |
+
return False
|
| 108 |
+
|
| 109 |
+
async def get_prediction(
|
| 110 |
+
self,
|
| 111 |
+
model_type: str,
|
| 112 |
+
input_data: Union[str, dict]
|
| 113 |
+
) -> Optional[dict]:
|
| 114 |
+
"""Get cached prediction"""
|
| 115 |
+
key = self._generate_cache_key(f"pred:{model_type}", input_data)
|
| 116 |
+
return await self.get(key)
|
| 117 |
+
|
| 118 |
+
async def set_prediction(
|
| 119 |
+
self,
|
| 120 |
+
model_type: str,
|
| 121 |
+
input_data: Union[str, dict],
|
| 122 |
+
result: dict,
|
| 123 |
+
ttl: Optional[int] = None
|
| 124 |
+
) -> bool:
|
| 125 |
+
"""Cache prediction result"""
|
| 126 |
+
key = self._generate_cache_key(f"pred:{model_type}", input_data)
|
| 127 |
+
return await self.set(key, result, ttl)
|
| 128 |
+
|
| 129 |
+
async def increment_rate_limit(
|
| 130 |
+
self,
|
| 131 |
+
identifier: str,
|
| 132 |
+
window_seconds: int
|
| 133 |
+
) -> int:
|
| 134 |
+
"""Increment rate limit counter"""
|
| 135 |
+
if not self.enabled or not self.redis:
|
| 136 |
+
return 0
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
+
key = f"ratelimit:{identifier}"
|
| 140 |
+
pipe = self.redis.pipeline()
|
| 141 |
+
pipe.incr(key)
|
| 142 |
+
pipe.expire(key, window_seconds)
|
| 143 |
+
result = await pipe.execute()
|
| 144 |
+
count = result[0]
|
| 145 |
+
logger.debug(f"Rate limit count for {identifier}: {count}")
|
| 146 |
+
return count
|
| 147 |
+
except Exception as e:
|
| 148 |
+
logger.warning(f"Rate limit increment error: {e}")
|
| 149 |
+
return 0
|
| 150 |
+
|
| 151 |
+
async def get_rate_limit_count(self, identifier: str) -> int:
|
| 152 |
+
"""Get current rate limit count"""
|
| 153 |
+
if not self.enabled or not self.redis:
|
| 154 |
+
return 0
|
| 155 |
+
|
| 156 |
+
try:
|
| 157 |
+
key = f"ratelimit:{identifier}"
|
| 158 |
+
count = await self.redis.get(key)
|
| 159 |
+
return int(count) if count else 0
|
| 160 |
+
except Exception as e:
|
| 161 |
+
logger.warning(f"Rate limit get error: {e}")
|
| 162 |
+
return 0
|
| 163 |
+
|
| 164 |
+
async def clear_all(self) -> bool:
|
| 165 |
+
"""Clear all cache (use with caution!)"""
|
| 166 |
+
if not self.enabled or not self.redis:
|
| 167 |
+
return False
|
| 168 |
+
|
| 169 |
+
try:
|
| 170 |
+
await self.redis.flushdb()
|
| 171 |
+
logger.warning("All cache cleared!")
|
| 172 |
+
return True
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logger.error(f"Cache clear error: {e}")
|
| 175 |
+
return False
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# Global cache instance
|
| 179 |
+
cache = RedisCache()
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# Decorator for caching function results
|
| 183 |
+
def cached(prefix: str, ttl: Optional[int] = None):
|
| 184 |
+
"""Decorator to cache function results"""
|
| 185 |
+
def decorator(func):
|
| 186 |
+
async def wrapper(*args, **kwargs):
|
| 187 |
+
# Generate cache key from function arguments
|
| 188 |
+
cache_data = {"args": str(args), "kwargs": str(kwargs)}
|
| 189 |
+
cache_key = cache._generate_cache_key(prefix, cache_data)
|
| 190 |
+
|
| 191 |
+
# Try to get from cache
|
| 192 |
+
cached_result = await cache.get(cache_key)
|
| 193 |
+
if cached_result is not None:
|
| 194 |
+
return cached_result
|
| 195 |
+
|
| 196 |
+
# Execute function
|
| 197 |
+
result = await func(*args, **kwargs)
|
| 198 |
+
|
| 199 |
+
# Cache result
|
| 200 |
+
await cache.set(cache_key, result, ttl)
|
| 201 |
+
|
| 202 |
+
return result
|
| 203 |
+
return wrapper
|
| 204 |
+
return decorator
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
if __name__ == "__main__":
|
| 208 |
+
import asyncio
|
| 209 |
+
|
| 210 |
+
async def test_cache():
|
| 211 |
+
# Connect
|
| 212 |
+
await cache.connect()
|
| 213 |
+
|
| 214 |
+
# Test basic operations
|
| 215 |
+
await cache.set("test_key", {"value": 123}, ttl=60)
|
| 216 |
+
result = await cache.get("test_key")
|
| 217 |
+
print(f"Retrieved: {result}")
|
| 218 |
+
|
| 219 |
+
# Test prediction caching
|
| 220 |
+
await cache.set_prediction(
|
| 221 |
+
"deepfake",
|
| 222 |
+
{"image": "test.jpg"},
|
| 223 |
+
{"prediction": "FAKE", "confidence": 0.95},
|
| 224 |
+
ttl=300
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
cached_pred = await cache.get_prediction("deepfake", {"image": "test.jpg"})
|
| 228 |
+
print(f"Cached prediction: {cached_pred}")
|
| 229 |
+
|
| 230 |
+
# Test rate limiting
|
| 231 |
+
for i in range(5):
|
| 232 |
+
count = await cache.increment_rate_limit("user:123", 60)
|
| 233 |
+
print(f"Request {i+1}: Rate limit count = {count}")
|
| 234 |
+
|
| 235 |
+
# Disconnect
|
| 236 |
+
await cache.disconnect()
|
| 237 |
+
|
| 238 |
+
asyncio.run(test_cache())
|
src/core/config.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Production Configuration Management
|
| 3 |
+
Handles environment-based settings, secrets, and feature flags
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Optional
|
| 9 |
+
from pydantic import Field, PostgresDsn, RedisDsn, field_validator
|
| 10 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Settings(BaseSettings):
|
| 14 |
+
"""Application configuration with environment variable support"""
|
| 15 |
+
|
| 16 |
+
# Application
|
| 17 |
+
APP_NAME: str = "Multimodal Misinformation Detection API"
|
| 18 |
+
APP_VERSION: str = "1.0.0"
|
| 19 |
+
API_V1_PREFIX: str = "/api/v1"
|
| 20 |
+
DEBUG: bool = Field(default=False, validation_alias="DEBUG")
|
| 21 |
+
ENVIRONMENT: str = Field(default="production", validation_alias="ENVIRONMENT")
|
| 22 |
+
|
| 23 |
+
# Server
|
| 24 |
+
HOST: str = Field(default="0.0.0.0", validation_alias="HOST")
|
| 25 |
+
PORT: int = Field(default=8000, validation_alias="PORT")
|
| 26 |
+
WORKERS: int = Field(default=4, validation_alias="WORKERS")
|
| 27 |
+
RELOAD: bool = Field(default=False, validation_alias="RELOAD")
|
| 28 |
+
|
| 29 |
+
# Security
|
| 30 |
+
SECRET_KEY: str = Field(
|
| 31 |
+
default="CHANGE-ME-IN-PRODUCTION-USE-OPENSSL-RAND-HEX-32",
|
| 32 |
+
validation_alias="SECRET_KEY"
|
| 33 |
+
)
|
| 34 |
+
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
| 35 |
+
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
|
| 36 |
+
ALGORITHM: str = "HS256"
|
| 37 |
+
|
| 38 |
+
# CORS
|
| 39 |
+
BACKEND_CORS_ORIGINS: List[str] = Field(
|
| 40 |
+
default=["http://localhost:3000", "http://localhost:8000"],
|
| 41 |
+
validation_alias="BACKEND_CORS_ORIGINS"
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
@field_validator("BACKEND_CORS_ORIGINS", mode="before")
|
| 45 |
+
@classmethod
|
| 46 |
+
def parse_cors_origins(cls, v):
|
| 47 |
+
if isinstance(v, str):
|
| 48 |
+
return [origin.strip() for origin in v.split(",")]
|
| 49 |
+
return v
|
| 50 |
+
|
| 51 |
+
# Database
|
| 52 |
+
POSTGRES_SERVER: str = Field(default="localhost", validation_alias="POSTGRES_SERVER")
|
| 53 |
+
POSTGRES_USER: str = Field(default="postgres", validation_alias="POSTGRES_USER")
|
| 54 |
+
POSTGRES_PASSWORD: str = Field(default="postgres", validation_alias="POSTGRES_PASSWORD")
|
| 55 |
+
POSTGRES_DB: str = Field(default="misinformation_detection", validation_alias="POSTGRES_DB")
|
| 56 |
+
POSTGRES_PORT: int = Field(default=5432, validation_alias="POSTGRES_PORT")
|
| 57 |
+
DATABASE_URL: Optional[str] = None
|
| 58 |
+
|
| 59 |
+
@field_validator("DATABASE_URL", mode="before")
|
| 60 |
+
@classmethod
|
| 61 |
+
def assemble_db_connection(cls, v, info):
|
| 62 |
+
if isinstance(v, str) and v:
|
| 63 |
+
return v
|
| 64 |
+
data = info.data
|
| 65 |
+
return f"postgresql://{data.get('POSTGRES_USER')}:{data.get('POSTGRES_PASSWORD')}@{data.get('POSTGRES_SERVER')}:{data.get('POSTGRES_PORT')}/{data.get('POSTGRES_DB')}"
|
| 66 |
+
|
| 67 |
+
# Redis
|
| 68 |
+
REDIS_HOST: str = Field(default="localhost", validation_alias="REDIS_HOST")
|
| 69 |
+
REDIS_PORT: int = Field(default=6379, validation_alias="REDIS_PORT")
|
| 70 |
+
REDIS_PASSWORD: Optional[str] = Field(default=None, validation_alias="REDIS_PASSWORD")
|
| 71 |
+
REDIS_DB: int = Field(default=0, validation_alias="REDIS_DB")
|
| 72 |
+
REDIS_URL: Optional[str] = None
|
| 73 |
+
|
| 74 |
+
@field_validator("REDIS_URL", mode="before")
|
| 75 |
+
@classmethod
|
| 76 |
+
def assemble_redis_connection(cls, v, info):
|
| 77 |
+
if isinstance(v, str) and v:
|
| 78 |
+
return v
|
| 79 |
+
data = info.data
|
| 80 |
+
password_part = f":{data.get('REDIS_PASSWORD')}@" if data.get('REDIS_PASSWORD') else ""
|
| 81 |
+
return f"redis://{password_part}{data.get('REDIS_HOST')}:{data.get('REDIS_PORT')}/{data.get('REDIS_DB')}"
|
| 82 |
+
|
| 83 |
+
# Cache
|
| 84 |
+
CACHE_TTL: int = Field(default=3600, validation_alias="CACHE_TTL") # 1 hour
|
| 85 |
+
CACHE_PREDICTIONS: bool = Field(default=True, validation_alias="CACHE_PREDICTIONS")
|
| 86 |
+
|
| 87 |
+
# Rate Limiting
|
| 88 |
+
RATE_LIMIT_ENABLED: bool = Field(default=True, validation_alias="RATE_LIMIT_ENABLED")
|
| 89 |
+
RATE_LIMIT_PER_MINUTE: int = Field(default=60, validation_alias="RATE_LIMIT_PER_MINUTE")
|
| 90 |
+
RATE_LIMIT_PER_HOUR: int = Field(default=1000, validation_alias="RATE_LIMIT_PER_HOUR")
|
| 91 |
+
|
| 92 |
+
# File Upload
|
| 93 |
+
MAX_UPLOAD_SIZE: int = Field(default=10 * 1024 * 1024, validation_alias="MAX_UPLOAD_SIZE") # 10MB
|
| 94 |
+
ALLOWED_IMAGE_TYPES: List[str] = Field(
|
| 95 |
+
default=["image/jpeg", "image/png", "image/webp"],
|
| 96 |
+
validation_alias="ALLOWED_IMAGE_TYPES"
|
| 97 |
+
)
|
| 98 |
+
ALLOWED_VIDEO_TYPES: List[str] = Field(
|
| 99 |
+
default=["video/mp4", "video/mpeg", "video/quicktime"],
|
| 100 |
+
validation_alias="ALLOWED_VIDEO_TYPES"
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# ML Models
|
| 104 |
+
MODEL_CACHE_DIR: Path = Field(
|
| 105 |
+
default=Path(__file__).parent.parent.parent / "models",
|
| 106 |
+
validation_alias="MODEL_CACHE_DIR"
|
| 107 |
+
)
|
| 108 |
+
DEVICE: str = Field(default="cpu", validation_alias="DEVICE") # cpu or cuda
|
| 109 |
+
BATCH_SIZE: int = Field(default=32, validation_alias="BATCH_SIZE")
|
| 110 |
+
|
| 111 |
+
# Model paths
|
| 112 |
+
DEEPFAKE_MODEL: str = Field(
|
| 113 |
+
default="timm/efficientnet_b4.ra2_in1k",
|
| 114 |
+
validation_alias="DEEPFAKE_MODEL"
|
| 115 |
+
)
|
| 116 |
+
TEXT_CLASSIFIER_MODEL: str = Field(
|
| 117 |
+
default="roberta-base",
|
| 118 |
+
validation_alias="TEXT_CLASSIFIER_MODEL"
|
| 119 |
+
)
|
| 120 |
+
PERPLEXITY_MODEL: str = Field(
|
| 121 |
+
default="gpt2",
|
| 122 |
+
validation_alias="PERPLEXITY_MODEL"
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Logging
|
| 126 |
+
LOG_LEVEL: str = Field(default="INFO", validation_alias="LOG_LEVEL")
|
| 127 |
+
LOG_FORMAT: str = Field(default="json", validation_alias="LOG_FORMAT") # json or text
|
| 128 |
+
LOG_FILE: Optional[Path] = Field(default=None, validation_alias="LOG_FILE")
|
| 129 |
+
|
| 130 |
+
# Monitoring
|
| 131 |
+
ENABLE_METRICS: bool = Field(default=True, validation_alias="ENABLE_METRICS")
|
| 132 |
+
ENABLE_TRACING: bool = Field(default=False, validation_alias="ENABLE_TRACING")
|
| 133 |
+
METRICS_PORT: int = Field(default=9090, validation_alias="METRICS_PORT")
|
| 134 |
+
|
| 135 |
+
# Feature Flags
|
| 136 |
+
ENABLE_VIDEO_ANALYSIS: bool = Field(default=True, validation_alias="ENABLE_VIDEO_ANALYSIS")
|
| 137 |
+
ENABLE_AUDIO_ANALYSIS: bool = Field(default=True, validation_alias="ENABLE_AUDIO_ANALYSIS")
|
| 138 |
+
ENABLE_BATCH_PROCESSING: bool = Field(default=True, validation_alias="ENABLE_BATCH_PROCESSING")
|
| 139 |
+
ENABLE_ASYNC_TASKS: bool = Field(default=True, validation_alias="ENABLE_ASYNC_TASKS")
|
| 140 |
+
|
| 141 |
+
# Celery (for async tasks)
|
| 142 |
+
CELERY_BROKER_URL: Optional[str] = None
|
| 143 |
+
CELERY_RESULT_BACKEND: Optional[str] = None
|
| 144 |
+
|
| 145 |
+
@field_validator("CELERY_BROKER_URL", mode="before")
|
| 146 |
+
@classmethod
|
| 147 |
+
def set_celery_broker(cls, v, info):
|
| 148 |
+
if isinstance(v, str) and v:
|
| 149 |
+
return v
|
| 150 |
+
return info.data.get("REDIS_URL")
|
| 151 |
+
|
| 152 |
+
@field_validator("CELERY_RESULT_BACKEND", mode="before")
|
| 153 |
+
@classmethod
|
| 154 |
+
def set_celery_backend(cls, v, info):
|
| 155 |
+
if isinstance(v, str) and v:
|
| 156 |
+
return v
|
| 157 |
+
return info.data.get("REDIS_URL")
|
| 158 |
+
|
| 159 |
+
# Email (for notifications)
|
| 160 |
+
SMTP_HOST: Optional[str] = Field(default=None, validation_alias="SMTP_HOST")
|
| 161 |
+
SMTP_PORT: int = Field(default=587, validation_alias="SMTP_PORT")
|
| 162 |
+
SMTP_USER: Optional[str] = Field(default=None, validation_alias="SMTP_USER")
|
| 163 |
+
SMTP_PASSWORD: Optional[str] = Field(default=None, validation_alias="SMTP_PASSWORD")
|
| 164 |
+
EMAILS_FROM_EMAIL: Optional[str] = Field(default=None, validation_alias="EMAILS_FROM_EMAIL")
|
| 165 |
+
|
| 166 |
+
# Admin
|
| 167 |
+
FIRST_SUPERUSER_EMAIL: str = Field(
|
| 168 |
+
default="admin@example.com",
|
| 169 |
+
validation_alias="FIRST_SUPERUSER_EMAIL"
|
| 170 |
+
)
|
| 171 |
+
FIRST_SUPERUSER_PASSWORD: str = Field(
|
| 172 |
+
default="changeme",
|
| 173 |
+
validation_alias="FIRST_SUPERUSER_PASSWORD"
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
model_config = SettingsConfigDict(
|
| 177 |
+
env_file=".env",
|
| 178 |
+
env_file_encoding="utf-8",
|
| 179 |
+
case_sensitive=True,
|
| 180 |
+
extra="allow"
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
@property
|
| 184 |
+
def is_production(self) -> bool:
|
| 185 |
+
"""Check if running in production environment"""
|
| 186 |
+
return self.ENVIRONMENT.lower() == "production"
|
| 187 |
+
|
| 188 |
+
@property
|
| 189 |
+
def is_development(self) -> bool:
|
| 190 |
+
"""Check if running in development environment"""
|
| 191 |
+
return self.ENVIRONMENT.lower() == "development"
|
| 192 |
+
|
| 193 |
+
@property
|
| 194 |
+
def is_testing(self) -> bool:
|
| 195 |
+
"""Check if running in testing environment"""
|
| 196 |
+
return self.ENVIRONMENT.lower() == "testing"
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# Global settings instance
|
| 200 |
+
settings = Settings()
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# Validate critical production settings
|
| 204 |
+
def validate_production_config():
|
| 205 |
+
"""Validate that production settings are properly configured"""
|
| 206 |
+
if settings.is_production:
|
| 207 |
+
errors = []
|
| 208 |
+
|
| 209 |
+
if settings.SECRET_KEY == "CHANGE-ME-IN-PRODUCTION-USE-OPENSSL-RAND-HEX-32":
|
| 210 |
+
errors.append("SECRET_KEY must be changed in production")
|
| 211 |
+
|
| 212 |
+
if settings.FIRST_SUPERUSER_PASSWORD == "changeme":
|
| 213 |
+
errors.append("FIRST_SUPERUSER_PASSWORD must be changed in production")
|
| 214 |
+
|
| 215 |
+
if settings.DEBUG:
|
| 216 |
+
errors.append("DEBUG must be False in production")
|
| 217 |
+
|
| 218 |
+
if not settings.POSTGRES_PASSWORD or settings.POSTGRES_PASSWORD == "postgres":
|
| 219 |
+
errors.append("Strong POSTGRES_PASSWORD required in production")
|
| 220 |
+
|
| 221 |
+
if errors:
|
| 222 |
+
raise ValueError(
|
| 223 |
+
f"Production configuration errors:\n" + "\n".join(f" - {err}" for err in errors)
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
if __name__ == "__main__":
|
| 228 |
+
# Test configuration loading
|
| 229 |
+
print(f"Environment: {settings.ENVIRONMENT}")
|
| 230 |
+
print(f"Database URL: {settings.DATABASE_URL}")
|
| 231 |
+
print(f"Redis URL: {settings.REDIS_URL}")
|
| 232 |
+
print(f"Debug Mode: {settings.DEBUG}")
|
| 233 |
+
print(f"Rate Limiting: {settings.RATE_LIMIT_ENABLED}")
|
src/core/exceptions.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom Exception Classes for Production Error Handling
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, Optional
|
| 6 |
+
from fastapi import status
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AppException(Exception):
|
| 10 |
+
"""Base application exception"""
|
| 11 |
+
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
message: str,
|
| 15 |
+
status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 16 |
+
details: Optional[Dict[str, Any]] = None
|
| 17 |
+
):
|
| 18 |
+
self.message = message
|
| 19 |
+
self.status_code = status_code
|
| 20 |
+
self.details = details or {}
|
| 21 |
+
super().__init__(self.message)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ValidationError(AppException):
|
| 25 |
+
"""Validation error exception"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
| 28 |
+
super().__init__(
|
| 29 |
+
message=message,
|
| 30 |
+
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
| 31 |
+
details=details
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class AuthenticationError(AppException):
|
| 36 |
+
"""Authentication error exception"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, message: str = "Authentication failed"):
|
| 39 |
+
super().__init__(
|
| 40 |
+
message=message,
|
| 41 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 42 |
+
details={"www_authenticate": "Bearer"}
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class AuthorizationError(AppException):
|
| 47 |
+
"""Authorization error exception"""
|
| 48 |
+
|
| 49 |
+
def __init__(self, message: str = "Insufficient permissions"):
|
| 50 |
+
super().__init__(
|
| 51 |
+
message=message,
|
| 52 |
+
status_code=status.HTTP_403_FORBIDDEN
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class ResourceNotFoundError(AppException):
|
| 57 |
+
"""Resource not found exception"""
|
| 58 |
+
|
| 59 |
+
def __init__(self, resource: str, identifier: Any):
|
| 60 |
+
super().__init__(
|
| 61 |
+
message=f"{resource} not found",
|
| 62 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 63 |
+
details={"resource": resource, "identifier": str(identifier)}
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class RateLimitExceededError(AppException):
|
| 68 |
+
"""Rate limit exceeded exception"""
|
| 69 |
+
|
| 70 |
+
def __init__(self, limit: int, window: str):
|
| 71 |
+
super().__init__(
|
| 72 |
+
message=f"Rate limit exceeded: {limit} requests per {window}",
|
| 73 |
+
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
| 74 |
+
details={"limit": limit, "window": window}
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class ModelLoadError(AppException):
|
| 79 |
+
"""ML model loading error"""
|
| 80 |
+
|
| 81 |
+
def __init__(self, model_name: str, reason: str):
|
| 82 |
+
super().__init__(
|
| 83 |
+
message=f"Failed to load model: {model_name}",
|
| 84 |
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 85 |
+
details={"model": model_name, "reason": reason}
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class PredictionError(AppException):
|
| 90 |
+
"""ML prediction error"""
|
| 91 |
+
|
| 92 |
+
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
| 93 |
+
super().__init__(
|
| 94 |
+
message=message,
|
| 95 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 96 |
+
details=details
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class FileUploadError(AppException):
|
| 101 |
+
"""File upload error"""
|
| 102 |
+
|
| 103 |
+
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
| 104 |
+
super().__init__(
|
| 105 |
+
message=message,
|
| 106 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 107 |
+
details=details
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class DatabaseError(AppException):
|
| 112 |
+
"""Database operation error"""
|
| 113 |
+
|
| 114 |
+
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
| 115 |
+
super().__init__(
|
| 116 |
+
message=message,
|
| 117 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 118 |
+
details=details
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class CacheError(AppException):
|
| 123 |
+
"""Cache operation error"""
|
| 124 |
+
|
| 125 |
+
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
| 126 |
+
super().__init__(
|
| 127 |
+
message=message,
|
| 128 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 129 |
+
details=details
|
| 130 |
+
)
|
src/core/logging.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Production-Grade Structured Logging
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import sys
|
| 7 |
+
import json
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import Any, Dict
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from pythonjsonlogger import jsonlogger
|
| 13 |
+
|
| 14 |
+
from .config import settings
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CustomJsonFormatter(jsonlogger.JsonFormatter):
|
| 18 |
+
"""Custom JSON formatter with additional fields"""
|
| 19 |
+
|
| 20 |
+
def add_fields(self, log_record: Dict[str, Any], record: logging.LogRecord, message_dict: dict):
|
| 21 |
+
super().add_fields(log_record, record, message_dict)
|
| 22 |
+
|
| 23 |
+
# Add timestamp
|
| 24 |
+
log_record['timestamp'] = datetime.utcnow().isoformat()
|
| 25 |
+
|
| 26 |
+
# Add log level
|
| 27 |
+
log_record['level'] = record.levelname
|
| 28 |
+
|
| 29 |
+
# Add application context
|
| 30 |
+
log_record['app'] = settings.APP_NAME
|
| 31 |
+
log_record['version'] = settings.APP_VERSION
|
| 32 |
+
log_record['environment'] = settings.ENVIRONMENT
|
| 33 |
+
|
| 34 |
+
# Add request ID if available (will be set by middleware)
|
| 35 |
+
if hasattr(record, 'request_id'):
|
| 36 |
+
log_record['request_id'] = record.request_id
|
| 37 |
+
|
| 38 |
+
# Add user ID if available
|
| 39 |
+
if hasattr(record, 'user_id'):
|
| 40 |
+
log_record['user_id'] = record.user_id
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def setup_logging():
|
| 44 |
+
"""Configure application logging"""
|
| 45 |
+
|
| 46 |
+
# Create logger
|
| 47 |
+
logger = logging.getLogger()
|
| 48 |
+
logger.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
| 49 |
+
|
| 50 |
+
# Remove existing handlers
|
| 51 |
+
logger.handlers = []
|
| 52 |
+
|
| 53 |
+
# Console handler
|
| 54 |
+
console_handler = logging.StreamHandler(sys.stdout)
|
| 55 |
+
|
| 56 |
+
if settings.LOG_FORMAT == "json":
|
| 57 |
+
console_formatter = CustomJsonFormatter(
|
| 58 |
+
'%(timestamp)s %(level)s %(name)s %(message)s'
|
| 59 |
+
)
|
| 60 |
+
else:
|
| 61 |
+
console_formatter = logging.Formatter(
|
| 62 |
+
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
console_handler.setFormatter(console_formatter)
|
| 66 |
+
logger.addHandler(console_handler)
|
| 67 |
+
|
| 68 |
+
# File handler (if configured)
|
| 69 |
+
if settings.LOG_FILE:
|
| 70 |
+
log_file = Path(settings.LOG_FILE)
|
| 71 |
+
log_file.parent.mkdir(parents=True, exist_ok=True)
|
| 72 |
+
|
| 73 |
+
file_handler = logging.FileHandler(log_file)
|
| 74 |
+
|
| 75 |
+
if settings.LOG_FORMAT == "json":
|
| 76 |
+
file_formatter = CustomJsonFormatter(
|
| 77 |
+
'%(timestamp)s %(level)s %(name)s %(message)s'
|
| 78 |
+
)
|
| 79 |
+
else:
|
| 80 |
+
file_formatter = logging.Formatter(
|
| 81 |
+
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
file_handler.setFormatter(file_formatter)
|
| 85 |
+
logger.addHandler(file_handler)
|
| 86 |
+
|
| 87 |
+
return logger
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# Create module-level logger
|
| 91 |
+
logger = setup_logging()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def log_api_request(
|
| 95 |
+
method: str,
|
| 96 |
+
path: str,
|
| 97 |
+
status_code: int,
|
| 98 |
+
duration_ms: float,
|
| 99 |
+
user_id: str = None,
|
| 100 |
+
request_id: str = None
|
| 101 |
+
):
|
| 102 |
+
"""Log API request with structured data"""
|
| 103 |
+
logger.info(
|
| 104 |
+
"API Request",
|
| 105 |
+
extra={
|
| 106 |
+
"method": method,
|
| 107 |
+
"path": path,
|
| 108 |
+
"status_code": status_code,
|
| 109 |
+
"duration_ms": duration_ms,
|
| 110 |
+
"user_id": user_id,
|
| 111 |
+
"request_id": request_id,
|
| 112 |
+
"event_type": "api_request"
|
| 113 |
+
}
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def log_prediction(
|
| 118 |
+
model_type: str,
|
| 119 |
+
input_size: int,
|
| 120 |
+
confidence: float,
|
| 121 |
+
duration_ms: float,
|
| 122 |
+
cached: bool = False,
|
| 123 |
+
user_id: str = None
|
| 124 |
+
):
|
| 125 |
+
"""Log ML prediction with metrics"""
|
| 126 |
+
logger.info(
|
| 127 |
+
"ML Prediction",
|
| 128 |
+
extra={
|
| 129 |
+
"model_type": model_type,
|
| 130 |
+
"input_size": input_size,
|
| 131 |
+
"confidence": confidence,
|
| 132 |
+
"duration_ms": duration_ms,
|
| 133 |
+
"cached": cached,
|
| 134 |
+
"user_id": user_id,
|
| 135 |
+
"event_type": "prediction"
|
| 136 |
+
}
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def log_error(
|
| 141 |
+
error: Exception,
|
| 142 |
+
context: Dict[str, Any] = None,
|
| 143 |
+
user_id: str = None,
|
| 144 |
+
request_id: str = None
|
| 145 |
+
):
|
| 146 |
+
"""Log error with full context"""
|
| 147 |
+
logger.error(
|
| 148 |
+
f"Error: {str(error)}",
|
| 149 |
+
extra={
|
| 150 |
+
"error_type": type(error).__name__,
|
| 151 |
+
"error_message": str(error),
|
| 152 |
+
"context": context or {},
|
| 153 |
+
"user_id": user_id,
|
| 154 |
+
"request_id": request_id,
|
| 155 |
+
"event_type": "error"
|
| 156 |
+
},
|
| 157 |
+
exc_info=True
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == "__main__":
|
| 162 |
+
# Test logging
|
| 163 |
+
logger.info("Application starting")
|
| 164 |
+
logger.debug("Debug message")
|
| 165 |
+
logger.warning("Warning message")
|
| 166 |
+
logger.error("Error message")
|
| 167 |
+
|
| 168 |
+
log_api_request("GET", "/api/v1/health", 200, 5.2, request_id="test-123")
|
| 169 |
+
log_prediction("deepfake", 1024, 0.95, 125.5, cached=False, user_id="user-1")
|
src/core/middleware.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Middleware for Production Security
|
| 3 |
+
Rate limiting, request logging, security headers, CORS
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import time
|
| 7 |
+
import uuid
|
| 8 |
+
from typing import Callable
|
| 9 |
+
from fastapi import Request, Response, status
|
| 10 |
+
from fastapi.responses import JSONResponse
|
| 11 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 12 |
+
from starlette.middleware.cors import CORSMiddleware
|
| 13 |
+
|
| 14 |
+
from src.core.config import settings
|
| 15 |
+
from src.core.logging import logger, log_api_request, log_error
|
| 16 |
+
from src.core.exceptions import RateLimitExceededError
|
| 17 |
+
from src.core.cache import cache
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RequestIDMiddleware(BaseHTTPMiddleware):
|
| 21 |
+
"""Add unique request ID to each request"""
|
| 22 |
+
|
| 23 |
+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
| 24 |
+
request_id = str(uuid.uuid4())
|
| 25 |
+
request.state.request_id = request_id
|
| 26 |
+
|
| 27 |
+
response = await call_next(request)
|
| 28 |
+
response.headers["X-Request-ID"] = request_id
|
| 29 |
+
|
| 30 |
+
return response
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
| 34 |
+
"""Log all API requests with performance metrics"""
|
| 35 |
+
|
| 36 |
+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
| 37 |
+
start_time = time.time()
|
| 38 |
+
|
| 39 |
+
# Get request ID
|
| 40 |
+
request_id = getattr(request.state, "request_id", None)
|
| 41 |
+
|
| 42 |
+
# Process request
|
| 43 |
+
response = await call_next(request)
|
| 44 |
+
|
| 45 |
+
# Calculate duration
|
| 46 |
+
duration_ms = (time.time() - start_time) * 1000
|
| 47 |
+
|
| 48 |
+
# Log request
|
| 49 |
+
log_api_request(
|
| 50 |
+
method=request.method,
|
| 51 |
+
path=str(request.url.path),
|
| 52 |
+
status_code=response.status_code,
|
| 53 |
+
duration_ms=duration_ms,
|
| 54 |
+
user_id=getattr(request.state, "user_id", None),
|
| 55 |
+
request_id=request_id
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# Add performance header
|
| 59 |
+
response.headers["X-Response-Time"] = f"{duration_ms:.2f}ms"
|
| 60 |
+
|
| 61 |
+
return response
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class RateLimitMiddleware(BaseHTTPMiddleware):
|
| 65 |
+
"""Rate limiting based on IP address or API key"""
|
| 66 |
+
|
| 67 |
+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
| 68 |
+
if not settings.RATE_LIMIT_ENABLED:
|
| 69 |
+
return await call_next(request)
|
| 70 |
+
|
| 71 |
+
# Skip rate limiting for health check
|
| 72 |
+
if request.url.path == "/health":
|
| 73 |
+
return await call_next(request)
|
| 74 |
+
|
| 75 |
+
# Get identifier (IP address or user ID)
|
| 76 |
+
client_ip = request.client.host if request.client else "unknown"
|
| 77 |
+
user_id = getattr(request.state, "user_id", None)
|
| 78 |
+
identifier = f"user:{user_id}" if user_id else f"ip:{client_ip}"
|
| 79 |
+
|
| 80 |
+
# Check rate limit (per minute)
|
| 81 |
+
count = await cache.increment_rate_limit(identifier, 60)
|
| 82 |
+
|
| 83 |
+
if count > settings.RATE_LIMIT_PER_MINUTE:
|
| 84 |
+
logger.warning(f"Rate limit exceeded for {identifier}: {count} requests")
|
| 85 |
+
raise RateLimitExceededError(
|
| 86 |
+
limit=settings.RATE_LIMIT_PER_MINUTE,
|
| 87 |
+
window="minute"
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# Add rate limit headers
|
| 91 |
+
response = await call_next(request)
|
| 92 |
+
response.headers["X-RateLimit-Limit"] = str(settings.RATE_LIMIT_PER_MINUTE)
|
| 93 |
+
response.headers["X-RateLimit-Remaining"] = str(max(0, settings.RATE_LIMIT_PER_MINUTE - count))
|
| 94 |
+
|
| 95 |
+
return response
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
| 99 |
+
"""Add security headers to responses"""
|
| 100 |
+
|
| 101 |
+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
| 102 |
+
response = await call_next(request)
|
| 103 |
+
|
| 104 |
+
# Security headers
|
| 105 |
+
response.headers["X-Content-Type-Options"] = "nosniff"
|
| 106 |
+
response.headers["X-Frame-Options"] = "DENY"
|
| 107 |
+
response.headers["X-XSS-Protection"] = "1; mode=block"
|
| 108 |
+
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
| 109 |
+
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
| 110 |
+
|
| 111 |
+
# Content Security Policy
|
| 112 |
+
if settings.is_production:
|
| 113 |
+
response.headers["Content-Security-Policy"] = (
|
| 114 |
+
"default-src 'self'; "
|
| 115 |
+
"script-src 'self' 'unsafe-inline'; "
|
| 116 |
+
"style-src 'self' 'unsafe-inline'; "
|
| 117 |
+
"img-src 'self' data: https:; "
|
| 118 |
+
"font-src 'self' data:; "
|
| 119 |
+
"connect-src 'self'"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
return response
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class ErrorHandlerMiddleware(BaseHTTPMiddleware):
|
| 126 |
+
"""Global error handler"""
|
| 127 |
+
|
| 128 |
+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
| 129 |
+
try:
|
| 130 |
+
response = await call_next(request)
|
| 131 |
+
return response
|
| 132 |
+
except Exception as e:
|
| 133 |
+
# Log error
|
| 134 |
+
log_error(
|
| 135 |
+
error=e,
|
| 136 |
+
context={
|
| 137 |
+
"method": request.method,
|
| 138 |
+
"path": str(request.url.path),
|
| 139 |
+
"client": request.client.host if request.client else None
|
| 140 |
+
},
|
| 141 |
+
request_id=getattr(request.state, "request_id", None)
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# Return error response
|
| 145 |
+
from src.core.exceptions import AppException
|
| 146 |
+
|
| 147 |
+
if isinstance(e, AppException):
|
| 148 |
+
return JSONResponse(
|
| 149 |
+
status_code=e.status_code,
|
| 150 |
+
content={
|
| 151 |
+
"error": type(e).__name__,
|
| 152 |
+
"message": e.message,
|
| 153 |
+
"details": e.details,
|
| 154 |
+
"request_id": getattr(request.state, "request_id", None)
|
| 155 |
+
}
|
| 156 |
+
)
|
| 157 |
+
else:
|
| 158 |
+
# Generic error response
|
| 159 |
+
return JSONResponse(
|
| 160 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 161 |
+
content={
|
| 162 |
+
"error": "InternalServerError",
|
| 163 |
+
"message": "An unexpected error occurred",
|
| 164 |
+
"request_id": getattr(request.state, "request_id", None)
|
| 165 |
+
}
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def setup_cors(app):
|
| 170 |
+
"""Configure CORS middleware"""
|
| 171 |
+
app.add_middleware(
|
| 172 |
+
CORSMiddleware,
|
| 173 |
+
allow_origins=settings.BACKEND_CORS_ORIGINS,
|
| 174 |
+
allow_credentials=True,
|
| 175 |
+
allow_methods=["*"],
|
| 176 |
+
allow_headers=["*"],
|
| 177 |
+
expose_headers=["X-Request-ID", "X-Response-Time", "X-RateLimit-Limit", "X-RateLimit-Remaining"]
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def setup_middleware(app):
|
| 182 |
+
"""Setup all middleware in correct order"""
|
| 183 |
+
|
| 184 |
+
# Order matters! Apply in reverse order of execution
|
| 185 |
+
|
| 186 |
+
# Error handling (outermost)
|
| 187 |
+
app.add_middleware(ErrorHandlerMiddleware)
|
| 188 |
+
|
| 189 |
+
# Security headers
|
| 190 |
+
app.add_middleware(SecurityHeadersMiddleware)
|
| 191 |
+
|
| 192 |
+
# Rate limiting
|
| 193 |
+
app.add_middleware(RateLimitMiddleware)
|
| 194 |
+
|
| 195 |
+
# Request logging
|
| 196 |
+
app.add_middleware(RequestLoggingMiddleware)
|
| 197 |
+
|
| 198 |
+
# Request ID (innermost)
|
| 199 |
+
app.add_middleware(RequestIDMiddleware)
|
| 200 |
+
|
| 201 |
+
# CORS
|
| 202 |
+
setup_cors(app)
|
| 203 |
+
|
| 204 |
+
logger.info("Middleware configured successfully")
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
if __name__ == "__main__":
|
| 208 |
+
print("Middleware module loaded")
|
| 209 |
+
print(f"Rate limiting: {'Enabled' if settings.RATE_LIMIT_ENABLED else 'Disabled'}")
|
| 210 |
+
print(f"CORS origins: {settings.BACKEND_CORS_ORIGINS}")
|
src/core/security.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Authentication and Authorization
|
| 3 |
+
JWT tokens, API keys, password hashing
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import secrets
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
+
from typing import Optional, Union
|
| 9 |
+
from jose import JWTError, jwt
|
| 10 |
+
from passlib.context import CryptContext
|
| 11 |
+
from fastapi import Depends, HTTPException, status, Security
|
| 12 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials, APIKeyHeader
|
| 13 |
+
from sqlalchemy.orm import Session
|
| 14 |
+
|
| 15 |
+
from src.core.config import settings
|
| 16 |
+
from src.core.exceptions import AuthenticationError, AuthorizationError
|
| 17 |
+
from src.db.models import User, APIKey, get_db
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Password hashing
|
| 21 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
| 22 |
+
|
| 23 |
+
# Security schemes
|
| 24 |
+
bearer_scheme = HTTPBearer()
|
| 25 |
+
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
| 29 |
+
"""Verify password against hash"""
|
| 30 |
+
return pwd_context.verify(plain_password, hashed_password)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_password_hash(password: str) -> str:
|
| 34 |
+
"""Generate password hash"""
|
| 35 |
+
return pwd_context.hash(password)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def create_access_token(
|
| 39 |
+
data: dict,
|
| 40 |
+
expires_delta: Optional[timedelta] = None
|
| 41 |
+
) -> str:
|
| 42 |
+
"""Create JWT access token"""
|
| 43 |
+
to_encode = data.copy()
|
| 44 |
+
|
| 45 |
+
if expires_delta:
|
| 46 |
+
expire = datetime.utcnow() + expires_delta
|
| 47 |
+
else:
|
| 48 |
+
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
| 49 |
+
|
| 50 |
+
to_encode.update({"exp": expire})
|
| 51 |
+
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
| 52 |
+
return encoded_jwt
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def create_refresh_token(
|
| 56 |
+
data: dict,
|
| 57 |
+
expires_delta: Optional[timedelta] = None
|
| 58 |
+
) -> str:
|
| 59 |
+
"""Create JWT refresh token"""
|
| 60 |
+
to_encode = data.copy()
|
| 61 |
+
|
| 62 |
+
if expires_delta:
|
| 63 |
+
expire = datetime.utcnow() + expires_delta
|
| 64 |
+
else:
|
| 65 |
+
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
| 66 |
+
|
| 67 |
+
to_encode.update({"exp": expire, "type": "refresh"})
|
| 68 |
+
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
| 69 |
+
return encoded_jwt
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def decode_token(token: str) -> dict:
|
| 73 |
+
"""Decode and validate JWT token"""
|
| 74 |
+
try:
|
| 75 |
+
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
| 76 |
+
return payload
|
| 77 |
+
except JWTError:
|
| 78 |
+
raise AuthenticationError("Invalid or expired token")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def generate_api_key() -> str:
|
| 82 |
+
"""Generate secure API key"""
|
| 83 |
+
return secrets.token_urlsafe(32)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Dependency: Get current user from JWT token
|
| 87 |
+
async def get_current_user(
|
| 88 |
+
credentials: HTTPAuthorizationCredentials = Security(bearer_scheme),
|
| 89 |
+
db: Session = Depends(get_db)
|
| 90 |
+
) -> User:
|
| 91 |
+
"""Get current authenticated user from JWT token"""
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
token = credentials.credentials
|
| 95 |
+
payload = decode_token(token)
|
| 96 |
+
|
| 97 |
+
user_id: int = payload.get("sub")
|
| 98 |
+
if user_id is None:
|
| 99 |
+
raise AuthenticationError("Invalid token payload")
|
| 100 |
+
|
| 101 |
+
except JWTError:
|
| 102 |
+
raise AuthenticationError("Could not validate credentials")
|
| 103 |
+
|
| 104 |
+
user = db.query(User).filter(User.id == user_id).first()
|
| 105 |
+
if user is None:
|
| 106 |
+
raise AuthenticationError("User not found")
|
| 107 |
+
|
| 108 |
+
if not user.is_active:
|
| 109 |
+
raise AuthenticationError("User account is inactive")
|
| 110 |
+
|
| 111 |
+
return user
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# Dependency: Get current user from API key
|
| 115 |
+
async def get_current_user_from_api_key(
|
| 116 |
+
api_key: Optional[str] = Security(api_key_header),
|
| 117 |
+
db: Session = Depends(get_db)
|
| 118 |
+
) -> Optional[User]:
|
| 119 |
+
"""Get current user from API key"""
|
| 120 |
+
|
| 121 |
+
if not api_key:
|
| 122 |
+
return None
|
| 123 |
+
|
| 124 |
+
# Find API key in database
|
| 125 |
+
api_key_obj = db.query(APIKey).filter(
|
| 126 |
+
APIKey.key == api_key,
|
| 127 |
+
APIKey.is_active == True
|
| 128 |
+
).first()
|
| 129 |
+
|
| 130 |
+
if not api_key_obj:
|
| 131 |
+
raise AuthenticationError("Invalid API key")
|
| 132 |
+
|
| 133 |
+
# Check expiration
|
| 134 |
+
if api_key_obj.expires_at and api_key_obj.expires_at < datetime.utcnow():
|
| 135 |
+
raise AuthenticationError("API key has expired")
|
| 136 |
+
|
| 137 |
+
# Update last used timestamp
|
| 138 |
+
api_key_obj.last_used_at = datetime.utcnow()
|
| 139 |
+
db.commit()
|
| 140 |
+
|
| 141 |
+
# Get user
|
| 142 |
+
user = db.query(User).filter(User.id == api_key_obj.user_id).first()
|
| 143 |
+
|
| 144 |
+
if not user or not user.is_active:
|
| 145 |
+
raise AuthenticationError("User not found or inactive")
|
| 146 |
+
|
| 147 |
+
return user
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# Dependency: Get current user (try JWT first, then API key)
|
| 151 |
+
async def get_current_user_flexible(
|
| 152 |
+
bearer: Optional[HTTPAuthorizationCredentials] = Security(bearer_scheme, auto_error=False),
|
| 153 |
+
api_key: Optional[str] = Security(api_key_header),
|
| 154 |
+
db: Session = Depends(get_db)
|
| 155 |
+
) -> User:
|
| 156 |
+
"""Get current user from JWT or API key"""
|
| 157 |
+
|
| 158 |
+
# Try JWT token first
|
| 159 |
+
if bearer:
|
| 160 |
+
try:
|
| 161 |
+
token = bearer.credentials
|
| 162 |
+
payload = decode_token(token)
|
| 163 |
+
user_id: int = payload.get("sub")
|
| 164 |
+
|
| 165 |
+
user = db.query(User).filter(User.id == user_id).first()
|
| 166 |
+
if user and user.is_active:
|
| 167 |
+
return user
|
| 168 |
+
except:
|
| 169 |
+
pass
|
| 170 |
+
|
| 171 |
+
# Try API key
|
| 172 |
+
if api_key:
|
| 173 |
+
user = await get_current_user_from_api_key(api_key, db)
|
| 174 |
+
if user:
|
| 175 |
+
return user
|
| 176 |
+
|
| 177 |
+
raise AuthenticationError("Authentication required")
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# Dependency: Require superuser
|
| 181 |
+
async def get_current_superuser(
|
| 182 |
+
current_user: User = Depends(get_current_user_flexible)
|
| 183 |
+
) -> User:
|
| 184 |
+
"""Require superuser privileges"""
|
| 185 |
+
|
| 186 |
+
if not current_user.is_superuser:
|
| 187 |
+
raise AuthorizationError("Superuser privileges required")
|
| 188 |
+
|
| 189 |
+
return current_user
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# Helper: Authenticate user
|
| 193 |
+
def authenticate_user(
|
| 194 |
+
db: Session,
|
| 195 |
+
email: str,
|
| 196 |
+
password: str
|
| 197 |
+
) -> Optional[User]:
|
| 198 |
+
"""Authenticate user with email and password"""
|
| 199 |
+
|
| 200 |
+
user = db.query(User).filter(User.email == email).first()
|
| 201 |
+
if not user:
|
| 202 |
+
return None
|
| 203 |
+
|
| 204 |
+
if not verify_password(password, user.hashed_password):
|
| 205 |
+
return None
|
| 206 |
+
|
| 207 |
+
return user
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# Helper: Create user
|
| 211 |
+
def create_user(
|
| 212 |
+
db: Session,
|
| 213 |
+
email: str,
|
| 214 |
+
password: str,
|
| 215 |
+
full_name: Optional[str] = None,
|
| 216 |
+
is_superuser: bool = False
|
| 217 |
+
) -> User:
|
| 218 |
+
"""Create new user"""
|
| 219 |
+
|
| 220 |
+
# Check if user exists
|
| 221 |
+
existing_user = db.query(User).filter(User.email == email).first()
|
| 222 |
+
if existing_user:
|
| 223 |
+
raise ValueError("User with this email already exists")
|
| 224 |
+
|
| 225 |
+
# Create user
|
| 226 |
+
user = User(
|
| 227 |
+
email=email,
|
| 228 |
+
hashed_password=get_password_hash(password),
|
| 229 |
+
full_name=full_name,
|
| 230 |
+
is_superuser=is_superuser,
|
| 231 |
+
is_active=True
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
db.add(user)
|
| 235 |
+
db.commit()
|
| 236 |
+
db.refresh(user)
|
| 237 |
+
|
| 238 |
+
return user
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# Helper: Create API key
|
| 242 |
+
def create_api_key_for_user(
|
| 243 |
+
db: Session,
|
| 244 |
+
user_id: int,
|
| 245 |
+
name: Optional[str] = None,
|
| 246 |
+
expires_days: Optional[int] = None
|
| 247 |
+
) -> APIKey:
|
| 248 |
+
"""Create API key for user"""
|
| 249 |
+
|
| 250 |
+
key = generate_api_key()
|
| 251 |
+
|
| 252 |
+
api_key = APIKey(
|
| 253 |
+
key=key,
|
| 254 |
+
name=name or "API Key",
|
| 255 |
+
user_id=user_id,
|
| 256 |
+
is_active=True,
|
| 257 |
+
rate_limit_per_minute=settings.RATE_LIMIT_PER_MINUTE,
|
| 258 |
+
rate_limit_per_hour=settings.RATE_LIMIT_PER_HOUR,
|
| 259 |
+
expires_at=datetime.utcnow() + timedelta(days=expires_days) if expires_days else None
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
db.add(api_key)
|
| 263 |
+
db.commit()
|
| 264 |
+
db.refresh(api_key)
|
| 265 |
+
|
| 266 |
+
return api_key
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
if __name__ == "__main__":
|
| 270 |
+
# Test password hashing
|
| 271 |
+
password = "test_password_123"
|
| 272 |
+
hashed = get_password_hash(password)
|
| 273 |
+
print(f"Hashed: {hashed}")
|
| 274 |
+
print(f"Verified: {verify_password(password, hashed)}")
|
| 275 |
+
|
| 276 |
+
# Test JWT token creation
|
| 277 |
+
token = create_access_token({"sub": 1, "email": "test@example.com"})
|
| 278 |
+
print(f"Token: {token}")
|
| 279 |
+
|
| 280 |
+
payload = decode_token(token)
|
| 281 |
+
print(f"Decoded: {payload}")
|
| 282 |
+
|
| 283 |
+
# Test API key generation
|
| 284 |
+
api_key = generate_api_key()
|
| 285 |
+
print(f"API Key: {api_key}")
|
src/db/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Database package"""
|
| 2 |
+
|
| 3 |
+
from .models import (
|
| 4 |
+
Base,
|
| 5 |
+
User,
|
| 6 |
+
APIKey,
|
| 7 |
+
RequestLog,
|
| 8 |
+
PredictionLog,
|
| 9 |
+
SystemMetric,
|
| 10 |
+
engine,
|
| 11 |
+
SessionLocal,
|
| 12 |
+
get_db,
|
| 13 |
+
create_tables,
|
| 14 |
+
drop_tables
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"Base",
|
| 19 |
+
"User",
|
| 20 |
+
"APIKey",
|
| 21 |
+
"RequestLog",
|
| 22 |
+
"PredictionLog",
|
| 23 |
+
"SystemMetric",
|
| 24 |
+
"engine",
|
| 25 |
+
"SessionLocal",
|
| 26 |
+
"get_db",
|
| 27 |
+
"create_tables",
|
| 28 |
+
"drop_tables"
|
| 29 |
+
]
|
src/db/models.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Database Models and Session Management
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from sqlalchemy import (
|
| 8 |
+
Boolean, Column, DateTime, Float, Integer, String, Text, JSON, ForeignKey, Index
|
| 9 |
+
)
|
| 10 |
+
from sqlalchemy.ext.declarative import declarative_base
|
| 11 |
+
from sqlalchemy.orm import relationship, Session
|
| 12 |
+
from sqlalchemy import create_engine
|
| 13 |
+
from sqlalchemy.orm import sessionmaker
|
| 14 |
+
from sqlalchemy.pool import QueuePool
|
| 15 |
+
|
| 16 |
+
from src.core.config import settings
|
| 17 |
+
|
| 18 |
+
# Create declarative base
|
| 19 |
+
Base = declarative_base()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Database Models
|
| 23 |
+
class User(Base):
|
| 24 |
+
"""User model for authentication"""
|
| 25 |
+
__tablename__ = "users"
|
| 26 |
+
|
| 27 |
+
id = Column(Integer, primary_key=True, index=True)
|
| 28 |
+
email = Column(String(255), unique=True, index=True, nullable=False)
|
| 29 |
+
hashed_password = Column(String(255), nullable=False)
|
| 30 |
+
full_name = Column(String(255))
|
| 31 |
+
is_active = Column(Boolean, default=True)
|
| 32 |
+
is_superuser = Column(Boolean, default=False)
|
| 33 |
+
created_at = Column(DateTime, default=datetime.utcnow)
|
| 34 |
+
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
| 35 |
+
|
| 36 |
+
# Relationships
|
| 37 |
+
api_keys = relationship("APIKey", back_populates="user", cascade="all, delete-orphan")
|
| 38 |
+
requests = relationship("RequestLog", back_populates="user", cascade="all, delete-orphan")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class APIKey(Base):
|
| 42 |
+
"""API Key model for API authentication"""
|
| 43 |
+
__tablename__ = "api_keys"
|
| 44 |
+
|
| 45 |
+
id = Column(Integer, primary_key=True, index=True)
|
| 46 |
+
key = Column(String(64), unique=True, index=True, nullable=False)
|
| 47 |
+
name = Column(String(255))
|
| 48 |
+
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
| 49 |
+
is_active = Column(Boolean, default=True)
|
| 50 |
+
rate_limit_per_minute = Column(Integer, default=60)
|
| 51 |
+
rate_limit_per_hour = Column(Integer, default=1000)
|
| 52 |
+
created_at = Column(DateTime, default=datetime.utcnow)
|
| 53 |
+
last_used_at = Column(DateTime)
|
| 54 |
+
expires_at = Column(DateTime)
|
| 55 |
+
|
| 56 |
+
# Relationships
|
| 57 |
+
user = relationship("User", back_populates="api_keys")
|
| 58 |
+
|
| 59 |
+
# Indexes
|
| 60 |
+
__table_args__ = (
|
| 61 |
+
Index('idx_apikey_user_active', 'user_id', 'is_active'),
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class RequestLog(Base):
|
| 66 |
+
"""Request logging for analytics and debugging"""
|
| 67 |
+
__tablename__ = "request_logs"
|
| 68 |
+
|
| 69 |
+
id = Column(Integer, primary_key=True, index=True)
|
| 70 |
+
request_id = Column(String(64), unique=True, index=True)
|
| 71 |
+
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
| 72 |
+
api_key_id = Column(Integer, ForeignKey("api_keys.id"), nullable=True)
|
| 73 |
+
|
| 74 |
+
# Request details
|
| 75 |
+
method = Column(String(10))
|
| 76 |
+
path = Column(String(500))
|
| 77 |
+
query_params = Column(JSON)
|
| 78 |
+
status_code = Column(Integer)
|
| 79 |
+
|
| 80 |
+
# Performance
|
| 81 |
+
duration_ms = Column(Float)
|
| 82 |
+
|
| 83 |
+
# Client info
|
| 84 |
+
ip_address = Column(String(45))
|
| 85 |
+
user_agent = Column(Text)
|
| 86 |
+
|
| 87 |
+
# Timestamps
|
| 88 |
+
created_at = Column(DateTime, default=datetime.utcnow, index=True)
|
| 89 |
+
|
| 90 |
+
# Relationships
|
| 91 |
+
user = relationship("User", back_populates="requests")
|
| 92 |
+
|
| 93 |
+
# Indexes
|
| 94 |
+
__table_args__ = (
|
| 95 |
+
Index('idx_request_user_created', 'user_id', 'created_at'),
|
| 96 |
+
Index('idx_request_created', 'created_at'),
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class PredictionLog(Base):
|
| 101 |
+
"""ML prediction logging for analytics"""
|
| 102 |
+
__tablename__ = "prediction_logs"
|
| 103 |
+
|
| 104 |
+
id = Column(Integer, primary_key=True, index=True)
|
| 105 |
+
request_id = Column(String(64), index=True)
|
| 106 |
+
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
| 107 |
+
|
| 108 |
+
# Prediction details
|
| 109 |
+
model_type = Column(String(50), index=True) # deepfake, ai_text, anomaly
|
| 110 |
+
input_type = Column(String(20)) # text, image, video, audio
|
| 111 |
+
input_size = Column(Integer) # bytes or character count
|
| 112 |
+
|
| 113 |
+
# Results
|
| 114 |
+
prediction = Column(String(50))
|
| 115 |
+
confidence = Column(Float)
|
| 116 |
+
details = Column(JSON)
|
| 117 |
+
|
| 118 |
+
# Performance
|
| 119 |
+
duration_ms = Column(Float)
|
| 120 |
+
cached = Column(Boolean, default=False)
|
| 121 |
+
|
| 122 |
+
# Timestamps
|
| 123 |
+
created_at = Column(DateTime, default=datetime.utcnow, index=True)
|
| 124 |
+
|
| 125 |
+
# Indexes
|
| 126 |
+
__table_args__ = (
|
| 127 |
+
Index('idx_prediction_model_created', 'model_type', 'created_at'),
|
| 128 |
+
Index('idx_prediction_user_created', 'user_id', 'created_at'),
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class SystemMetric(Base):
|
| 133 |
+
"""System performance metrics"""
|
| 134 |
+
__tablename__ = "system_metrics"
|
| 135 |
+
|
| 136 |
+
id = Column(Integer, primary_key=True, index=True)
|
| 137 |
+
metric_name = Column(String(100), index=True)
|
| 138 |
+
metric_value = Column(Float)
|
| 139 |
+
labels = Column(JSON)
|
| 140 |
+
created_at = Column(DateTime, default=datetime.utcnow, index=True)
|
| 141 |
+
|
| 142 |
+
# Indexes
|
| 143 |
+
__table_args__ = (
|
| 144 |
+
Index('idx_metric_name_created', 'metric_name', 'created_at'),
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Database Engine and Session
|
| 149 |
+
engine = create_engine(
|
| 150 |
+
settings.DATABASE_URL,
|
| 151 |
+
poolclass=QueuePool,
|
| 152 |
+
pool_size=10,
|
| 153 |
+
max_overflow=20,
|
| 154 |
+
pool_pre_ping=True,
|
| 155 |
+
echo=settings.DEBUG
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Dependency for FastAPI
|
| 162 |
+
def get_db():
|
| 163 |
+
"""Get database session"""
|
| 164 |
+
db = SessionLocal()
|
| 165 |
+
try:
|
| 166 |
+
yield db
|
| 167 |
+
finally:
|
| 168 |
+
db.close()
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# Database initialization
|
| 172 |
+
def create_tables():
|
| 173 |
+
"""Create all tables"""
|
| 174 |
+
Base.metadata.create_all(bind=engine)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def drop_tables():
|
| 178 |
+
"""Drop all tables (use with caution!)"""
|
| 179 |
+
Base.metadata.drop_all(bind=engine)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
if __name__ == "__main__":
|
| 183 |
+
print("Creating database tables...")
|
| 184 |
+
create_tables()
|
| 185 |
+
print("Tables created successfully!")
|
src/detection/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Init file for detection module."""
|
| 2 |
+
|
| 3 |
+
from .deepfake_detector import DeepfakeDetector
|
| 4 |
+
from .ai_text_detector import AITextDetector
|
| 5 |
+
from .anomaly_detector import AnomalyDetector
|
| 6 |
+
|
| 7 |
+
__all__ = ['DeepfakeDetector', 'AITextDetector', 'AnomalyDetector']
|
src/detection/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (490 Bytes). View file
|
|
|
src/detection/__pycache__/ai_text_detector.cpython-313.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
src/detection/__pycache__/anomaly_detector.cpython-313.pyc
ADDED
|
Binary file (17.6 kB). View file
|
|
|
src/detection/__pycache__/deepfake_detector.cpython-313.pyc
ADDED
|
Binary file (17.1 kB). View file
|
|
|
src/detection/ai_text_detector.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AI Text Detection Module
|
| 3 |
+
|
| 4 |
+
Detects AI-generated text from models like GPT-4, ChatGPT, Gemini, Claude.
|
| 5 |
+
|
| 6 |
+
Uses multiple detection strategies:
|
| 7 |
+
1. Perplexity analysis
|
| 8 |
+
2. Token probability distribution
|
| 9 |
+
3. Stylometric features
|
| 10 |
+
4. Statistical patterns
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from transformers import (
|
| 16 |
+
AutoTokenizer,
|
| 17 |
+
AutoModelForSequenceClassification,
|
| 18 |
+
GPT2LMHeadModel,
|
| 19 |
+
GPT2Tokenizer
|
| 20 |
+
)
|
| 21 |
+
from typing import Dict, List, Tuple
|
| 22 |
+
import numpy as np
|
| 23 |
+
import re
|
| 24 |
+
from collections import Counter
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class AITextDetector:
|
| 28 |
+
"""
|
| 29 |
+
Detects AI-generated text using multiple approaches.
|
| 30 |
+
|
| 31 |
+
Combines:
|
| 32 |
+
- Fine-tuned BERT classifier
|
| 33 |
+
- Perplexity-based detection
|
| 34 |
+
- Statistical feature analysis
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
model_path: str = "models/ai_text_detector.pth",
|
| 40 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
| 41 |
+
threshold: float = 0.7
|
| 42 |
+
):
|
| 43 |
+
"""
|
| 44 |
+
Initialize AI text detector.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
model_path: Path to fine-tuned model
|
| 48 |
+
device: Device for inference
|
| 49 |
+
threshold: Detection threshold
|
| 50 |
+
"""
|
| 51 |
+
self.device = device
|
| 52 |
+
self.threshold = threshold
|
| 53 |
+
|
| 54 |
+
# Load classifier model
|
| 55 |
+
self.tokenizer = AutoTokenizer.from_pretrained("roberta-base")
|
| 56 |
+
self.classifier = AutoModelForSequenceClassification.from_pretrained(
|
| 57 |
+
"roberta-base",
|
| 58 |
+
num_labels=2
|
| 59 |
+
).to(device)
|
| 60 |
+
|
| 61 |
+
# Load GPT-2 for perplexity calculation
|
| 62 |
+
self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
| 63 |
+
self.gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
|
| 64 |
+
self.gpt2_model.eval()
|
| 65 |
+
|
| 66 |
+
self.classifier.eval()
|
| 67 |
+
|
| 68 |
+
print("✓ AI Text Detector initialized")
|
| 69 |
+
|
| 70 |
+
def analyze_text(
|
| 71 |
+
self,
|
| 72 |
+
text: str,
|
| 73 |
+
detailed: bool = True
|
| 74 |
+
) -> Dict:
|
| 75 |
+
"""
|
| 76 |
+
Analyze text for AI generation indicators.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
text: Input text to analyze
|
| 80 |
+
detailed: Return detailed analysis
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
Detection results dictionary
|
| 84 |
+
"""
|
| 85 |
+
if len(text.strip()) < 10:
|
| 86 |
+
return {
|
| 87 |
+
'verdict': 'TOO_SHORT',
|
| 88 |
+
'confidence': 0.0,
|
| 89 |
+
'explanation': 'Text too short for reliable analysis (min 10 chars)'
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
# Method 1: Classifier-based detection
|
| 93 |
+
classifier_score = self._classifier_detection(text)
|
| 94 |
+
|
| 95 |
+
# Method 2: Perplexity-based detection
|
| 96 |
+
perplexity = self._calculate_perplexity(text)
|
| 97 |
+
perplexity_score = self._perplexity_to_score(perplexity)
|
| 98 |
+
|
| 99 |
+
# Method 3: Statistical feature analysis
|
| 100 |
+
statistical_score = self._statistical_analysis(text)
|
| 101 |
+
|
| 102 |
+
# Ensemble the scores
|
| 103 |
+
final_score = (
|
| 104 |
+
0.5 * classifier_score +
|
| 105 |
+
0.3 * perplexity_score +
|
| 106 |
+
0.2 * statistical_score
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
is_ai_generated = final_score > self.threshold
|
| 110 |
+
|
| 111 |
+
result = {
|
| 112 |
+
'verdict': 'AI_GENERATED' if is_ai_generated else 'HUMAN_WRITTEN',
|
| 113 |
+
'confidence': float(final_score),
|
| 114 |
+
'threshold': self.threshold,
|
| 115 |
+
'perplexity': float(perplexity),
|
| 116 |
+
'explanation': self._generate_explanation(final_score, perplexity)
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
if detailed:
|
| 120 |
+
result['detailed_scores'] = {
|
| 121 |
+
'classifier': float(classifier_score),
|
| 122 |
+
'perplexity': float(perplexity_score),
|
| 123 |
+
'statistical': float(statistical_score)
|
| 124 |
+
}
|
| 125 |
+
result['features'] = self._extract_features(text)
|
| 126 |
+
result['indicators'] = self._identify_indicators(text, final_score)
|
| 127 |
+
|
| 128 |
+
return result
|
| 129 |
+
|
| 130 |
+
def _classifier_detection(self, text: str) -> float:
|
| 131 |
+
"""Use fine-tuned classifier for detection."""
|
| 132 |
+
# Tokenize
|
| 133 |
+
inputs = self.tokenizer(
|
| 134 |
+
text,
|
| 135 |
+
return_tensors="pt",
|
| 136 |
+
truncation=True,
|
| 137 |
+
max_length=512,
|
| 138 |
+
padding=True
|
| 139 |
+
).to(self.device)
|
| 140 |
+
|
| 141 |
+
# Get prediction
|
| 142 |
+
with torch.no_grad():
|
| 143 |
+
outputs = self.classifier(**inputs)
|
| 144 |
+
logits = outputs.logits
|
| 145 |
+
probs = torch.softmax(logits, dim=-1)
|
| 146 |
+
ai_prob = probs[0][1].item() # Probability of AI-generated
|
| 147 |
+
|
| 148 |
+
return ai_prob
|
| 149 |
+
|
| 150 |
+
def _calculate_perplexity(self, text: str) -> float:
|
| 151 |
+
"""
|
| 152 |
+
Calculate perplexity using GPT-2.
|
| 153 |
+
|
| 154 |
+
AI-generated text typically has lower perplexity.
|
| 155 |
+
"""
|
| 156 |
+
# Tokenize
|
| 157 |
+
encodings = self.gpt2_tokenizer(
|
| 158 |
+
text,
|
| 159 |
+
return_tensors="pt",
|
| 160 |
+
truncation=True,
|
| 161 |
+
max_length=1024
|
| 162 |
+
).to(self.device)
|
| 163 |
+
|
| 164 |
+
max_length = encodings.input_ids.size(1)
|
| 165 |
+
|
| 166 |
+
# Calculate loss
|
| 167 |
+
with torch.no_grad():
|
| 168 |
+
outputs = self.gpt2_model(**encodings, labels=encodings.input_ids)
|
| 169 |
+
loss = outputs.loss
|
| 170 |
+
|
| 171 |
+
# Perplexity = exp(loss)
|
| 172 |
+
perplexity = torch.exp(loss).item()
|
| 173 |
+
|
| 174 |
+
return perplexity
|
| 175 |
+
|
| 176 |
+
def _perplexity_to_score(self, perplexity: float) -> float:
|
| 177 |
+
"""
|
| 178 |
+
Convert perplexity to detection score.
|
| 179 |
+
|
| 180 |
+
Lower perplexity → higher AI probability
|
| 181 |
+
"""
|
| 182 |
+
# Typical ranges:
|
| 183 |
+
# Human text: 50-300
|
| 184 |
+
# AI text: 10-80
|
| 185 |
+
|
| 186 |
+
if perplexity < 20:
|
| 187 |
+
return 0.95 # Very likely AI
|
| 188 |
+
elif perplexity < 50:
|
| 189 |
+
return 0.75
|
| 190 |
+
elif perplexity < 100:
|
| 191 |
+
return 0.50
|
| 192 |
+
elif perplexity < 200:
|
| 193 |
+
return 0.25
|
| 194 |
+
else:
|
| 195 |
+
return 0.10 # Likely human
|
| 196 |
+
|
| 197 |
+
def _statistical_analysis(self, text: str) -> float:
|
| 198 |
+
"""
|
| 199 |
+
Analyze statistical features of text.
|
| 200 |
+
|
| 201 |
+
AI-generated text often has:
|
| 202 |
+
- More uniform sentence lengths
|
| 203 |
+
- Consistent vocabulary diversity
|
| 204 |
+
- Predictable structure
|
| 205 |
+
"""
|
| 206 |
+
features = self._extract_features(text)
|
| 207 |
+
|
| 208 |
+
score = 0.0
|
| 209 |
+
indicators = 0
|
| 210 |
+
|
| 211 |
+
# Check sentence length uniformity
|
| 212 |
+
if features['sentence_length_variance'] < 50:
|
| 213 |
+
score += 0.2
|
| 214 |
+
indicators += 1
|
| 215 |
+
|
| 216 |
+
# Check vocabulary diversity
|
| 217 |
+
if 0.4 < features['vocabulary_diversity'] < 0.6:
|
| 218 |
+
score += 0.2
|
| 219 |
+
indicators += 1
|
| 220 |
+
|
| 221 |
+
# Check average sentence length (AI often uses medium-length sentences)
|
| 222 |
+
if 15 < features['avg_sentence_length'] < 25:
|
| 223 |
+
score += 0.15
|
| 224 |
+
indicators += 1
|
| 225 |
+
|
| 226 |
+
# Check for repetitive patterns
|
| 227 |
+
if features['repetition_ratio'] < 0.05:
|
| 228 |
+
score += 0.15
|
| 229 |
+
indicators += 1
|
| 230 |
+
|
| 231 |
+
# Check for balanced punctuation
|
| 232 |
+
if 0.08 < features['punctuation_ratio'] < 0.15:
|
| 233 |
+
score += 0.15
|
| 234 |
+
indicators += 1
|
| 235 |
+
|
| 236 |
+
# Check for consistent paragraph structure
|
| 237 |
+
if features['avg_paragraph_length'] > 3:
|
| 238 |
+
score += 0.15
|
| 239 |
+
indicators += 1
|
| 240 |
+
|
| 241 |
+
return score
|
| 242 |
+
|
| 243 |
+
def _extract_features(self, text: str) -> Dict:
|
| 244 |
+
"""Extract statistical features from text."""
|
| 245 |
+
# Sentence segmentation
|
| 246 |
+
sentences = re.split(r'[.!?]+', text)
|
| 247 |
+
sentences = [s.strip() for s in sentences if s.strip()]
|
| 248 |
+
|
| 249 |
+
# Word tokenization
|
| 250 |
+
words = re.findall(r'\b\w+\b', text.lower())
|
| 251 |
+
|
| 252 |
+
# Calculate features
|
| 253 |
+
sentence_lengths = [len(s.split()) for s in sentences]
|
| 254 |
+
|
| 255 |
+
# Paragraph detection
|
| 256 |
+
paragraphs = text.split('\n\n')
|
| 257 |
+
paragraphs = [p.strip() for p in paragraphs if p.strip()]
|
| 258 |
+
|
| 259 |
+
features = {
|
| 260 |
+
'total_words': len(words),
|
| 261 |
+
'total_sentences': len(sentences),
|
| 262 |
+
'total_paragraphs': len(paragraphs),
|
| 263 |
+
'avg_sentence_length': np.mean(sentence_lengths) if sentence_lengths else 0,
|
| 264 |
+
'sentence_length_variance': np.var(sentence_lengths) if sentence_lengths else 0,
|
| 265 |
+
'vocabulary_diversity': len(set(words)) / len(words) if words else 0,
|
| 266 |
+
'avg_word_length': np.mean([len(w) for w in words]) if words else 0,
|
| 267 |
+
'punctuation_ratio': len(re.findall(r'[,.!?;:]', text)) / len(words) if words else 0,
|
| 268 |
+
'repetition_ratio': self._calculate_repetition(words),
|
| 269 |
+
'avg_paragraph_length': np.mean([len(p.split()) for p in paragraphs]) if paragraphs else 0
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
return features
|
| 273 |
+
|
| 274 |
+
def _calculate_repetition(self, words: List[str]) -> float:
|
| 275 |
+
"""Calculate word repetition ratio."""
|
| 276 |
+
if len(words) < 10:
|
| 277 |
+
return 0.0
|
| 278 |
+
|
| 279 |
+
# Look for repeated 3-grams
|
| 280 |
+
trigrams = [tuple(words[i:i+3]) for i in range(len(words)-2)]
|
| 281 |
+
trigram_counts = Counter(trigrams)
|
| 282 |
+
|
| 283 |
+
# Calculate ratio of repeated trigrams
|
| 284 |
+
repeated = sum(1 for count in trigram_counts.values() if count > 1)
|
| 285 |
+
total = len(trigrams)
|
| 286 |
+
|
| 287 |
+
return repeated / total if total > 0 else 0.0
|
| 288 |
+
|
| 289 |
+
def _identify_indicators(self, text: str, score: float) -> List[str]:
|
| 290 |
+
"""Identify specific AI generation indicators."""
|
| 291 |
+
indicators = []
|
| 292 |
+
|
| 293 |
+
features = self._extract_features(text)
|
| 294 |
+
perplexity = self._calculate_perplexity(text)
|
| 295 |
+
|
| 296 |
+
# Low perplexity
|
| 297 |
+
if perplexity < 30:
|
| 298 |
+
indicators.append(f"Very low perplexity ({perplexity:.1f}) suggests high predictability")
|
| 299 |
+
|
| 300 |
+
# Uniform sentence structure
|
| 301 |
+
if features['sentence_length_variance'] < 50:
|
| 302 |
+
indicators.append("Unusually uniform sentence lengths")
|
| 303 |
+
|
| 304 |
+
# Vocabulary consistency
|
| 305 |
+
if 0.4 < features['vocabulary_diversity'] < 0.6:
|
| 306 |
+
indicators.append("Vocabulary diversity typical of AI generation")
|
| 307 |
+
|
| 308 |
+
# Repetitive patterns
|
| 309 |
+
if features['repetition_ratio'] < 0.03:
|
| 310 |
+
indicators.append("Minimal repetition (uncommon in human writing)")
|
| 311 |
+
|
| 312 |
+
# Generic phrases common in AI
|
| 313 |
+
generic_phrases = [
|
| 314 |
+
"it's important to note",
|
| 315 |
+
"it's worth noting",
|
| 316 |
+
"in conclusion",
|
| 317 |
+
"to summarize",
|
| 318 |
+
"additionally",
|
| 319 |
+
"furthermore",
|
| 320 |
+
"moreover",
|
| 321 |
+
"in other words"
|
| 322 |
+
]
|
| 323 |
+
|
| 324 |
+
text_lower = text.lower()
|
| 325 |
+
found_phrases = [p for p in generic_phrases if p in text_lower]
|
| 326 |
+
if len(found_phrases) >= 2:
|
| 327 |
+
indicators.append(f"Multiple generic transition phrases: {', '.join(found_phrases[:3])}")
|
| 328 |
+
|
| 329 |
+
# Lack of personal pronouns
|
| 330 |
+
personal_pronouns = len(re.findall(r'\b(I|me|my|mine|we|us|our)\b', text, re.IGNORECASE))
|
| 331 |
+
if personal_pronouns == 0 and len(text.split()) > 50:
|
| 332 |
+
indicators.append("Absence of personal pronouns")
|
| 333 |
+
|
| 334 |
+
return indicators
|
| 335 |
+
|
| 336 |
+
def _generate_explanation(self, score: float, perplexity: float) -> str:
|
| 337 |
+
"""Generate human-readable explanation."""
|
| 338 |
+
if score > 0.9:
|
| 339 |
+
return (
|
| 340 |
+
f"Strong indicators of AI generation. "
|
| 341 |
+
f"Very low perplexity ({perplexity:.1f}) and multiple statistical markers."
|
| 342 |
+
)
|
| 343 |
+
elif score > 0.7:
|
| 344 |
+
return (
|
| 345 |
+
f"Likely AI-generated. "
|
| 346 |
+
f"Low perplexity ({perplexity:.1f}) and consistent with AI patterns."
|
| 347 |
+
)
|
| 348 |
+
elif score > 0.5:
|
| 349 |
+
return (
|
| 350 |
+
f"Possible AI generation. "
|
| 351 |
+
f"Some indicators present, but not conclusive."
|
| 352 |
+
)
|
| 353 |
+
elif score > 0.3:
|
| 354 |
+
return (
|
| 355 |
+
f"Likely human-written. "
|
| 356 |
+
f"Natural variation in style and structure."
|
| 357 |
+
)
|
| 358 |
+
else:
|
| 359 |
+
return (
|
| 360 |
+
f"Strong indicators of human writing. "
|
| 361 |
+
f"High perplexity ({perplexity:.1f}) and natural language patterns."
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
def batch_analyze(self, texts: List[str]) -> List[Dict]:
|
| 365 |
+
"""Analyze multiple texts efficiently."""
|
| 366 |
+
results = []
|
| 367 |
+
|
| 368 |
+
for text in texts:
|
| 369 |
+
result = self.analyze_text(text, detailed=False)
|
| 370 |
+
results.append(result)
|
| 371 |
+
|
| 372 |
+
return results
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
# Example usage
|
| 376 |
+
if __name__ == "__main__":
|
| 377 |
+
detector = AITextDetector()
|
| 378 |
+
|
| 379 |
+
# Test with sample text
|
| 380 |
+
ai_text = """
|
| 381 |
+
Artificial intelligence has revolutionized numerous industries in recent years.
|
| 382 |
+
It's important to note that machine learning algorithms have become increasingly
|
| 383 |
+
sophisticated. Furthermore, these technologies continue to advance at a rapid pace.
|
| 384 |
+
In conclusion, AI will likely play an even larger role in the future.
|
| 385 |
+
"""
|
| 386 |
+
|
| 387 |
+
human_text = """
|
| 388 |
+
I can't believe how much AI has changed things! Last week I was playing around
|
| 389 |
+
with ChatGPT and honestly... it's wild. My boss thinks we should use it for
|
| 390 |
+
everything but idk, seems risky? Anyway, what do you think?
|
| 391 |
+
"""
|
| 392 |
+
|
| 393 |
+
print("AI Text Analysis:")
|
| 394 |
+
result = detector.analyze_text(ai_text)
|
| 395 |
+
print(f"Verdict: {result['verdict']}")
|
| 396 |
+
print(f"Confidence: {result['confidence']:.2%}")
|
| 397 |
+
print(f"Indicators: {result['indicators']}\n")
|
| 398 |
+
|
| 399 |
+
print("Human Text Analysis:")
|
| 400 |
+
result = detector.analyze_text(human_text)
|
| 401 |
+
print(f"Verdict: {result['verdict']}")
|
| 402 |
+
print(f"Confidence: {result['confidence']:.2%}")
|
src/detection/anomaly_detector.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Anomaly Detection Module
|
| 3 |
+
|
| 4 |
+
Detects coordinated inauthentic behavior, bot networks, and suspicious patterns.
|
| 5 |
+
|
| 6 |
+
Key Features:
|
| 7 |
+
1. Bot account identification
|
| 8 |
+
2. Coordinated campaign detection
|
| 9 |
+
3. Viral spread analysis
|
| 10 |
+
4. Temporal pattern anomalies
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from typing import Dict, List, Tuple, Optional
|
| 16 |
+
from sklearn.ensemble import IsolationForest
|
| 17 |
+
from sklearn.preprocessing import StandardScaler
|
| 18 |
+
import networkx as nx
|
| 19 |
+
from datetime import datetime, timedelta
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class AnomalyDetector:
|
| 25 |
+
"""
|
| 26 |
+
Multi-method anomaly detector for social media content.
|
| 27 |
+
|
| 28 |
+
Detects:
|
| 29 |
+
- Bot accounts (behavioral patterns)
|
| 30 |
+
- Coordinated campaigns (network analysis)
|
| 31 |
+
- Suspicious viral patterns
|
| 32 |
+
- Time-series anomalies
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
contamination: float = 0.1,
|
| 38 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
| 39 |
+
):
|
| 40 |
+
"""
|
| 41 |
+
Initialize anomaly detector.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
contamination: Expected proportion of anomalies (0-0.5)
|
| 45 |
+
device: Device for deep learning models
|
| 46 |
+
"""
|
| 47 |
+
self.contamination = contamination
|
| 48 |
+
self.device = device
|
| 49 |
+
|
| 50 |
+
# Isolation Forest for bot detection
|
| 51 |
+
self.bot_detector = IsolationForest(
|
| 52 |
+
contamination=contamination,
|
| 53 |
+
random_state=42,
|
| 54 |
+
n_estimators=100
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Scaler for feature normalization
|
| 58 |
+
self.scaler = StandardScaler()
|
| 59 |
+
|
| 60 |
+
print("✓ Anomaly Detector initialized")
|
| 61 |
+
|
| 62 |
+
def detect_bot_accounts(
|
| 63 |
+
self,
|
| 64 |
+
user_data: pd.DataFrame,
|
| 65 |
+
return_scores: bool = True
|
| 66 |
+
) -> Dict:
|
| 67 |
+
"""
|
| 68 |
+
Detect bot accounts based on behavioral features.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
user_data: DataFrame with user activity data
|
| 72 |
+
Required columns: user_id, post_count, follower_count,
|
| 73 |
+
following_count, account_age_days, avg_post_interval,
|
| 74 |
+
verified, profile_has_image, bio_length
|
| 75 |
+
return_scores: Return anomaly scores for all users
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
Detection results with bot predictions
|
| 79 |
+
"""
|
| 80 |
+
# Extract features
|
| 81 |
+
features = self._extract_bot_features(user_data)
|
| 82 |
+
|
| 83 |
+
# Normalize features
|
| 84 |
+
features_scaled = self.scaler.fit_transform(features)
|
| 85 |
+
|
| 86 |
+
# Detect anomalies
|
| 87 |
+
predictions = self.bot_detector.fit_predict(features_scaled)
|
| 88 |
+
anomaly_scores = self.bot_detector.score_samples(features_scaled)
|
| 89 |
+
|
| 90 |
+
# -1 = anomaly (bot), 1 = normal
|
| 91 |
+
bot_mask = predictions == -1
|
| 92 |
+
bot_users = user_data.loc[bot_mask, 'user_id'].tolist()
|
| 93 |
+
|
| 94 |
+
# Calculate confidence scores
|
| 95 |
+
# Convert anomaly scores to 0-1 probability
|
| 96 |
+
scores_normalized = 1 / (1 + np.exp(anomaly_scores))
|
| 97 |
+
|
| 98 |
+
result = {
|
| 99 |
+
'total_users': len(user_data),
|
| 100 |
+
'bots_detected': int(np.sum(bot_mask)),
|
| 101 |
+
'bot_percentage': float(np.mean(bot_mask) * 100),
|
| 102 |
+
'bot_user_ids': bot_users,
|
| 103 |
+
'summary': self._generate_bot_summary(user_data[bot_mask])
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
if return_scores:
|
| 107 |
+
result['user_scores'] = pd.DataFrame({
|
| 108 |
+
'user_id': user_data['user_id'],
|
| 109 |
+
'is_bot': bot_mask,
|
| 110 |
+
'bot_probability': scores_normalized,
|
| 111 |
+
'anomaly_score': anomaly_scores
|
| 112 |
+
}).to_dict('records')
|
| 113 |
+
|
| 114 |
+
return result
|
| 115 |
+
|
| 116 |
+
def _extract_bot_features(self, user_data: pd.DataFrame) -> np.ndarray:
|
| 117 |
+
"""Extract features for bot detection."""
|
| 118 |
+
features = []
|
| 119 |
+
|
| 120 |
+
# Feature 1: Post frequency
|
| 121 |
+
if 'account_age_days' in user_data and 'post_count' in user_data:
|
| 122 |
+
post_frequency = user_data['post_count'] / (user_data['account_age_days'] + 1)
|
| 123 |
+
features.append(post_frequency)
|
| 124 |
+
|
| 125 |
+
# Feature 2: Follower/following ratio
|
| 126 |
+
if 'follower_count' in user_data and 'following_count' in user_data:
|
| 127 |
+
ff_ratio = user_data['follower_count'] / (user_data['following_count'] + 1)
|
| 128 |
+
features.append(ff_ratio)
|
| 129 |
+
|
| 130 |
+
# Feature 3: Account completeness score
|
| 131 |
+
completeness = 0
|
| 132 |
+
if 'verified' in user_data:
|
| 133 |
+
completeness += user_data['verified'].astype(int)
|
| 134 |
+
if 'profile_has_image' in user_data:
|
| 135 |
+
completeness += user_data['profile_has_image'].astype(int)
|
| 136 |
+
if 'bio_length' in user_data:
|
| 137 |
+
completeness += (user_data['bio_length'] > 20).astype(int)
|
| 138 |
+
features.append(completeness)
|
| 139 |
+
|
| 140 |
+
# Feature 4: Posting pattern regularity
|
| 141 |
+
if 'avg_post_interval' in user_data:
|
| 142 |
+
features.append(user_data['avg_post_interval'])
|
| 143 |
+
|
| 144 |
+
# Feature 5: Account age
|
| 145 |
+
if 'account_age_days' in user_data:
|
| 146 |
+
features.append(user_data['account_age_days'])
|
| 147 |
+
|
| 148 |
+
# Stack features
|
| 149 |
+
feature_array = np.column_stack(features)
|
| 150 |
+
|
| 151 |
+
return feature_array
|
| 152 |
+
|
| 153 |
+
def _generate_bot_summary(self, bot_data: pd.DataFrame) -> Dict:
|
| 154 |
+
"""Generate summary statistics for detected bots."""
|
| 155 |
+
if len(bot_data) == 0:
|
| 156 |
+
return {'message': 'No bots detected'}
|
| 157 |
+
|
| 158 |
+
summary = {
|
| 159 |
+
'avg_post_frequency': float(bot_data['post_count'].mean() / (bot_data['account_age_days'].mean() + 1)) if 'post_count' in bot_data else None,
|
| 160 |
+
'avg_account_age_days': float(bot_data['account_age_days'].mean()) if 'account_age_days' in bot_data else None,
|
| 161 |
+
'percent_unverified': float((~bot_data['verified']).mean() * 100) if 'verified' in bot_data else None,
|
| 162 |
+
'percent_no_profile_image': float((~bot_data['profile_has_image']).mean() * 100) if 'profile_has_image' in bot_data else None
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
return summary
|
| 166 |
+
|
| 167 |
+
def detect_coordinated_campaign(
|
| 168 |
+
self,
|
| 169 |
+
activity_data: pd.DataFrame,
|
| 170 |
+
time_window: str = "1h",
|
| 171 |
+
min_accounts: int = 5
|
| 172 |
+
) -> Dict:
|
| 173 |
+
"""
|
| 174 |
+
Detect coordinated campaigns using network analysis.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
activity_data: DataFrame with columns: user_id, content_id,
|
| 178 |
+
timestamp, content_hash, action_type
|
| 179 |
+
time_window: Time window for coordination (e.g., "1h", "30m")
|
| 180 |
+
min_accounts: Minimum accounts for a campaign
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
Detected campaigns
|
| 184 |
+
"""
|
| 185 |
+
# Convert time window to timedelta
|
| 186 |
+
time_delta = self._parse_time_window(time_window)
|
| 187 |
+
|
| 188 |
+
# Group activities by content
|
| 189 |
+
content_groups = activity_data.groupby('content_hash')
|
| 190 |
+
|
| 191 |
+
campaigns = []
|
| 192 |
+
|
| 193 |
+
for content_hash, group in content_groups:
|
| 194 |
+
if len(group) < min_accounts:
|
| 195 |
+
continue
|
| 196 |
+
|
| 197 |
+
# Check temporal clustering
|
| 198 |
+
timestamps = pd.to_datetime(group['timestamp'])
|
| 199 |
+
time_range = (timestamps.max() - timestamps.min()).total_seconds()
|
| 200 |
+
|
| 201 |
+
# If all actions within time window
|
| 202 |
+
if time_range <= time_delta.total_seconds():
|
| 203 |
+
# Calculate coordination score
|
| 204 |
+
coordination_score = self._calculate_coordination_score(group)
|
| 205 |
+
|
| 206 |
+
if coordination_score > 0.7:
|
| 207 |
+
campaigns.append({
|
| 208 |
+
'content_hash': content_hash,
|
| 209 |
+
'participant_count': len(group),
|
| 210 |
+
'time_range_seconds': time_range,
|
| 211 |
+
'coordination_score': float(coordination_score),
|
| 212 |
+
'user_ids': group['user_id'].tolist(),
|
| 213 |
+
'start_time': timestamps.min().isoformat(),
|
| 214 |
+
'end_time': timestamps.max().isoformat()
|
| 215 |
+
})
|
| 216 |
+
|
| 217 |
+
# Network analysis
|
| 218 |
+
campaign_network = self._build_campaign_network(campaigns)
|
| 219 |
+
|
| 220 |
+
return {
|
| 221 |
+
'campaigns_detected': len(campaigns),
|
| 222 |
+
'campaigns': campaigns,
|
| 223 |
+
'network_metrics': campaign_network,
|
| 224 |
+
'explanation': self._explain_campaigns(campaigns)
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
def _parse_time_window(self, time_window: str) -> timedelta:
|
| 228 |
+
"""Parse time window string to timedelta."""
|
| 229 |
+
unit = time_window[-1]
|
| 230 |
+
value = int(time_window[:-1])
|
| 231 |
+
|
| 232 |
+
if unit == 's':
|
| 233 |
+
return timedelta(seconds=value)
|
| 234 |
+
elif unit == 'm':
|
| 235 |
+
return timedelta(minutes=value)
|
| 236 |
+
elif unit == 'h':
|
| 237 |
+
return timedelta(hours=value)
|
| 238 |
+
elif unit == 'd':
|
| 239 |
+
return timedelta(days=value)
|
| 240 |
+
else:
|
| 241 |
+
raise ValueError(f"Unknown time unit: {unit}")
|
| 242 |
+
|
| 243 |
+
def _calculate_coordination_score(self, activity_group: pd.DataFrame) -> float:
|
| 244 |
+
"""
|
| 245 |
+
Calculate coordination score based on:
|
| 246 |
+
- Temporal clustering
|
| 247 |
+
- Account similarity
|
| 248 |
+
- Action synchronization
|
| 249 |
+
"""
|
| 250 |
+
score = 0.0
|
| 251 |
+
|
| 252 |
+
# 1. Temporal clustering (max 0.4)
|
| 253 |
+
timestamps = pd.to_datetime(activity_group['timestamp'])
|
| 254 |
+
time_std = timestamps.astype(int).std() / 1e9 # Convert to seconds
|
| 255 |
+
|
| 256 |
+
if time_std < 60: # Within 1 minute
|
| 257 |
+
score += 0.4
|
| 258 |
+
elif time_std < 300: # Within 5 minutes
|
| 259 |
+
score += 0.3
|
| 260 |
+
elif time_std < 3600: # Within 1 hour
|
| 261 |
+
score += 0.2
|
| 262 |
+
|
| 263 |
+
# 2. Account age similarity (max 0.3)
|
| 264 |
+
if 'account_age_days' in activity_group:
|
| 265 |
+
age_std = activity_group['account_age_days'].std()
|
| 266 |
+
if age_std < 30: # Similar account ages
|
| 267 |
+
score += 0.3
|
| 268 |
+
elif age_std < 90:
|
| 269 |
+
score += 0.2
|
| 270 |
+
|
| 271 |
+
# 3. Action type uniformity (max 0.3)
|
| 272 |
+
if 'action_type' in activity_group:
|
| 273 |
+
action_entropy = self._calculate_entropy(
|
| 274 |
+
activity_group['action_type'].value_counts(normalize=True)
|
| 275 |
+
)
|
| 276 |
+
# Low entropy = uniform actions = coordinated
|
| 277 |
+
score += 0.3 * (1 - action_entropy)
|
| 278 |
+
|
| 279 |
+
return min(score, 1.0)
|
| 280 |
+
|
| 281 |
+
def _calculate_entropy(self, probabilities: pd.Series) -> float:
|
| 282 |
+
"""Calculate Shannon entropy."""
|
| 283 |
+
return -np.sum(probabilities * np.log2(probabilities + 1e-10))
|
| 284 |
+
|
| 285 |
+
def _build_campaign_network(self, campaigns: List[Dict]) -> Dict:
|
| 286 |
+
"""Build network graph of campaign participants."""
|
| 287 |
+
if not campaigns:
|
| 288 |
+
return {'nodes': 0, 'edges': 0, 'components': 0}
|
| 289 |
+
|
| 290 |
+
# Create graph
|
| 291 |
+
G = nx.Graph()
|
| 292 |
+
|
| 293 |
+
# Add nodes and edges
|
| 294 |
+
for campaign in campaigns:
|
| 295 |
+
users = campaign['user_ids']
|
| 296 |
+
|
| 297 |
+
# Add all users
|
| 298 |
+
G.add_nodes_from(users)
|
| 299 |
+
|
| 300 |
+
# Connect users who participated in same campaign
|
| 301 |
+
for i, user1 in enumerate(users):
|
| 302 |
+
for user2 in users[i+1:]:
|
| 303 |
+
if G.has_edge(user1, user2):
|
| 304 |
+
G[user1][user2]['weight'] += 1
|
| 305 |
+
else:
|
| 306 |
+
G.add_edge(user1, user2, weight=1)
|
| 307 |
+
|
| 308 |
+
# Calculate network metrics
|
| 309 |
+
connected_components = list(nx.connected_components(G))
|
| 310 |
+
|
| 311 |
+
metrics = {
|
| 312 |
+
'nodes': G.number_of_nodes(),
|
| 313 |
+
'edges': G.number_of_edges(),
|
| 314 |
+
'connected_components': len(connected_components),
|
| 315 |
+
'largest_component_size': max(len(c) for c in connected_components) if connected_components else 0,
|
| 316 |
+
'avg_clustering_coefficient': nx.average_clustering(G) if G.number_of_nodes() > 0 else 0
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
return metrics
|
| 320 |
+
|
| 321 |
+
def _explain_campaigns(self, campaigns: List[Dict]) -> str:
|
| 322 |
+
"""Generate explanation for detected campaigns."""
|
| 323 |
+
if not campaigns:
|
| 324 |
+
return "No coordinated campaigns detected."
|
| 325 |
+
|
| 326 |
+
total_participants = sum(c['participant_count'] for c in campaigns)
|
| 327 |
+
avg_coordination = np.mean([c['coordination_score'] for c in campaigns])
|
| 328 |
+
|
| 329 |
+
return (
|
| 330 |
+
f"Detected {len(campaigns)} coordinated campaign(s) involving "
|
| 331 |
+
f"{total_participants} accounts. Average coordination score: {avg_coordination:.2f}. "
|
| 332 |
+
f"This suggests organized, inauthentic behavior patterns."
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
def analyze_viral_spread(
|
| 336 |
+
self,
|
| 337 |
+
spread_data: pd.DataFrame
|
| 338 |
+
) -> Dict:
|
| 339 |
+
"""
|
| 340 |
+
Analyze viral spread patterns for anomalies.
|
| 341 |
+
|
| 342 |
+
Args:
|
| 343 |
+
spread_data: DataFrame with columns: timestamp, share_count,
|
| 344 |
+
view_count, engagement_rate
|
| 345 |
+
|
| 346 |
+
Returns:
|
| 347 |
+
Viral spread analysis
|
| 348 |
+
"""
|
| 349 |
+
# Sort by timestamp
|
| 350 |
+
spread_data = spread_data.sort_values('timestamp')
|
| 351 |
+
|
| 352 |
+
# Calculate growth rate
|
| 353 |
+
spread_data['growth_rate'] = spread_data['share_count'].pct_change()
|
| 354 |
+
|
| 355 |
+
# Detect suspicious patterns
|
| 356 |
+
anomalies = []
|
| 357 |
+
|
| 358 |
+
# 1. Sudden spike detection
|
| 359 |
+
mean_growth = spread_data['growth_rate'].mean()
|
| 360 |
+
std_growth = spread_data['growth_rate'].std()
|
| 361 |
+
|
| 362 |
+
spikes = spread_data[
|
| 363 |
+
spread_data['growth_rate'] > mean_growth + 3 * std_growth
|
| 364 |
+
]
|
| 365 |
+
|
| 366 |
+
if len(spikes) > 0:
|
| 367 |
+
anomalies.append({
|
| 368 |
+
'type': 'sudden_spike',
|
| 369 |
+
'description': f'Detected {len(spikes)} sudden spike(s) in sharing activity',
|
| 370 |
+
'timestamps': spikes['timestamp'].tolist()
|
| 371 |
+
})
|
| 372 |
+
|
| 373 |
+
# 2. Unnatural growth pattern
|
| 374 |
+
# Real viral content has exponential then logarithmic growth
|
| 375 |
+
# Inorganic content has linear or step-function growth
|
| 376 |
+
|
| 377 |
+
correlation_with_time = spread_data['share_count'].corr(
|
| 378 |
+
pd.Series(range(len(spread_data)))
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
if abs(correlation_with_time) > 0.95: # Too linear
|
| 382 |
+
anomalies.append({
|
| 383 |
+
'type': 'linear_growth',
|
| 384 |
+
'description': 'Unnaturally linear growth pattern (typical of bot-driven spread)',
|
| 385 |
+
'correlation': float(correlation_with_time)
|
| 386 |
+
})
|
| 387 |
+
|
| 388 |
+
# 3. Low engagement rate despite high shares
|
| 389 |
+
if 'engagement_rate' in spread_data:
|
| 390 |
+
avg_engagement = spread_data['engagement_rate'].mean()
|
| 391 |
+
if avg_engagement < 0.01: # Less than 1%
|
| 392 |
+
anomalies.append({
|
| 393 |
+
'type': 'low_engagement',
|
| 394 |
+
'description': 'High share count but abnormally low engagement',
|
| 395 |
+
'avg_engagement_rate': float(avg_engagement)
|
| 396 |
+
})
|
| 397 |
+
|
| 398 |
+
return {
|
| 399 |
+
'is_suspicious': len(anomalies) > 0,
|
| 400 |
+
'anomaly_count': len(anomalies),
|
| 401 |
+
'anomalies': anomalies,
|
| 402 |
+
'growth_statistics': {
|
| 403 |
+
'total_shares': int(spread_data['share_count'].iloc[-1]) if len(spread_data) > 0 else 0,
|
| 404 |
+
'avg_growth_rate': float(mean_growth),
|
| 405 |
+
'max_growth_rate': float(spread_data['growth_rate'].max()),
|
| 406 |
+
'time_to_peak': str(spread_data.loc[spread_data['share_count'].idxmax(), 'timestamp']) if len(spread_data) > 0 else None
|
| 407 |
+
},
|
| 408 |
+
'verdict': 'SUSPICIOUS' if len(anomalies) >= 2 else 'NORMAL',
|
| 409 |
+
'explanation': self._explain_viral_analysis(anomalies)
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
def _explain_viral_analysis(self, anomalies: List[Dict]) -> str:
|
| 413 |
+
"""Generate explanation for viral spread analysis."""
|
| 414 |
+
if not anomalies:
|
| 415 |
+
return "Viral spread pattern appears organic and natural."
|
| 416 |
+
|
| 417 |
+
explanations = [a['description'] for a in anomalies]
|
| 418 |
+
return "Suspicious patterns detected: " + "; ".join(explanations)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
# Example usage
|
| 422 |
+
if __name__ == "__main__":
|
| 423 |
+
detector = AnomalyDetector()
|
| 424 |
+
|
| 425 |
+
# Example: Bot detection
|
| 426 |
+
user_data = pd.DataFrame({
|
| 427 |
+
'user_id': ['user1', 'user2', 'user3', 'user4', 'user5'],
|
| 428 |
+
'post_count': [1000, 50, 800, 30, 20],
|
| 429 |
+
'follower_count': [100, 500, 120, 300, 250],
|
| 430 |
+
'following_count': [5000, 200, 4800, 180, 220],
|
| 431 |
+
'account_age_days': [30, 365, 25, 400, 350],
|
| 432 |
+
'avg_post_interval': [0.1, 8, 0.15, 12, 10],
|
| 433 |
+
'verified': [False, True, False, True, True],
|
| 434 |
+
'profile_has_image': [False, True, False, True, True],
|
| 435 |
+
'bio_length': [5, 150, 8, 120, 100]
|
| 436 |
+
})
|
| 437 |
+
|
| 438 |
+
result = detector.detect_bot_accounts(user_data)
|
| 439 |
+
print(f"Bots detected: {result['bots_detected']}")
|
| 440 |
+
print(f"Bot user IDs: {result['bot_user_ids']}")
|
src/detection/deepfake_detector.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Deepfake Detection Module
|
| 3 |
+
|
| 4 |
+
This module implements state-of-the-art deepfake detection using:
|
| 5 |
+
1. EfficientNet-based architecture for face manipulation detection
|
| 6 |
+
2. Temporal consistency analysis for video deepfakes
|
| 7 |
+
3. Attention mechanisms for explainability
|
| 8 |
+
4. Multi-scale feature extraction
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
from typing import Dict, Tuple, Optional, List
|
| 16 |
+
import numpy as np
|
| 17 |
+
import cv2
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import timm
|
| 20 |
+
|
| 21 |
+
# Simplified imports - use available modules
|
| 22 |
+
try:
|
| 23 |
+
from ..utils.face_detection import detect_faces
|
| 24 |
+
from ..utils.preprocessing import preprocess_image
|
| 25 |
+
except ImportError:
|
| 26 |
+
# If relative imports fail, try absolute
|
| 27 |
+
import sys
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 30 |
+
from utils.face_detection import detect_faces
|
| 31 |
+
from utils.preprocessing import preprocess_image
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class DeepfakeDetector:
|
| 35 |
+
"""
|
| 36 |
+
Production-ready deepfake detector with ensemble approach.
|
| 37 |
+
|
| 38 |
+
Combines multiple detection strategies:
|
| 39 |
+
- Spatial artifact detection
|
| 40 |
+
- Temporal consistency (for videos)
|
| 41 |
+
- Frequency domain analysis
|
| 42 |
+
- Attention-based feature extraction
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
model_path: str = "models/deepfake_efficientnet_b4.pth",
|
| 48 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
| 49 |
+
threshold: float = 0.5,
|
| 50 |
+
use_ensemble: bool = True
|
| 51 |
+
):
|
| 52 |
+
"""
|
| 53 |
+
Initialize the deepfake detector.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
model_path: Path to pre-trained model weights
|
| 57 |
+
device: Device to run inference on (cuda/cpu)
|
| 58 |
+
threshold: Detection threshold (0-1)
|
| 59 |
+
use_ensemble: Whether to use ensemble of models
|
| 60 |
+
"""
|
| 61 |
+
self.device = device
|
| 62 |
+
self.threshold = threshold
|
| 63 |
+
self.use_ensemble = use_ensemble
|
| 64 |
+
|
| 65 |
+
# Load models
|
| 66 |
+
self._load_models(model_path)
|
| 67 |
+
|
| 68 |
+
# Image preprocessing
|
| 69 |
+
self.transform = transforms.Compose([
|
| 70 |
+
transforms.Resize((380, 380)),
|
| 71 |
+
transforms.CenterCrop(299),
|
| 72 |
+
transforms.ToTensor(),
|
| 73 |
+
transforms.Normalize(
|
| 74 |
+
mean=[0.485, 0.456, 0.406],
|
| 75 |
+
std=[0.229, 0.224, 0.225]
|
| 76 |
+
)
|
| 77 |
+
])
|
| 78 |
+
|
| 79 |
+
def _load_models(self, model_path: str):
|
| 80 |
+
"""Load pre-trained models."""
|
| 81 |
+
# Primary model: EfficientNet-B4
|
| 82 |
+
self.primary_model = timm.create_model(
|
| 83 |
+
'efficientnet_b4',
|
| 84 |
+
pretrained=False,
|
| 85 |
+
num_classes=1
|
| 86 |
+
).to(self.device)
|
| 87 |
+
|
| 88 |
+
# Load weights if available
|
| 89 |
+
try:
|
| 90 |
+
checkpoint = torch.load(model_path, map_location=self.device)
|
| 91 |
+
self.primary_model.load_state_dict(checkpoint['model_state_dict'])
|
| 92 |
+
print(f"✓ Loaded model from {model_path}")
|
| 93 |
+
except FileNotFoundError:
|
| 94 |
+
print(f"⚠ Model not found at {model_path}. Using random initialization.")
|
| 95 |
+
print(" Run: python scripts/download_models.py to download pre-trained weights")
|
| 96 |
+
|
| 97 |
+
self.primary_model.eval()
|
| 98 |
+
|
| 99 |
+
# Secondary models for ensemble
|
| 100 |
+
if self.use_ensemble:
|
| 101 |
+
self.secondary_models = self._load_ensemble_models()
|
| 102 |
+
|
| 103 |
+
def _load_ensemble_models(self) -> List[nn.Module]:
|
| 104 |
+
"""Load additional models for ensemble."""
|
| 105 |
+
models = []
|
| 106 |
+
|
| 107 |
+
# XceptionNet - good for GAN artifacts
|
| 108 |
+
xception = timm.create_model(
|
| 109 |
+
'xception',
|
| 110 |
+
pretrained=False,
|
| 111 |
+
num_classes=1
|
| 112 |
+
).to(self.device)
|
| 113 |
+
xception.eval()
|
| 114 |
+
models.append(xception)
|
| 115 |
+
|
| 116 |
+
# ResNet50 - robust baseline
|
| 117 |
+
resnet = timm.create_model(
|
| 118 |
+
'resnet50',
|
| 119 |
+
pretrained=False,
|
| 120 |
+
num_classes=1
|
| 121 |
+
).to(self.device)
|
| 122 |
+
resnet.eval()
|
| 123 |
+
models.append(resnet)
|
| 124 |
+
|
| 125 |
+
return models
|
| 126 |
+
|
| 127 |
+
def analyze_image(
|
| 128 |
+
self,
|
| 129 |
+
image_path: str,
|
| 130 |
+
return_attention: bool = True
|
| 131 |
+
) -> Dict:
|
| 132 |
+
"""
|
| 133 |
+
Analyze a single image for deepfake artifacts.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
image_path: Path to image file
|
| 137 |
+
return_attention: Whether to return attention maps
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
Dictionary with detection results
|
| 141 |
+
"""
|
| 142 |
+
# Load and preprocess image
|
| 143 |
+
image = Image.open(image_path).convert('RGB')
|
| 144 |
+
original_size = image.size
|
| 145 |
+
|
| 146 |
+
# Detect faces
|
| 147 |
+
faces = detect_faces(image)
|
| 148 |
+
|
| 149 |
+
if len(faces) == 0:
|
| 150 |
+
return {
|
| 151 |
+
'verdict': 'NO_FACE_DETECTED',
|
| 152 |
+
'confidence': 0.0,
|
| 153 |
+
'explanation': 'No faces detected in the image',
|
| 154 |
+
'faces_analyzed': 0,
|
| 155 |
+
'artifacts_detected': []
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
# Analyze each face
|
| 159 |
+
face_results = []
|
| 160 |
+
for i, face_coords in enumerate(faces):
|
| 161 |
+
face_crop = self._crop_face(image, face_coords)
|
| 162 |
+
result = self._analyze_face(face_crop, return_attention)
|
| 163 |
+
face_results.append(result)
|
| 164 |
+
|
| 165 |
+
# Aggregate results
|
| 166 |
+
avg_confidence = np.mean([r['confidence'] for r in face_results])
|
| 167 |
+
is_fake = avg_confidence > self.threshold
|
| 168 |
+
|
| 169 |
+
return {
|
| 170 |
+
'verdict': 'FAKE' if is_fake else 'REAL',
|
| 171 |
+
'confidence': float(avg_confidence),
|
| 172 |
+
'threshold': self.threshold,
|
| 173 |
+
'faces_analyzed': len(faces),
|
| 174 |
+
'face_results': face_results,
|
| 175 |
+
'explanation': self._generate_explanation(avg_confidence, face_results),
|
| 176 |
+
'artifacts_detected': self._detect_artifacts(image)
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
def _analyze_face(
|
| 180 |
+
self,
|
| 181 |
+
face_image: Image.Image,
|
| 182 |
+
return_attention: bool
|
| 183 |
+
) -> Dict:
|
| 184 |
+
"""Analyze a single face crop."""
|
| 185 |
+
# Preprocess
|
| 186 |
+
input_tensor = self.transform(face_image).unsqueeze(0).to(self.device)
|
| 187 |
+
|
| 188 |
+
# Primary model inference
|
| 189 |
+
with torch.no_grad():
|
| 190 |
+
logits = self.primary_model(input_tensor)
|
| 191 |
+
confidence = torch.sigmoid(logits).item()
|
| 192 |
+
|
| 193 |
+
# Ensemble if enabled
|
| 194 |
+
if self.use_ensemble:
|
| 195 |
+
ensemble_confidences = [confidence]
|
| 196 |
+
for model in self.secondary_models:
|
| 197 |
+
with torch.no_grad():
|
| 198 |
+
logits = model(input_tensor)
|
| 199 |
+
conf = torch.sigmoid(logits).item()
|
| 200 |
+
ensemble_confidences.append(conf)
|
| 201 |
+
|
| 202 |
+
confidence = np.mean(ensemble_confidences)
|
| 203 |
+
|
| 204 |
+
result = {
|
| 205 |
+
'confidence': confidence,
|
| 206 |
+
'is_fake': confidence > self.threshold
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
# Add attention map if requested
|
| 210 |
+
if return_attention:
|
| 211 |
+
result['attention_map'] = self._generate_attention_map(input_tensor)
|
| 212 |
+
|
| 213 |
+
return result
|
| 214 |
+
|
| 215 |
+
def _crop_face(
|
| 216 |
+
self,
|
| 217 |
+
image: Image.Image,
|
| 218 |
+
face_coords: Tuple[int, int, int, int]
|
| 219 |
+
) -> Image.Image:
|
| 220 |
+
"""Crop face from image with padding."""
|
| 221 |
+
x, y, w, h = face_coords
|
| 222 |
+
|
| 223 |
+
# Add 30% padding
|
| 224 |
+
padding = int(0.3 * max(w, h))
|
| 225 |
+
x1 = max(0, x - padding)
|
| 226 |
+
y1 = max(0, y - padding)
|
| 227 |
+
x2 = min(image.width, x + w + padding)
|
| 228 |
+
y2 = min(image.height, y + h + padding)
|
| 229 |
+
|
| 230 |
+
return image.crop((x1, y1, x2, y2))
|
| 231 |
+
|
| 232 |
+
def _generate_attention_map(self, input_tensor: torch.Tensor) -> np.ndarray:
|
| 233 |
+
"""Generate Grad-CAM attention map."""
|
| 234 |
+
# Simplified attention map generation
|
| 235 |
+
# In production, implement full Grad-CAM
|
| 236 |
+
|
| 237 |
+
# Get feature maps from last conv layer
|
| 238 |
+
features = self.primary_model.features(input_tensor)
|
| 239 |
+
|
| 240 |
+
# Global average pooling
|
| 241 |
+
attention = F.adaptive_avg_pool2d(features, (1, 1))
|
| 242 |
+
attention = attention.squeeze().cpu().numpy()
|
| 243 |
+
|
| 244 |
+
return attention
|
| 245 |
+
|
| 246 |
+
def _detect_artifacts(self, image: Image.Image) -> List[str]:
|
| 247 |
+
"""Detect specific deepfake artifacts."""
|
| 248 |
+
artifacts = []
|
| 249 |
+
|
| 250 |
+
# Convert to numpy array
|
| 251 |
+
img_array = np.array(image)
|
| 252 |
+
|
| 253 |
+
# Check for common artifacts
|
| 254 |
+
|
| 255 |
+
# 1. Face boundary inconsistencies
|
| 256 |
+
if self._check_boundary_artifacts(img_array):
|
| 257 |
+
artifacts.append("Face boundary inconsistencies detected")
|
| 258 |
+
|
| 259 |
+
# 2. Color inconsistencies
|
| 260 |
+
if self._check_color_artifacts(img_array):
|
| 261 |
+
artifacts.append("Abnormal color distribution in face region")
|
| 262 |
+
|
| 263 |
+
# 3. Frequency domain artifacts
|
| 264 |
+
if self._check_frequency_artifacts(img_array):
|
| 265 |
+
artifacts.append("Suspicious frequency patterns detected")
|
| 266 |
+
|
| 267 |
+
# 4. Eye/teeth artifacts (common in face-swap)
|
| 268 |
+
if self._check_facial_feature_artifacts(img_array):
|
| 269 |
+
artifacts.append("Inconsistencies in facial features")
|
| 270 |
+
|
| 271 |
+
return artifacts
|
| 272 |
+
|
| 273 |
+
def _check_boundary_artifacts(self, image: np.ndarray) -> bool:
|
| 274 |
+
"""Check for boundary artifacts using edge detection."""
|
| 275 |
+
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
| 276 |
+
edges = cv2.Canny(gray, 50, 150)
|
| 277 |
+
|
| 278 |
+
# Calculate edge density
|
| 279 |
+
edge_density = np.sum(edges > 0) / edges.size
|
| 280 |
+
|
| 281 |
+
# Suspicious if too many sharp edges (indicates blending)
|
| 282 |
+
return edge_density > 0.15
|
| 283 |
+
|
| 284 |
+
def _check_color_artifacts(self, image: np.ndarray) -> bool:
|
| 285 |
+
"""Check for color inconsistencies."""
|
| 286 |
+
# Convert to LAB color space
|
| 287 |
+
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
|
| 288 |
+
|
| 289 |
+
# Calculate color variance
|
| 290 |
+
color_var = np.var(lab, axis=(0, 1))
|
| 291 |
+
|
| 292 |
+
# Suspicious if variance is abnormal
|
| 293 |
+
return color_var[0] > 1000 # Threshold for L channel
|
| 294 |
+
|
| 295 |
+
def _check_frequency_artifacts(self, image: np.ndarray) -> bool:
|
| 296 |
+
"""Check frequency domain for artifacts."""
|
| 297 |
+
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
| 298 |
+
|
| 299 |
+
# Apply FFT
|
| 300 |
+
fft = np.fft.fft2(gray)
|
| 301 |
+
fft_shift = np.fft.fftshift(fft)
|
| 302 |
+
magnitude = np.abs(fft_shift)
|
| 303 |
+
|
| 304 |
+
# Check for abnormal frequency patterns
|
| 305 |
+
high_freq_energy = np.sum(magnitude[magnitude.shape[0]//4:3*magnitude.shape[0]//4,
|
| 306 |
+
magnitude.shape[1]//4:3*magnitude.shape[1]//4])
|
| 307 |
+
total_energy = np.sum(magnitude)
|
| 308 |
+
|
| 309 |
+
ratio = high_freq_energy / total_energy
|
| 310 |
+
|
| 311 |
+
# GAN-generated images often have specific frequency signatures
|
| 312 |
+
return ratio < 0.1 or ratio > 0.4
|
| 313 |
+
|
| 314 |
+
def _check_facial_feature_artifacts(self, image: np.ndarray) -> bool:
|
| 315 |
+
"""Check for artifacts in facial features."""
|
| 316 |
+
# Simplified check - in production, use facial landmark detection
|
| 317 |
+
# and analyze consistency of eyes, nose, mouth
|
| 318 |
+
|
| 319 |
+
# For now, return False (placeholder)
|
| 320 |
+
return False
|
| 321 |
+
|
| 322 |
+
def _generate_explanation(
|
| 323 |
+
self,
|
| 324 |
+
confidence: float,
|
| 325 |
+
face_results: List[Dict]
|
| 326 |
+
) -> str:
|
| 327 |
+
"""Generate human-readable explanation."""
|
| 328 |
+
if confidence > 0.9:
|
| 329 |
+
return "Strong indicators of manipulation detected. Multiple artifacts found."
|
| 330 |
+
elif confidence > 0.7:
|
| 331 |
+
return "Likely manipulated. Several suspicious patterns identified."
|
| 332 |
+
elif confidence > 0.5:
|
| 333 |
+
return "Possible manipulation. Some inconsistencies detected."
|
| 334 |
+
elif confidence > 0.3:
|
| 335 |
+
return "Minor inconsistencies found, but likely authentic."
|
| 336 |
+
else:
|
| 337 |
+
return "No significant manipulation detected. Image appears authentic."
|
| 338 |
+
|
| 339 |
+
def analyze_video(
|
| 340 |
+
self,
|
| 341 |
+
video_path: str,
|
| 342 |
+
sample_rate: int = 5,
|
| 343 |
+
max_frames: int = 100
|
| 344 |
+
) -> Dict:
|
| 345 |
+
"""
|
| 346 |
+
Analyze video for deepfake artifacts.
|
| 347 |
+
|
| 348 |
+
Args:
|
| 349 |
+
video_path: Path to video file
|
| 350 |
+
sample_rate: Analyze every Nth frame
|
| 351 |
+
max_frames: Maximum frames to analyze
|
| 352 |
+
|
| 353 |
+
Returns:
|
| 354 |
+
Dictionary with detection results
|
| 355 |
+
"""
|
| 356 |
+
cap = cv2.VideoCapture(video_path)
|
| 357 |
+
|
| 358 |
+
frame_results = []
|
| 359 |
+
frame_count = 0
|
| 360 |
+
analyzed_count = 0
|
| 361 |
+
|
| 362 |
+
while cap.isOpened() and analyzed_count < max_frames:
|
| 363 |
+
ret, frame = cap.read()
|
| 364 |
+
if not ret:
|
| 365 |
+
break
|
| 366 |
+
|
| 367 |
+
# Sample frames
|
| 368 |
+
if frame_count % sample_rate == 0:
|
| 369 |
+
# Convert BGR to RGB
|
| 370 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 371 |
+
pil_image = Image.fromarray(frame_rgb)
|
| 372 |
+
|
| 373 |
+
# Analyze frame
|
| 374 |
+
result = self.analyze_image(pil_image, return_attention=False)
|
| 375 |
+
frame_results.append({
|
| 376 |
+
'frame_number': frame_count,
|
| 377 |
+
'confidence': result['confidence'],
|
| 378 |
+
'verdict': result['verdict']
|
| 379 |
+
})
|
| 380 |
+
|
| 381 |
+
analyzed_count += 1
|
| 382 |
+
|
| 383 |
+
frame_count += 1
|
| 384 |
+
|
| 385 |
+
cap.release()
|
| 386 |
+
|
| 387 |
+
# Analyze temporal consistency
|
| 388 |
+
confidences = [r['confidence'] for r in frame_results]
|
| 389 |
+
avg_confidence = np.mean(confidences)
|
| 390 |
+
confidence_variance = np.var(confidences)
|
| 391 |
+
|
| 392 |
+
# High variance suggests inconsistent manipulation
|
| 393 |
+
temporal_inconsistency = confidence_variance > 0.05
|
| 394 |
+
|
| 395 |
+
return {
|
| 396 |
+
'verdict': 'FAKE' if avg_confidence > self.threshold else 'REAL',
|
| 397 |
+
'confidence': float(avg_confidence),
|
| 398 |
+
'confidence_variance': float(confidence_variance),
|
| 399 |
+
'temporal_inconsistency': temporal_inconsistency,
|
| 400 |
+
'frames_analyzed': analyzed_count,
|
| 401 |
+
'total_frames': frame_count,
|
| 402 |
+
'frame_results': frame_results,
|
| 403 |
+
'explanation': self._generate_video_explanation(
|
| 404 |
+
avg_confidence,
|
| 405 |
+
temporal_inconsistency
|
| 406 |
+
)
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
def _generate_video_explanation(
|
| 410 |
+
self,
|
| 411 |
+
confidence: float,
|
| 412 |
+
temporal_inconsistency: bool
|
| 413 |
+
) -> str:
|
| 414 |
+
"""Generate explanation for video analysis."""
|
| 415 |
+
base_explanation = self._generate_explanation(confidence, [])
|
| 416 |
+
|
| 417 |
+
if temporal_inconsistency:
|
| 418 |
+
base_explanation += " Temporal inconsistencies detected across frames."
|
| 419 |
+
|
| 420 |
+
return base_explanation
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
# Example usage
|
| 424 |
+
if __name__ == "__main__":
|
| 425 |
+
detector = DeepfakeDetector()
|
| 426 |
+
|
| 427 |
+
# Analyze image
|
| 428 |
+
result = detector.analyze_image("test_image.jpg")
|
| 429 |
+
print(f"Verdict: {result['verdict']}")
|
| 430 |
+
print(f"Confidence: {result['confidence']:.2%}")
|
| 431 |
+
print(f"Explanation: {result['explanation']}")
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Init file for models module."""
|
src/models/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (268 Bytes). View file
|
|
|
src/training/train_deepfake.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training Pipeline for Deepfake Detection Models
|
| 3 |
+
|
| 4 |
+
Implements:
|
| 5 |
+
- Distributed training (multi-GPU)
|
| 6 |
+
- Mixed precision training
|
| 7 |
+
- Experiment tracking with MLflow
|
| 8 |
+
- Checkpoint management
|
| 9 |
+
- Data augmentation
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.optim as optim
|
| 15 |
+
from torch.utils.data import DataLoader, Dataset
|
| 16 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 17 |
+
import pytorch_lightning as pl
|
| 18 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
| 19 |
+
from pytorch_lightning.loggers import MLFlowLogger
|
| 20 |
+
import timm
|
| 21 |
+
from typing import Dict, Tuple, Optional
|
| 22 |
+
import mlflow
|
| 23 |
+
import numpy as np
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
import albumentations as A
|
| 26 |
+
from albumentations.pytorch import ToTensorV2
|
| 27 |
+
from PIL import Image
|
| 28 |
+
import cv2
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DeepfakeDataset(Dataset):
|
| 32 |
+
"""Dataset for deepfake detection training."""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
image_paths: list,
|
| 37 |
+
labels: list,
|
| 38 |
+
transform=None,
|
| 39 |
+
mode: str = "train"
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
Args:
|
| 43 |
+
image_paths: List of paths to images
|
| 44 |
+
labels: List of labels (0=real, 1=fake)
|
| 45 |
+
transform: Albumentations transforms
|
| 46 |
+
mode: 'train', 'val', or 'test'
|
| 47 |
+
"""
|
| 48 |
+
self.image_paths = image_paths
|
| 49 |
+
self.labels = labels
|
| 50 |
+
self.transform = transform
|
| 51 |
+
self.mode = mode
|
| 52 |
+
|
| 53 |
+
def __len__(self):
|
| 54 |
+
return len(self.image_paths)
|
| 55 |
+
|
| 56 |
+
def __getitem__(self, idx):
|
| 57 |
+
# Load image
|
| 58 |
+
image_path = self.image_paths[idx]
|
| 59 |
+
image = cv2.imread(str(image_path))
|
| 60 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 61 |
+
|
| 62 |
+
label = self.labels[idx]
|
| 63 |
+
|
| 64 |
+
# Apply transforms
|
| 65 |
+
if self.transform:
|
| 66 |
+
augmented = self.transform(image=image)
|
| 67 |
+
image = augmented['image']
|
| 68 |
+
|
| 69 |
+
return image, label
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class DeepfakeDetectionModel(pl.LightningModule):
|
| 73 |
+
"""PyTorch Lightning module for deepfake detection."""
|
| 74 |
+
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
model_name: str = "efficientnet_b4",
|
| 78 |
+
learning_rate: float = 1e-4,
|
| 79 |
+
weight_decay: float = 1e-5,
|
| 80 |
+
num_classes: int = 1
|
| 81 |
+
):
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.save_hyperparameters()
|
| 84 |
+
|
| 85 |
+
# Load pre-trained model
|
| 86 |
+
self.model = timm.create_model(
|
| 87 |
+
model_name,
|
| 88 |
+
pretrained=True,
|
| 89 |
+
num_classes=num_classes
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Loss function
|
| 93 |
+
self.criterion = nn.BCEWithLogitsLoss()
|
| 94 |
+
|
| 95 |
+
# Metrics
|
| 96 |
+
self.train_accuracy = []
|
| 97 |
+
self.val_accuracy = []
|
| 98 |
+
|
| 99 |
+
def forward(self, x):
|
| 100 |
+
return self.model(x)
|
| 101 |
+
|
| 102 |
+
def training_step(self, batch, batch_idx):
|
| 103 |
+
images, labels = batch
|
| 104 |
+
labels = labels.float().unsqueeze(1)
|
| 105 |
+
|
| 106 |
+
# Forward pass
|
| 107 |
+
logits = self(images)
|
| 108 |
+
loss = self.criterion(logits, labels)
|
| 109 |
+
|
| 110 |
+
# Calculate accuracy
|
| 111 |
+
probs = torch.sigmoid(logits)
|
| 112 |
+
preds = (probs > 0.5).float()
|
| 113 |
+
accuracy = (preds == labels).float().mean()
|
| 114 |
+
|
| 115 |
+
# Log metrics
|
| 116 |
+
self.log('train_loss', loss, prog_bar=True)
|
| 117 |
+
self.log('train_accuracy', accuracy, prog_bar=True)
|
| 118 |
+
|
| 119 |
+
return loss
|
| 120 |
+
|
| 121 |
+
def validation_step(self, batch, batch_idx):
|
| 122 |
+
images, labels = batch
|
| 123 |
+
labels = labels.float().unsqueeze(1)
|
| 124 |
+
|
| 125 |
+
# Forward pass
|
| 126 |
+
logits = self(images)
|
| 127 |
+
loss = self.criterion(logits, labels)
|
| 128 |
+
|
| 129 |
+
# Calculate accuracy
|
| 130 |
+
probs = torch.sigmoid(logits)
|
| 131 |
+
preds = (probs > 0.5).float()
|
| 132 |
+
accuracy = (preds == labels).float().mean()
|
| 133 |
+
|
| 134 |
+
# Log metrics
|
| 135 |
+
self.log('val_loss', loss, prog_bar=True)
|
| 136 |
+
self.log('val_accuracy', accuracy, prog_bar=True)
|
| 137 |
+
|
| 138 |
+
return {'val_loss': loss, 'val_accuracy': accuracy}
|
| 139 |
+
|
| 140 |
+
def configure_optimizers(self):
|
| 141 |
+
optimizer = optim.AdamW(
|
| 142 |
+
self.parameters(),
|
| 143 |
+
lr=self.hparams.learning_rate,
|
| 144 |
+
weight_decay=self.hparams.weight_decay
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
| 148 |
+
optimizer,
|
| 149 |
+
T_max=10,
|
| 150 |
+
eta_min=1e-6
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
return {
|
| 154 |
+
'optimizer': optimizer,
|
| 155 |
+
'lr_scheduler': {
|
| 156 |
+
'scheduler': scheduler,
|
| 157 |
+
'monitor': 'val_loss'
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def get_transforms(mode: str = "train") -> A.Compose:
|
| 163 |
+
"""Get augmentation transforms."""
|
| 164 |
+
|
| 165 |
+
if mode == "train":
|
| 166 |
+
return A.Compose([
|
| 167 |
+
A.Resize(380, 380),
|
| 168 |
+
A.CenterCrop(299, 299),
|
| 169 |
+
A.HorizontalFlip(p=0.5),
|
| 170 |
+
A.Rotate(limit=15, p=0.5),
|
| 171 |
+
A.ColorJitter(
|
| 172 |
+
brightness=0.2,
|
| 173 |
+
contrast=0.2,
|
| 174 |
+
saturation=0.2,
|
| 175 |
+
hue=0.1,
|
| 176 |
+
p=0.5
|
| 177 |
+
),
|
| 178 |
+
A.GaussNoise(p=0.3),
|
| 179 |
+
A.Normalize(
|
| 180 |
+
mean=[0.485, 0.456, 0.406],
|
| 181 |
+
std=[0.229, 0.224, 0.225]
|
| 182 |
+
),
|
| 183 |
+
ToTensorV2()
|
| 184 |
+
])
|
| 185 |
+
else:
|
| 186 |
+
return A.Compose([
|
| 187 |
+
A.Resize(380, 380),
|
| 188 |
+
A.CenterCrop(299, 299),
|
| 189 |
+
A.Normalize(
|
| 190 |
+
mean=[0.485, 0.456, 0.406],
|
| 191 |
+
std=[0.229, 0.224, 0.225]
|
| 192 |
+
),
|
| 193 |
+
ToTensorV2()
|
| 194 |
+
])
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class DeepfakeTrainer:
|
| 198 |
+
"""Training pipeline manager."""
|
| 199 |
+
|
| 200 |
+
def __init__(
|
| 201 |
+
self,
|
| 202 |
+
config: Dict,
|
| 203 |
+
experiment_name: str = "deepfake-detection"
|
| 204 |
+
):
|
| 205 |
+
"""
|
| 206 |
+
Args:
|
| 207 |
+
config: Training configuration
|
| 208 |
+
experiment_name: MLflow experiment name
|
| 209 |
+
"""
|
| 210 |
+
self.config = config
|
| 211 |
+
self.experiment_name = experiment_name
|
| 212 |
+
|
| 213 |
+
# Setup MLflow
|
| 214 |
+
mlflow.set_experiment(experiment_name)
|
| 215 |
+
self.mlflow_logger = MLFlowLogger(
|
| 216 |
+
experiment_name=experiment_name,
|
| 217 |
+
tracking_uri=config.get('mlflow_uri', 'http://localhost:5000')
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
def train(
|
| 221 |
+
self,
|
| 222 |
+
train_data: Tuple[list, list],
|
| 223 |
+
val_data: Tuple[list, list]
|
| 224 |
+
):
|
| 225 |
+
"""
|
| 226 |
+
Train the model.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
train_data: Tuple of (image_paths, labels)
|
| 230 |
+
val_data: Tuple of (image_paths, labels)
|
| 231 |
+
"""
|
| 232 |
+
# Start MLflow run
|
| 233 |
+
with mlflow.start_run():
|
| 234 |
+
# Log parameters
|
| 235 |
+
mlflow.log_params(self.config)
|
| 236 |
+
|
| 237 |
+
# Create datasets
|
| 238 |
+
train_dataset = DeepfakeDataset(
|
| 239 |
+
*train_data,
|
| 240 |
+
transform=get_transforms("train"),
|
| 241 |
+
mode="train"
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
val_dataset = DeepfakeDataset(
|
| 245 |
+
*val_data,
|
| 246 |
+
transform=get_transforms("val"),
|
| 247 |
+
mode="val"
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Create data loaders
|
| 251 |
+
train_loader = DataLoader(
|
| 252 |
+
train_dataset,
|
| 253 |
+
batch_size=self.config['batch_size'],
|
| 254 |
+
shuffle=True,
|
| 255 |
+
num_workers=self.config['num_workers'],
|
| 256 |
+
pin_memory=True
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
val_loader = DataLoader(
|
| 260 |
+
val_dataset,
|
| 261 |
+
batch_size=self.config['batch_size'],
|
| 262 |
+
shuffle=False,
|
| 263 |
+
num_workers=self.config['num_workers'],
|
| 264 |
+
pin_memory=True
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# Create model
|
| 268 |
+
model = DeepfakeDetectionModel(
|
| 269 |
+
model_name=self.config['model_name'],
|
| 270 |
+
learning_rate=self.config['learning_rate'],
|
| 271 |
+
weight_decay=self.config['weight_decay']
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Callbacks
|
| 275 |
+
checkpoint_callback = ModelCheckpoint(
|
| 276 |
+
dirpath=self.config['checkpoint_dir'],
|
| 277 |
+
filename='deepfake-{epoch:02d}-{val_accuracy:.4f}',
|
| 278 |
+
monitor='val_accuracy',
|
| 279 |
+
mode='max',
|
| 280 |
+
save_top_k=3,
|
| 281 |
+
save_last=True
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
early_stop_callback = EarlyStopping(
|
| 285 |
+
monitor='val_loss',
|
| 286 |
+
patience=self.config['early_stop_patience'],
|
| 287 |
+
mode='min'
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
# Trainer
|
| 291 |
+
trainer = pl.Trainer(
|
| 292 |
+
max_epochs=self.config['epochs'],
|
| 293 |
+
accelerator='auto',
|
| 294 |
+
devices=self.config.get('gpus', 1),
|
| 295 |
+
precision=self.config.get('precision', 16),
|
| 296 |
+
logger=self.mlflow_logger,
|
| 297 |
+
callbacks=[checkpoint_callback, early_stop_callback],
|
| 298 |
+
log_every_n_steps=10,
|
| 299 |
+
gradient_clip_val=1.0
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# Train
|
| 303 |
+
trainer.fit(model, train_loader, val_loader)
|
| 304 |
+
|
| 305 |
+
# Log best model
|
| 306 |
+
best_model_path = checkpoint_callback.best_model_path
|
| 307 |
+
mlflow.log_artifact(best_model_path)
|
| 308 |
+
|
| 309 |
+
print(f"✓ Training completed!")
|
| 310 |
+
print(f" Best model: {best_model_path}")
|
| 311 |
+
print(f" Best val accuracy: {checkpoint_callback.best_model_score:.4f}")
|
| 312 |
+
|
| 313 |
+
return model, trainer
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# Example usage
|
| 317 |
+
if __name__ == "__main__":
|
| 318 |
+
# Training configuration
|
| 319 |
+
config = {
|
| 320 |
+
'model_name': 'efficientnet_b4',
|
| 321 |
+
'batch_size': 32,
|
| 322 |
+
'learning_rate': 1e-4,
|
| 323 |
+
'weight_decay': 1e-5,
|
| 324 |
+
'epochs': 50,
|
| 325 |
+
'num_workers': 4,
|
| 326 |
+
'gpus': 1,
|
| 327 |
+
'precision': 16,
|
| 328 |
+
'checkpoint_dir': 'models/checkpoints',
|
| 329 |
+
'early_stop_patience': 5,
|
| 330 |
+
'mlflow_uri': 'http://localhost:5000'
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
# Example data (replace with actual data loading)
|
| 334 |
+
train_paths = ['path/to/train/img1.jpg', 'path/to/train/img2.jpg']
|
| 335 |
+
train_labels = [0, 1] # 0=real, 1=fake
|
| 336 |
+
|
| 337 |
+
val_paths = ['path/to/val/img1.jpg', 'path/to/val/img2.jpg']
|
| 338 |
+
val_labels = [0, 1]
|
| 339 |
+
|
| 340 |
+
# Create trainer
|
| 341 |
+
trainer = DeepfakeTrainer(config)
|
| 342 |
+
|
| 343 |
+
# Train
|
| 344 |
+
# model, pl_trainer = trainer.train(
|
| 345 |
+
# train_data=(train_paths, train_labels),
|
| 346 |
+
# val_data=(val_paths, val_labels)
|
| 347 |
+
# )
|
| 348 |
+
|
| 349 |
+
print("Training script ready. Uncomment the training code to run.")
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Init file for utils module."""
|
src/utils/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (266 Bytes). View file
|
|
|
src/utils/__pycache__/face_detection.cpython-313.pyc
ADDED
|
Binary file (1.45 kB). View file
|
|
|
src/utils/__pycache__/preprocessing.cpython-313.pyc
ADDED
|
Binary file (1.32 kB). View file
|
|
|
src/utils/face_detection.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility module for face detection."""
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def detect_faces(image: Image.Image) -> List[Tuple[int, int, int, int]]:
|
| 10 |
+
"""
|
| 11 |
+
Detect faces in an image.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
image: PIL Image
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
List of face bounding boxes (x, y, w, h)
|
| 18 |
+
"""
|
| 19 |
+
# Convert PIL to OpenCV format
|
| 20 |
+
img_array = np.array(image)
|
| 21 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 22 |
+
|
| 23 |
+
# Load Haar Cascade
|
| 24 |
+
face_cascade = cv2.CascadeClassifier(
|
| 25 |
+
cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# Detect faces
|
| 29 |
+
faces = face_cascade.detectMultiScale(
|
| 30 |
+
gray,
|
| 31 |
+
scaleFactor=1.1,
|
| 32 |
+
minNeighbors=5,
|
| 33 |
+
minSize=(30, 30)
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
return faces.tolist() if len(faces) > 0 else []
|
src/utils/preprocessing.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Preprocessing utilities."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def preprocess_image(image: Image.Image, target_size: Tuple[int, int] = (299, 299)) -> np.ndarray:
|
| 9 |
+
"""
|
| 10 |
+
Preprocess image for model input.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
image: PIL Image
|
| 14 |
+
target_size: Target dimensions
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
Preprocessed image array
|
| 18 |
+
"""
|
| 19 |
+
# Resize
|
| 20 |
+
image = image.resize(target_size, Image.LANCZOS)
|
| 21 |
+
|
| 22 |
+
# Convert to array
|
| 23 |
+
img_array = np.array(image) / 255.0
|
| 24 |
+
|
| 25 |
+
# Normalize
|
| 26 |
+
mean = np.array([0.485, 0.456, 0.406])
|
| 27 |
+
std = np.array([0.229, 0.224, 0.225])
|
| 28 |
+
img_array = (img_array - mean) / std
|
| 29 |
+
|
| 30 |
+
return img_array.astype(np.float32)
|