# NOTE(review): the three lines below ("Spaces:" / "Sleeping" x2) appear to be
# Hugging Face Spaces page residue from a copy/paste, not Python source —
# commented out so the file parses.
# Spaces:
# Sleeping
# Sleeping
# --- Standard library and third-party imports ---
import os
import shutil
import tempfile
import uvicorn
import warnings
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from mtcnn.mtcnn import MTCNN
import tensorflow as tf
from huggingface_hub import hf_hub_download

# --- Suppress TensorFlow & MTCNN Warnings ---
# '2' hides INFO and WARNING messages from the TF C++ backend (errors still show).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.get_logger().setLevel('ERROR')
warnings.filterwarnings('ignore')

# HuggingFace Hub Configuration
# Repo holding the image/video model weights downloaded in load_models().
HF_REPO_ID = "piyushnaula/deepfake_model_return0"
HF_TOKEN = os.getenv("HF_TOKEN")  # Set this in environment variables (None works for public repos)
# --- Imports for config and prediction functions ---
try:
    # Normal case: this module is imported as part of its package.
    from . import config
    from .predict import get_image_prediction
    from .predict_video_model import get_video_prediction
    from .database import connect_to_mongo, close_mongo_connection, get_database
    from .auth import (
        UserSignup, UserLogin, UserResponse, UsageResponse,
        hash_password, verify_password, generate_api_key, create_user_document,
        validate_api_key, hash_api_key
    )
except ImportError:
    # This fallback lets us run the file directly if needed
    # (e.g. `python main.py` from inside the package directory).
    import config
    from predict import get_image_prediction
    from predict_video_model import get_video_prediction
    from database import connect_to_mongo, close_mongo_connection, get_database
    from auth import (
        UserSignup, UserLogin, UserResponse, UsageResponse,
        hash_password, verify_password, generate_api_key, create_user_document,
        validate_api_key, hash_api_key
    )
# --- 1. Create the FastAPI app ---
app = FastAPI(
    title="Deepfake Detector API",
    description="An API to detect deepfake images and videos using advanced ML models.",
    version="1.0.0"
)

# --- CORS Configuration ---
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True per the CORS spec — confirm whether credentialed
# cross-origin requests are actually required here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# --- 2. Load Models at Startup (Best Practice) ---
# This dictionary will hold our models, loaded ONCE.
# This is far more efficient than loading them for every request.
# NOTE(review): no @app.on_event / route decorators are visible on the handler
# functions in this file — they may have been lost in extraction; confirm the
# routes and startup/shutdown hooks are actually registered with `app`.
models = {}
async def startup_event():
    """Application startup hook.

    Opens the MongoDB connection first, then warms up the ML models, so the
    database is reachable before any model bookkeeping runs.
    """
    await connect_to_mongo()   # 1. database connection
    await load_models()        # 2. ML model warm-up
async def shutdown_event():
    """Application shutdown hook: release the MongoDB client cleanly."""
    await close_mongo_connection()
async def load_models():
    """
    Load all ML models from HuggingFace Hub when the API server starts.

    Populates the module-level ``models`` dict with (when loading succeeds):
    "image_model", "video_model", "mtcnn_detector", "audio_model",
    "audio_processor". Each loader is wrapped in its own try/except so one
    failing download does not block the others — a missing model is reported
    by its endpoint at request time instead of crashing startup.
    """
    print("--- Loading models from HuggingFace Hub... ---")
    # --- Load Image Model from HuggingFace ---
    try:
        print("Downloading Image Model (baseline_model.h5) from HuggingFace...")
        image_model_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename="baseline_model.h5",
            token=HF_TOKEN
        )
        # compile=False: inference only, so skip restoring optimizer/loss state.
        models["image_model"] = tf.keras.models.load_model(image_model_path, compile=False)
        print("Image model loaded successfully.")
    except Exception as e:
        # Non-fatal: the image endpoint will answer "not loaded" instead.
        print(f"WARNING: Failed to load Image Model: {e}")
    # --- Load Video Model from HuggingFace ---
    try:
        print("Downloading Video Model from HuggingFace...")
        # Download finetuned encoder first (needed for video_model.py)
        finetuned_path = None
        try:
            finetuned_path = hf_hub_download(
                repo_id=HF_REPO_ID,
                filename="finetuned_model.h5",
                token=HF_TOKEN
            )
            print(f"Finetuned encoder downloaded to: {finetuned_path}")
        except Exception as e:
            # finetuned_path stays None; build_video_model falls back below.
            print(f"WARNING: Could not download finetuned_model.h5: {e}")
            print("Will use ImageNet weights as fallback...")
        # Import build_video_model here to avoid circular imports
        try:
            from .video_model import build_video_model
        except ImportError:
            from video_model import build_video_model
        # Build the model architecture (pass the downloaded path)
        print("Building video model architecture...")
        video_model = build_video_model(finetuned_model_path=finetuned_path)
        # Try to download and load weights
        try:
            video_weights_path = hf_hub_download(
                repo_id=HF_REPO_ID,
                filename="video_model_v2.keras",
                token=HF_TOKEN
            )
            video_model.load_weights(video_weights_path)
            print("Video model weights loaded successfully.")
        except Exception as e:
            # Deliberate best-effort: the architecture is still registered so
            # the endpoint works, just with untrained (random) weights.
            print(f"WARNING: Could not load video weights: {e}")
            print("Video model will use untrained weights (less accurate).")
        models["video_model"] = video_model
        print("Video model initialized successfully.")
    except Exception as e:
        print(f"WARNING: Failed to load Video Model: {e}")
        import traceback
        traceback.print_exc()
    # --- Load MTCNN Detector ---
    # Face detector used by the video pipeline (passed to get_video_prediction).
    models["mtcnn_detector"] = MTCNN()
    print("MTCNN detector initialized.")
    # --- Load Audio Model (HuggingFace Transformers) ---
    try:
        print("Loading Audio Model (wav2vec2-base-finetuned)...")
        from transformers import Wav2Vec2FeatureExtractor, AutoModelForAudioClassification
        # Use FeatureExtractor instead of Processor (no tokenizer needed for classification)
        models["audio_processor"] = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
        models["audio_model"] = AutoModelForAudioClassification.from_pretrained("mo-thecreator/wav2vec2-base-finetuned")
        print("Audio Model and Feature Extractor loaded successfully!")
    except Exception as e:
        print(f"CRITICAL: Failed to load Audio Model: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
    # Print final status
    print("--- Model loading complete ---")
    print(f"Loaded models: {list(models.keys())}")
# --- 3. Define API Endpoints ---
def read_root():
    """Health check: confirms the server process is up and responding."""
    payload = {"status": "Deepfake Detector API is online and running."}
    return payload
# --- 4. Authentication Endpoints ---
async def signup(user: UserSignup):
    """
    Create a new user account and get your API key.

    ⚠️ IMPORTANT: Save your API key! It will only be shown ONCE.
    """
    db = get_database()

    # Reject duplicate registrations up front.
    if await db.users.find_one({"email": user.email}) is not None:
        raise HTTPException(
            status_code=400,
            detail="Email already registered. Please login to get your API key."
        )

    # create_user_document returns the Mongo document plus the one-time raw key.
    new_doc, plaintext_key = create_user_document(user.email, user.password)
    await db.users.insert_one(new_doc)

    return UserResponse(
        email=user.email,
        api_key=plaintext_key,
        message="Account created! ⚠️ SAVE YOUR API KEY NOW - it will NOT be shown again!"
    )
async def login(user: UserLogin):
    """
    Login to view your API key prefix.
    Note: For security, full key is only shown at signup.
    Use /regenerate-key to get a new key if lost.

    Raises 404 if the email is unknown, 401 on a wrong password.
    """
    db = get_database()
    # Find user
    existing = await db.users.find_one({"email": user.email})
    if not existing:
        raise HTTPException(status_code=404, detail="User not found. Please signup first.")
    # Verify password
    if not verify_password(user.password, existing["password_hash"]):
        raise HTTPException(status_code=401, detail="Invalid password.")
    # Update last login.
    # FIX: datetime.utcnow() is deprecated (Python 3.12+) and returns a naive
    # datetime; use a timezone-aware UTC timestamp instead. BSON stores it as
    # UTC milliseconds either way, so the persisted value is equivalent.
    from datetime import datetime, timezone
    await db.users.update_one(
        {"email": user.email},
        {"$set": {"last_login": datetime.now(timezone.utc)}}
    )
    return UserResponse(
        email=user.email,
        api_key=existing.get("api_key_prefix", "Key hidden for security"),
        message="Login successful. Use /regenerate-key if you need a new API key."
    )
async def regenerate_key(user: UserLogin):
    """
    Generate a new API key. The old key will stop working.

    ⚠️ IMPORTANT: Save your new API key! It will only be shown ONCE.
    """
    db = get_database()

    # Guard clauses: account must exist and the password must match.
    account = await db.users.find_one({"email": user.email})
    if account is None:
        raise HTTPException(status_code=404, detail="User not found.")
    if not verify_password(user.password, account["password_hash"]):
        raise HTTPException(status_code=401, detail="Invalid password.")

    # Rotate the key: persist only the hash/prefix, return the raw key once.
    fresh_key, fresh_hash, fresh_prefix = generate_api_key()
    await db.users.update_one(
        {"email": user.email},
        {"$set": {"api_key_hash": fresh_hash, "api_key_prefix": fresh_prefix}}
    )
    return UserResponse(
        email=user.email,
        api_key=fresh_key,
        message="New API key generated! ⚠️ SAVE IT NOW - old key is now invalid!"
    )
async def get_usage(user: dict = Depends(validate_api_key)):
    """
    Check your API usage and remaining quota.

    Requires x-api-key header (enforced by the validate_api_key dependency).
    """
    limit = user.get("rate_limit", 100)
    used = user.get("requests_today", 0)
    # Quota can be exceeded mid-day; never report a negative remainder.
    remaining = limit - used
    if remaining < 0:
        remaining = 0
    return UsageResponse(
        email=user["email"],
        requests_today=used,
        rate_limit=limit,
        remaining=remaining,
        total_requests=user.get("total_requests", 0),
    )
# --- 5. Prediction Endpoints (Protected) ---
async def predict_image_api(
    file: UploadFile = File(...),
    user: dict = Depends(validate_api_key)
):
    """
    Endpoint for predicting a single deepfake image.
    Requires API key in x-api-key header.

    The upload is spooled to a temporary file because the prediction function
    expects a file path; the temp file is always removed afterwards.
    """
    if "image_model" not in models:
        raise HTTPException(status_code=500, detail="Image model is not loaded.")
    # FIX: use only the extension as the temp-file suffix. The raw client
    # filename is untrusted input — it may be None or contain path
    # separators, either of which breaks NamedTemporaryFile creation.
    suffix = os.path.splitext(file.filename or "")[1]
    temp_file_path = ""
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            shutil.copyfileobj(file.file, temp_file)
            temp_file_path = temp_file.name
        print(f"Processing image: {temp_file_path}")
        # Call our prediction function and pass it the pre-loaded model
        result = get_image_prediction(
            image_path=temp_file_path,
            model=models["image_model"]
        )
        return result
    except Exception as e:
        # If anything goes wrong, return an error (chained for debuggability).
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
    finally:
        # CRITICAL: Always clean up the temp file
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
async def predict_video_api(
    file: UploadFile = File(...),
    user: dict = Depends(validate_api_key)
):
    """
    Endpoint for predicting a single deepfake video.
    Requires API key in x-api-key header.

    The upload is spooled to a temporary file because the prediction pipeline
    operates on file paths; the temp file is always removed afterwards.
    """
    if "video_model" not in models or "mtcnn_detector" not in models:
        raise HTTPException(status_code=500, detail="Video models are not loaded.")
    # FIX: use only the extension as the temp-file suffix. The raw client
    # filename is untrusted input — it may be None or contain path
    # separators, either of which breaks NamedTemporaryFile creation.
    suffix = os.path.splitext(file.filename or "")[1]
    temp_file_path = ""
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            shutil.copyfileobj(file.file, temp_file)
            temp_file_path = temp_file.name
        print(f"Processing video: {temp_file_path}")
        # Call the video prediction function
        result = get_video_prediction(
            video_path=temp_file_path,
            video_model=models["video_model"],
            detector=models["mtcnn_detector"]
        )
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
    finally:
        # CRITICAL: Always clean up the temp file
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
async def predict_audio_api(
    file: UploadFile = File(...),
    user: dict = Depends(validate_api_key)
):
    """
    Endpoint for predicting a single deepfake audio.
    Requires API key in x-api-key header.

    Resamples the upload to 16 kHz, runs it through the wav2vec2 classifier,
    and returns fake/real scores as percentages.
    """
    if "audio_model" not in models or "audio_processor" not in models:
        # Lazy retry in case the startup load failed (e.g. transient network issue).
        try:
            from transformers import Wav2Vec2FeatureExtractor, AutoModelForAudioClassification
            models["audio_processor"] = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
            models["audio_model"] = AutoModelForAudioClassification.from_pretrained("mo-thecreator/wav2vec2-base-finetuned")
        except Exception:
            # FIX: was a bare `except:` which also swallowed SystemExit /
            # KeyboardInterrupt; narrowed to Exception.
            raise HTTPException(status_code=500, detail="Audio model is not loaded.")
    # FIX: use only the extension as the temp-file suffix. The raw client
    # filename is untrusted input — it may be None or contain path
    # separators, either of which breaks NamedTemporaryFile creation.
    suffix = os.path.splitext(file.filename or "")[1]
    temp_file_path = ""
    try:
        # Save temp file
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            shutil.copyfileobj(file.file, temp_file)
            temp_file_path = temp_file.name
        print(f"Processing audio: {temp_file_path}")
        # Imported lazily so the API can start even if audio deps are absent.
        import librosa
        import torch
        # Load audio and resample to 16kHz (required for wav2vec2)
        audio_array, sampling_rate = librosa.load(temp_file_path, sr=16000)
        print(f"Audio loaded: {len(audio_array)} samples at {sampling_rate}Hz")
        # Process audio with the processor
        processor = models["audio_processor"]
        model = models["audio_model"]
        inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
        # Run inference (no_grad: inference only, skip autograd bookkeeping)
        with torch.no_grad():
            logits = model(**inputs).logits
        # Get probabilities using softmax
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        print(f"Raw logits: {logits}")
        print(f"Probabilities: {probabilities}")
        # Get the predicted class
        predicted_class_id = logits.argmax().item()
        predicted_label = model.config.id2label[predicted_class_id]
        # Assumed label mapping id2label: {0: "fake", 1: "real"} —
        # TODO(review): confirm against the checkpoint's config.
        fake_score = probabilities[0][0].item() * 100  # Index 0 = fake
        real_score = probabilities[0][1].item() * 100  # Index 1 = real
        print(f"Fake Score: {fake_score:.2f}%, Real Score: {real_score:.2f}%")
        print(f"Predicted: {predicted_label}")
        # Determine prediction
        if fake_score > real_score:
            prediction = "FAKE"
            confidence = fake_score
        else:
            prediction = "REAL"
            confidence = real_score
        return {
            "prediction": prediction,
            "confidence": round(confidence, 2),
            "fake_score": round(fake_score, 2),
            "real_score": round(real_score, 2),
            "raw": f"Fake: {fake_score:.2f}%, Real: {real_score:.2f}%"
        }
    except Exception as e:
        print(f"Audio prediction error: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
    finally:
        # CRITICAL: Always clean up the temp file
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
# --- 4. How to run this file for development ---
if __name__ == "__main__":
    print("--- Starting FastAPI server directly (for development) ---")
    print("--- Go to http://127.0.0.1:8000 for the API ---")
    # reload=True watches the source tree for changes; app_dir="src" tells
    # uvicorn where to import the "main:app" string from.
    uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True, app_dir="src")