diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..3fd38b6af9dcce562e2f6ada45b78811ccee6326 --- /dev/null +++ b/.env.example @@ -0,0 +1,24 @@ +# DeepFake Detector Backend - Environment Variables +# Copy this file to .env and update with your values + +# Hugging Face Configuration +# Available fusion models: +# - DeepFakeDetector/fusion-logreg-final (Logistic Regression - default) +# - DeepFakeDetector/fusion-meta-final (Meta-classifier) +HF_FUSION_REPO_ID=DeepFakeDetector/fusion-logreg-final +HF_CACHE_DIR=.hf_cache +# HF_TOKEN=your_huggingface_token_here # Optional: for private repos + +# Google Gemini API (Optional - for LLM explanations) +# GOOGLE_API_KEY=your_google_api_key_here + +# Server Configuration +HOST=0.0.0.0 +PORT=8000 + +# CORS Configuration (comma-separated list of allowed origins) +CORS_ORIGINS=http://localhost:8082,https://www.deepfake-detector.app,https://deepfake-detector.app + +# Debugging +ENABLE_DEBUG=false +LOG_LEVEL=INFO diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..44ac5a73214f068107c42237e33853b0453f1b11 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,43 @@ +# DeepFake Detector API - Hugging Face Spaces Docker Image +# Optimized for HF Spaces deployment with GPU support + +FROM python:3.11-slim + +# Set working directory +WORKDIR /app + +# Set environment variables +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + PORT=7860 + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + git \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user (HF Spaces requirement) +RUN useradd -m -u 1000 user +USER user + +# Set PATH for user-installed packages +ENV PATH="/home/user/.local/bin:$PATH" + +# Copy requirements and install dependencies as user +COPY --chown=user:user requirements.txt . 
+RUN pip install --no-cache-dir --upgrade -r requirements.txt + +# Copy application code +COPY --chown=user:user . /app + +# Create cache directory for Hugging Face models +RUN mkdir -p /app/.hf_cache + +# Expose HF Spaces port +EXPOSE 7860 + +# Run the application (start.sh already defaults to port 7860) +CMD ["./start.sh"] diff --git a/Dockerfile.huggingface b/Dockerfile.huggingface new file mode 100644 index 0000000000000000000000000000000000000000..44ac5a73214f068107c42237e33853b0453f1b11 --- /dev/null +++ b/Dockerfile.huggingface @@ -0,0 +1,43 @@ +# DeepFake Detector API - Hugging Face Spaces Docker Image +# Optimized for HF Spaces deployment with GPU support + +FROM python:3.11-slim + +# Set working directory +WORKDIR /app + +# Set environment variables +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + PORT=7860 + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + git \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user (HF Spaces requirement) +RUN useradd -m -u 1000 user +USER user + +# Set PATH for user-installed packages +ENV PATH="/home/user/.local/bin:$PATH" + +# Copy requirements and install dependencies as user +COPY --chown=user:user requirements.txt . +RUN pip install --no-cache-dir --upgrade -r requirements.txt + +# Copy application code +COPY --chown=user:user . 
/app + +# Create cache directory for Hugging Face models +RUN mkdir -p /app/.hf_cache + +# Expose HF Spaces port +EXPOSE 7860 + +# Run the application (start.sh already defaults to port 7860) +CMD ["./start.sh"] diff --git a/README.md b/README.md index 4b745075da954d6a68a0b05f67f06fabff663b28..f6eb7ab792b1515bf0d303781af89a6886591eb7 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,182 @@ --- -title: DeepFakeDetectorBackend -emoji: 👁 -colorFrom: gray -colorTo: yellow +title: DeepFake Detector API +emoji: 🎭 +colorFrom: blue +colorTo: purple sdk: docker -pinned: false -license: mit -short_description: FastAPI Backend for MacAI Society DeepFake Detector +app_port: 7860 --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# 🎭 DeepFake Detector API + +FastAPI backend for detecting AI-generated (deepfake) images using an ensemble of state-of-the-art deep learning models. + +## 🤖 Models + +This API uses a fusion ensemble of 5 deep learning models: + +- **CNN Transfer** (EfficientNet-B0) - Transfer learning from ImageNet +- **ViT Base** (Vision Transformer) - Attention-based architecture +- **DeiT Distilled** (Data-efficient Image Transformer) - Distilled ViT variant +- **Gradient Field CNN** - Custom architecture analyzing gradient patterns +- **FFT CNN** - Frequency domain analysis using Fast Fourier Transform + +All models are combined using a **Logistic Regression stacking ensemble** for optimal accuracy. 
+ +## 🔗 API Endpoints + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/health` | GET | Health check - returns API status | +| `/ready` | GET | Model readiness check - confirms models are loaded | +| `/models` | GET | List all loaded models with metadata | +| `/predict` | POST | Predict if an image is real or AI-generated | +| `/docs` | GET | Interactive Swagger API documentation | +| `/redoc` | GET | Alternative API documentation | + +## 🚀 Usage Example + +### Using cURL + +```bash +# Check if API is ready +curl https://lukhsaankumar-deepfakedetectorbackend.hf.space/ready + +# Make a prediction +curl -X POST "https://lukhsaankumar-deepfakedetectorbackend.hf.space/predict" \ + -F "file=@image.jpg" \ + -F "explain=true" +``` + +### Using Python + +```python +import requests + +# Upload an image for prediction +url = "https://lukhsaankumar-deepfakedetectorbackend.hf.space/predict" +files = {"file": open("image.jpg", "rb")} +data = {"explain": True} + +response = requests.post(url, files=files, data=data) +result = response.json() + +print(f"Prediction: {result['prediction']}") +print(f"Confidence: {result['confidence']:.2%}") +print(f"Explanation: {result['explanation']}") +``` + +## 🎯 Response Format + +```json +{ + "prediction": "fake", + "confidence": 0.8734, + "probabilities": { + "real": 0.1266, + "fake": 0.8734 + }, + "model_predictions": { + "cnn_transfer": {"prediction": "fake", "confidence": 0.89}, + "vit_base": {"prediction": "fake", "confidence": 0.92}, + "deit": {"prediction": "fake", "confidence": 0.85}, + "gradient_field": {"prediction": "real", "confidence": 0.55}, + "fft_cnn": {"prediction": "fake", "confidence": 0.78} + }, + "fusion_confidence": 0.8734, + "explanation": "AI-powered analysis of the prediction...", + "processing_time_ms": 342 +} +``` + +## 🔧 Configuration + +### Required Secrets + +Set these in your Space Settings → Repository secrets: + +| Secret | Description | Required | 
+|--------|-------------|----------| +| `GOOGLE_API_KEY` | Google Gemini API key for AI explanations | Yes | +| `HF_TOKEN` | Hugging Face token (auto-set by Spaces) | No | + +### Optional Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `HF_FUSION_REPO_ID` | `DeepFakeDetector/fusion-logreg-final` | Hugging Face model repository | +| `CORS_ORIGINS` | Multiple defaults | Comma-separated allowed CORS origins | +| `GEMINI_MODEL` | `gemini-2.5-flash` | Gemini model for explanations | + +## 🏗️ Architecture + +``` +┌─────────────┐ +│ Client │ +└──────┬──────┘ + │ + ▼ +┌─────────────────────────────────┐ +│ FastAPI Backend │ +│ ┌──────────────────────────┐ │ +│ │ Model Registry │ │ +│ │ ┌────────────────────┐ │ │ +│ │ │ CNN Transfer │ │ │ +│ │ │ ViT Base │ │ │ +│ │ │ DeiT Distilled │ │ │ +│ │ │ Gradient Field │ │ │ +│ │ │ FFT CNN │ │ │ +│ │ └────────────────────┘ │ │ +│ │ ┌────────────────────┐ │ │ +│ │ │ Fusion Ensemble │ │ │ +│ │ │ (LogReg Stacking) │ │ │ +│ │ └────────────────────┘ │ │ +│ └──────────────────────────┘ │ +│ ┌──────────────────────────┐ │ +│ │ Gemini Explainer │ │ +│ └──────────────────────────┘ │ +└─────────────────────────────────┘ +``` + +## 📊 Performance + +- **Accuracy**: ~87% on test set (OpenFake dataset) +- **Inference Time**: ~200-500ms per image (with GPU) +- **Model Size**: ~500MB total +- **Supported Formats**: JPG, PNG, WEBP + +## 🐛 Troubleshooting + +### Models not loading? +- Check the Logs tab for specific errors +- Verify `HF_FUSION_REPO_ID` points to a valid repository +- Ensure the repository is public or `HF_TOKEN` is set + +### Explanations not working? +- Verify `GOOGLE_API_KEY` is set in Space Settings +- Check if you have Gemini API quota remaining +- Review logs for API errors + +### CORS errors? 
+- Add your frontend domain to `CORS_ORIGINS` in Space Settings +- Format: `https://yourdomain.com,https://www.yourdomain.com` + +## 📚 Documentation + +- **Interactive Docs**: Visit `/docs` for Swagger UI +- **ReDoc**: Visit `/redoc` for alternative documentation +- **Source Code**: [GitHub Repository](https://github.com/lukhsaankumar/DeepFakeDetector) + +## 📝 License + +This project is part of the MacAI Society research initiative. + +## 🙏 Acknowledgments + +- Models trained on OpenFake, ImageNet, and custom datasets +- Powered by PyTorch, Hugging Face, and FastAPI +- AI explanations by Google Gemini + +--- + +**Built with ❤️ by MacAI Society** diff --git a/README_HF.md b/README_HF.md new file mode 100644 index 0000000000000000000000000000000000000000..f6eb7ab792b1515bf0d303781af89a6886591eb7 --- /dev/null +++ b/README_HF.md @@ -0,0 +1,182 @@ +--- +title: DeepFake Detector API +emoji: 🎭 +colorFrom: blue +colorTo: purple +sdk: docker +app_port: 7860 +--- + +# 🎭 DeepFake Detector API + +FastAPI backend for detecting AI-generated (deepfake) images using an ensemble of state-of-the-art deep learning models. + +## 🤖 Models + +This API uses a fusion ensemble of 5 deep learning models: + +- **CNN Transfer** (EfficientNet-B0) - Transfer learning from ImageNet +- **ViT Base** (Vision Transformer) - Attention-based architecture +- **DeiT Distilled** (Data-efficient Image Transformer) - Distilled ViT variant +- **Gradient Field CNN** - Custom architecture analyzing gradient patterns +- **FFT CNN** - Frequency domain analysis using Fast Fourier Transform + +All models are combined using a **Logistic Regression stacking ensemble** for optimal accuracy. 
+ +## 🔗 API Endpoints + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/health` | GET | Health check - returns API status | +| `/ready` | GET | Model readiness check - confirms models are loaded | +| `/models` | GET | List all loaded models with metadata | +| `/predict` | POST | Predict if an image is real or AI-generated | +| `/docs` | GET | Interactive Swagger API documentation | +| `/redoc` | GET | Alternative API documentation | + +## 🚀 Usage Example + +### Using cURL + +```bash +# Check if API is ready +curl https://lukhsaankumar-deepfakedetectorbackend.hf.space/ready + +# Make a prediction +curl -X POST "https://lukhsaankumar-deepfakedetectorbackend.hf.space/predict" \ + -F "file=@image.jpg" \ + -F "explain=true" +``` + +### Using Python + +```python +import requests + +# Upload an image for prediction +url = "https://lukhsaankumar-deepfakedetectorbackend.hf.space/predict" +files = {"file": open("image.jpg", "rb")} +data = {"explain": True} + +response = requests.post(url, files=files, data=data) +result = response.json() + +print(f"Prediction: {result['prediction']}") +print(f"Confidence: {result['confidence']:.2%}") +print(f"Explanation: {result['explanation']}") +``` + +## 🎯 Response Format + +```json +{ + "prediction": "fake", + "confidence": 0.8734, + "probabilities": { + "real": 0.1266, + "fake": 0.8734 + }, + "model_predictions": { + "cnn_transfer": {"prediction": "fake", "confidence": 0.89}, + "vit_base": {"prediction": "fake", "confidence": 0.92}, + "deit": {"prediction": "fake", "confidence": 0.85}, + "gradient_field": {"prediction": "real", "confidence": 0.55}, + "fft_cnn": {"prediction": "fake", "confidence": 0.78} + }, + "fusion_confidence": 0.8734, + "explanation": "AI-powered analysis of the prediction...", + "processing_time_ms": 342 +} +``` + +## 🔧 Configuration + +### Required Secrets + +Set these in your Space Settings → Repository secrets: + +| Secret | Description | Required | 
+|--------|-------------|----------| +| `GOOGLE_API_KEY` | Google Gemini API key for AI explanations | Yes | +| `HF_TOKEN` | Hugging Face token (auto-set by Spaces) | No | + +### Optional Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `HF_FUSION_REPO_ID` | `DeepFakeDetector/fusion-logreg-final` | Hugging Face model repository | +| `CORS_ORIGINS` | Multiple defaults | Comma-separated allowed CORS origins | +| `GEMINI_MODEL` | `gemini-2.5-flash` | Gemini model for explanations | + +## 🏗️ Architecture + +``` +┌─────────────┐ +│ Client │ +└──────┬──────┘ + │ + ▼ +┌─────────────────────────────────┐ +│ FastAPI Backend │ +│ ┌──────────────────────────┐ │ +│ │ Model Registry │ │ +│ │ ┌────────────────────┐ │ │ +│ │ │ CNN Transfer │ │ │ +│ │ │ ViT Base │ │ │ +│ │ │ DeiT Distilled │ │ │ +│ │ │ Gradient Field │ │ │ +│ │ │ FFT CNN │ │ │ +│ │ └────────────────────┘ │ │ +│ │ ┌────────────────────┐ │ │ +│ │ │ Fusion Ensemble │ │ │ +│ │ │ (LogReg Stacking) │ │ │ +│ │ └────────────────────┘ │ │ +│ └──────────────────────────┘ │ +│ ┌──────────────────────────┐ │ +│ │ Gemini Explainer │ │ +│ └──────────────────────────┘ │ +└─────────────────────────────────┘ +``` + +## 📊 Performance + +- **Accuracy**: ~87% on test set (OpenFake dataset) +- **Inference Time**: ~200-500ms per image (with GPU) +- **Model Size**: ~500MB total +- **Supported Formats**: JPG, PNG, WEBP + +## 🐛 Troubleshooting + +### Models not loading? +- Check the Logs tab for specific errors +- Verify `HF_FUSION_REPO_ID` points to a valid repository +- Ensure the repository is public or `HF_TOKEN` is set + +### Explanations not working? +- Verify `GOOGLE_API_KEY` is set in Space Settings +- Check if you have Gemini API quota remaining +- Review logs for API errors + +### CORS errors? 
+- Add your frontend domain to `CORS_ORIGINS` in Space Settings +- Format: `https://yourdomain.com,https://www.yourdomain.com` + +## 📚 Documentation + +- **Interactive Docs**: Visit `/docs` for Swagger UI +- **ReDoc**: Visit `/redoc` for alternative documentation +- **Source Code**: [GitHub Repository](https://github.com/lukhsaankumar/DeepFakeDetector) + +## 📝 License + +This project is part of the MacAI Society research initiative. + +## 🙏 Acknowledgments + +- Models trained on OpenFake, ImageNet, and custom datasets +- Powered by PyTorch, Hugging Face, and FastAPI +- AI explanations by Google Gemini + +--- + +**Built with ❤️ by MacAI Society** diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1844cddfa2938958d3f372459c223ee1000adbe2 --- /dev/null +++ b/app/__init__.py @@ -0,0 +1 @@ +# DeepFake Detector Backend Application diff --git a/app/__pycache__/__init__.cpython-312.pyc b/app/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e595a285a5ec57466f53b517dab7af0284b677e6 Binary files /dev/null and b/app/__pycache__/__init__.cpython-312.pyc differ diff --git a/app/__pycache__/main.cpython-312.pyc b/app/__pycache__/main.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..344fccc4e886feeacc2f0f1b0cf6409f7575c5cc Binary files /dev/null and b/app/__pycache__/main.cpython-312.pyc differ diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c7f58ef7befcb42f1b879de6c1ef5f6705c9188 --- /dev/null +++ b/app/api/__init__.py @@ -0,0 +1 @@ +# API module diff --git a/app/api/__pycache__/__init__.cpython-312.pyc b/app/api/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fbcb4e5bdb46755a7df9939caa50ab5a392aa4d Binary files /dev/null and b/app/api/__pycache__/__init__.cpython-312.pyc differ diff --git 
a/app/api/__pycache__/routes_health.cpython-312.pyc b/app/api/__pycache__/routes_health.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b51ec2efce69b85c620bc4c2addc1ca23c95b004 Binary files /dev/null and b/app/api/__pycache__/routes_health.cpython-312.pyc differ diff --git a/app/api/__pycache__/routes_models.cpython-312.pyc b/app/api/__pycache__/routes_models.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df91288dbb4de0e98a0bd0aa74cd56f5c9a306c0 Binary files /dev/null and b/app/api/__pycache__/routes_models.cpython-312.pyc differ diff --git a/app/api/__pycache__/routes_predict.cpython-312.pyc b/app/api/__pycache__/routes_predict.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94b8c62637f36c7b2cce938bd8513d443d0d0818 Binary files /dev/null and b/app/api/__pycache__/routes_predict.cpython-312.pyc differ diff --git a/app/api/routes_health.py b/app/api/routes_health.py new file mode 100644 index 0000000000000000000000000000000000000000..0f8c036cb81ad1270f863dc094bb04e2e96fd0e0 --- /dev/null +++ b/app/api/routes_health.py @@ -0,0 +1,62 @@ +""" +Health check routes. +""" + +from fastapi import APIRouter + +from app.core.logging import get_logger +from app.schemas.models import HealthResponse, ReadyResponse +from app.services.model_registry import get_model_registry + +logger = get_logger(__name__) +router = APIRouter(tags=["health"]) + + +@router.get( + "/health", + response_model=HealthResponse, + summary="Health check", + description="Simple health check to verify the API is running" +) +async def health_check() -> HealthResponse: + """ + Health check endpoint. + + Returns OK if the API server is running. 
+ """ + return HealthResponse(status="ok") + + +@router.get( + "/ready", + response_model=ReadyResponse, + summary="Readiness check", + description="Check if models are loaded and the API is ready to serve predictions" +) +async def readiness_check() -> ReadyResponse: + """ + Readiness check endpoint. + + Verifies that models are loaded and ready for inference. + Returns detailed information about loaded models. + """ + registry = get_model_registry() + + if not registry.is_loaded: + return ReadyResponse( + status="not_ready", + models_loaded=False, + fusion_repo=None, + submodels=[] + ) + + return ReadyResponse( + status="ready", + models_loaded=True, + fusion_repo=registry.get_fusion_repo_id(), + submodels=[ + model["repo_id"] + for model in registry.list_models() + if model["model_type"] == "submodel" + ] + ) diff --git a/app/api/routes_models.py b/app/api/routes_models.py new file mode 100644 index 0000000000000000000000000000000000000000..cc53f8e4e55479803978377a8f56df7620b41190 --- /dev/null +++ b/app/api/routes_models.py @@ -0,0 +1,51 @@ +""" +Model listing routes. +""" + +from fastapi import APIRouter + +from app.core.logging import get_logger +from app.schemas.models import ModelsListResponse, ModelInfo +from app.services.model_registry import get_model_registry + +logger = get_logger(__name__) +router = APIRouter(tags=["models"]) + + +@router.get( + "/models", + response_model=ModelsListResponse, + summary="List loaded models", + description="Get information about all loaded models including fusion and submodels" +) +async def list_models() -> ModelsListResponse: + """ + List all loaded models. + + Returns information about the fusion model and all submodels, + including their Hugging Face repository IDs and configurations. 
+ """ + registry = get_model_registry() + models = registry.list_models() + + fusion_info = None + submodels_info = [] + + for model in models: + model_info = ModelInfo( + repo_id=model["repo_id"], + name=model["name"], + model_type=model["model_type"], + config=model.get("config") + ) + + if model["model_type"] == "fusion": + fusion_info = model_info + else: + submodels_info.append(model_info) + + return ModelsListResponse( + fusion=fusion_info, + submodels=submodels_info, + total_count=len(models) + ) diff --git a/app/api/routes_predict.py b/app/api/routes_predict.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba65ec084039fbe30b02084cbeb7511be0567db --- /dev/null +++ b/app/api/routes_predict.py @@ -0,0 +1,286 @@ +""" +Prediction routes. +""" + +import base64 +from typing import Optional + +from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile + +from app.core.errors import ( + DeepFakeDetectorError, + ImageProcessingError, + InferenceError, + FusionError, + ModelNotFoundError, + ModelNotLoadedError +) +from app.core.logging import get_logger +from app.schemas.predict import ( + PredictResponse, + PredictionResult, + TimingInfo, + ErrorResponse, + FusionMeta, + ModelDisplayInfo, + ExplainModelResponse, + SingleModelInsight +) +from app.services.inference_service import get_inference_service +from app.services.fusion_service import get_fusion_service +from app.services.preprocess_service import get_preprocess_service +from app.services.model_registry import get_model_registry +from app.services.llm_service import get_llm_service, get_model_display_info, MODEL_DISPLAY_INFO +from app.utils.timing import Timer + +logger = get_logger(__name__) +router = APIRouter(tags=["predict"]) + + +@router.post( + "/predict", + response_model=PredictResponse, + summary="Predict if image is real or fake", + description="Upload an image to get a deepfake detection prediction", + responses={ + 400: {"model": ErrorResponse, "description": 
"Invalid image or request"}, + 404: {"model": ErrorResponse, "description": "Model not found"}, + 500: {"model": ErrorResponse, "description": "Inference error"} + } +) +async def predict( + image: UploadFile = File(..., description="Image file to analyze"), + use_fusion: bool = Query( + True, + description="Use fusion model (majority vote) across all submodels" + ), + model: Optional[str] = Query( + None, + description="Specific submodel to use (name or repo_id). Only used when use_fusion=false" + ), + return_submodels: Optional[bool] = Query( + None, + description="Include individual submodel predictions in response. Defaults to true when use_fusion=true" + ), + explain: bool = Query( + True, + description="Generate explainability heatmaps (Grad-CAM for CNNs, attention rollout for transformers)" + ) +) -> PredictResponse: + """ + Predict if an uploaded image is real or fake. + + When use_fusion=true (default): + - Runs all submodels on the image + - Combines predictions using majority vote fusion + - Returns the fused result plus optionally individual submodel results + + When use_fusion=false: + - Runs only the specified submodel (or the first available if not specified) + - Returns just that model's prediction + + Response includes timing information for each step. 
+ """ + timer = Timer() + timer.start_total() + + # Determine if we should return submodel results + should_return_submodels = return_submodels if return_submodels is not None else use_fusion + + try: + # Read image bytes + with timer.measure("download"): + image_bytes = await image.read() + + # Validate and preprocess + with timer.measure("preprocess"): + preprocess_service = get_preprocess_service() + preprocess_service.validate_image(image_bytes) + + inference_service = get_inference_service() + fusion_service = get_fusion_service() + registry = get_model_registry() + + if use_fusion: + # Run all submodels + with timer.measure("inference"): + submodel_outputs = inference_service.predict_all_submodels( + image_bytes=image_bytes, + explain=explain + ) + + # Run fusion + with timer.measure("fusion"): + final_result = fusion_service.fuse(submodel_outputs=submodel_outputs) + + timer.stop_total() + + # Extract fusion meta (contribution percentages) + fusion_meta_dict = final_result.get("meta", {}) + contribution_percentages = fusion_meta_dict.get("contribution_percentages", {}) + + # Build fusion meta object + fusion_meta = FusionMeta( + submodel_weights=fusion_meta_dict.get("submodel_weights", {}), + weighted_contributions=fusion_meta_dict.get("weighted_contributions", {}), + contribution_percentages=contribution_percentages + ) if fusion_meta_dict else None + + # Build model display info for frontend + model_display_info = { + name: ModelDisplayInfo(**get_model_display_info(name)) + for name in submodel_outputs.keys() + } + + # Build response + return PredictResponse( + final=PredictionResult( + pred=final_result["pred"], + pred_int=final_result["pred_int"], + prob_fake=final_result["prob_fake"] + ), + fusion_used=True, + submodels={ + name: PredictionResult( + pred=output["pred"], + pred_int=output["pred_int"], + prob_fake=output["prob_fake"], + heatmap_base64=output.get("heatmap_base64"), + explainability_type=output.get("explainability_type"), + 
focus_summary=output.get("focus_summary"), + contribution_percentage=contribution_percentages.get(name) + ) + for name, output in submodel_outputs.items() + } if should_return_submodels else None, + fusion_meta=fusion_meta, + model_display_info=model_display_info if should_return_submodels else None, + timing_ms=TimingInfo(**timer.get_timings()) + ) + + else: + # Single model prediction + model_key = model or registry.get_submodel_names()[0] + + with timer.measure("inference"): + result = inference_service.predict_single( + model_key=model_key, + image_bytes=image_bytes, + explain=explain + ) + + timer.stop_total() + + return PredictResponse( + final=PredictionResult( + pred=result["pred"], + pred_int=result["pred_int"], + prob_fake=result["prob_fake"], + heatmap_base64=result.get("heatmap_base64"), + explainability_type=result.get("explainability_type"), + focus_summary=result.get("focus_summary") + ), + fusion_used=False, + submodels=None, + timing_ms=TimingInfo(**timer.get_timings()) + ) + + except ImageProcessingError as e: + logger.warning(f"Image processing error: {e.message}") + raise HTTPException( + status_code=400, + detail={"error": "ImageProcessingError", "message": e.message, "details": e.details} + ) + + except ModelNotFoundError as e: + logger.warning(f"Model not found: {e.message}") + raise HTTPException( + status_code=404, + detail={"error": "ModelNotFoundError", "message": e.message, "details": e.details} + ) + + except ModelNotLoadedError as e: + logger.error(f"Models not loaded: {e.message}") + raise HTTPException( + status_code=503, + detail={"error": "ModelNotLoadedError", "message": e.message, "details": e.details} + ) + + except (InferenceError, FusionError) as e: + logger.error(f"Inference/Fusion error: {e.message}") + raise HTTPException( + status_code=500, + detail={"error": type(e).__name__, "message": e.message, "details": e.details} + ) + + except Exception as e: + logger.exception(f"Unexpected error in predict endpoint: {e}") + raise 
HTTPException( + status_code=500, + detail={"error": "InternalError", "message": str(e)} + ) + + +@router.post("/explain-model", response_model=ExplainModelResponse) +async def explain_model( + image: UploadFile = File(...), + model_name: str = Form(...), + prob_fake: float = Form(...), + contribution_percentage: float = Form(None), + heatmap_base64: str = Form(None), + focus_summary: str = Form(None) +): + """ + Generate an on-demand LLM explanation for a single model's prediction. + This endpoint is token-efficient - only called when user requests insights. + """ + try: + # Read and validate image + image_bytes = await image.read() + if len(image_bytes) == 0: + raise HTTPException(status_code=400, detail="Empty image file") + + # Encode image to base64 for LLM + original_b64 = base64.b64encode(image_bytes).decode('utf-8') + + # Get LLM service + llm_service = get_llm_service() + if not llm_service.enabled: + raise HTTPException( + status_code=503, + detail="LLM service is not enabled. Set GEMINI_API_KEY environment variable." 
+ ) + + # Generate explanation + result = llm_service.generate_single_model_explanation( + model_name=model_name, + original_image_b64=original_b64, + prob_fake=prob_fake, + heatmap_b64=heatmap_base64, + contribution_percentage=contribution_percentage, + focus_summary=focus_summary + ) + + if result is None: + raise HTTPException( + status_code=500, + detail="Failed to generate explanation from LLM" + ) + + return ExplainModelResponse( + model_name=model_name, + insight=SingleModelInsight( + key_finding=result["key_finding"], + what_model_saw=result["what_model_saw"], + important_regions=result["important_regions"], + confidence_qualifier=result["confidence_qualifier"] + ) + ) + + except HTTPException: + raise + except Exception as e: + logger.exception(f"Error generating model explanation: {e}") + raise HTTPException( + status_code=500, + detail={"error": "ExplanationError", "message": str(e)} + ) diff --git a/app/core/__init__.py b/app/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e83c630dd1490116c306cfaf22393cdebc7b8da --- /dev/null +++ b/app/core/__init__.py @@ -0,0 +1 @@ +# Core module diff --git a/app/core/__pycache__/__init__.cpython-312.pyc b/app/core/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0167dedcbf72e39e72f11afa9607c71267a929f Binary files /dev/null and b/app/core/__pycache__/__init__.cpython-312.pyc differ diff --git a/app/core/__pycache__/config.cpython-312.pyc b/app/core/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b597a3c18dc4d66c2ee513193eaf93e180fdc7e1 Binary files /dev/null and b/app/core/__pycache__/config.cpython-312.pyc differ diff --git a/app/core/__pycache__/errors.cpython-312.pyc b/app/core/__pycache__/errors.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d7e0c1ce790dbbf79cf7f4ca495a651072630d8 Binary files /dev/null and 
b/app/core/__pycache__/errors.cpython-312.pyc differ diff --git a/app/core/__pycache__/logging.cpython-312.pyc b/app/core/__pycache__/logging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04ebf358104e029ce2a271e2f8fe9a8bd74e82f7 Binary files /dev/null and b/app/core/__pycache__/logging.cpython-312.pyc differ diff --git a/app/core/config.py b/app/core/config.py new file mode 100644 index 0000000000000000000000000000000000000000..6e36d0a1689e3b2211a129946cac3faf90482856 --- /dev/null +++ b/app/core/config.py @@ -0,0 +1,64 @@ +""" +Application configuration with environment variable support. +""" + +import os +from functools import lru_cache +from pydantic_settings import BaseSettings +from typing import Optional + + +class Settings(BaseSettings): + """Application settings loaded from environment variables.""" + + # Hugging Face configuration + # Available fusion models: + # - DeepFakeDetector/fusion-logreg-final (Logistic Regression - default) + # - DeepFakeDetector/fusion-meta-final (Meta-classifier) + HF_FUSION_REPO_ID: str = "DeepFakeDetector/fusion-logreg-final" + HF_CACHE_DIR: str = ".hf_cache" + HF_TOKEN: Optional[str] = None + + # Google Gemini API configuration + GOOGLE_API_KEY: Optional[str] = None + GEMINI_MODEL: str = "gemini-2.5-flash" + + @property + def llm_enabled(self) -> bool: + """Check if LLM explanations are available.""" + return self.GOOGLE_API_KEY is not None and len(self.GOOGLE_API_KEY) > 0 + + # Application configuration + ENABLE_DEBUG: bool = False + LOG_LEVEL: str = "INFO" + + # Server configuration + HOST: str = "0.0.0.0" + PORT: int = 8000 + + # CORS configuration + CORS_ORIGINS: str = "http://localhost:5173,http://localhost:3000,https://www.deepfake-detector.app,https://deepfake-detector.app" + + @property + def cors_origins_list(self) -> list[str]: + """Parse CORS origins from comma-separated string.""" + return [origin.strip() for origin in self.CORS_ORIGINS.split(",") if origin.strip()] + + # API 
configuration + API_V1_PREFIX: str = "/api/v1" + PROJECT_NAME: str = "DeepFake Detector API" + VERSION: str = "0.1.0" + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + case_sensitive = True + + +@lru_cache() +def get_settings() -> Settings: + """Get cached settings instance.""" + return Settings() + + +settings = get_settings() diff --git a/app/core/errors.py b/app/core/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..81c8f987ab21332800887755d53a16d2ecf40b0b --- /dev/null +++ b/app/core/errors.py @@ -0,0 +1,53 @@ +""" +Custom exceptions and error handling for the application. +""" + +from typing import Any, Dict, Optional + + +class DeepFakeDetectorError(Exception): + """Base exception for DeepFake Detector application.""" + + def __init__( + self, + message: str, + details: Optional[Dict[str, Any]] = None + ): + self.message = message + self.details = details or {} + super().__init__(self.message) + + +class ModelNotLoadedError(DeepFakeDetectorError): + """Raised when attempting to use a model that hasn't been loaded.""" + pass + + +class ModelNotFoundError(DeepFakeDetectorError): + """Raised when a requested model is not found in the registry.""" + pass + + +class HuggingFaceDownloadError(DeepFakeDetectorError): + """Raised when downloading from Hugging Face fails.""" + pass + + +class ImageProcessingError(DeepFakeDetectorError): + """Raised when image processing/decoding fails.""" + pass + + +class InferenceError(DeepFakeDetectorError): + """Raised when model inference fails.""" + pass + + +class FusionError(DeepFakeDetectorError): + """Raised when fusion prediction fails.""" + pass + + +class ConfigurationError(DeepFakeDetectorError): + """Raised when configuration is invalid or missing.""" + pass diff --git a/app/core/logging.py b/app/core/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f7ba227174b011dbed3513ee12017691c06fe3 --- /dev/null +++ b/app/core/logging.py @@ -0,0 +1,61 
@@ +""" +Logging configuration for the application. +""" + +import logging +import sys +from typing import Optional + +from app.core.config import settings + + +def setup_logging(level: Optional[str] = None) -> logging.Logger: + """ + Set up application logging. + + Args: + level: Log level string (DEBUG, INFO, WARNING, ERROR, CRITICAL) + + Returns: + Configured logger instance + """ + log_level = level or settings.LOG_LEVEL + + # Create formatter + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S" + ) + + # Configure root logger + root_logger = logging.getLogger() + root_logger.setLevel(getattr(logging, log_level.upper(), logging.INFO)) + + # Remove existing handlers + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + + # Add stdout handler + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setFormatter(formatter) + root_logger.addHandler(stdout_handler) + + # Set third-party loggers to WARNING to reduce noise + logging.getLogger("uvicorn").setLevel(logging.WARNING) + logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("huggingface_hub").setLevel(logging.WARNING) + + return root_logger + + +def get_logger(name: str) -> logging.Logger: + """ + Get a named logger instance. + + Args: + name: Logger name (typically __name__) + + Returns: + Logger instance + """ + return logging.getLogger(name) diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000000000000000000000000000000000000..6b8ee91e929158a0e210891405ac8936e32c2685 --- /dev/null +++ b/app/main.py @@ -0,0 +1,128 @@ +""" +FastAPI application entry point. + +DeepFake Detector API - Milestone 1: Hugging Face hosted dummy models. 
+""" + +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from app.api import routes_health, routes_models, routes_predict +from app.core.config import settings +from app.core.errors import DeepFakeDetectorError +from app.core.logging import setup_logging, get_logger +from app.services.model_registry import get_model_registry + +# Set up logging +setup_logging() +logger = get_logger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + """ + Application lifespan manager. + + Handles startup and shutdown events: + - Startup: Load models from Hugging Face + - Shutdown: Cleanup resources + """ + # Startup + logger.info("Starting DeepFake Detector API...") + logger.info(f"Configuration: HF_FUSION_REPO_ID={settings.HF_FUSION_REPO_ID}") + logger.info(f"Configuration: HF_CACHE_DIR={settings.HF_CACHE_DIR}") + + # Load models from Hugging Face + try: + registry = get_model_registry() + await registry.load_from_fusion_repo(settings.HF_FUSION_REPO_ID) + logger.info("Models loaded successfully!") + except Exception as e: + logger.error(f"Failed to load models on startup: {e}") + logger.warning("API will start but /ready will report not_ready until models are loaded") + + yield # Application runs here + + # Shutdown + logger.info("Shutting down DeepFake Detector API...") + + +# Create FastAPI application +app = FastAPI( + title=settings.PROJECT_NAME, + version=settings.VERSION, + description=""" + DeepFake Detector API - Analyze images to detect AI-generated content. 
+ + ## Features + + - **Fusion prediction**: Combines multiple model predictions using majority vote + - **Individual model prediction**: Run specific submodels directly + - **Timing information**: Detailed performance metrics for each request + + ## Milestone 1 + + This is the initial milestone using dummy random models hosted on Hugging Face + for testing the API infrastructure. + """, + lifespan=lifespan, + debug=settings.ENABLE_DEBUG +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins_list, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +logger.info(f"CORS enabled for origins: {settings.cors_origins_list}") + + +# Global exception handler for custom errors +@app.exception_handler(DeepFakeDetectorError) +async def deepfake_error_handler(request: Request, exc: DeepFakeDetectorError): + """Handle custom DeepFakeDetector exceptions.""" + return JSONResponse( + status_code=500, + content={ + "error": type(exc).__name__, + "message": exc.message, + "details": exc.details + } + ) + + +# Include routers +app.include_router(routes_health.router) +app.include_router(routes_models.router) +app.include_router(routes_predict.router) + + +# Root endpoint +@app.get("/", tags=["root"]) +async def root(): + """Root endpoint with API information.""" + return { + "name": settings.PROJECT_NAME, + "version": settings.VERSION, + "docs": "/docs", + "health": "/health", + "ready": "/ready" + } + + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app.main:app", + host=settings.HOST, + port=settings.PORT, + reload=settings.ENABLE_DEBUG + ) diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2313c5f27005d22f22a8d72437a122e4ef8ba68 --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1 @@ +# Models module diff --git a/app/models/__pycache__/__init__.cpython-312.pyc b/app/models/__pycache__/__init__.cpython-312.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..c5d5a42d59814dc6faa7d2b763ee559fadacba04 Binary files /dev/null and b/app/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/app/models/wrappers/__init__.py b/app/models/wrappers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..46df5bfe55d8c1d34a9eb148fa429a660f4f11bb --- /dev/null +++ b/app/models/wrappers/__init__.py @@ -0,0 +1 @@ +# Model wrappers module diff --git a/app/models/wrappers/__pycache__/__init__.cpython-312.pyc b/app/models/wrappers/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ff574cf7e797af8a08c83cea347ad7ab35c1746 Binary files /dev/null and b/app/models/wrappers/__pycache__/__init__.cpython-312.pyc differ diff --git a/app/models/wrappers/__pycache__/base_wrapper.cpython-312.pyc b/app/models/wrappers/__pycache__/base_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79085f08459bd10c40d3530b663d405d1b47f8ab Binary files /dev/null and b/app/models/wrappers/__pycache__/base_wrapper.cpython-312.pyc differ diff --git a/app/models/wrappers/__pycache__/cnn_transfer_wrapper.cpython-312.pyc b/app/models/wrappers/__pycache__/cnn_transfer_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22342605534905f5eafe151f3d7084d816c3af6f Binary files /dev/null and b/app/models/wrappers/__pycache__/cnn_transfer_wrapper.cpython-312.pyc differ diff --git a/app/models/wrappers/__pycache__/deit_distilled_wrapper.cpython-312.pyc b/app/models/wrappers/__pycache__/deit_distilled_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86d3b7e6cf5c64c3875f7ce652d72ec76800da33 Binary files /dev/null and b/app/models/wrappers/__pycache__/deit_distilled_wrapper.cpython-312.pyc differ diff --git a/app/models/wrappers/__pycache__/dummy_majority_fusion_wrapper.cpython-312.pyc 
b/app/models/wrappers/__pycache__/dummy_majority_fusion_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbc0e167344df87cda3bc8d30c5f2c817ecc50b7 Binary files /dev/null and b/app/models/wrappers/__pycache__/dummy_majority_fusion_wrapper.cpython-312.pyc differ diff --git a/app/models/wrappers/__pycache__/dummy_random_wrapper.cpython-312.pyc b/app/models/wrappers/__pycache__/dummy_random_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2aa149717ef17b1a153acfde85c383f90658f9df Binary files /dev/null and b/app/models/wrappers/__pycache__/dummy_random_wrapper.cpython-312.pyc differ diff --git a/app/models/wrappers/__pycache__/gradfield_cnn_wrapper.cpython-312.pyc b/app/models/wrappers/__pycache__/gradfield_cnn_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dc590844f4d51820aa01b267e0970d83584e431 Binary files /dev/null and b/app/models/wrappers/__pycache__/gradfield_cnn_wrapper.cpython-312.pyc differ diff --git a/app/models/wrappers/__pycache__/logreg_fusion_wrapper.cpython-312.pyc b/app/models/wrappers/__pycache__/logreg_fusion_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c67d07f7427d4618495a6abb7ce76ab8cb26cf8 Binary files /dev/null and b/app/models/wrappers/__pycache__/logreg_fusion_wrapper.cpython-312.pyc differ diff --git a/app/models/wrappers/__pycache__/vit_base_wrapper.cpython-312.pyc b/app/models/wrappers/__pycache__/vit_base_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12d1de22b39c3fff6b5c7e045f299e673753480b Binary files /dev/null and b/app/models/wrappers/__pycache__/vit_base_wrapper.cpython-312.pyc differ diff --git a/app/models/wrappers/base_wrapper.py b/app/models/wrappers/base_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..216f234507a05c6a8f9449d69ff424f74f02ab50 --- /dev/null +++ 
b/app/models/wrappers/base_wrapper.py @@ -0,0 +1,150 @@ +""" +Base wrapper class for model wrappers. +""" + +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Optional + +from PIL import Image + + +class BaseModelWrapper(ABC): + """ + Abstract base class for model wrappers. + + All model wrappers should inherit from this class and implement + the abstract methods. + """ + + def __init__( + self, + repo_id: str, + config: Dict[str, Any], + local_path: str + ): + """ + Initialize the wrapper. + + Args: + repo_id: Hugging Face repository ID + config: Configuration from config.json + local_path: Local path where the model files are stored + """ + self.repo_id = repo_id + self.config = config + self.local_path = local_path + self._predict_fn: Optional[Callable] = None + + @property + def name(self) -> str: + """ + Get the short name of the model. + + Prefers 'name' from config if available, otherwise derives from repo_id. + Strips '-final' suffix to ensure consistency with fusion configs. + """ + # Try to get name from config first + config_name = self.config.get("name") + if config_name: + # Strip -final suffix if present + return config_name.replace("-final", "") + + # Fall back to repo_id last part, strip -final suffix + repo_name = self.repo_id.split("/")[-1] + return repo_name.replace("-final", "") + + @abstractmethod + def load(self) -> None: + """ + Load the model and prepare for inference. + + This method should import the predict function from the downloaded + repository and store it for later use. + """ + pass + + @abstractmethod + def predict(self, *args, **kwargs) -> Dict[str, Any]: + """ + Run prediction. 
+ + Returns: + Dictionary with standardized prediction fields: + - pred_int: 0 (real) or 1 (fake) + - pred: "real" or "fake" + - prob_fake: float probability + - meta: dict with any additional metadata + """ + pass + + def is_loaded(self) -> bool: + """Check if the model is loaded and ready for inference.""" + return self._predict_fn is not None + + def get_info(self) -> Dict[str, Any]: + """ + Get model information. + + Returns: + Dictionary with model info + """ + return { + "repo_id": self.repo_id, + "name": self.name, + "config": self.config, + "local_path": self.local_path, + "is_loaded": self.is_loaded() + } + + +class BaseSubmodelWrapper(BaseModelWrapper): + """Base wrapper for submodels that process images.""" + + @abstractmethod + def predict( + self, + image: Optional[Image.Image] = None, + image_bytes: Optional[bytes] = None, + explain: bool = False, + **kwargs + ) -> Dict[str, Any]: + """ + Run prediction on an image. + + Args: + image: PIL Image object + image_bytes: Raw image bytes (alternative to image) + explain: If True, include explainability heatmap in output + **kwargs: Additional arguments + + Returns: + Standardized prediction dictionary with: + - pred_int: 0 (real) or 1 (fake) + - pred: "real" or "fake" + - prob_fake: float probability + - heatmap_base64: Optional[str] (when explain=True) + - explainability_type: Optional[str] (when explain=True) + """ + pass + + +class BaseFusionWrapper(BaseModelWrapper): + """Base wrapper for fusion models that combine submodel outputs.""" + + @abstractmethod + def predict( + self, + submodel_outputs: Dict[str, Dict[str, Any]], + **kwargs + ) -> Dict[str, Any]: + """ + Run fusion prediction on submodel outputs. 
+ + Args: + submodel_outputs: Dictionary mapping submodel name to its output + **kwargs: Additional arguments + + Returns: + Standardized prediction dictionary + """ + pass diff --git a/app/models/wrappers/cnn_transfer_wrapper.py b/app/models/wrappers/cnn_transfer_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8af0034bc583139c04901f2e149def9b7bb0c824 --- /dev/null +++ b/app/models/wrappers/cnn_transfer_wrapper.py @@ -0,0 +1,226 @@ +""" +Wrapper for CNN Transfer (EfficientNet-B0) submodel. +""" + +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path +from typing import Any, Dict, Optional, Tuple +from PIL import Image +from torchvision import transforms +from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights + +from app.core.errors import InferenceError, ConfigurationError +from app.core.logging import get_logger +from app.models.wrappers.base_wrapper import BaseSubmodelWrapper +from app.services.explainability import GradCAM, heatmap_to_base64, compute_focus_summary + +logger = get_logger(__name__) + + +class CNNTransferWrapper(BaseSubmodelWrapper): + """ + Wrapper for CNN Transfer model using EfficientNet-B0 backbone. + + Model expects 224x224 RGB images with ImageNet normalization. 
+ """ + + def __init__( + self, + repo_id: str, + config: Dict[str, Any], + local_path: str + ): + super().__init__(repo_id, config, local_path) + self._model: Optional[nn.Module] = None + self._transform: Optional[transforms.Compose] = None + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self._threshold = config.get("threshold", 0.5) + logger.info(f"Initialized CNNTransferWrapper for {repo_id}") + + def load(self) -> None: + """Load the EfficientNet-B0 model with trained weights.""" + weights_path = Path(self.local_path) / "model.pth" + preprocess_path = Path(self.local_path) / "preprocess.json" + + if not weights_path.exists(): + raise ConfigurationError( + message=f"model.pth not found in {self.local_path}", + details={"repo_id": self.repo_id, "expected_path": str(weights_path)} + ) + + try: + # Load preprocessing config + preprocess_config = {} + if preprocess_path.exists(): + with open(preprocess_path, "r") as f: + preprocess_config = json.load(f) + + # Build transform pipeline + input_size = preprocess_config.get("input_size", [224, 224]) + if isinstance(input_size, int): + input_size = [input_size, input_size] + + normalize_config = preprocess_config.get("normalize", {}) + mean = normalize_config.get("mean", [0.485, 0.456, 0.406]) + std = normalize_config.get("std", [0.229, 0.224, 0.225]) + + self._transform = transforms.Compose([ + transforms.Resize(input_size), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std) + ]) + + # Create model architecture + num_classes = self.config.get("num_classes", 2) + self._model = efficientnet_b0(weights=None) + + # Replace classifier for binary classification + in_features = self._model.classifier[1].in_features + self._model.classifier = nn.Sequential( + nn.Dropout(p=0.2, inplace=True), + nn.Linear(in_features, num_classes) + ) + + # Load trained weights + state_dict = torch.load(weights_path, map_location=self._device, weights_only=True) + 
self._model.load_state_dict(state_dict) + self._model.to(self._device) + self._model.eval() + + # Mark as loaded + self._predict_fn = self._run_inference + logger.info(f"Loaded CNN Transfer model from {self.repo_id}") + + except ConfigurationError: + raise + except Exception as e: + logger.error(f"Failed to load CNN Transfer model: {e}") + raise ConfigurationError( + message=f"Failed to load model: {e}", + details={"repo_id": self.repo_id, "error": str(e)} + ) + + def _run_inference( + self, + image_tensor: torch.Tensor, + explain: bool = False + ) -> Dict[str, Any]: + """Run model inference on preprocessed tensor.""" + heatmap = None + + if explain: + # Use GradCAM for explainability (requires gradients) + target_layer = self._model.features[-1] # Last MBConv block + gradcam = GradCAM(self._model, target_layer) + try: + # GradCAM needs gradients, so don't use no_grad + logits = self._model(image_tensor) + probs = F.softmax(logits, dim=1) + prob_fake = probs[0, 1].item() + pred_int = 1 if prob_fake >= self._threshold else 0 + + # Compute heatmap for predicted class + heatmap = gradcam( + image_tensor.clone(), + target_class=pred_int, + output_size=(224, 224) + ) + finally: + gradcam.remove_hooks() + else: + with torch.no_grad(): + logits = self._model(image_tensor) + probs = F.softmax(logits, dim=1) + prob_fake = probs[0, 1].item() + pred_int = 1 if prob_fake >= self._threshold else 0 + + result = { + "logits": logits[0].detach().cpu().numpy().tolist(), + "prob_fake": prob_fake, + "pred_int": pred_int + } + + if heatmap is not None: + result["heatmap"] = heatmap + + return result + + def predict( + self, + image: Optional[Image.Image] = None, + image_bytes: Optional[bytes] = None, + explain: bool = False, + **kwargs + ) -> Dict[str, Any]: + """ + Run prediction on an image. 
+ + Args: + image: PIL Image object + image_bytes: Raw image bytes (will be converted to PIL Image) + explain: If True, compute GradCAM heatmap + + Returns: + Standardized prediction dictionary with optional heatmap + """ + if self._model is None or self._transform is None: + raise InferenceError( + message="Model not loaded", + details={"repo_id": self.repo_id} + ) + + try: + # Convert bytes to PIL Image if needed + if image is None and image_bytes is not None: + import io + image = Image.open(io.BytesIO(image_bytes)).convert("RGB") + elif image is not None: + image = image.convert("RGB") + else: + raise InferenceError( + message="No image provided", + details={"repo_id": self.repo_id} + ) + + # Preprocess + image_tensor = self._transform(image).unsqueeze(0).to(self._device) + + # Run inference + result = self._run_inference(image_tensor, explain=explain) + + # Standardize output + labels = self.config.get("labels", {"0": "real", "1": "fake"}) + pred_int = result["pred_int"] + + output = { + "pred_int": pred_int, + "pred": labels.get(str(pred_int), "unknown"), + "prob_fake": result["prob_fake"], + "meta": { + "model": self.name, + "threshold": self._threshold, + "logits": result["logits"] + } + } + + # Add heatmap if requested + if explain and "heatmap" in result: + heatmap = result["heatmap"] + output["heatmap_base64"] = heatmap_to_base64(heatmap) + output["explainability_type"] = "grad_cam" + output["focus_summary"] = compute_focus_summary(heatmap) + + return output + + except InferenceError: + raise + except Exception as e: + logger.error(f"Prediction failed for {self.repo_id}: {e}") + raise InferenceError( + message=f"Prediction failed: {e}", + details={"repo_id": self.repo_id, "error": str(e)} + ) diff --git a/app/models/wrappers/deit_distilled_wrapper.py b/app/models/wrappers/deit_distilled_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8f53606d1ed56175e8d4631ba926c9a12d23f8e0 --- /dev/null +++ 
b/app/models/wrappers/deit_distilled_wrapper.py @@ -0,0 +1,312 @@ +""" +Wrapper for DeiT Distilled submodel. +""" + +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +from PIL import Image +from torchvision import transforms + +try: + import timm + TIMM_AVAILABLE = True +except ImportError: + TIMM_AVAILABLE = False + +from app.core.errors import InferenceError, ConfigurationError +from app.core.logging import get_logger +from app.models.wrappers.base_wrapper import BaseSubmodelWrapper +from app.services.explainability import attention_rollout, heatmap_to_base64, compute_focus_summary + +logger = get_logger(__name__) + + +def create_custom_mlp_head(in_features: int = 768, num_classes: int = 2) -> nn.Sequential: + """ + Create custom MLP head for DeiT model matching training configuration. + + Returns nn.Sequential to match saved state dict keys (0, 1, 4 indices). + """ + return nn.Sequential( + nn.LayerNorm(in_features), # 0 + nn.Linear(in_features, 512), # 1 + nn.GELU(), # 2 (no params) + nn.Dropout(p=0.2), # 3 (no params) + nn.Linear(512, num_classes) # 4 + ) + + +class DeiTDistilledWrapper(BaseSubmodelWrapper): + """ + Wrapper for DeiT Distilled model. + + Model expects 224x224 RGB images with ImageNet normalization. + Uses a custom MLP head for classification. 
+ """ + + def __init__( + self, + repo_id: str, + config: Dict[str, Any], + local_path: str + ): + super().__init__(repo_id, config, local_path) + self._model: Optional[nn.Module] = None + self._transform: Optional[transforms.Compose] = None + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self._threshold = config.get("threshold", 0.5) + logger.info(f"Initialized DeiTDistilledWrapper for {repo_id}") + + def load(self) -> None: + """Load the DeiT model with custom head and trained weights.""" + if not TIMM_AVAILABLE: + raise ConfigurationError( + message="timm package not installed. Run: pip install timm", + details={"repo_id": self.repo_id} + ) + + weights_path = Path(self.local_path) / "deit_distilled_final.pt" + preprocess_path = Path(self.local_path) / "preprocess.json" + + if not weights_path.exists(): + raise ConfigurationError( + message=f"deit_distilled_final.pt not found in {self.local_path}", + details={"repo_id": self.repo_id, "expected_path": str(weights_path)} + ) + + try: + # Load preprocessing config + preprocess_config = {} + if preprocess_path.exists(): + with open(preprocess_path, "r") as f: + preprocess_config = json.load(f) + + # Build transform pipeline + input_size = preprocess_config.get("input_size", 224) + if isinstance(input_size, list): + input_size = input_size[0] + + normalize_config = preprocess_config.get("normalize", {}) + mean = normalize_config.get("mean", [0.485, 0.456, 0.406]) + std = normalize_config.get("std", [0.229, 0.224, 0.225]) + + # Use bicubic interpolation as specified + interpolation = preprocess_config.get("interpolation", "bicubic") + interp_mode = transforms.InterpolationMode.BICUBIC if interpolation == "bicubic" else transforms.InterpolationMode.BILINEAR + + self._transform = transforms.Compose([ + transforms.Resize((input_size, input_size), interpolation=interp_mode), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std) + ]) + + # Create model architecture + model_name 
= self.config.get("model_name", "deit_base_distilled_patch16_224") + num_classes = self.config.get("num_classes", 2) + + # Create base model without pretrained weights + self._model = timm.create_model(model_name, pretrained=False, num_classes=0) + + # Replace heads with custom MLP heads (Sequential assigned directly) + # Note: state dict has separate keys for head and head_dist, so don't share + hidden_dim = 768 # DeiT base hidden dimension + self._model.head = create_custom_mlp_head(hidden_dim, num_classes) + self._model.head_dist = create_custom_mlp_head(hidden_dim, num_classes) + + # Load trained weights + state_dict = torch.load(weights_path, map_location=self._device, weights_only=True) + self._model.load_state_dict(state_dict) + self._model.to(self._device) + self._model.eval() + + # Mark as loaded + self._predict_fn = self._run_inference + logger.info(f"Loaded DeiT Distilled model from {self.repo_id}") + + except ConfigurationError: + raise + except Exception as e: + logger.error(f"Failed to load DeiT Distilled model: {e}") + raise ConfigurationError( + message=f"Failed to load model: {e}", + details={"repo_id": self.repo_id, "error": str(e)} + ) + + def _run_inference( + self, + image_tensor: torch.Tensor, + explain: bool = False + ) -> Dict[str, Any]: + """Run model inference on preprocessed tensor.""" + heatmap = None + + if explain: + # Collect attention weights from all blocks + attentions: List[torch.Tensor] = [] + handles = [] + + # Hook into attention modules to capture weights + # DeiT blocks structure: blocks[i].attn + def create_attn_hook(): + stored_attn = [] + + def hook(module, inputs, outputs): + # Get q, k from the module's forward computation + # inputs[0] is x of shape [B, N, C] + x = inputs[0] + B, N, C = x.shape + + # Access the attention module's parameters + qkv = module.qkv(x) # [B, N, 3*dim] + qkv = qkv.reshape(B, N, 3, module.num_heads, C // module.num_heads) + qkv = qkv.permute(2, 0, 3, 1, 4) # [3, B, heads, N, dim_head] + q, k, v 
= qkv[0], qkv[1], qkv[2] + + # Compute attention weights + scale = (C // module.num_heads) ** -0.5 + attn = (q @ k.transpose(-2, -1)) * scale + attn = attn.softmax(dim=-1) # [B, heads, N, N] + + # Average over heads + attn_avg = attn.mean(dim=1) # [B, N, N] + stored_attn.append(attn_avg.detach()) + + return hook, stored_attn + + all_stored_attns = [] + for block in self._model.blocks: + hook_fn, stored = create_attn_hook() + all_stored_attns.append(stored) + handle = block.attn.register_forward_hook(hook_fn) + handles.append(handle) + + try: + with torch.no_grad(): + logits = self._model(image_tensor) + probs = F.softmax(logits, dim=1) + prob_fake = probs[0, 1].item() + pred_int = 1 if prob_fake >= self._threshold else 0 + + # Get attention from hooks + attention_list = [stored[0] for stored in all_stored_attns if len(stored) > 0] + + if attention_list: + # Stack: [num_layers, B, N, N] + attention_stack = torch.stack(attention_list, dim=0) + # Compute rollout - returns (grid_size, grid_size) heatmap + attention_map = attention_rollout( + attention_stack[:, 0], # [num_layers, N, N] + head_fusion="mean", # Already averaged + discard_ratio=0.0, + num_prefix_tokens=2 # DeiT has CLS + distillation token + ) # Returns (14, 14) for DeiT-Base + + # Resize to image size + from PIL import Image as PILImage + heatmap_img = PILImage.fromarray( + (attention_map * 255).astype(np.uint8) + ).resize((224, 224), PILImage.BILINEAR) + heatmap = np.array(heatmap_img).astype(np.float32) / 255.0 + + finally: + for handle in handles: + handle.remove() + else: + with torch.no_grad(): + # In eval mode, DeiT returns single tensor + logits = self._model(image_tensor) + probs = F.softmax(logits, dim=1) + prob_fake = probs[0, 1].item() + pred_int = 1 if prob_fake >= self._threshold else 0 + + result = { + "logits": logits[0].cpu().numpy().tolist(), + "prob_fake": prob_fake, + "pred_int": pred_int + } + + if heatmap is not None: + result["heatmap"] = heatmap + + return result + + def predict( + 
self, + image: Optional[Image.Image] = None, + image_bytes: Optional[bytes] = None, + explain: bool = False, + **kwargs + ) -> Dict[str, Any]: + """ + Run prediction on an image. + + Args: + image: PIL Image object + image_bytes: Raw image bytes (will be converted to PIL Image) + explain: If True, compute attention rollout heatmap + + Returns: + Standardized prediction dictionary with optional heatmap + """ + if self._model is None or self._transform is None: + raise InferenceError( + message="Model not loaded", + details={"repo_id": self.repo_id} + ) + + try: + # Convert bytes to PIL Image if needed + if image is None and image_bytes is not None: + import io + image = Image.open(io.BytesIO(image_bytes)).convert("RGB") + elif image is not None: + image = image.convert("RGB") + else: + raise InferenceError( + message="No image provided", + details={"repo_id": self.repo_id} + ) + + # Preprocess + image_tensor = self._transform(image).unsqueeze(0).to(self._device) + + # Run inference + result = self._run_inference(image_tensor, explain=explain) + + # Standardize output + class_mapping = self.config.get("class_mapping", {"0": "real", "1": "fake"}) + pred_int = result["pred_int"] + + output = { + "pred_int": pred_int, + "pred": class_mapping.get(str(pred_int), "unknown"), + "prob_fake": result["prob_fake"], + "meta": { + "model": self.name, + "threshold": self._threshold, + "logits": result["logits"] + } + } + + # Add heatmap if requested + if explain and "heatmap" in result: + heatmap = result["heatmap"] + output["heatmap_base64"] = heatmap_to_base64(heatmap) + output["explainability_type"] = "attention_rollout" + output["focus_summary"] = compute_focus_summary(heatmap) + + return output + + except InferenceError: + raise + except Exception as e: + logger.error(f"Prediction failed for {self.repo_id}: {e}") + raise InferenceError( + message=f"Prediction failed: {e}", + details={"repo_id": self.repo_id, "error": str(e)} + ) diff --git 
a/app/models/wrappers/dummy_majority_fusion_wrapper.py b/app/models/wrappers/dummy_majority_fusion_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..cbcc2118d289ede1f3515bf1411723ae7762c5e0 --- /dev/null +++ b/app/models/wrappers/dummy_majority_fusion_wrapper.py @@ -0,0 +1,171 @@ +""" +Wrapper for dummy majority vote fusion model. +""" + +import importlib.util +import sys +from pathlib import Path +from typing import Any, Dict, List + +from app.core.errors import FusionError, ConfigurationError +from app.core.logging import get_logger +from app.models.wrappers.base_wrapper import BaseFusionWrapper + +logger = get_logger(__name__) + + +class DummyMajorityFusionWrapper(BaseFusionWrapper): + """ + Wrapper for dummy majority vote fusion models. + + These models are hosted on Hugging Face and contain a predict.py + with a predict() function that performs majority voting on submodel outputs. + """ + + def __init__( + self, + repo_id: str, + config: Dict[str, Any], + local_path: str + ): + """ + Initialize the wrapper. + + Args: + repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/fusion-majority-test") + config: Configuration from config.json + local_path: Local path where the model files are stored + """ + super().__init__(repo_id, config, local_path) + self._submodel_repos: List[str] = config.get("submodels", []) + logger.info(f"Initialized DummyMajorityFusionWrapper for {repo_id}") + logger.info(f"Submodels: {self._submodel_repos}") + + @property + def submodel_repos(self) -> List[str]: + """Get list of submodel repository IDs.""" + return self._submodel_repos + + def load(self) -> None: + """ + Load the fusion predict function from the downloaded repository. + + Dynamically imports predict.py and extracts the predict function. 
+ """ + fusion_path = Path(self.local_path) / "predict.py" + + if not fusion_path.exists(): + raise ConfigurationError( + message=f"predict.py not found in {self.local_path}", + details={"repo_id": self.repo_id, "expected_path": str(fusion_path)} + ) + + try: + # Create a unique module name to avoid conflicts + module_name = f"hf_model_{self.name.replace('-', '_')}_fusion" + + # Load the module dynamically + spec = importlib.util.spec_from_file_location(module_name, fusion_path) + if spec is None or spec.loader is None: + raise ConfigurationError( + message=f"Could not load spec for {fusion_path}", + details={"repo_id": self.repo_id} + ) + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + # Get the predict function + if not hasattr(module, "predict"): + raise ConfigurationError( + message=f"predict.py does not have a 'predict' function", + details={"repo_id": self.repo_id} + ) + + self._predict_fn = module.predict + logger.info(f"Loaded fusion predict function from {self.repo_id}") + + except ConfigurationError: + raise + except Exception as e: + logger.error(f"Failed to load fusion function from {self.repo_id}: {e}") + raise ConfigurationError( + message=f"Failed to load fusion model: {e}", + details={"repo_id": self.repo_id, "error": str(e)} + ) + + def predict( + self, + submodel_outputs: Dict[str, Dict[str, Any]], + **kwargs + ) -> Dict[str, Any]: + """ + Run fusion prediction on submodel outputs. 
+ + Args: + submodel_outputs: Dictionary mapping submodel name to its prediction output + **kwargs: Additional arguments passed to the fusion function + + Returns: + Standardized prediction dictionary with: + - pred_int: 0 or 1 + - pred: "real" or "fake" + - prob_fake: float (average of pred_ints) + - meta: dict + """ + if self._predict_fn is None: + raise FusionError( + message="Fusion model not loaded", + details={"repo_id": self.repo_id} + ) + + try: + # Call the actual fusion predict function from the HF repo + result = self._predict_fn(submodel_outputs=submodel_outputs, **kwargs) + + # Validate and standardize the output + standardized = self._standardize_output(result) + return standardized + + except FusionError: + raise + except Exception as e: + logger.error(f"Fusion prediction failed for {self.repo_id}: {e}") + raise FusionError( + message=f"Fusion prediction failed: {e}", + details={"repo_id": self.repo_id, "error": str(e)} + ) + + def _standardize_output(self, result: Dict[str, Any]) -> Dict[str, Any]: + """ + Standardize the fusion output to ensure consistent format. 
class DummyRandomWrapper(BaseSubmodelWrapper):
    """
    Wrapper for dummy random prediction models.

    These models are hosted on Hugging Face and contain a predict.py
    with a predict() function that returns random predictions.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/test-random-a")
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)
        logger.info(f"Initialized DummyRandomWrapper for {repo_id}")

    def load(self) -> None:
        """
        Load the predict function from the downloaded repository.

        Dynamically imports predict.py and binds its predict() function.

        Raises:
            ConfigurationError: If predict.py is missing, cannot be imported,
                or does not expose a predict() function.
        """
        predict_path = Path(self.local_path) / "predict.py"

        if not predict_path.exists():
            raise ConfigurationError(
                message=f"predict.py not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(predict_path)}
            )

        try:
            # Unique module name per repo so dynamically loaded modules
            # from different HF repositories do not clobber each other.
            module_name = f"hf_model_{self.name.replace('-', '_')}_predict"

            # Load the module dynamically
            spec = importlib.util.spec_from_file_location(module_name, predict_path)
            if spec is None or spec.loader is None:
                raise ConfigurationError(
                    message=f"Could not load spec for {predict_path}",
                    details={"repo_id": self.repo_id}
                )

            module = importlib.util.module_from_spec(spec)
            # Register before exec_module so intra-module imports resolve.
            sys.modules[module_name] = module
            spec.loader.exec_module(module)

            # Get the predict function
            if not hasattr(module, "predict"):
                raise ConfigurationError(
                    message="predict.py does not have a 'predict' function",
                    details={"repo_id": self.repo_id}
                )

            self._predict_fn = module.predict
            logger.info(f"Loaded predict function from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load predict function from {self.repo_id}: {e}")
            # Chain the original exception for easier debugging.
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object (optional for dummy model; note the dummy
                predict function is only given image_bytes)
            image_bytes: Raw image bytes (optional for dummy model)
            **kwargs: Additional arguments passed to the model

        Returns:
            Standardized prediction dictionary with:
            - pred_int: 0 or 1
            - pred: "real" or "fake"
            - prob_fake: float
            - meta: dict

        Raises:
            InferenceError: If the model is not loaded or prediction fails.
        """
        if self._predict_fn is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Call the actual predict function from the HF repo
            result = self._predict_fn(image_bytes=image_bytes, **kwargs)

            # Validate and standardize the output
            standardized = self._standardize_output(result)
            return standardized

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def _standardize_output(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Standardize the model output to ensure consistent format.

        Args:
            result: Raw model output

        Returns:
            Standardized dictionary with pred_int, pred, prob_fake and meta keys
        """
        pred_int = result.get("pred_int")
        prob_fake = result.get("prob_fake")

        if pred_int is None:
            # Fix: previously a missing pred_int silently defaulted to 0
            # ("real") even when prob_fake indicated "fake". Derive the hard
            # label from the probability when available; fall back to 0.
            pred_int = 1 if (prob_fake is not None and prob_fake >= 0.5) else 0
        elif pred_int not in (0, 1):
            # Some models emit a score instead of a hard 0/1 label.
            pred_int = 1 if pred_int > 0.5 else 0

        # Generate pred label if not present
        pred = result.get("pred")
        if pred is None:
            pred = "fake" if pred_int == 1 else "real"

        # Generate prob_fake if not present
        if prob_fake is None:
            prob_fake = float(pred_int)

        return {
            "pred_int": pred_int,
            "pred": pred,
            "prob_fake": float(prob_fake),
            "meta": result.get("meta", {})
        }
class CompactGradientNet(nn.Module):
    """
    CNN for gradient field classification with discriminative features.

    Input: Luminance image (1-channel)
    Internal: Computes 6-channel gradient field [luminance, Gx, Gy, magnitude, angle, coherence]
    Output: Logits and embeddings
    """

    def __init__(self, depth=4, base_filters=32, dropout=0.3, embedding_dim=128):
        super().__init__()

        # Fixed (non-trainable) Sobel kernels for horizontal / vertical gradients.
        kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                          dtype=torch.float32).view(1, 1, 3, 3)
        ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                          dtype=torch.float32).view(1, 1, 3, 3)
        self.register_buffer('sobel_x', kx)
        self.register_buffer('sobel_y', ky)

        # 5x5 binomial approximation of a Gaussian, used to smooth the
        # structure tensor before extracting coherence.
        blur = torch.tensor([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4],
                             [6, 24, 36, 24, 6], [4, 16, 24, 16, 4],
                             [1, 4, 6, 4, 1]], dtype=torch.float32) / 256.0
        self.register_buffer('gaussian', blur.view(1, 1, 5, 5))

        # Normalize the 6 derived channels, then let a 1x1 conv remix them.
        self.input_norm = nn.BatchNorm2d(6)
        self.channel_mix = nn.Sequential(
            nn.Conv2d(6, 6, kernel_size=1),
            nn.ReLU()
        )

        # Backbone: `depth` conv stages, doubling the filter count each stage
        # and halving spatial resolution via max pooling.
        stages = []
        channels = 6
        for stage_idx in range(depth):
            width = base_filters * (2 ** stage_idx)
            stages.append(nn.Conv2d(channels, width, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(width))
            stages.append(nn.ReLU())
            stages.append(nn.MaxPool2d(2))
            if dropout > 0:
                stages.append(nn.Dropout2d(dropout))
            channels = width

        self.cnn = nn.Sequential(*stages)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.embedding = nn.Linear(channels, embedding_dim)
        self.classifier = nn.Linear(embedding_dim, 1)

    def compute_gradient_field(self, luminance):
        """Compute 6-channel gradient field on GPU (includes luminance)."""
        gx = F.conv2d(luminance, self.sobel_x, padding=1)
        gy = F.conv2d(luminance, self.sobel_y, padding=1)

        magnitude = torch.sqrt(gx**2 + gy**2 + 1e-8)
        # Orientation mapped from [-pi, pi] to [-1, 1].
        angle = torch.atan2(gy, gx) / math.pi

        # Structure tensor (smoothed gradient outer product) for coherence.
        sxx = F.conv2d(gx * gx, self.gaussian, padding=2)
        sxy = F.conv2d(gx * gy, self.gaussian, padding=2)
        syy = F.conv2d(gy * gy, self.gaussian, padding=2)

        trace = sxx + syy
        root = torch.sqrt((sxx - syy)**2 + 4 * sxy**2 + 1e-8)
        lam_hi = 0.5 * (trace + root)
        lam_lo = 0.5 * (trace - root)
        coherence = ((lam_hi - lam_lo) / (lam_hi + lam_lo + 1e-8))**2

        # Log-compress magnitude so strong edges do not dominate.
        magnitude_scaled = torch.log1p(magnitude * 10)

        return torch.cat([luminance, gx, gy, magnitude_scaled, angle, coherence], dim=1)

    def forward(self, luminance):
        field = self.compute_gradient_field(luminance)
        field = self.input_norm(field)
        field = self.channel_mix(field)
        feat = self.cnn(field)
        pooled = self.global_pool(feat).flatten(1)
        emb = self.embedding(pooled)
        logit = self.classifier(emb)
        return logit.squeeze(1), emb
class GradfieldCNNWrapper(BaseSubmodelWrapper):
    """
    Wrapper for Gradient Field CNN model.

    Model expects 256x256 luminance images.
    Internally computes Sobel gradients and other discriminative features.
    """

    # BT.709 luminance coefficients
    R_COEFF = 0.2126
    G_COEFF = 0.7152
    B_COEFF = 0.0722

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID
            config: Configuration from config.json (may supply "threshold",
                "model_parameters" and "labels")
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)
        self._model: Optional[nn.Module] = None
        self._resize: Optional[transforms.Resize] = None
        # GPU when available; also used as map_location when loading weights.
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._threshold = config.get("threshold", 0.5)
        logger.info(f"Initialized GradfieldCNNWrapper for {repo_id}")

    def load(self) -> None:
        """Load the Gradient Field CNN model with trained weights.

        Raises:
            ConfigurationError: If no weights file is found or loading fails.
        """
        # Try different weight file names (newest checkpoint name first).
        weights_path = None
        for fname in ["gradient_field_cnn_v3_finetuned.pth", "gradient_field_cnn_v2.pth", "weights.pt", "model.pth"]:
            candidate = Path(self.local_path) / fname
            if candidate.exists():
                weights_path = candidate
                break

        preprocess_path = Path(self.local_path) / "preprocess.json"

        if weights_path is None:
            raise ConfigurationError(
                message=f"No weights file found in {self.local_path}",
                details={"repo_id": self.repo_id}
            )

        try:
            # Load preprocessing config (optional file; defaults used otherwise)
            preprocess_config = {}
            if preprocess_path.exists():
                with open(preprocess_path, "r") as f:
                    preprocess_config = json.load(f)

            # Get input size (default 256 for gradient field)
            input_size = preprocess_config.get("input_size", 256)
            if isinstance(input_size, list):
                # Some configs store [H, W]; assume square and take the first.
                input_size = input_size[0]

            self._resize = transforms.Resize((input_size, input_size))

            # Get model parameters from config
            model_params = self.config.get("model_parameters", {})
            depth = model_params.get("depth", 4)
            base_filters = model_params.get("base_filters", 32)
            dropout = model_params.get("dropout", 0.3)
            embedding_dim = model_params.get("embedding_dim", 128)

            # Create model
            self._model = CompactGradientNet(
                depth=depth,
                base_filters=base_filters,
                dropout=dropout,
                embedding_dim=embedding_dim
            )

            # Load trained weights
            # Note: weights_only=False needed because checkpoint contains numpy types
            state_dict = torch.load(weights_path, map_location=self._device, weights_only=False)

            # Handle different checkpoint formats (bare state dict, or a
            # training checkpoint with one of several wrapper keys).
            if isinstance(state_dict, dict):
                if "model_state_dict" in state_dict:
                    state_dict = state_dict["model_state_dict"]
                elif "state_dict" in state_dict:
                    state_dict = state_dict["state_dict"]
                elif "model" in state_dict:
                    state_dict = state_dict["model"]

            self._model.load_state_dict(state_dict)
            self._model.to(self._device)
            self._model.eval()

            # Mark as loaded
            self._predict_fn = self._run_inference
            logger.info(f"Loaded Gradient Field CNN model from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load Gradient Field CNN model: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _rgb_to_luminance(self, img_tensor: torch.Tensor) -> torch.Tensor:
        """
        Convert RGB tensor to luminance using BT.709 coefficients.

        Args:
            img_tensor: RGB tensor of shape (3, H, W) with values in [0, 1]

        Returns:
            Luminance tensor of shape (1, H, W)
        """
        luminance = (
            self.R_COEFF * img_tensor[0] +
            self.G_COEFF * img_tensor[1] +
            self.B_COEFF * img_tensor[2]
        )
        return luminance.unsqueeze(0)

    def _run_inference(
        self,
        luminance_tensor: torch.Tensor,
        explain: bool = False
    ) -> Dict[str, Any]:
        """Run model inference on preprocessed luminance tensor.

        Args:
            luminance_tensor: Batched luminance input, shape (1, 1, H, W).
            explain: If True, additionally compute a Grad-CAM heatmap.

        Returns:
            Dict with "logits", "prob_fake", "pred_int", "embedding" and,
            when explain is True and hooks fired, a "heatmap" ndarray.
        """
        heatmap = None

        if explain:
            # Custom GradCAM implementation for single-logit binary model
            # Using absolute CAM values to capture both positive and negative contributions
            # Target the last Conv2d layer (cnn[-5])
            # NOTE(review): the -5 index assumes each stage ends with
            # [conv, bn, relu, pool, dropout], i.e. dropout > 0; with
            # dropout == 0 this would select the previous stage's pool
            # layer instead of the last conv — confirm configs always
            # set dropout > 0.
            target_layer = self._model.cnn[-5]

            activations = None
            gradients = None

            def forward_hook(module, input, output):
                nonlocal activations
                activations = output.detach()

            def backward_hook(module, grad_input, grad_output):
                nonlocal gradients
                gradients = grad_output[0].detach()

            h_fwd = target_layer.register_forward_hook(forward_hook)
            h_bwd = target_layer.register_full_backward_hook(backward_hook)

            try:
                # Forward pass with gradients
                input_tensor = luminance_tensor.clone().requires_grad_(True)
                logits, embedding = self._model(input_tensor)
                prob_fake = torch.sigmoid(logits).item()
                pred_int = 1 if prob_fake >= self._threshold else 0

                # Backward pass
                self._model.zero_grad()
                logits.backward()

                if gradients is not None and activations is not None:
                    # Compute Grad-CAM weights (global average pooled gradients)
                    weights = gradients.mean(dim=(2, 3), keepdim=True)  # [1, C, 1, 1]

                    # Weighted combination of activation maps
                    cam = (weights * activations).sum(dim=1, keepdim=True)  # [1, 1, H, W]

                    # Use absolute values instead of ReLU to capture all contributions
                    # This is important for models where negative gradients carry meaning
                    cam = torch.abs(cam)

                    # Normalize to [0, 1]
                    cam = cam - cam.min()
                    cam_max = cam.max()
                    if cam_max > 0:
                        cam = cam / cam_max

                    # Resize to output size (256x256)
                    cam = F.interpolate(
                        cam,
                        size=(256, 256),
                        mode='bilinear',
                        align_corners=False
                    )

                    heatmap = cam.squeeze().cpu().numpy()
                else:
                    # Hooks did not fire; fall back to an empty heatmap so the
                    # caller still gets a well-formed response.
                    logger.warning("GradCAM: gradients or activations not captured")
                    heatmap = np.zeros((256, 256), dtype=np.float32)

            finally:
                # Always detach the hooks, even if the forward/backward failed.
                h_fwd.remove()
                h_bwd.remove()
        else:
            with torch.no_grad():
                logits, embedding = self._model(luminance_tensor)
                prob_fake = torch.sigmoid(logits).item()
                pred_int = 1 if prob_fake >= self._threshold else 0

        result = {
            # In the explain path logits require grad, so detach before numpy.
            "logits": logits.detach().cpu().numpy().tolist() if hasattr(logits, 'detach') else logits.cpu().numpy().tolist(),
            "prob_fake": prob_fake,
            "pred_int": pred_int,
            "embedding": embedding.detach().cpu().numpy().tolist() if explain else embedding.cpu().numpy().tolist()
        }

        if heatmap is not None:
            result["heatmap"] = heatmap

        return result

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (will be converted to PIL Image)
            explain: If True, compute GradCAM heatmap

        Returns:
            Standardized prediction dictionary with optional heatmap

        Raises:
            InferenceError: If the model is not loaded, no image is given,
                or inference fails.
        """
        if self._model is None or self._resize is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Convert bytes to PIL Image if needed
            if image is None and image_bytes is not None:
                import io
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            elif image is not None:
                image = image.convert("RGB")
            else:
                raise InferenceError(
                    message="No image provided",
                    details={"repo_id": self.repo_id}
                )

            # Resize
            image = self._resize(image)

            # Convert to tensor
            img_tensor = transforms.functional.to_tensor(image)

            # Convert to luminance
            luminance = self._rgb_to_luminance(img_tensor)
            luminance = luminance.unsqueeze(0).to(self._device)  # Add batch dim

            # Run inference
            result = self._run_inference(luminance, explain=explain)

            # Standardize output
            labels = self.config.get("labels", {"0": "real", "1": "fake"})
            pred_int = result["pred_int"]

            output = {
                "pred_int": pred_int,
                "pred": labels.get(str(pred_int), "unknown"),
                "prob_fake": result["prob_fake"],
                "meta": {
                    "model": self.name,
                    "threshold": self._threshold
                }
            }

            # Add heatmap if requested
            if explain and "heatmap" in result:
                heatmap = result["heatmap"]
                output["heatmap_base64"] = heatmap_to_base64(heatmap)
                output["explainability_type"] = "grad_cam"
                output["focus_summary"] = compute_focus_summary(heatmap) + " (edge-based analysis)"

            return output

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )
class LogRegFusionWrapper(BaseFusionWrapper):
    """
    Probability-stacking fusion backed by a trained logistic regression.

    Each submodel's `prob_fake` score is collected in a fixed order,
    stacked into a single feature vector, and fed to the scikit-learn
    classifier shipped in the Hugging Face repository.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)
        self._model = None
        # Order matters: the classifier was trained on features stacked in
        # exactly this submodel order.
        self._submodel_order: List[str] = config.get("submodel_order", [])
        self._threshold: float = config.get("threshold", 0.5)
        logger.info(f"Initialized LogRegFusionWrapper for {repo_id}")
        logger.info(f"Submodel order: {self._submodel_order}")

    @property
    def submodel_repos(self) -> List[str]:
        """Get list of submodel repository IDs."""
        return self.config.get("submodels", [])

    def load(self) -> None:
        """
        Load the logistic regression model from the downloaded repository.

        Raises:
            ConfigurationError: If fusion_logreg.pkl is absent or cannot be
                deserialized.
        """
        model_path = Path(self.local_path) / "fusion_logreg.pkl"

        if not model_path.exists():
            raise ConfigurationError(
                message=f"fusion_logreg.pkl not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(model_path)}
            )

        try:
            # sklearn artefacts are serialized with joblib, not plain pickle.
            self._model = joblib.load(model_path)
            logger.info(f"Loaded logistic regression fusion model from {self.repo_id}")

        except Exception as e:
            logger.error(f"Failed to load fusion model from {self.repo_id}: {e}")
            raise ConfigurationError(
                message=f"Failed to load fusion model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _extract_prob(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        submodel_name: str
    ) -> float:
        """Return `prob_fake` for one submodel, validating its presence."""
        if submodel_name not in submodel_outputs:
            raise FusionError(
                message=f"Missing output from submodel: {submodel_name}",
                details={
                    "repo_id": self.repo_id,
                    "missing_submodel": submodel_name,
                    "available_submodels": list(submodel_outputs.keys())
                }
            )

        output = submodel_outputs[submodel_name]
        if "prob_fake" not in output:
            raise FusionError(
                message=f"Submodel output missing 'prob_fake': {submodel_name}",
                details={
                    "repo_id": self.repo_id,
                    "submodel": submodel_name,
                    "output_keys": list(output.keys())
                }
            )

        return output["prob_fake"]

    def predict(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run fusion prediction on submodel outputs.

        Stacks submodel probabilities in the configured order and passes
        them through the logistic regression classifier.

        Args:
            submodel_outputs: Dictionary mapping submodel name to its
                prediction output; each output must contain "prob_fake"
            **kwargs: Additional arguments (unused)

        Returns:
            Standardized prediction dictionary with:
            - pred_int: 0 or 1
            - pred: "real" or "fake"
            - prob_fake: float probability of being fake
            - meta: dict with submodel probabilities

        Raises:
            FusionError: If the model is not loaded, a submodel output is
                missing/malformed, or classification fails.
        """
        if self._model is None:
            raise FusionError(
                message="Fusion model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            probs = [
                self._extract_prob(submodel_outputs, name)
                for name in self._submodel_order
            ]

            # sklearn expects a 2D matrix: one row, one column per submodel.
            features = np.array(probs).reshape(1, -1)

            # Column 1 of predict_proba is the "fake" class probability.
            prob_fake = float(self._model.predict_proba(features)[0, 1])
            pred_int = 1 if prob_fake >= self._threshold else 0

            return {
                "pred_int": pred_int,
                "pred": "fake" if pred_int == 1 else "real",
                "prob_fake": prob_fake,
                "meta": {
                    "submodel_probs": {
                        name: p for name, p in zip(self._submodel_order, probs)
                    },
                    "threshold": self._threshold
                }
            }

        except FusionError:
            raise
        except Exception as e:
            logger.error(f"Fusion prediction failed for {self.repo_id}: {e}")
            raise FusionError(
                message=f"Fusion prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )
class ViTWithMLPHead(nn.Module):
    """
    ViT backbone plus the two-layer MLP head used at training time.

    Attribute names (`vit`, `fc1`, `fc2`) must stay as-is: they are the
    prefixes of the training checkpoint's state-dict keys:
    - vit: timm ViT backbone (num_classes=0)
    - fc1: Linear(768, hidden)
    - fc2: Linear(hidden, num_classes)
    """

    def __init__(self, arch: str = "vit_base_patch16_224", num_classes: int = 2, hidden_dim: int = 512):
        super().__init__()
        # num_classes=0 yields a headless backbone that outputs pooled features.
        self.vit = timm.create_model(arch, pretrained=False, num_classes=0)
        embed_dim = self.vit.embed_dim  # 768 for ViT-Base
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled = self.vit(x)  # [B, embed_dim]
        hidden = F.relu(self.fc1(pooled))
        return self.fc2(hidden)
    def load(self) -> None:
        """Load the ViT Base model with trained weights.

        Raises:
            ConfigurationError: If timm is unavailable, the checkpoint file
                is missing, or loading/deserialization fails.
        """
        if not TIMM_AVAILABLE:
            raise ConfigurationError(
                message="timm package not installed. Run: pip install timm",
                details={"repo_id": self.repo_id}
            )

        weights_path = Path(self.local_path) / "deepfake_vit_finetuned_wildfake.pth"
        preprocess_path = Path(self.local_path) / "preprocess.json"

        if not weights_path.exists():
            raise ConfigurationError(
                message=f"deepfake_vit_finetuned_wildfake.pth not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(weights_path)}
            )

        try:
            # Load preprocessing config (optional; ImageNet defaults otherwise)
            preprocess_config = {}
            if preprocess_path.exists():
                with open(preprocess_path, "r") as f:
                    preprocess_config = json.load(f)

            # Build transform pipeline
            input_size = preprocess_config.get("input_size", 224)
            if isinstance(input_size, list):
                # Some configs store [H, W]; assume square and take the first.
                input_size = input_size[0]

            normalize_config = preprocess_config.get("normalize", {})
            mean = normalize_config.get("mean", [0.485, 0.456, 0.406])
            std = normalize_config.get("std", [0.229, 0.224, 0.225])

            # Use bicubic interpolation as specified
            interpolation = preprocess_config.get("interpolation", "bicubic")
            interp_mode = transforms.InterpolationMode.BICUBIC if interpolation == "bicubic" else transforms.InterpolationMode.BILINEAR

            self._transform = transforms.Compose([
                transforms.Resize((input_size, input_size), interpolation=interp_mode),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])

            # Create model architecture matching the training checkpoint format
            arch = self.config.get("arch", "vit_base_patch16_224")
            num_classes = self.config.get("num_classes", 2)
            # MLP hidden dim is 512 per training notebook (fc1: 768->512, fc2: 512->2)
            # Note: config.hidden_dim (768) is ViT embedding dim, not MLP hidden dim
            mlp_hidden_dim = self.config.get("mlp_hidden_dim", 512)

            # Use custom wrapper that matches checkpoint structure (vit.* + fc1/fc2)
            self._model = ViTWithMLPHead(arch=arch, num_classes=num_classes, hidden_dim=mlp_hidden_dim)

            # Load trained weights
            checkpoint = torch.load(weights_path, map_location=self._device, weights_only=False)

            # Handle training checkpoint format (has "model", "optimizer_state", "epoch" keys)
            if isinstance(checkpoint, dict) and "model" in checkpoint:
                state_dict = checkpoint["model"]
            else:
                state_dict = checkpoint

            self._model.load_state_dict(state_dict)
            self._model.to(self._device)
            self._model.eval()

            # Mark as loaded
            self._predict_fn = self._run_inference
            logger.info(f"Loaded ViT Base model from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load ViT Base model: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )

    def _run_inference(
        self,
        image_tensor: torch.Tensor,
        explain: bool = False
    ) -> Dict[str, Any]:
        """Run model inference on preprocessed tensor.

        Args:
            image_tensor: Normalized batched input, shape (1, 3, H, W).
            explain: If True, additionally compute an attention-rollout heatmap.

        Returns:
            Dict with "logits", "prob_fake", "pred_int" and, when explain is
            True and attention was captured, a "heatmap" ndarray in [0, 1].
        """
        heatmap = None

        if explain:
            # Collect attention weights from all blocks
            # NOTE(review): `attentions` and `get_attention_hook` below are
            # never used — the manual hooks created by create_attn_hook do
            # the actual capture. Dead code kept for reference; candidate
            # for removal.
            attentions: List[torch.Tensor] = []
            handles = []

            def get_attention_hook(module, input, output):
                # For timm ViT, the attention forward returns (attn @ v)
                # We need to hook into the softmax to get raw attention weights
                # Alternative: access module's internal attn variable if available
                pass

            # Hook into attention modules to capture weights
            # timm ViT blocks structure: blocks[i].attn
            # We'll use a forward hook that computes attention manually
            def create_attn_hook():
                stored_attn = []

                def hook(module, inputs, outputs):
                    # Get q, k from the module's forward computation
                    # inputs[0] is x of shape [B, N, C]
                    x = inputs[0]
                    B, N, C = x.shape

                    # Access the attention module's parameters
                    # (recomputes qkv — duplicates work the module already
                    # did, but yields the raw pre-softmax attention)
                    qkv = module.qkv(x)  # [B, N, 3*dim]
                    qkv = qkv.reshape(B, N, 3, module.num_heads, C // module.num_heads)
                    qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, dim_head]
                    q, k, v = qkv[0], qkv[1], qkv[2]

                    # Compute attention weights
                    scale = (C // module.num_heads) ** -0.5
                    attn = (q @ k.transpose(-2, -1)) * scale
                    attn = attn.softmax(dim=-1)  # [B, heads, N, N]

                    # Average over heads
                    attn_avg = attn.mean(dim=1)  # [B, N, N]
                    stored_attn.append(attn_avg.detach())

                return hook, stored_attn

            all_stored_attns = []
            for block in self._model.vit.blocks:
                hook_fn, stored = create_attn_hook()
                all_stored_attns.append(stored)
                handle = block.attn.register_forward_hook(hook_fn)
                handles.append(handle)

            try:
                with torch.no_grad():
                    logits = self._model(image_tensor)
                    probs = F.softmax(logits, dim=1)
                    prob_fake = probs[0, 1].item()
                    pred_int = 1 if prob_fake >= self._threshold else 0

                # Get attention from hooks
                attention_list = [stored[0] for stored in all_stored_attns if len(stored) > 0]

                if attention_list:
                    # Stack: [num_layers, B, N, N]
                    attention_stack = torch.stack(attention_list, dim=0)
                    # Compute rollout - returns (grid_size, grid_size) heatmap
                    attention_map = attention_rollout(
                        attention_stack[:, 0],  # [num_layers, N, N]
                        head_fusion="mean",  # Already averaged
                        discard_ratio=0.0,
                        num_prefix_tokens=1  # ViT has 1 CLS token
                    )  # Returns (14, 14) for ViT-Base

                    # Resize to image size
                    from PIL import Image as PILImage
                    heatmap_img = PILImage.fromarray(
                        (attention_map * 255).astype(np.uint8)
                    ).resize((224, 224), PILImage.BILINEAR)
                    heatmap = np.array(heatmap_img).astype(np.float32) / 255.0

            finally:
                # Always detach the hooks, even on failure.
                for handle in handles:
                    handle.remove()
        else:
            with torch.no_grad():
                logits = self._model(image_tensor)
                probs = F.softmax(logits, dim=1)
                prob_fake = probs[0, 1].item()
                pred_int = 1 if prob_fake >= self._threshold else 0

        result = {
            "logits": logits[0].cpu().numpy().tolist(),
            "prob_fake": prob_fake,
            "pred_int": pred_int
        }

        if heatmap is not None:
            result["heatmap"] = heatmap

        return result

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (will be converted to PIL Image)
            explain: If True, compute attention rollout heatmap

        Returns:
            Standardized prediction dictionary with optional heatmap

        Raises:
            InferenceError: If the model is not loaded, no image is given,
                or inference fails.
        """
        if self._model is None or self._transform is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Convert bytes to PIL Image if needed
            if image is None and image_bytes is not None:
                import io
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            elif image is not None:
                image = image.convert("RGB")
            else:
                raise InferenceError(
                    message="No image provided",
                    details={"repo_id": self.repo_id}
                )

            # Preprocess
            image_tensor = self._transform(image).unsqueeze(0).to(self._device)

            # Run inference
            result = self._run_inference(image_tensor, explain=explain)

            # Standardize output
            labels = self.config.get("labels", {"0": "real", "1": "fake"})
            pred_int = result["pred_int"]

            output = {
                "pred_int": pred_int,
                "pred": labels.get(str(pred_int), "unknown"),
                "prob_fake": result["prob_fake"],
                "meta": {
                    "model": self.name,
                    "threshold": self._threshold,
                    "logits": result["logits"]
                }
            }

            # Add heatmap if requested
            if explain and "heatmap" in result:
                heatmap = result["heatmap"]
                output["heatmap_base64"] = heatmap_to_base64(heatmap)
                output["explainability_type"] = "attention_rollout"
                output["focus_summary"] = compute_focus_summary(heatmap)

            return output

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            )
module diff --git a/app/schemas/__pycache__/__init__.cpython-312.pyc b/app/schemas/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a347d7c630964faf5bf1eb7751b8e71f9de31f0d Binary files /dev/null and b/app/schemas/__pycache__/__init__.cpython-312.pyc differ diff --git a/app/schemas/__pycache__/models.cpython-312.pyc b/app/schemas/__pycache__/models.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd9aae4664c5fdce7063a6d33a6c86ff691c03eb Binary files /dev/null and b/app/schemas/__pycache__/models.cpython-312.pyc differ diff --git a/app/schemas/__pycache__/predict.cpython-312.pyc b/app/schemas/__pycache__/predict.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8beb5b30808f6e12ad5480152826ca2bdb434a1 Binary files /dev/null and b/app/schemas/__pycache__/predict.cpython-312.pyc differ diff --git a/app/schemas/models.py b/app/schemas/models.py new file mode 100644 index 0000000000000000000000000000000000000000..8ca526327f3a80b321b0a64c999353db5c19095a --- /dev/null +++ b/app/schemas/models.py @@ -0,0 +1,53 @@ +""" +Pydantic schemas for model-related endpoints. 
+""" + +from typing import Dict, List, Literal, Optional, Any +from pydantic import BaseModel, Field + + +class ModelInfo(BaseModel): + """Information about a loaded model.""" + + repo_id: str = Field(..., description="Hugging Face repository ID") + name: str = Field(..., description="Short name of the model") + model_type: Literal["submodel", "fusion"] = Field( + ..., + description="Type of model" + ) + config: Optional[Dict[str, Any]] = Field( + None, + description="Model configuration from config.json" + ) + + +class ModelsListResponse(BaseModel): + """Response schema for listing models.""" + + fusion: Optional[ModelInfo] = Field( + None, + description="Fusion model information" + ) + submodels: List[ModelInfo] = Field( + default_factory=list, + description="List of loaded submodels" + ) + total_count: int = Field(..., description="Total number of loaded models") + + +class HealthResponse(BaseModel): + """Response schema for health check.""" + + status: Literal["ok", "error"] = Field(..., description="Health status") + + +class ReadyResponse(BaseModel): + """Response schema for readiness check.""" + + status: Literal["ready", "not_ready"] = Field(..., description="Readiness status") + models_loaded: bool = Field(..., description="Whether models are loaded") + fusion_repo: Optional[str] = Field(None, description="Fusion repository ID") + submodels: List[str] = Field( + default_factory=list, + description="List of loaded submodel repository IDs" + ) diff --git a/app/schemas/predict.py b/app/schemas/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7ca15060081b523674263d12c40e6364ab04fa --- /dev/null +++ b/app/schemas/predict.py @@ -0,0 +1,161 @@ +""" +Pydantic schemas for prediction endpoints. 
+""" + +from typing import Dict, List, Literal, Optional +from pydantic import BaseModel, Field + + +# Model display info schema (for frontend) +class ModelDisplayInfo(BaseModel): + """User-friendly display information for a model.""" + + display_name: str = Field(..., description="User-friendly model name (e.g., 'Texture Analysis')") + short_name: str = Field(..., description="Short identifier (e.g., 'CNN')") + method_name: str = Field(..., description="Explainability method name (e.g., 'Grad-CAM')") + method_description: str = Field(..., description="Brief description of the method") + educational_text: str = Field(..., description="Educational text about what this model analyzes") + what_it_looks_for: List[str] = Field(..., description="List of things this model looks for") + + +# LLM single-model explanation schema (for on-demand requests) +class SingleModelInsight(BaseModel): + """LLM-generated insight for a single model (on-demand).""" + + key_finding: str = Field(..., description="Main finding from the model") + what_model_saw: str = Field(..., description="What the model detected in the image") + important_regions: List[str] = Field(..., description="Key regions identified") + confidence_qualifier: str = Field(..., description="Confidence assessment with hedging") + + +class PredictionResult(BaseModel): + """Single prediction result from a model.""" + + pred: Literal["real", "fake"] = Field( + ..., + description="Human-readable prediction label" + ) + pred_int: Literal[0, 1] = Field( + ..., + description="Integer prediction: 0=real, 1=fake" + ) + prob_fake: float = Field( + ..., + ge=0.0, + le=1.0, + description="Probability that the image is fake (0.0-1.0)" + ) + heatmap_base64: Optional[str] = Field( + None, + description="Base64-encoded PNG heatmap showing model attention/saliency (when explain=true)" + ) + explainability_type: Optional[Literal["grad_cam", "attention_rollout"]] = Field( + None, + description="Type of explainability method used" + ) + 
focus_summary: Optional[str] = Field( + None, + description="Brief description of where the model focused (e.g., 'concentrated on face region')" + ) + contribution_percentage: Optional[float] = Field( + None, + ge=0.0, + le=100.0, + description="How much this model contributed to the fusion decision (0-100%)" + ) + + +class FusionMeta(BaseModel): + """Metadata from fusion model about how decision was made.""" + + submodel_weights: Dict[str, float] = Field( + default_factory=dict, + description="Learned coefficients for each submodel" + ) + weighted_contributions: Dict[str, float] = Field( + default_factory=dict, + description="Actual contribution to this prediction (weight * prob_fake)" + ) + contribution_percentages: Dict[str, float] = Field( + default_factory=dict, + description="Normalized percentages for display" + ) + + +class TimingInfo(BaseModel): + """Timing breakdown for the prediction request.""" + + total: int = Field(..., description="Total time in milliseconds") + download: Optional[int] = Field(None, description="Image download time in ms") + preprocess: Optional[int] = Field(None, description="Preprocessing time in ms") + inference: Optional[int] = Field(None, description="Model inference time in ms") + fusion: Optional[int] = Field(None, description="Fusion computation time in ms") + + +class PredictResponse(BaseModel): + """Response schema for prediction endpoint.""" + + final: PredictionResult = Field( + ..., + description="Final prediction result" + ) + fusion_used: bool = Field( + ..., + description="Whether fusion was used for this prediction" + ) + submodels: Optional[Dict[str, PredictionResult]] = Field( + None, + description="Individual submodel predictions (when fusion_used=true and return_submodels=true)" + ) + fusion_meta: Optional[FusionMeta] = Field( + None, + description="Fusion metadata including model weights and contributions" + ) + model_display_info: Optional[Dict[str, ModelDisplayInfo]] = Field( + None, + description="Display 
information for each model (for frontend rendering)" + ) + timing_ms: TimingInfo = Field( + ..., + description="Timing breakdown in milliseconds" + ) + + class Config: + json_schema_extra = { + "example": { + "final": {"pred": "fake", "pred_int": 1, "prob_fake": 0.6667}, + "fusion_used": True, + "submodels": { + "cnn-transfer": {"pred": "fake", "pred_int": 1, "prob_fake": 0.82, "contribution_percentage": 45.2}, + "vit-base": {"pred": "fake", "pred_int": 1, "prob_fake": 0.75, "contribution_percentage": 32.1}, + "gradfield-cnn": {"pred": "fake", "pred_int": 1, "prob_fake": 0.91, "contribution_percentage": 22.7} + }, + "timing_ms": {"total": 250, "inference": 200, "fusion": 5} + } + } + + +# Request/Response for single-model explanation endpoint +class ExplainModelRequest(BaseModel): + """Request schema for single-model explanation.""" + + model_name: str = Field(..., description="Name of the model to explain") + prob_fake: float = Field(..., ge=0.0, le=1.0, description="Model's fake probability") + heatmap_base64: Optional[str] = Field(None, description="Base64-encoded heatmap") + focus_summary: Optional[str] = Field(None, description="Where the model focused") + contribution_percentage: Optional[float] = Field(None, description="Model's contribution to fusion") + + +class ExplainModelResponse(BaseModel): + """Response schema for single-model explanation.""" + + model_name: str = Field(..., description="Internal model name") + insight: SingleModelInsight = Field(..., description="LLM-generated insight") + + +class ErrorResponse(BaseModel): + """Error response schema.""" + + error: str = Field(..., description="Error type") + message: str = Field(..., description="Error message") + details: Optional[Dict] = Field(None, description="Additional error details") diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0557eb635c5522686a57e633065437599726336a --- /dev/null +++ 
b/app/services/__init__.py @@ -0,0 +1 @@ +# Services module diff --git a/app/services/__pycache__/__init__.cpython-312.pyc b/app/services/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..166d46cac6b8425d5d614b3421077bd156f588c8 Binary files /dev/null and b/app/services/__pycache__/__init__.cpython-312.pyc differ diff --git a/app/services/__pycache__/explainability.cpython-312.pyc b/app/services/__pycache__/explainability.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd9a7bc92af13260cee464429c198d571f0c1410 Binary files /dev/null and b/app/services/__pycache__/explainability.cpython-312.pyc differ diff --git a/app/services/__pycache__/fusion_service.cpython-312.pyc b/app/services/__pycache__/fusion_service.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8d4e90606903293854bd41690aa181817cc26bc Binary files /dev/null and b/app/services/__pycache__/fusion_service.cpython-312.pyc differ diff --git a/app/services/__pycache__/hf_hub_service.cpython-312.pyc b/app/services/__pycache__/hf_hub_service.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55ee4186d37a5c6c2c9a21afaefae125b24d9dd9 Binary files /dev/null and b/app/services/__pycache__/hf_hub_service.cpython-312.pyc differ diff --git a/app/services/__pycache__/inference_service.cpython-312.pyc b/app/services/__pycache__/inference_service.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..029117938a3548b2cb13cae12ee03eae2c796aca Binary files /dev/null and b/app/services/__pycache__/inference_service.cpython-312.pyc differ diff --git a/app/services/__pycache__/llm_service.cpython-312.pyc b/app/services/__pycache__/llm_service.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2804f4205b82c0ab66cf12950d52abc4c9aa0ed9 Binary files /dev/null and 
class CacheService:
    """
    Simple in-memory cache service with TTL and size-based eviction.

    This is a basic implementation for Milestone 1.
    Future versions may use Redis or other backends.

    Not thread-safe: callers needing concurrent access must add locking.
    """

    def __init__(self, max_size: int = 100, ttl_seconds: int = 3600):
        """
        Initialize the cache service.

        Args:
            max_size: Maximum number of items to cache
            ttl_seconds: Time-to-live for cache entries in seconds
        """
        # key -> {"value": Any, "timestamp": float (monotonic clock)}
        self._cache: Dict[str, Dict[str, Any]] = {}
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds

    def _generate_key(self, data: bytes) -> str:
        """Generate a cache key from data (SHA256 hash)."""
        return hashlib.sha256(data).hexdigest()

    def get(self, key: str) -> Optional[Any]:
        """
        Get a value from cache.

        Args:
            key: Cache key

        Returns:
            Cached value or None if not found/expired
        """
        entry = self._cache.get(key)
        if entry is None:
            return None

        # Lazy expiry: drop the entry on first access past its TTL.
        if time.monotonic() - entry["timestamp"] > self.ttl_seconds:
            del self._cache[key]
            return None

        return entry["value"]

    def set(self, key: str, value: Any) -> None:
        """
        Set a value in cache.

        Args:
            key: Cache key
            value: Value to cache
        """
        # BUGFIX: only evict when inserting a NEW key at capacity.
        # Overwriting an existing key does not grow the cache, so nothing
        # must be removed; the previous code evicted the oldest entry on
        # every set at capacity, dropping an unrelated entry when merely
        # refreshing an existing one.
        if key not in self._cache and len(self._cache) >= self.max_size:
            oldest_key = min(self._cache, key=lambda k: self._cache[k]["timestamp"])
            del self._cache[oldest_key]

        self._cache[key] = {
            "value": value,
            # monotonic(): TTL math is immune to wall-clock adjustments.
            "timestamp": time.monotonic()
        }

    def clear(self) -> None:
        """Clear all cached entries."""
        self._cache.clear()

    def size(self) -> int:
        """Get current cache size."""
        return len(self._cache)


# Global singleton instance
_cache_service: Optional[CacheService] = None


def get_cache_service() -> CacheService:
    """
    Get the global cache service instance (created lazily on first use).

    Returns:
        CacheService instance
    """
    global _cache_service
    if _cache_service is None:
        _cache_service = CacheService()
    return _cache_service
class GradCAM:
    """
    Gradient-weighted Class Activation Mapping for CNN models.

    Computes importance heatmaps by weighting feature map activations
    by the gradients flowing into them from the target class.

    Usage:
        gradcam = GradCAM(model, target_layer)
        heatmap = gradcam(input_tensor, target_class=1)
    """

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        """
        Args:
            model: The CNN model
            target_layer: The convolutional layer to compute Grad-CAM on
                (typically the last conv layer before pooling)
        """
        self.model = model
        self.target_layer = target_layer
        # Filled in by the hooks during the forward/backward passes.
        self.gradients: Optional[torch.Tensor] = None
        self.activations: Optional[torch.Tensor] = None
        self._hooks: List = []

        self._register_hooks()

    def _register_hooks(self):
        """Register forward and backward hooks on target layer."""
        def forward_hook(module, input, output):
            # Feature maps of the target layer, detached from the graph.
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            # grad_output[0]: gradient of the target score w.r.t. the
            # layer's output feature maps.
            self.gradients = grad_output[0].detach()

        self._hooks.append(
            self.target_layer.register_forward_hook(forward_hook)
        )
        # full_backward_hook is the non-deprecated replacement for
        # register_backward_hook.
        self._hooks.append(
            self.target_layer.register_full_backward_hook(backward_hook)
        )

    def remove_hooks(self):
        """Remove registered hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

    def __call__(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        output_size: Tuple[int, int] = (224, 224)
    ) -> np.ndarray:
        """
        Compute Grad-CAM heatmap.

        Args:
            input_tensor: Input image tensor [1, C, H, W]
            target_class: Class index to compute gradients for.
                If None, uses the predicted class.
            output_size: Size to resize the heatmap to (H, W)

        Returns:
            Normalized heatmap as numpy array [H, W] in range [0, 1]
        """
        self.model.eval()

        # Enable gradients for this forward pass (clone avoids mutating the
        # caller's tensor).
        input_tensor = input_tensor.clone().requires_grad_(True)

        # Forward pass
        output = self.model(input_tensor)

        # Handle different output formats
        if isinstance(output, tuple):
            logits = output[0]  # Some models return (logits, embeddings)
        else:
            logits = output

        # Ensure logits is 2D [batch, classes]
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)

        # Determine target class
        if target_class is None:
            target_class = logits.argmax(dim=1).item()

        # Zero gradients
        self.model.zero_grad()

        # Backward pass for target class
        if logits.shape[-1] > 1:
            # Multi-class: select target class score
            target_score = logits[0, target_class]
        else:
            # Binary with single output: use the logit directly
            # (negated for class 0 so "positive" means "toward the target").
            target_score = logits[0, 0] if target_class == 1 else -logits[0, 0]

        target_score.backward(retain_graph=True)

        # Compute Grad-CAM
        if self.gradients is None or self.activations is None:
            logger.warning("Gradients or activations not captured")
            return np.zeros(output_size, dtype=np.float32)

        # Global average pool gradients to get per-channel weights.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # [1, C, 1, 1]

        # Weighted combination of activation maps
        cam = (weights * self.activations).sum(dim=1, keepdim=True)  # [1, 1, H, W]

        # ReLU to keep only positive contributions
        cam = F.relu(cam)

        # Normalize to [0, 1]
        cam = cam - cam.min()
        cam_max = cam.max()
        if cam_max > 0:
            cam = cam / cam_max

        # Resize to output size
        cam = F.interpolate(
            cam,
            size=output_size,
            mode='bilinear',
            align_corners=False
        )

        # Convert to numpy
        heatmap = cam.squeeze().cpu().numpy()

        return heatmap
def attention_rollout(
    attentions: Union[List[torch.Tensor], torch.Tensor],
    discard_ratio: float = 0.0,
    head_fusion: str = "mean",
    num_prefix_tokens: int = 1
) -> np.ndarray:
    """
    Compute attention rollout for Vision Transformers.

    Aggregates attention across all layers by matrix multiplication,
    accounting for residual connections.

    Args:
        attentions: Attention tensors from each layer. Can be:
            - List of tensors, each shape [batch, num_heads, seq_len, seq_len] or [seq_len, seq_len]
            - Stacked tensor of shape [num_layers, seq_len, seq_len] (already head-fused)
        discard_ratio: Fraction of lowest attention weights to discard
        head_fusion: How to combine attention heads ("mean", "max", "min")
        num_prefix_tokens: Number of special tokens (1 for ViT cls, 2 for DeiT cls+dist)

    Returns:
        Attention map as numpy array of shape (grid_size, grid_size)
    """
    # Default grid size for ViT-Base (14x14 patches from 224x224 with 16x16 patch size)
    default_grid_size = 14

    # Handle empty input
    if attentions is None:
        logger.warning("No attention tensors provided (None)")
        return np.zeros((default_grid_size, default_grid_size), dtype=np.float32)

    # Convert tensor to list if needed
    if isinstance(attentions, torch.Tensor):
        if attentions.numel() == 0:
            logger.warning("Empty attention tensor provided")
            return np.zeros((default_grid_size, default_grid_size), dtype=np.float32)
        # Convert stacked tensor to list
        attentions = [attentions[i] for i in range(attentions.shape[0])]

    # Check if list is empty
    if len(attentions) == 0:
        logger.warning("Empty attention list provided")
        return np.zeros((default_grid_size, default_grid_size), dtype=np.float32)

    result = None

    for attention in attentions:
        # Handle different input formats
        if attention.dim() == 2:
            # Already fused: [seq_len, seq_len]
            attention_fused = attention.unsqueeze(0)  # [1, seq, seq]
        elif attention.dim() == 3:
            # Batched without heads or already fused: [B, seq, seq]
            attention_fused = attention
        elif attention.dim() == 4:
            # Full attention: [B, heads, seq, seq] - fuse heads
            if head_fusion == "mean":
                attention_fused = attention.mean(dim=1)  # [B, seq, seq]
            elif head_fusion == "max":
                attention_fused = attention.max(dim=1)[0]
            elif head_fusion == "min":
                attention_fused = attention.min(dim=1)[0]
            else:
                # Unknown fusion mode: fall back to mean rather than failing.
                attention_fused = attention.mean(dim=1)
        else:
            logger.warning(f"Unexpected attention shape: {attention.shape}")
            continue

        # Discard low attention (optional)
        if discard_ratio > 0:
            flat = attention_fused.view(attention_fused.size(0), -1)
            threshold = torch.quantile(flat, discard_ratio, dim=1, keepdim=True)
            threshold = threshold.view(attention_fused.size(0), 1, 1)
            attention_fused = torch.where(
                attention_fused < threshold,
                torch.zeros_like(attention_fused),
                attention_fused
            )
            # Renormalize so each row is a probability distribution again
            # (epsilon guards against all-zero rows).
            attention_fused = attention_fused / (attention_fused.sum(dim=-1, keepdim=True) + 1e-9)

        # Add identity for residual connection: 0.5*A + 0.5*I models the
        # skip path around each attention layer.
        seq_len = attention_fused.size(-1)
        identity = torch.eye(seq_len, device=attention_fused.device, dtype=attention_fused.dtype)
        attention_with_residual = 0.5 * attention_fused + 0.5 * identity.unsqueeze(0)

        # Matrix multiply through layers (each new layer composes on the left)
        if result is None:
            result = attention_with_residual
        else:
            result = torch.bmm(attention_with_residual, result)

    # Extract CLS token attention to all patch tokens
    # result shape: [B, seq_len, seq_len]
    # CLS token is at index 0, patches start at index num_prefix_tokens
    # NOTE(review): if every layer had an unexpected shape, result is still
    # None here and this indexing would raise — consider guarding.
    cls_attention = result[0, 0, num_prefix_tokens:]  # [num_patches]

    # Reshape to grid
    num_patches = cls_attention.size(0)
    grid_size = int(math.sqrt(num_patches))

    if grid_size * grid_size != num_patches:
        logger.warning(f"Non-square number of patches: {num_patches}")
        # Pad or truncate to nearest square
        grid_size = int(math.ceil(math.sqrt(num_patches)))
        padded = torch.zeros(grid_size * grid_size, device=cls_attention.device)
        padded[:num_patches] = cls_attention
        cls_attention = padded

    attention_map = cls_attention.reshape(grid_size, grid_size).cpu().numpy()

    # Normalize to [0, 1] for visualization
    attention_map = attention_map - attention_map.min()
    if attention_map.max() > 0:
        attention_map = attention_map / attention_map.max()

    return attention_map
grid_size, device=cls_attention.device) + padded[:num_patches] = cls_attention + cls_attention = padded + + attention_map = cls_attention.reshape(grid_size, grid_size).cpu().numpy() + + # Normalize + attention_map = attention_map - attention_map.min() + if attention_map.max() > 0: + attention_map = attention_map / attention_map.max() + + return attention_map + + +def heatmap_to_base64( + heatmap: np.ndarray, + colormap: str = "turbo", + output_size: Optional[Tuple[int, int]] = None +) -> str: + """ + Convert a heatmap array to base64-encoded PNG string. + + Args: + heatmap: 2D numpy array with values in [0, 1] + colormap: Matplotlib colormap name ("turbo", "jet", "viridis", "inferno") + output_size: Optional (width, height) to resize to + + Returns: + Base64-encoded PNG string (without data:image/png;base64, prefix) + """ + import matplotlib + matplotlib.use('Agg') # Non-interactive backend + import matplotlib.pyplot as plt + import matplotlib.cm as cm + + # Get colormap + cmap = cm.get_cmap(colormap) + + # Apply colormap (returns RGBA) + colored = cmap(heatmap) + + # Convert to uint8 RGB + rgb = (colored[:, :, :3] * 255).astype(np.uint8) + + # Create PIL image + img = Image.fromarray(rgb) + + # Resize if needed + if output_size is not None: + img = img.resize(output_size, Image.BILINEAR) + + # Save to bytes + buffer = io.BytesIO() + img.save(buffer, format='PNG', optimize=True) + buffer.seek(0) + + # Encode to base64 + encoded = base64.b64encode(buffer.getvalue()).decode('utf-8') + + return encoded + + +def overlay_heatmap_on_image( + image: Union[np.ndarray, Image.Image], + heatmap: np.ndarray, + alpha: float = 0.5, + colormap: str = "turbo" +) -> str: + """ + Overlay a heatmap on an image and return as base64 PNG. 
+ + Args: + image: Original image (numpy array HWC or PIL Image) + heatmap: 2D heatmap array [0, 1] + alpha: Blend factor (0 = image only, 1 = heatmap only) + colormap: Matplotlib colormap name + + Returns: + Base64-encoded PNG of the overlaid image + """ + import matplotlib + matplotlib.use('Agg') + import matplotlib.cm as cm + + # Convert image to numpy if needed + if isinstance(image, Image.Image): + image = np.array(image) + + # Ensure image is uint8 RGB + if image.dtype != np.uint8: + image = (image * 255).astype(np.uint8) + if image.ndim == 2: + image = np.stack([image] * 3, axis=-1) + elif image.shape[-1] == 1: + image = np.concatenate([image] * 3, axis=-1) + elif image.shape[-1] == 4: + image = image[:, :, :3] + + # Resize heatmap to match image size + h, w = image.shape[:2] + heatmap_resized = np.array( + Image.fromarray((heatmap * 255).astype(np.uint8)).resize((w, h), Image.BILINEAR) + ) / 255.0 + + # Apply colormap + cmap = cm.get_cmap(colormap) + heatmap_colored = cmap(heatmap_resized)[:, :, :3] + heatmap_colored = (heatmap_colored * 255).astype(np.uint8) + + # Blend + blended = ( + (1 - alpha) * image.astype(np.float32) + + alpha * heatmap_colored.astype(np.float32) + ).astype(np.uint8) + + # Convert to base64 + img = Image.fromarray(blended) + buffer = io.BytesIO() + img.save(buffer, format='PNG', optimize=True) + buffer.seek(0) + + return base64.b64encode(buffer.getvalue()).decode('utf-8') + + +class AttentionExtractor: + """ + Hook-based attention extractor for ViT/DeiT models. + + Registers hooks on transformer blocks to capture attention weights + during forward pass. 
class AttentionExtractor:
    """
    Hook-based attention extractor for ViT/DeiT models.

    Registers forward hooks on each transformer block's attention module
    and recomputes the post-softmax attention weights from the module's qkv
    projection on every forward pass (timm's Attention module does not
    expose the weights directly).

    Usage:
        extractor = AttentionExtractor(model.blocks)
        output = model(input)
        attentions = extractor.get_attentions()
        extractor.clear()
    """

    def __init__(self, blocks: nn.ModuleList):
        """
        Args:
            blocks: List of transformer blocks (each should have .attn attribute)
        """
        self.attentions: List[torch.Tensor] = []
        self._hooks: List = []

        for block in blocks:
            if hasattr(block, 'attn'):
                # Hook the attention module so weights are captured on forward
                hook = block.attn.register_forward_hook(self._make_hook())
                self._hooks.append(hook)

    def _make_hook(self):
        """
        Create a forward hook that captures attention weights.

        BUGFIX: the previous hook body was a no-op (`pass`), so
        `get_attentions()` always returned an empty list. The hook now
        recomputes the attention weights from the module's input, using the
        same qkv-based reconstruction as extract_attention_from_block.
        """
        def hook(module, input, output):
            # input[0] is the token sequence fed to the attention module.
            # assumes module has timm-style .qkv and .num_heads — TODO
            # confirm for non-timm backbones.
            x = input[0]
            B, N, C = x.shape

            head_dim = C // module.num_heads
            qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, head_dim)
            qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, head_dim]
            q, k = qkv[0], qkv[1]

            scale = head_dim ** -0.5
            attn_weights = (q @ k.transpose(-2, -1)) * scale
            attn_weights = attn_weights.softmax(dim=-1)  # rows sum to 1

            self.attentions.append(attn_weights.detach())
        return hook

    def extract_attention_from_block(
        self,
        block: nn.Module,
        x: torch.Tensor
    ) -> torch.Tensor:
        """
        Extract attention weights from a single transformer block.

        Args:
            block: Transformer block with attention module
            x: Input tensor [B, seq_len, embed_dim]

        Returns:
            Attention weights [B, num_heads, seq_len, seq_len]
        """
        attn = block.attn
        B, N, C = x.shape

        # Get qkv
        qkv = attn.qkv(x).reshape(B, N, 3, attn.num_heads, C // attn.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Compute attention
        scale = (C // attn.num_heads) ** -0.5
        attn_weights = (q @ k.transpose(-2, -1)) * scale
        attn_weights = attn_weights.softmax(dim=-1)

        return attn_weights

    def get_attentions(self) -> List[torch.Tensor]:
        """Return captured attention tensors (one per hooked block per forward)."""
        return self.attentions

    def clear(self):
        """Clear captured attentions."""
        self.attentions.clear()

    def remove_hooks(self):
        """Remove all hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

    def __del__(self):
        self.remove_hooks()
+ + Args: + model: The ViT model (should have .blocks attribute with transformer layers) + input_tensor: Input image tensor [1, 3, H, W] + blocks_attr: Attribute name for transformer blocks (e.g., "blocks" or "vit.blocks") + num_prefix_tokens: Number of prefix tokens (1 for CLS, 2 for CLS+DIST) + output_size: Size to resize output heatmap + + Returns: + Attention heatmap as numpy array [H, W] in range [0, 1] + """ + model.eval() + + # Navigate to blocks + blocks = model + for attr in blocks_attr.split('.'): + blocks = getattr(blocks, attr) + + attentions = [] + + # Hook to capture attention weights + def make_attn_hook(storage): + def hook(module, input, output): + # Recompute attention weights + x = input[0] + B, N, C = x.shape + + # Get qkv projection + qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + # Compute attention weights + scale = (C // module.num_heads) ** -0.5 + attn = (q @ k.transpose(-2, -1)) * scale + attn = attn.softmax(dim=-1) + + storage.append(attn.detach()) + return hook + + # Register hooks + hooks = [] + for block in blocks: + if hasattr(block, 'attn'): + h = block.attn.register_forward_hook(make_attn_hook(attentions)) + hooks.append(h) + + try: + # Forward pass + with torch.no_grad(): + _ = model(input_tensor) + + # Compute rollout + if attentions: + rollout = attention_rollout( + attentions, + num_prefix_tokens=num_prefix_tokens + ) + + # Resize to output size + rollout_img = Image.fromarray((rollout * 255).astype(np.uint8)) + rollout_img = rollout_img.resize(output_size, Image.BILINEAR) + rollout = np.array(rollout_img) / 255.0 + + return rollout + else: + logger.warning("No attention weights captured") + return np.zeros(output_size, dtype=np.float32) + + finally: + # Clean up hooks + for h in hooks: + h.remove() + + +def compute_focus_summary( + heatmap: np.ndarray, + threshold: float = 0.5 +) -> str: + """ + Compute a human-readable 
def compute_focus_summary(
    heatmap: np.ndarray,
    threshold: float = 0.5
) -> str:
    """
    Build a human-readable summary of where a heatmap focuses.

    Describes how concentrated the high-activation area is, where it sits
    (e.g. "upper-left region"), and — for diffuse maps — roughly how much of
    the image it covers.

    Args:
        heatmap: 2D numpy array with values in [0, 1], shape (H, W)
        threshold: Threshold for considering a region as "high activation"

    Returns:
        Human-readable focus summary string
    """
    if heatmap is None or heatmap.size == 0:
        return "no activation data available"

    # Work on a float copy, scaled so the peak value is 1.
    grid = np.array(heatmap, dtype=np.float32)
    peak = grid.max()
    if peak > 0:
        grid = grid / peak

    rows, cols = grid.shape

    # Locate the "hot" pixels; relax the threshold when nothing clears it.
    hot = grid > threshold
    if not hot.any():
        hot = grid > (grid.max() * 0.5)
    if not hot.any():
        return "very low activation across entire image"

    ys, xs = np.where(hot)

    # Centroid of the hot region, normalized to [0, 1].
    cy = ys.mean() / rows
    cx = xs.mean() / cols

    # Dispersion: std-dev of hot-pixel coordinates relative to image size.
    sy = ys.std() / rows if len(ys) > 1 else 0
    sx = xs.std() / cols if len(xs) > 1 else 0
    dispersion = (sy + sx) / 2

    # Fraction of the image that is hot.
    coverage = hot.sum() / hot.size

    # How tight is the focus?
    if dispersion < 0.15:
        concentration = "highly concentrated"
    elif dispersion < 0.25:
        concentration = "moderately concentrated"
    else:
        concentration = "spread across"

    # Where is it? (thirds of the image in each axis)
    vertical = "upper" if cy < 0.33 else ("lower" if cy > 0.67 else "middle")
    horizontal = "left" if cx < 0.33 else ("right" if cx > 0.67 else "center")

    # Avoid the awkward "middle-center region" wording.
    if (vertical, horizontal) == ("middle", "center"):
        location = "central region"
    else:
        location = f"{vertical}-{horizontal} region"

    pieces = [concentration, location]
    if coverage > 0.4:
        pieces.append(f"(~{int(coverage * 100)}% of image)")

    summary = " ".join(pieces)

    # Heuristic semantic hints for portrait-style inputs: a tight blob in the
    # upper-central area usually corresponds to the face/subject, while a
    # wide spread suggests the model examined many regions.
    if cy < 0.5 and 0.3 < cx < 0.7 and dispersion < 0.2:
        summary += " (likely face/subject area)"
    elif dispersion > 0.3:
        summary += " (examining multiple regions)"

    return summary
+ + Args: + submodel_outputs: Dictionary mapping submodel name to its prediction output + **kwargs: Additional arguments for the fusion model + + Returns: + Standardized prediction dictionary + + Raises: + FusionError: If fusion fails + """ + try: + fusion = self._registry.get_fusion() + return fusion.predict(submodel_outputs=submodel_outputs, **kwargs) + except FusionError: + raise + except Exception as e: + logger.error(f"Fusion failed: {e}") + raise FusionError( + message="Fusion prediction failed", + details={"error": str(e)} + ) + + +# Global singleton instance +_fusion_service = None + + +def get_fusion_service() -> FusionService: + """ + Get the global fusion service instance. + + Returns: + FusionService instance + """ + global _fusion_service + if _fusion_service is None: + _fusion_service = FusionService() + return _fusion_service diff --git a/app/services/hf_hub_service.py b/app/services/hf_hub_service.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc97f64958c761e3bdf83b51bb3b8319c0baeb2 --- /dev/null +++ b/app/services/hf_hub_service.py @@ -0,0 +1,146 @@ +""" +Hugging Face Hub service for downloading model repositories. +""" + +import os +from pathlib import Path +from typing import Optional + +from huggingface_hub import snapshot_download +from huggingface_hub.utils import HfHubHTTPError + +from app.core.config import settings +from app.core.errors import HuggingFaceDownloadError +from app.core.logging import get_logger + +logger = get_logger(__name__) + +# Disable symlink warnings on Windows +os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1" + + +class HFHubService: + """ + Service for interacting with Hugging Face Hub. + + Handles downloading model repositories and caching them locally. + """ + + def __init__(self, cache_dir: Optional[str] = None, token: Optional[str] = None): + """ + Initialize the HF Hub service. + + Args: + cache_dir: Local directory for caching downloads. 
+ Defaults to settings.HF_CACHE_DIR + token: Hugging Face API token for private repos. + Defaults to settings.HF_TOKEN + """ + self.cache_dir = cache_dir or settings.HF_CACHE_DIR + self.token = token or settings.HF_TOKEN + + # Ensure cache directory exists + Path(self.cache_dir).mkdir(parents=True, exist_ok=True) + logger.info(f"HF Hub service initialized with cache dir: {self.cache_dir}") + + def download_repo( + self, + repo_id: str, + revision: Optional[str] = None, + force_download: bool = False + ) -> str: + """ + Download a repository from Hugging Face Hub. + + Uses snapshot_download which handles caching automatically. + If the repo is already cached and not stale, it returns the cached path. + + Args: + repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/test-random-a") + revision: Git revision (branch, tag, or commit hash). Defaults to "main" + force_download: If True, re-download even if cached + + Returns: + Local path to the downloaded repository + + Raises: + HuggingFaceDownloadError: If download fails + """ + logger.info(f"Downloading repo: {repo_id} (revision={revision}, force={force_download})") + + try: + # Use local_dir instead of cache_dir to avoid symlink issues on Windows + repo_name = repo_id.replace("/", "--") + local_dir = Path(self.cache_dir) / repo_name + + local_path = snapshot_download( + repo_id=repo_id, + revision=revision or "main", + local_dir=str(local_dir), + token=self.token, + force_download=force_download, + local_files_only=False + ) + + logger.info(f"Downloaded {repo_id} to {local_path}") + return local_path + + except HfHubHTTPError as e: + logger.error(f"HTTP error downloading {repo_id}: {e}") + raise HuggingFaceDownloadError( + message=f"Failed to download repository: {repo_id}", + details={"repo_id": repo_id, "error": str(e)} + ) + except Exception as e: + logger.error(f"Error downloading {repo_id}: {e}") + raise HuggingFaceDownloadError( + message=f"Failed to download repository: {repo_id}", + details={"repo_id": 
repo_id, "error": str(e)} + ) + + def get_cached_path(self, repo_id: str) -> Optional[str]: + """ + Get the cached path for a repository if it exists. + + Args: + repo_id: Hugging Face repository ID + + Returns: + Local path if cached, None otherwise + """ + # Check local_dir path format (used to avoid symlinks on Windows) + repo_name = repo_id.replace("/", "--") + local_dir = Path(self.cache_dir) / repo_name + + if local_dir.exists() and any(local_dir.iterdir()): + return str(local_dir) + return None + + def is_cached(self, repo_id: str) -> bool: + """ + Check if a repository is already cached. + + Args: + repo_id: Hugging Face repository ID + + Returns: + True if cached, False otherwise + """ + return self.get_cached_path(repo_id) is not None + + +# Global singleton instance +_hf_hub_service: Optional[HFHubService] = None + + +def get_hf_hub_service() -> HFHubService: + """ + Get the global HF Hub service instance. + + Returns: + HFHubService instance + """ + global _hf_hub_service + if _hf_hub_service is None: + _hf_hub_service = HFHubService() + return _hf_hub_service diff --git a/app/services/inference_service.py b/app/services/inference_service.py new file mode 100644 index 0000000000000000000000000000000000000000..fcdaa2e589917ff831f4f70605b602f62897f6b0 --- /dev/null +++ b/app/services/inference_service.py @@ -0,0 +1,111 @@ +""" +Inference service for running model predictions. +""" + +from typing import Any, Dict, Optional + +from PIL import Image + +from app.core.errors import InferenceError, ModelNotFoundError +from app.core.logging import get_logger +from app.services.model_registry import get_model_registry +from app.utils.timing import Timer + +logger = get_logger(__name__) + + +class InferenceService: + """ + Service for running inference on individual models. 
+ """ + + def __init__(self): + self._registry = get_model_registry() + + def predict_single( + self, + model_key: str, + image: Optional[Image.Image] = None, + image_bytes: Optional[bytes] = None, + **kwargs + ) -> Dict[str, Any]: + """ + Run prediction on a single submodel. + + Args: + model_key: Submodel name or repo_id + image: PIL Image object + image_bytes: Raw image bytes (alternative to image) + **kwargs: Additional arguments for the model + + Returns: + Standardized prediction dictionary + + Raises: + ModelNotFoundError: If model not found + InferenceError: If prediction fails + """ + try: + submodel = self._registry.get_submodel(model_key) + return submodel.predict(image=image, image_bytes=image_bytes, **kwargs) + except ModelNotFoundError: + raise + except Exception as e: + logger.error(f"Inference failed for {model_key}: {e}") + raise InferenceError( + message=f"Inference failed for model {model_key}", + details={"model": model_key, "error": str(e)} + ) + + def predict_all_submodels( + self, + image: Optional[Image.Image] = None, + image_bytes: Optional[bytes] = None, + **kwargs + ) -> Dict[str, Dict[str, Any]]: + """ + Run prediction on all loaded submodels. 
+ + Args: + image: PIL Image object + image_bytes: Raw image bytes (alternative to image) + **kwargs: Additional arguments for the models + + Returns: + Dictionary mapping submodel name to prediction result + + Raises: + InferenceError: If any prediction fails + """ + submodels = self._registry.get_all_submodels() + results = {} + + for name, submodel in submodels.items(): + try: + result = submodel.predict(image=image, image_bytes=image_bytes, **kwargs) + results[name] = result + except Exception as e: + logger.error(f"Inference failed for submodel {name}: {e}") + raise InferenceError( + message=f"Inference failed for submodel {name}", + details={"model": name, "error": str(e)} + ) + + return results + + +# Global singleton instance +_inference_service: Optional[InferenceService] = None + + +def get_inference_service() -> InferenceService: + """ + Get the global inference service instance. + + Returns: + InferenceService instance + """ + global _inference_service + if _inference_service is None: + _inference_service = InferenceService() + return _inference_service diff --git a/app/services/llm_service.py b/app/services/llm_service.py new file mode 100644 index 0000000000000000000000000000000000000000..52c894b931e89d364adaaf6fdacd39609b3caeab --- /dev/null +++ b/app/services/llm_service.py @@ -0,0 +1,583 @@ +""" +LLM Service for generating human-readable explanations of model predictions. + +Uses Google Gemini to translate model-space evidence (heatmaps, attention maps) +into human-understandable hypotheses with proper hedging language. 
+""" + +import json +import base64 +from typing import Any, Dict, List, Optional +from functools import lru_cache + +from app.core.config import get_settings +from app.core.logging import get_logger + +logger = get_logger(__name__) + +# Model type descriptions for the LLM +MODEL_TYPE_DESCRIPTIONS = { + "cnn-transfer": { + "type": "rgb_texture_cnn", + "description": "Analyzes RGB pixel textures, colors, and fine details at multiple scales", + "typical_cues": ["skin texture uniformity", "shading gradients", "fine detail at boundaries"] + }, + "vit-base": { + "type": "patch_consistency_vit", + "description": "Analyzes global consistency and relationships between image patches", + "typical_cues": ["lighting consistency", "background blur patterns", "patch-level coherence"] + }, + "deit-distilled": { + "type": "patch_consistency_vit", + "description": "Analyzes global consistency with knowledge distillation for refined attention", + "typical_cues": ["global-local consistency", "texture repetition", "depth coherence"] + }, + "gradfield-cnn": { + "type": "edge_coherence_cnn", + "description": "Analyzes edge patterns, boundary sharpness, and gradient field coherence", + "typical_cues": ["edge smoothness", "boundary naturalness", "gradient consistency"] + } +} + +# User-facing display information for each model (used in frontend) +MODEL_DISPLAY_INFO = { + "cnn-transfer": { + "display_name": "Texture Analysis", + "short_name": "CNN", + "method_name": "Grad-CAM", + "method_description": "Gradient-weighted Class Activation Mapping", + "educational_text": ( + "This model examines fine-grained texture patterns and pixel-level details. " + "The heatmap highlights regions where texture anomalies were detected. " + "AI-generated images often have subtle texture inconsistencies - overly smooth skin, " + "unnatural fabric patterns, or repetitive background textures that this model can detect." 
+ ), + "what_it_looks_for": [ + "Skin texture uniformity vs natural variation", + "Fine detail preservation at edges and boundaries", + "Color gradient smoothness and shading realism" + ] + }, + "vit-base": { + "display_name": "Patch Consistency", + "short_name": "ViT", + "method_name": "Attention Rollout", + "method_description": "Aggregated attention across all transformer layers", + "educational_text": ( + "This model analyzes how different parts of the image relate to each other. " + "The heatmap shows which image patches drew the most attention. " + "AI-generated images may have inconsistencies between regions - " + "mismatched lighting, perspective errors, or elements that don't quite fit together." + ), + "what_it_looks_for": [ + "Consistency of lighting across the image", + "Spatial relationships between objects", + "Background-foreground coherence" + ] + }, + "deit-distilled": { + "display_name": "Global Structure", + "short_name": "DeiT", + "method_name": "Attention Rollout", + "method_description": "Distilled attention patterns from teacher model", + "educational_text": ( + "This model uses knowledge distillation to detect global structural anomalies. " + "The heatmap reveals areas where the overall image structure seems inconsistent. " + "AI-generated images sometimes have subtle global issues - " + "like depth inconsistencies or anatomical improbabilities." + ), + "what_it_looks_for": [ + "Global-to-local consistency", + "Depth and perspective coherence", + "Structural plausibility of objects" + ] + }, + "gradfield-cnn": { + "display_name": "Edge Coherence", + "short_name": "GradField", + "method_name": "Gradient Field Analysis", + "method_description": "Analysis of image gradient patterns and edge transitions", + "educational_text": ( + "This model analyzes edge patterns and how colors transition across boundaries. " + "The heatmap highlights areas with unusual edge characteristics. 
" + "AI-generated images often have telltale edge artifacts - " + "unnaturally sharp or blurry boundaries, inconsistent edge directions, or gradient anomalies." + ), + "what_it_looks_for": [ + "Edge sharpness consistency", + "Natural boundary transitions", + "Gradient flow coherence" + ] + } +} + +def get_model_display_info(model_name: str) -> Dict[str, Any]: + """Get display info for a model, with fallback for unknown models.""" + return MODEL_DISPLAY_INFO.get(model_name, { + "display_name": model_name.replace("-", " ").title(), + "short_name": model_name[:3].upper(), + "method_name": "Analysis", + "method_description": "Model-specific analysis", + "educational_text": f"This model ({model_name}) analyzes the image for signs of AI generation.", + "what_it_looks_for": ["Image anomalies", "Generation artifacts"] + }) + +SYSTEM_PROMPT = """You are an AI image analysis interpreter for a deepfake detection system. Your role is to translate model evidence into human-understandable hypotheses. + +CRITICAL RULES: +1. NEVER claim certainty. Always use hedging language: "may", "suggests", "possible", "could indicate", "might show" +2. ALWAYS cite which model's evidence supports each statement (e.g., "based on CNN heatmap focus") +3. If evidence is diffuse or unclear, say so explicitly: "Evidence is spread across the image; interpretation is less certain" +4. Provide user-checkable observations, not definitive claims about what IS fake +5. 
Remember: you are explaining what the MODEL focused on, not proving the image is fake + +MODEL TYPES AND WHAT THEY ANALYZE: +- CNN (rgb_texture_cnn): Pixel textures, colors, fine details - looks for texture anomalies +- ViT/DeiT (patch_consistency_vit): Global consistency, patch relationships - looks for coherence issues +- GradField (edge_coherence_cnn): Edge patterns, boundaries, gradient fields - looks for edge artifacts + +OUTPUT FORMAT: +You must respond with valid JSON matching this exact structure: +{ + "per_model_insights": { + "": { + "what_model_relied_on": "One sentence describing the model's focus area", + "possible_cues": ["Cue 1 with hedging (based on evidence)", "Cue 2...", "Cue 3..."], + "confidence_note": "Note about confidence level" + } + }, + "consensus_summary": [ + "Bullet 1 about model agreement/disagreement", + "Bullet 2 about overall evidence pattern" + ] +}""" + + +class LLMService: + """Service for generating LLM-powered explanations of model predictions.""" + + def __init__(self): + self._client = None + self._model_name = None + self._enabled = False + self._initialize() + + def _initialize(self): + """Initialize the Gemini client if API key is available.""" + settings = get_settings() + + if not settings.llm_enabled: + logger.info("LLM explanations disabled: No GOOGLE_API_KEY configured") + return + + try: + from google import genai + self._client = genai.Client(api_key=settings.GOOGLE_API_KEY) + self._model_name = settings.GEMINI_MODEL + self._enabled = True + logger.info(f"LLM service initialized with model: {settings.GEMINI_MODEL}") + except ImportError: + logger.warning("google-genai package not installed. 
LLM explanations disabled.") + except Exception as e: + logger.error(f"Failed to initialize LLM service: {e}") + + @property + def enabled(self) -> bool: + """Check if LLM explanations are available.""" + return self._enabled + + def build_evidence_packet( + self, + model_name: str, + model_output: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Build a structured evidence packet from model output. + + Args: + model_name: Name of the model (e.g., "cnn-transfer") + model_output: Raw output from the model's predict() method + + Returns: + Structured evidence packet for LLM consumption + """ + model_info = MODEL_TYPE_DESCRIPTIONS.get(model_name, { + "type": "unknown", + "description": "Unknown model type", + "typical_cues": [] + }) + + return { + "model_name": model_name, + "model_type": model_info["type"], + "model_description": model_info["description"], + "prob_fake": model_output.get("prob_fake", 0.0), + "prediction": model_output.get("pred", "unknown"), + "focus_summary": model_output.get("focus_summary", "focus pattern not available"), + "explainability_type": model_output.get("explainability_type", "unknown"), + "typical_cues_for_this_model": model_info["typical_cues"] + } + + def generate_explanation( + self, + original_image_b64: Optional[str], + submodel_outputs: Dict[str, Dict[str, Any]], + include_images: bool = True + ) -> Optional[Dict[str, Any]]: + """ + Generate LLM explanation for model predictions. 
+ + Args: + original_image_b64: Base64-encoded original image (optional) + submodel_outputs: Dict mapping model names to their outputs + include_images: Whether to include images in the prompt (uses vision model) + + Returns: + ExplanationResult dict or None if generation fails + """ + if not self._enabled: + logger.warning("LLM explanations requested but service not enabled") + return None + + try: + # Build evidence packets for all models + evidence_packets = {} + for model_name, output in submodel_outputs.items(): + evidence_packets[model_name] = self.build_evidence_packet(model_name, output) + + # Build the prompt + user_prompt = self._build_user_prompt(evidence_packets, submodel_outputs) + + # Build content parts (text + optional images) + content_parts = [] + + # Add images if requested and available + if include_images: + # Add original image + if original_image_b64: + content_parts.append({ + "mime_type": "image/png", + "data": original_image_b64 + }) + content_parts.append("Original image shown above.\n\n") + + # Add heatmap overlays for each model + for model_name, output in submodel_outputs.items(): + if output.get("heatmap_base64"): + content_parts.append({ + "mime_type": "image/png", + "data": output["heatmap_base64"] + }) + content_parts.append(f"Heatmap overlay for {model_name} shown above.\n\n") + + # Add the main text prompt + content_parts.append(user_prompt) + + # Call the LLM using new google.genai API + logger.info("Generating LLM explanation...") + from google.genai import types + + # Build the parts list for the new API + parts = [] + for part in content_parts: + if isinstance(part, dict) and "mime_type" in part: + # Image part + parts.append(types.Part.from_bytes( + data=__import__('base64').b64decode(part["data"]), + mime_type=part["mime_type"] + )) + else: + # Text part + parts.append(types.Part.from_text(text=str(part))) + + response = self._client.models.generate_content( + model=self._model_name, + contents=[SYSTEM_PROMPT] + parts, + 
config=types.GenerateContentConfig( + temperature=0.3, + top_p=0.8, + max_output_tokens=2048, + ) + ) + + # Parse the response + return self._parse_response(response.text, list(submodel_outputs.keys())) + + except Exception as e: + logger.error(f"Failed to generate LLM explanation: {e}") + return None + + def _build_user_prompt( + self, + evidence_packets: Dict[str, Dict], + submodel_outputs: Dict[str, Dict] + ) -> str: + """Build the user prompt with evidence data.""" + + # Calculate some aggregate stats + prob_fakes = [p["prob_fake"] for p in evidence_packets.values()] + avg_prob = sum(prob_fakes) / len(prob_fakes) if prob_fakes else 0 + agreement = "Models generally agree" if max(prob_fakes) - min(prob_fakes) < 0.3 else "Models show disagreement" + + prompt = f"""I have {len(evidence_packets)} deepfake detection models analyzing an image. + +EVIDENCE FROM EACH MODEL: +{json.dumps(evidence_packets, indent=2)} + +AGGREGATE ANALYSIS: +- Average fake probability: {avg_prob:.1%} +- Model agreement: {agreement} +- Probability range: {min(prob_fakes):.1%} to {max(prob_fakes):.1%} + +TASK: +For each model, provide: +1. "what_model_relied_on": One sentence describing where the model focused (cite the focus_summary) +2. "possible_cues": 2-4 possible visual cues a human could check, phrased as hypotheses with hedging language +3. "confidence_note": Assessment based on prob_fake value and focus pattern + +Then provide "consensus_summary": 2-3 bullets about where models agreed/disagreed and overall evidence quality. + +Remember: Use hedging language ("may", "suggests", "possible"). Never claim certainty. 
+ +Respond with valid JSON only, no markdown formatting.""" + + return prompt + + def _parse_response( + self, + response_text: str, + expected_models: List[str] + ) -> Optional[Dict[str, Any]]: + """Parse and validate the LLM response.""" + + try: + # Try to extract JSON from the response + # Sometimes the model wraps it in markdown code blocks + text = response_text.strip() + if text.startswith("```"): + # Remove markdown code block + lines = text.split("\n") + text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:]) + text = text.strip() + + result = json.loads(text) + + # Validate structure + if "per_model_insights" not in result: + logger.warning("LLM response missing per_model_insights") + result["per_model_insights"] = {} + + if "consensus_summary" not in result: + logger.warning("LLM response missing consensus_summary") + result["consensus_summary"] = ["Model analysis completed."] + + # Ensure all expected models have entries (fill with defaults if missing) + for model_name in expected_models: + if model_name not in result["per_model_insights"]: + result["per_model_insights"][model_name] = { + "what_model_relied_on": f"The {model_name} model analyzed the image.", + "possible_cues": ["Evidence details not available for this model."], + "confidence_note": "Unable to generate detailed analysis." + } + + return result + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse LLM response as JSON: {e}") + logger.debug(f"Raw response: {response_text[:500]}...") + + # Return a fallback response + return { + "per_model_insights": { + model: { + "what_model_relied_on": f"The {model} model analyzed the image.", + "possible_cues": ["Unable to generate detailed explanation."], + "confidence_note": "LLM response parsing failed." 
+ } + for model in expected_models + }, + "consensus_summary": ["Model analysis completed but detailed explanation unavailable."] + } + + def generate_single_model_explanation( + self, + model_name: str, + prob_fake: float, + original_image_b64: Optional[str] = None, + heatmap_b64: Optional[str] = None, + focus_summary: Optional[str] = None, + contribution_percentage: Optional[float] = None + ) -> Optional[Dict[str, Any]]: + """ + Generate LLM explanation for a single model's prediction. + + This is more token-efficient than generating all explanations at once, + and allows users to request explanations on-demand per model. + + Args: + model_name: Name of the model (e.g., "cnn-transfer") + prob_fake: The model's fake probability + original_image_b64: Base64-encoded original image + heatmap_b64: Base64-encoded heatmap overlay + focus_summary: Text summary of where model focused + contribution_percentage: How much this model contributed to fusion decision + + Returns: + Dict with insight for this model or None if generation fails + """ + if not self._enabled: + logger.warning("LLM explanations requested but service not enabled") + return None + + try: + # Get display info for this model + display_info = get_model_display_info(model_name) + model_type_info = MODEL_TYPE_DESCRIPTIONS.get(model_name, { + "type": "unknown", + "description": "Unknown model type", + "typical_cues": [] + }) + + # Build focused prompt for single model + prompt = f"""You are analyzing a single model's output from a deepfake detection system. 
+ +MODEL INFORMATION: +- Display Name: {display_info['display_name']} +- Analysis Method: {display_info['method_name']} ({display_info['method_description']}) +- What It Analyzes: {model_type_info['description']} +- Typical Cues It Detects: {', '.join(model_type_info['typical_cues'])} + +DETECTION RESULTS: +- Fake Probability: {prob_fake:.1%} +- Prediction: {"Likely AI-Generated" if prob_fake >= 0.5 else "Likely Real"} +- Focus Summary: {focus_summary or "Not available"} +{f"- Contribution to Final Decision: {contribution_percentage:.1f}%" if contribution_percentage else ""} + +The heatmap shows where this model focused its attention. Brighter/warmer colors indicate higher attention. + +TASK: +Analyze the image and heatmap to explain what this specific model detected. Provide: +1. A clear explanation of what the model focused on and why it might indicate AI generation (or authenticity) +2. 2-4 specific visual cues a human could verify, phrased as hypotheses with hedging language +3. A confidence assessment based on the probability and focus pattern + +CRITICAL: Use hedging language - "may", "suggests", "possible", "could indicate". Never claim certainty. 
+ +Respond with valid JSON matching this exact structure: +{{ + "key_finding": "One sentence main finding about what the model detected", + "what_model_saw": "2-3 sentences explaining what the model detected and why it matters", + "important_regions": ["Region 1 with hedging language", "Region 2...", "Region 3..."], + "confidence_qualifier": "Assessment of reliability with appropriate hedging" +}} + +Respond with valid JSON only, no markdown formatting.""" + + # Build content parts + content_parts = [] + + if original_image_b64: + from google.genai import types + content_parts.append(types.Part.from_bytes( + data=base64.b64decode(original_image_b64), + mime_type="image/png" + )) + content_parts.append(types.Part.from_text(text="Original image shown above.\n\n")) + + if heatmap_b64: + from google.genai import types + content_parts.append(types.Part.from_bytes( + data=base64.b64decode(heatmap_b64), + mime_type="image/png" + )) + content_parts.append(types.Part.from_text(text=f"{display_info['method_name']} heatmap shown above.\n\n")) + + from google.genai import types + content_parts.append(types.Part.from_text(text=prompt)) + + # Call the LLM with JSON response mode + logger.info(f"Generating LLM explanation for {model_name}...") + + response = self._client.models.generate_content( + model=self._model_name, + contents=content_parts, + config=types.GenerateContentConfig( + temperature=0.3, + top_p=0.8, + max_output_tokens=2048, # Increased to avoid truncation + response_mime_type="application/json", + ) + ) + + # Parse response - even with JSON mode, sometimes there are issues + text = response.text.strip() + + try: + result = json.loads(text) + except json.JSONDecodeError as parse_err: + # Log the problematic text for debugging + logger.warning(f"Initial JSON parse failed: {parse_err}") + logger.warning(f"Raw text (first 500 chars): {repr(text[:500])}") + + # Try to fix common issues: newlines inside strings + # Replace literal newlines with escaped ones, but only 
inside quoted strings + import re + + # More robust approach: find all string values and escape newlines + def escape_newlines_in_strings(s): + result = [] + in_string = False + escape_next = False + for i, c in enumerate(s): + if escape_next: + result.append(c) + escape_next = False + continue + if c == '\\': + escape_next = True + result.append(c) + continue + if c == '"' and not escape_next: + in_string = not in_string + result.append(c) + continue + if in_string and c == '\n': + result.append('\\n') + elif in_string and c == '\r': + result.append('\\r') + else: + result.append(c) + return ''.join(result) + + fixed_text = escape_newlines_in_strings(text) + result = json.loads(fixed_text) + + # Add model metadata to result + result["model_name"] = model_name + + return result + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse single model LLM response: {e}") + return { + "model_name": model_name, + "key_finding": f"The {display_info['display_name']} detected potential signs of manipulation.", + "what_model_saw": f"The model analyzed the image but detailed analysis could not be parsed. The fake probability was {prob_fake:.1%}.", + "important_regions": ["Unable to identify specific regions."], + "confidence_qualifier": "Analysis completed but detailed explanation unavailable due to parsing error." 
+ } + except Exception as e: + logger.error(f"Failed to generate single model explanation: {e}") + return None + + +# Global singleton +_llm_service: Optional[LLMService] = None + + +def get_llm_service() -> LLMService: + """Get the global LLM service instance.""" + global _llm_service + if _llm_service is None: + _llm_service = LLMService() + return _llm_service diff --git a/app/services/model_registry.py b/app/services/model_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..7a69d56d2431f2596e3990e25c6a42dcff385634 --- /dev/null +++ b/app/services/model_registry.py @@ -0,0 +1,343 @@ +""" +Model registry for managing loaded models. +""" + +import asyncio +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, Type + +from app.core.config import settings +from app.core.errors import ModelNotFoundError, ModelNotLoadedError, ConfigurationError +from app.core.logging import get_logger +from app.models.wrappers.base_wrapper import BaseSubmodelWrapper, BaseFusionWrapper +from app.models.wrappers.dummy_random_wrapper import DummyRandomWrapper +from app.models.wrappers.dummy_majority_fusion_wrapper import DummyMajorityFusionWrapper +from app.models.wrappers.logreg_fusion_wrapper import LogRegFusionWrapper +# Real production wrappers +from app.models.wrappers.cnn_transfer_wrapper import CNNTransferWrapper +from app.models.wrappers.deit_distilled_wrapper import DeiTDistilledWrapper +from app.models.wrappers.vit_base_wrapper import ViTBaseWrapper +from app.models.wrappers.gradfield_cnn_wrapper import GradfieldCNNWrapper +from app.services.hf_hub_service import get_hf_hub_service + +logger = get_logger(__name__) + + +def get_wrapper_class(config: Dict[str, Any]) -> Type[BaseSubmodelWrapper]: + """ + Select the appropriate wrapper class based on model config. + + Uses architecture hints or model_type to dispatch to the correct wrapper. + Falls back to DummyRandomWrapper if no match found (useful for testing). 
+ + Args: + config: Model configuration dictionary + + Returns: + Wrapper class (not instance) + """ + # Check various config fields that might indicate model type + arch = config.get("arch", "").lower() + model_type = config.get("type", "").lower() + model_class = config.get("model_class", "").lower() + model_name = config.get("model_name", "").lower() + library = config.get("library", "").lower() + + # EfficientNet / CNN Transfer + if "efficientnet" in arch or "cnn-transfer" in model_type or "efficientnet" in model_name: + return CNNTransferWrapper + + # DeiT Distilled + if "deit" in arch or "deit-distilled" in model_type or "deit" in model_name: + return DeiTDistilledWrapper + + # ViT Base (check vit but not deit) + if (("vit" in arch or "vit" in model_name) and "deit" not in arch and "deit" not in model_name) or "vit-base" in model_type: + return ViTBaseWrapper + + # Gradient Field CNN + if "gradient" in arch or "gradientnet" in model_class or "gradfield" in model_type or "gradient" in model_name: + return GradfieldCNNWrapper + + # Fallback to dummy wrapper + logger.warning(f"No matching wrapper for config, using DummyRandomWrapper: {config}") + return DummyRandomWrapper + + +def get_fusion_wrapper_class(config: Dict[str, Any]) -> Type[BaseFusionWrapper]: + """ + Select the appropriate fusion wrapper class based on config. + + Args: + config: Fusion model configuration dictionary + + Returns: + Fusion wrapper class (not instance) + """ + fusion_type = config.get("type", "").lower() + + # Logistic regression stacking fusion + if "probability_stacking" in fusion_type or "logreg" in fusion_type: + return LogRegFusionWrapper + + # Majority vote fusion + if "majority" in fusion_type: + return DummyMajorityFusionWrapper + + # Default to majority fusion + logger.warning(f"Unknown fusion type, using DummyMajorityFusionWrapper: {fusion_type}") + return DummyMajorityFusionWrapper + + +class ModelRegistry: + """ + Central registry for all loaded models. 
+ + Manages downloading, loading, and accessing models from Hugging Face Hub. + This is the single source of truth for model state. + """ + + def __init__(self): + self._fusion: Optional[BaseFusionWrapper] = None + self._submodels: Dict[str, BaseSubmodelWrapper] = {} + self._is_loaded: bool = False + self._load_lock = asyncio.Lock() + self._hf_service = get_hf_hub_service() + + @property + def is_loaded(self) -> bool: + """Check if models are loaded.""" + return self._is_loaded + + async def load_from_fusion_repo( + self, + fusion_repo_id: str, + force_reload: bool = False + ) -> None: + """ + Load fusion model and all submodels from a fusion repository. + + This is the main entry point for loading models. It: + 1. Downloads the fusion repo and reads its config.json + 2. Extracts submodel repo IDs from config + 3. Downloads and loads each submodel + 4. Loads the fusion model + + Args: + fusion_repo_id: Hugging Face repository ID for fusion model + force_reload: If True, reload even if already loaded + """ + async with self._load_lock: + if self._is_loaded and not force_reload: + logger.info("Models already loaded, skipping") + return + + logger.info(f"Loading models from fusion repo: {fusion_repo_id}") + + # Download fusion repo + fusion_path = await asyncio.to_thread( + self._hf_service.download_repo, fusion_repo_id + ) + + # Read fusion config + fusion_config = self._read_config(fusion_path) + logger.info(f"Fusion config: {fusion_config}") + + # Get submodel repo IDs from config + submodel_repos = fusion_config.get("submodels", []) + if not submodel_repos: + raise ConfigurationError( + message="Fusion config does not specify any submodels", + details={"repo_id": fusion_repo_id} + ) + + # Download and load each submodel + for submodel_repo_id in submodel_repos: + await self._load_submodel(submodel_repo_id) + + # Create and load fusion wrapper + fusion_wrapper_class = get_fusion_wrapper_class(fusion_config) + logger.info(f"Using fusion wrapper class 
{fusion_wrapper_class.__name__}") + self._fusion = fusion_wrapper_class( + repo_id=fusion_repo_id, + config=fusion_config, + local_path=fusion_path + ) + self._fusion.load() + + self._is_loaded = True + logger.info(f"Successfully loaded {len(self._submodels)} submodels and fusion model") + + async def _load_submodel(self, repo_id: str) -> None: + """ + Download and load a single submodel. + + Uses the config to determine the correct wrapper class. + + Args: + repo_id: Hugging Face repository ID for the submodel + """ + logger.info(f"Loading submodel: {repo_id}") + + # Download the repo + local_path = await asyncio.to_thread( + self._hf_service.download_repo, repo_id + ) + + # Read config + config = self._read_config(local_path) + + # Select appropriate wrapper class based on config + wrapper_class = get_wrapper_class(config) + logger.info(f"Using wrapper class {wrapper_class.__name__} for {repo_id}") + + # Create and load wrapper + wrapper = wrapper_class( + repo_id=repo_id, + config=config, + local_path=local_path + ) + wrapper.load() + + # Store by short name + self._submodels[wrapper.name] = wrapper + logger.info(f"Loaded submodel: {wrapper.name}") + + def _read_config(self, local_path: str) -> Dict[str, Any]: + """ + Read config.json from a local model path. + + Args: + local_path: Path to the downloaded model + + Returns: + Configuration dictionary + """ + config_path = Path(local_path) / "config.json" + + if not config_path.exists(): + logger.warning(f"config.json not found at {config_path}, using empty config") + return {} + + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + + def list_models(self) -> List[Dict[str, Any]]: + """ + List all loaded models. 
+ + Returns: + List of model info dictionaries + """ + models = [] + + # Add fusion model + if self._fusion: + models.append({ + "repo_id": self._fusion.repo_id, + "name": self._fusion.name, + "model_type": "fusion", + "config": self._fusion.config + }) + + # Add submodels + for name, wrapper in self._submodels.items(): + models.append({ + "repo_id": wrapper.repo_id, + "name": name, + "model_type": "submodel", + "config": wrapper.config + }) + + return models + + def get_submodel(self, key: str) -> BaseSubmodelWrapper: + """ + Get a submodel by name or repo_id. + + Args: + key: Submodel name or full repo_id + + Returns: + Submodel wrapper + + Raises: + ModelNotFoundError: If submodel not found + ModelNotLoadedError: If models not loaded + """ + if not self._is_loaded: + raise ModelNotLoadedError( + message="Models not loaded yet", + details={"requested_model": key} + ) + + # Try by name first + if key in self._submodels: + return self._submodels[key] + + # Try by repo_id + for name, wrapper in self._submodels.items(): + if wrapper.repo_id == key: + return wrapper + + raise ModelNotFoundError( + message=f"Submodel not found: {key}", + details={ + "requested_model": key, + "available_models": list(self._submodels.keys()) + } + ) + + def get_all_submodels(self) -> Dict[str, BaseSubmodelWrapper]: + """ + Get all loaded submodels. + + Returns: + Dictionary mapping name to submodel wrapper + + Raises: + ModelNotLoadedError: If models not loaded + """ + if not self._is_loaded: + raise ModelNotLoadedError(message="Models not loaded yet") + return self._submodels.copy() + + def get_fusion(self) -> BaseFusionWrapper: + """ + Get the fusion model. 
+ + Returns: + Fusion model wrapper + + Raises: + ModelNotLoadedError: If models not loaded + """ + if not self._is_loaded or self._fusion is None: + raise ModelNotLoadedError(message="Fusion model not loaded yet") + return self._fusion + + def get_submodel_names(self) -> List[str]: + """Get list of loaded submodel names.""" + return list(self._submodels.keys()) + + def get_fusion_repo_id(self) -> Optional[str]: + """Get the fusion repo ID if loaded.""" + return self._fusion.repo_id if self._fusion else None + + +# Global singleton instance +_model_registry: Optional[ModelRegistry] = None + + +def get_model_registry() -> ModelRegistry: + """ + Get the global model registry instance. + + Returns: + ModelRegistry instance + """ + global _model_registry + if _model_registry is None: + _model_registry = ModelRegistry() + return _model_registry diff --git a/app/services/preprocess_service.py b/app/services/preprocess_service.py new file mode 100644 index 0000000000000000000000000000000000000000..e31776a3672558015c5fb67dc919deaa08a647dc --- /dev/null +++ b/app/services/preprocess_service.py @@ -0,0 +1,128 @@ +""" +Image preprocessing service. +""" + +from typing import Optional, Dict, Any + +from PIL import Image + +from app.core.errors import ImageProcessingError +from app.core.logging import get_logger +from app.utils.image import load_image_from_bytes, validate_image_bytes +from app.utils.security import validate_file_size, validate_image_content, MAX_FILE_SIZE_BYTES + +logger = get_logger(__name__) + + +class PreprocessService: + """ + Service for preprocessing images before model inference. + + For Milestone 1, this is minimal - just validates and optionally + decodes images. Future milestones will add more preprocessing. + """ + + def __init__(self, max_file_size: int = MAX_FILE_SIZE_BYTES): + """ + Initialize the preprocess service. 
+ + Args: + max_file_size: Maximum allowed file size in bytes + """ + self.max_file_size = max_file_size + + def validate_image(self, image_bytes: bytes) -> Dict[str, Any]: + """ + Validate uploaded image bytes. + + Args: + image_bytes: Raw image bytes + + Returns: + Dictionary with validation results + + Raises: + ImageProcessingError: If validation fails + """ + # Check file size + if not validate_file_size(image_bytes, self.max_file_size): + raise ImageProcessingError( + message=f"File too large. Maximum size is {self.max_file_size // (1024*1024)}MB", + details={"size": len(image_bytes), "max_size": self.max_file_size} + ) + + # Check content type via magic bytes + if not validate_image_content(image_bytes): + raise ImageProcessingError( + message="Invalid image format. Supported formats: JPEG, PNG, GIF, WebP, BMP", + details={"size": len(image_bytes)} + ) + + return { + "valid": True, + "size_bytes": len(image_bytes) + } + + def decode_image(self, image_bytes: bytes) -> Image.Image: + """ + Decode image bytes to PIL Image. + + Args: + image_bytes: Raw image bytes + + Returns: + PIL Image object + + Raises: + ImageProcessingError: If decoding fails + """ + return load_image_from_bytes(image_bytes) + + def preprocess( + self, + image_bytes: bytes, + decode: bool = False + ) -> Dict[str, Any]: + """ + Full preprocessing pipeline. 
+ + Args: + image_bytes: Raw image bytes + decode: Whether to decode to PIL Image + + Returns: + Dictionary with: + - image_bytes: Original or processed bytes + - image: PIL Image if decode=True + - validation: Validation results + """ + # Validate + validation = self.validate_image(image_bytes) + + result = { + "image_bytes": image_bytes, + "validation": validation + } + + # Optionally decode + if decode: + result["image"] = self.decode_image(image_bytes) + + return result + + +# Global singleton instance +_preprocess_service: Optional[PreprocessService] = None + + +def get_preprocess_service() -> PreprocessService: + """ + Get the global preprocess service instance. + + Returns: + PreprocessService instance + """ + global _preprocess_service + if _preprocess_service is None: + _preprocess_service = PreprocessService() + return _preprocess_service diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..feddb932ca0920b603aeff70b51cd5a931bb4876 --- /dev/null +++ b/app/utils/__init__.py @@ -0,0 +1 @@ +# Utils module diff --git a/app/utils/__pycache__/__init__.cpython-312.pyc b/app/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f920b3a057c23adbed92573137beb02fe4460747 Binary files /dev/null and b/app/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/app/utils/__pycache__/image.cpython-312.pyc b/app/utils/__pycache__/image.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81a1abd38ca63b9bac42b68647b649e3b174b5c8 Binary files /dev/null and b/app/utils/__pycache__/image.cpython-312.pyc differ diff --git a/app/utils/__pycache__/security.cpython-312.pyc b/app/utils/__pycache__/security.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6780beb39a59399ee14288c057a32a3d7cfd302 Binary files /dev/null and b/app/utils/__pycache__/security.cpython-312.pyc differ 
# --- app/utils/image.py ---
"""
Image processing utilities.
"""

from io import BytesIO
from typing import Optional, Tuple

from PIL import Image

from app.core.errors import ImageProcessingError
from app.core.logging import get_logger

logger = get_logger(__name__)


def load_image_from_bytes(image_bytes: bytes) -> Image.Image:
    """
    Decode raw bytes into an RGB PIL Image.

    Args:
        image_bytes: Raw encoded image bytes

    Returns:
        PIL Image in RGB mode

    Raises:
        ImageProcessingError: If the bytes cannot be decoded as an image
    """
    try:
        decoded = Image.open(BytesIO(image_bytes))
        # Normalize RGBA / palette / grayscale inputs to RGB.
        return decoded if decoded.mode == "RGB" else decoded.convert("RGB")
    except Exception as e:
        logger.error(f"Failed to decode image: {e}")
        raise ImageProcessingError(
            message="Failed to decode image",
            details={"error": str(e)}
        )


def validate_image_bytes(image_bytes: bytes) -> bool:
    """
    Check whether the given bytes decode to a structurally valid image.

    Args:
        image_bytes: Raw encoded image bytes

    Returns:
        True when PIL can parse and verify the data, False otherwise
    """
    try:
        Image.open(BytesIO(image_bytes)).verify()
    except Exception:
        return False
    return True


def get_image_info(image: Image.Image) -> dict:
    """
    Summarize basic properties of an image.

    Args:
        image: PIL Image object

    Returns:
        Dictionary with width, height, mode and format
    """
    return {
        "width": image.width,
        "height": image.height,
        "mode": image.mode,
        "format": image.format,
    }


def resize_image(
    image: Image.Image,
    size: Tuple[int, int],
    resample: int = Image.Resampling.LANCZOS
) -> Image.Image:
    """
    Resize an image to the requested dimensions.

    Args:
        image: PIL Image object
        size: Target (width, height)
        resample: Resampling filter (a PIL ``Image.Resampling`` member)

    Returns:
        Resized PIL Image
    """
    return image.resize(size, resample=resample)


def image_to_bytes(
    image: Image.Image,
    format: str = "PNG",
    quality: int = 95
) -> bytes:
    """
    Serialize a PIL Image to encoded bytes.

    Args:
        image: PIL Image object
        format: Output format (PNG, JPEG, etc.)
        quality: JPEG quality (1-95); ignored for other formats

    Returns:
        Encoded image bytes
    """
    buffer = BytesIO()
    save_kwargs = {"format": format}
    if format.upper() == "JPEG":
        save_kwargs["quality"] = quality
    image.save(buffer, **save_kwargs)
    return buffer.getvalue()
# --- app/utils/security.py ---
"""
Security utilities for the application.
"""

import hashlib
import secrets
from typing import List, Optional

# NOTE(review): the original module imported app.core.logging.get_logger and
# bound a module-level `logger`, but nothing in this module ever logs; the
# unused dependency was removed so these helpers stay importable standalone.

# Maximum allowed file size (10 MB)
MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024

# Allowed image MIME types
ALLOWED_MIME_TYPES = [
    "image/jpeg",
    "image/png",
    "image/gif",
    "image/webp",
    "image/bmp"
]

# Image magic bytes for validation.  The RIFF prefix is shared by several
# container formats (WAV, AVI, WebP); detect_mime_type() disambiguates it.
IMAGE_SIGNATURES = {
    b'\xff\xd8\xff': 'image/jpeg',
    b'\x89PNG\r\n\x1a\n': 'image/png',
    b'GIF87a': 'image/gif',
    b'GIF89a': 'image/gif',
    b'RIFF': 'image/webp',  # WebP; requires the 'WEBP' fourcc check below
    b'BM': 'image/bmp'
}


def validate_file_size(content: bytes, max_size: int = MAX_FILE_SIZE_BYTES) -> bool:
    """
    Validate that file size is within allowed limits.

    Args:
        content: File content as bytes
        max_size: Maximum allowed size in bytes

    Returns:
        True if valid, False otherwise
    """
    return len(content) <= max_size


def detect_mime_type(content: bytes) -> Optional[str]:
    """
    Detect MIME type from file content using magic bytes.

    Fix: a bare 'RIFF' prefix previously classified ANY RIFF container
    (WAV audio, AVI video, ...) as image/webp.  A real WebP file also
    carries the 'WEBP' fourcc at bytes 8-12, so that is now required.

    Args:
        content: File content as bytes

    Returns:
        Detected MIME type or None
    """
    for signature, mime_type in IMAGE_SIGNATURES.items():
        if content.startswith(signature):
            if signature == b'RIFF':
                # RIFF alone is ambiguous; require the WebP fourcc.
                return mime_type if content[8:12] == b'WEBP' else None
            return mime_type
    return None


def validate_image_content(
    content: bytes,
    allowed_types: List[str] = ALLOWED_MIME_TYPES
) -> bool:
    """
    Validate image content by checking magic bytes.

    Args:
        content: File content as bytes
        allowed_types: List of allowed MIME types

    Returns:
        True if valid image type, False otherwise
    """
    detected_type = detect_mime_type(content)
    if detected_type is None:
        return False
    return detected_type in allowed_types


def compute_file_hash(content: bytes, algorithm: str = "sha256") -> str:
    """
    Compute hash of file content.

    Args:
        content: File content as bytes
        algorithm: Hash algorithm name understood by hashlib (sha256, md5, etc.)

    Returns:
        Hex-encoded hash string
    """
    hasher = hashlib.new(algorithm)
    hasher.update(content)
    return hasher.hexdigest()
def generate_request_id() -> str:
    """
    Generate a unique request ID.

    Returns:
        Random 16-character hex string (8 random bytes)
    """
    return secrets.token_hex(8)


def sanitize_filename(filename: str) -> str:
    """
    Sanitize a filename to prevent path traversal.

    Args:
        filename: Original filename

    Returns:
        Sanitized filename, truncated to 255 characters; "unnamed" when the
        input reduces to an empty string
    """
    # Remove path separators and null bytes
    sanitized = filename.replace("/", "_").replace("\\", "_").replace("\x00", "")
    # Remove leading dots to prevent hidden files
    sanitized = sanitized.lstrip(".")
    return sanitized[:255] if sanitized else "unnamed"


# --- app/utils/timing.py ---
"""
Timing utilities for performance measurement.
"""

import time
from contextlib import contextmanager
from functools import wraps
from typing import Dict, Generator, Optional

try:
    from app.core.logging import get_logger
    logger = get_logger(__name__)
except ImportError:  # pragma: no cover - allows standalone use of this module
    import logging
    logger = logging.getLogger(__name__)


class Timer:
    """
    Timer for measuring execution time of named code blocks.

    Usage:
        timer = Timer()
        with timer.measure("inference"):
            # do inference
        with timer.measure("fusion"):
            # do fusion
        timings = timer.get_timings()
    """

    def __init__(self):
        # Recorded durations, name -> milliseconds.
        self._timings: Dict[str, int] = {}
        # perf_counter() value captured by start_total(), else None.
        # (The original also defined an unused `_start_time`; removed.)
        self._total_start: Optional[float] = None

    def start_total(self) -> None:
        """Start the total timer."""
        self._total_start = time.perf_counter()

    def stop_total(self) -> None:
        """Stop the total timer and record the duration under "total"."""
        if self._total_start is not None:
            elapsed_ms = int((time.perf_counter() - self._total_start) * 1000)
            self._timings["total"] = elapsed_ms

    @contextmanager
    def measure(self, name: str) -> Generator[None, None, None]:
        """
        Context manager to measure execution time of a block.

        The duration is recorded even if the block raises.

        Args:
            name: Name for this timing measurement

        Yields:
            None
        """
        start = time.perf_counter()
        try:
            yield
        finally:
            elapsed_ms = int((time.perf_counter() - start) * 1000)
            self._timings[name] = elapsed_ms
            logger.debug(f"Timer [{name}]: {elapsed_ms}ms")

    def record(self, name: str, duration_ms: int) -> None:
        """
        Manually record a timing.

        Args:
            name: Name for this timing
            duration_ms: Duration in milliseconds
        """
        self._timings[name] = duration_ms

    def get_timings(self) -> Dict[str, int]:
        """
        Get a copy of all recorded timings.

        Returns:
            Dictionary of timing name -> milliseconds
        """
        return self._timings.copy()

    def get(self, name: str) -> Optional[int]:
        """
        Get a specific timing.

        Args:
            name: Timing name

        Returns:
            Duration in milliseconds, or None if not recorded
        """
        return self._timings.get(name)

    def reset(self) -> None:
        """Reset all timings and the total timer."""
        self._timings.clear()
        self._total_start = None


def measure_time(func):
    """
    Decorator that logs the wrapped function's execution time at DEBUG level.

    Fix: applies functools.wraps so the wrapped function keeps its
    __name__/__doc__ metadata (the original wrapper clobbered them, which
    also broke the name used in this decorator's own log line for nested
    decoration).
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            elapsed_ms = int((time.perf_counter() - start) * 1000)
            logger.debug(f"Function [{func.__name__}]: {elapsed_ms}ms")
    return wrapper
+ """ + def wrapper(*args, **kwargs): + start = time.perf_counter() + try: + result = func(*args, **kwargs) + return result + finally: + elapsed_ms = int((time.perf_counter() - start) * 1000) + logger.debug(f"Function [{func.__name__}]: {elapsed_ms}ms") + return wrapper diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..30f1a3474ed38631ef03ec4dfec12afba15269d6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,40 @@ +# DeepFake Detector Backend - Python Dependencies + +# Web framework +fastapi>=0.109.0,<1.0.0 +uvicorn[standard]>=0.27.0,<1.0.0 + +# Data validation +pydantic>=2.5.0,<3.0.0 +pydantic-settings>=2.1.0,<3.0.0 + +# Image processing +pillow>=10.2.0,<11.0.0 +matplotlib>=3.7.0,<4.0.0 +opencv-python>=4.8.0,<5.0.0 + +# Deep Learning +torch>=2.0.0,<3.0.0 +torchvision>=0.15.0,<1.0.0 +timm>=0.9.0 + +# Machine Learning (for fusion models) +scikit-learn>=1.3.0,<2.0.0 +numpy>=1.24.0,<2.0.0 + +# Hugging Face Hub +huggingface_hub>=0.20.0,<1.0.0 +hf_xet + +# HTTP client (optional, for testing) +httpx>=0.26.0,<1.0.0 + +# File uploads +python-multipart>=0.0.6,<1.0.0 + +# Testing +pytest>=7.4.0,<9.0.0 +pytest-asyncio>=0.23.0,<1.0.0 + +# Google Gemini for LLM explanations +google-genai>=1.0.0 \ No newline at end of file diff --git a/start.sh b/start.sh new file mode 100644 index 0000000000000000000000000000000000000000..25aa75220f92a3a5a9d087e468ed49f801d5f35c --- /dev/null +++ b/start.sh @@ -0,0 +1,8 @@ +#!/bin/bash +# Startup script for deployment +# Uses PORT environment variable if set, otherwise defaults to 7860 (HuggingFace Spaces) or 8000 + +PORT=${PORT:-7860} + +echo "Starting uvicorn on port $PORT" +exec uvicorn app.main:app --host 0.0.0.0 --port "$PORT" --log-level info diff --git a/test_explainability.py b/test_explainability.py new file mode 100644 index 0000000000000000000000000000000000000000..8e63c9ca98b4daf843c981e4af345e526b6c270b --- /dev/null +++ b/test_explainability.py @@ -0,0 +1,45 @@ 
"""Test script for explainability features."""
import asyncio
import traceback
import numpy as np
from PIL import Image
import io

async def main():
    """Load the registry and smoke-test each submodel with explain=True."""
    from app.services.model_registry import get_model_registry
    from app.core.config import settings

    registry = get_model_registry()

    # Load models from fusion repo
    print("Loading models...")
    await registry.load_from_fusion_repo(settings.HF_FUSION_REPO_ID)
    print("Models loaded!")

    # Random RGB noise is sufficient to exercise the explain pipeline.
    noise = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    buf = io.BytesIO()
    Image.fromarray(noise).save(buf, format='PNG')
    img_bytes = buf.getvalue()

    # Exercise every submodel in turn; failures are reported, not fatal.
    for model_name in ('cnn-transfer', 'gradfield-cnn', 'vit-base', 'deit-distilled'):
        print(f"\nTesting {model_name}...")
        try:
            wrapper = registry.get_submodel(model_name)
            result = wrapper.predict(image_bytes=img_bytes, explain=True)
            has_heatmap = 'heatmap_base64' in result
            print(f"  Success! pred={result['pred']}, has_heatmap={has_heatmap}")
            if has_heatmap:
                # Round-trip the payload to prove it is valid base64.
                import base64
                decoded = base64.b64decode(result['heatmap_base64'])
                print(f"  Heatmap size: {len(decoded)} bytes")
        except Exception as e:
            print(f"  ERROR: {e}")
            traceback.print_exc()

if __name__ == "__main__":
    asyncio.run(main())