Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| import os | |
| import tempfile | |
| import shutil | |
| from dotenv import load_dotenv | |
| from model_utils import load_model, get_device | |
| from preprocessing import preprocess_video, predict | |
| # Audio deepfake detection imports (separate from video pipeline) | |
| from audio_predict import predict_audio, AudioPredictionError | |
| from audio_preprocessing import AudioValidationError, AudioLoadError | |
# Load environment variables (e.g. FRONTEND_URL) from a .env file, if present.
load_dotenv()

# Initialize FastAPI app.
# The version string matches the "version" reported by the health-check
# handler ("1.1.0"), which advertises both video and audio capabilities.
app = FastAPI(
    title="Deepfake Detection API",
    description="Video and Audio deepfake detection API",
    version="1.1.0",
)

# CORS configuration.
# FRONTEND_URL holds the production frontend origin; leave it blank in development.
# Browsers send the Origin header as scheme://host[:port] with no trailing
# slash, and CORSMiddleware checks allow_origins by exact string membership,
# so a configured value like "https://app.example.com/" would never match.
# Strip surrounding whitespace and any trailing slash.
frontend_url = os.getenv("FRONTEND_URL", "").strip().rstrip("/")

# Local development origins.
allowed_origins = [
    "http://localhost:8080",  # dev server port used by this project's frontend (TODO confirm)
    "http://localhost:5173",  # Vite's default dev server port
    "http://127.0.0.1:8080",
    "http://127.0.0.1:5173",
]

# Add production frontend URL if specified.
if frontend_url:
    allowed_origins.append(frontend_url)
    print(f"✓ Production frontend URL added to CORS: {frontend_url}")
else:
    print("✓ Development mode: Using localhost CORS origins")

# Configure CORS middleware: only the origins listed above may call the API;
# credentials, all methods and all request headers are allowed for them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Startup banner summarizing the effective configuration.
print("\n=== Deepfake Detection Server Configuration ===")
print(f"Allowed CORS origins: {allowed_origins}")
print(f"Device: {get_device()}")
print("=" * 50 + "\n")
async def root():
    """Health check: report that the service is up and what it supports.

    Returns:
        dict with ``status``, ``service``, ``version``, the compute ``device``
        returned by ``get_device()``, and the supported ``capabilities``.
    """
    # NOTE(review): no route decorator is visible on this handler here;
    # confirm how it is registered with ``app``.
    service_info = dict(
        status="online",
        service="Deepfake Detection API",
        version="1.1.0",
        device=get_device(),
        capabilities=["video", "audio"],
    )
    return service_info
async def predict_video_endpoint(
    file: UploadFile = File(...),
    sequence_length: int = Form(...),
    face_focus: bool = Form(True)
):
    """
    Predict whether a video is real or fake.

    Args:
        file: Uploaded video; its content type must start with ``video/``.
        sequence_length: Number of frames to extract. Must be one of
            10, 20, 40, 60, 80, 100. The matching model is loaded via
            ``load_model(sequence_length, device)``.
        face_focus: Only logged. It is not forwarded to
            ``preprocess_video`` (face focus is currently always enabled).

    Returns:
        JSONResponse with these keys:
        ``prediction`` ("REAL" when the model outputs 1, otherwise "FAKE"),
        ``confidence`` (rounded to one decimal), ``sequence_length``,
        ``device``, ``faces_found``, ``total_frames_analyzed``, and
        ``frame_images`` (at most the first 6 face-cropped frames).

    Raises:
        HTTPException: 400 for an unsupported ``sequence_length`` or a
            non-video upload. 500 if the model fails to load or any later
            step raises.
    """
    # NOTE(review): no route decorator is visible on this handler here;
    # confirm how it is registered with ``app``.
    # NOTE(review): load_model / preprocess_video / predict are called
    # synchronously inside this async handler, so they block the event loop
    # while they run; consider a plain ``def`` endpoint or a thread pool.
    valid_lengths = [10, 20, 40, 60, 80, 100]
    max_display_frames = 6  # cap on returned frame images to keep the response small
    temp_video_path = None
    try:
        if sequence_length not in valid_lengths:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid sequence_length. Must be one of {valid_lengths}"
            )
        if not file.content_type or not file.content_type.startswith('video/'):
            raise HTTPException(
                status_code=400,
                detail="File must be a video"
            )

        print(f"\n{'='*50}")
        print(f"Processing video: {file.filename}")
        print(f"Sequence length: {sequence_length} frames")
        print(f"Face focus: {face_focus}")
        print(f"{'='*50}\n")

        # UploadFile.filename may be None/empty; os.path.splitext(None) would
        # raise TypeError and surface as a confusing 500.
        suffix = os.path.splitext(file.filename)[1] if file.filename else ""
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path *before* copying so the finally block removes the
            # file even if the copy fails part-way (delete=False keeps it on disk).
            temp_video_path = temp_file.name
            shutil.copyfileobj(file.file, temp_file)
        print(f"✓ Video saved to: {temp_video_path}")

        # Load the model trained for this sequence length.
        device = get_device()
        model = load_model(sequence_length, device)
        if model is None:
            raise HTTPException(
                status_code=500,
                detail=f"Failed to load model for {sequence_length} frames"
            )
        print("✓ Model loaded successfully")

        print("⏳ Preprocessing video...")
        frames_tensor, preprocessed_images, face_cropped_images, faces_found = preprocess_video(
            temp_video_path,
            sequence_length,
            save_preprocessed=False  # Set to True if you want to save frames
        )
        print("✓ Preprocessing complete")

        print("⏳ Running prediction...")
        prediction_int, confidence = predict(model, frames_tensor, device)
        prediction_label = "REAL" if prediction_int == 1 else "FAKE"

        print(f"\n{'='*50}")
        print(f"✓ PREDICTION: {prediction_label}")
        print(f"✓ CONFIDENCE: {confidence:.1f}%")
        print(f"{'='*50}\n")

        # Slicing already handles lists shorter than the cap.
        display_frames = face_cropped_images[:max_display_frames]
        return JSONResponse(content={
            "prediction": prediction_label,
            "confidence": round(confidence, 1),
            "sequence_length": sequence_length,
            "device": device,
            "faces_found": faces_found,
            "total_frames_analyzed": len(face_cropped_images),
            "frame_images": display_frames
        })
    except HTTPException:
        raise
    except Exception as e:
        print(f"\n❌ Error during prediction: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}"
        ) from e
    finally:
        # Best-effort cleanup of the temporary upload; a failure here is only logged.
        if temp_video_path and os.path.exists(temp_video_path):
            try:
                os.unlink(temp_video_path)
                print("✓ Cleaned up temporary file")
            except Exception as e:
                print(f"⚠ Warning: Could not delete temporary file: {e}")
| # ============================================================================= | |
| # AUDIO DEEPFAKE DETECTION ENDPOINT | |
| # ============================================================================= | |
async def predict_audio_endpoint(file: UploadFile = File(...)):
    """
    Predict whether an audio file is real or fake (deepfake).

    Detection is delegated to ``predict_audio``. The original notes say it
    uses MelodyMachine/Deepfake-audio-detection-V2 (a Wav2Vec2-based model).

    Args:
        file: Uploaded audio file (WAV, MP3, FLAC, M4A, OGG supported). If a
            content type is sent, it must start with ``audio/``. A missing
            content type is let through.

    Returns:
        JSONResponse whose body is the dict returned by ``predict_audio``,
        documented as:
        {
            "prediction": "REAL" | "FAKE",
            "confidence": float (0-100),
            "model": "MelodyMachine/Deepfake-audio-detection-V2",
            "all_scores": {"real": float, "fake": float}
        }

    Raises:
        HTTPException: 400 for a non-audio content type, AudioValidationError
            or AudioLoadError. 500 for AudioPredictionError or any other
            unexpected failure.
    """
    # NOTE(review): no route decorator is visible on this handler here;
    # confirm how it is registered with ``app``.
    temp_audio_path = None
    try:
        if file.content_type and not file.content_type.startswith('audio/'):
            raise HTTPException(
                status_code=400,
                detail="File must be an audio file"
            )

        print(f"\n{'='*50}")
        print(f"Processing audio: {file.filename}")
        print(f"Content type: {file.content_type}")
        print(f"{'='*50}\n")

        # Keep the upload's extension (the loader may rely on it; TODO confirm).
        # Default to .wav when no filename was sent.
        file_ext = os.path.splitext(file.filename)[1] if file.filename else '.wav'
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
            # Record the path *before* copying so the finally block removes the
            # file even if the copy fails part-way (delete=False keeps it on disk).
            temp_audio_path = temp_file.name
            shutil.copyfileobj(file.file, temp_file)
        print(f"✓ Audio saved to: {temp_audio_path}")

        print("⏳ Running audio deepfake detection...")
        result = predict_audio(temp_audio_path, file.content_type)

        print(f"\n{'='*50}")
        print(f"✓ AUDIO PREDICTION: {result['prediction']}")
        print(f"✓ CONFIDENCE: {result['confidence']:.1f}%")
        print(f"{'='*50}\n")
        return JSONResponse(content=result)
    except AudioValidationError as e:
        print(f"\n❌ Audio validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e)) from e
    except AudioLoadError as e:
        print(f"\n❌ Audio load error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e)) from e
    except AudioPredictionError as e:
        print(f"\n❌ Audio prediction error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    except HTTPException:
        raise
    except Exception as e:
        print(f"\n❌ Error during audio prediction: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Audio prediction failed: {str(e)}"
        ) from e
    finally:
        # Best-effort cleanup of the temporary upload; a failure here is only logged.
        if temp_audio_path and os.path.exists(temp_audio_path):
            try:
                os.unlink(temp_audio_path)
                print("✓ Cleaned up temporary audio file")
            except Exception as e:
                print(f"⚠ Warning: Could not delete temporary audio file: {e}")
async def list_available_models():
    """List the ``*.pt`` model files found in ``./models``.

    The directory is resolved relative to the current working directory.
    Each filename is split on ``_``: part 1 is reported as the accuracy
    (suffixed with ``%``) and part 3 is parsed as the frame count. Files
    whose names lack those parts, or whose part 3 is not an integer, are
    skipped.

    Returns:
        dict with ``available_models`` (entries of ``filename``, ``frames``
        and ``accuracy``, sorted by ``frames``) and ``total`` (entry count).
    """
    # NOTE(review): no route decorator is visible on this handler here;
    # confirm how it is registered with ``app``.
    import glob

    def _describe(path):
        # Build one catalog entry, or return None if the name doesn't parse.
        name = os.path.basename(path)
        pieces = name.split("_")
        try:
            return {
                "filename": name,
                "frames": int(pieces[3]),
                "accuracy": f"{pieces[1]}%",
            }
        except (IndexError, ValueError):
            return None

    catalog = []
    for path in glob.glob(os.path.join("models", "*.pt")):
        entry = _describe(path)
        if entry is not None:
            catalog.append(entry)

    return {
        "available_models": sorted(catalog, key=lambda item: item["frames"]),
        "total": len(catalog),
    }
if __name__ == "__main__":
    # Local entry point: serve the app with uvicorn on all interfaces.
    # The PORT environment variable overrides the default port 8000.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 8000)))