"""Veritas Forensic Engine — deepfake detection API (image / video / audio)."""

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
import torch
import librosa
from PIL import Image
import numpy as np
import cv2
import io
import os
import base64
import tempfile

# ==========================================
# 1. INITIAL SETUP
# ==========================================
app = FastAPI(title="Veritas Forensic Engine")

# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests — tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- GLOBAL VARIABLES ---
IMAGE_MODEL_PATH = "with_flux_model.keras"
AUDIO_MODEL_ID = "MelodyMachine/Deepfake-audio-detection-V2"

image_model = None        # Keras classifier shared by the image & video endpoints
audio_model = None        # Hugging Face Wav2Vec2 audio classifier
feature_extractor = None  # feature extractor matching audio_model

# ==========================================
# 2. STARTUP EVENT (Loads BOTH Models)
# ==========================================
@app.on_event("startup")
async def startup_event():
    """Load both models exactly once when the server boots.

    Each model failing to load is reported but non-fatal; the corresponding
    endpoint later answers with an HTTP 500 instead.
    """
    global image_model, audio_model, feature_extractor

    # A. Load Image Model
    if os.path.exists(IMAGE_MODEL_PATH):
        print("📷 Loading Image/Video Model...")
        # compile=False: inference only, no optimizer state needed.
        image_model = load_model(IMAGE_MODEL_PATH, compile=False)
        print("✅ Image Model Ready.")
    else:
        print(f"❌ CRITICAL: {IMAGE_MODEL_PATH} not found.")

    # B. Load Audio Model
    print("🎤 Loading Audio Model (Wav2Vec2)...")
    try:
        feature_extractor = AutoFeatureExtractor.from_pretrained(AUDIO_MODEL_ID)
        audio_model = AutoModelForAudioClassification.from_pretrained(AUDIO_MODEL_ID)
        print("✅ Audio Model Ready.")
    except Exception as e:
        print(f"❌ Audio Model Failed: {e}")
# ==========================================
# 3. HELPER FUNCTIONS (Image/GradCAM)
# ==========================================
def preprocess_image_array(img_array):
    """Resize an RGB array to 224x224, add a batch axis, and apply
    EfficientNetV2 preprocessing.

    Returns a (1, 224, 224, 3) float tensor ready for ``image_model``.
    """
    img_resized = tf.image.resize(img_array, (224, 224), method='lanczos3')
    img_expanded = tf.expand_dims(img_resized, axis=0)
    return preprocess_input(img_expanded)


def generate_heatmap_safe(img_tensor, pred_index):
    """Compute a Grad-CAM heatmap for class ``pred_index``.

    Returns a 2-D numpy array normalised to [0, 1], or ``None`` when no
    suitable backbone layer is found or any step fails (callers treat
    ``None`` as "skip the overlay").

    FIX: normalisation previously divided by ``reduce_max`` with no
    guard, so an all-non-positive heatmap produced NaN/inf values that
    corrupted the rendered overlay.
    """
    try:
        # Locate the last convolutional feature map of the EfficientNet backbone.
        target_layer = None
        for layer in image_model.layers:
            if "efficientnet" in layer.name or "top_activation" in layer.name:
                target_layer = layer
                break
        if not target_layer:
            return None

        # Model that exposes both the conv feature map and the final predictions.
        grad_model = tf.keras.models.Model(
            [image_model.inputs], [target_layer.output, image_model.output]
        )

        with tf.GradientTape() as tape:
            last_conv_layer_output, preds = grad_model(img_tensor)
            class_channel = preds[:, pred_index]

        # Channel importance = gradient of the class score w.r.t. the feature map,
        # averaged over the spatial dimensions.
        grads = tape.gradient(class_channel, last_conv_layer_output)
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
        last_conv_layer_output = last_conv_layer_output[0]
        heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
        heatmap = tf.squeeze(heatmap)

        # Normalise to [0, 1]; a zero maximum (flat/negative activations)
        # previously caused a division by zero — bail out instead.
        heatmap = tf.maximum(heatmap, 0)
        max_val = tf.math.reduce_max(heatmap)
        if max_val == 0:
            return None
        return (heatmap / max_val).numpy()
    except Exception as e:
        print(f"⚠️ GradCAM Skipped: {e}")
        return None


def overlay_heatmap(original_img_pil, heatmap):
    """Blend a Grad-CAM heatmap onto the original image.

    Returns the composite (or the plain image when ``heatmap`` is None)
    as a base64-encoded JPEG string.
    """
    img = np.array(original_img_pil)
    if heatmap is None:
        # No heatmap available — encode the untouched image instead.
        is_success, buffer = cv2.imencode(".jpg", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        return base64.b64encode(buffer).decode("utf-8")

    heatmap = np.uint8(255 * heatmap)
    jet = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    jet = cv2.resize(jet, (img.shape[1], img.shape[0]))
    # 40/60 blend keeps the underlying content visible beneath the colour map.
    superimposed_img = jet * 0.4 + img * 0.6
    superimposed_img = np.clip(superimposed_img, 0, 255).astype("uint8")
    is_success, buffer = cv2.imencode(".jpg", cv2.cvtColor(superimposed_img, cv2.COLOR_RGB2BGR))
    return base64.b64encode(buffer).decode("utf-8")
# ==========================================
# 4. ENDPOINT: IMAGE ANALYSIS
# ==========================================
@app.post("/api/analyze-image")
async def analyze_image(file: UploadFile = File(...)):
    """Classify an uploaded image as AI-generated or real.

    Returns the label, confidence (percent), class probabilities and a
    base64 Grad-CAM overlay. Analysis failures are returned as a JSON
    ``{"error": ...}`` payload (HTTP 200), matching the other endpoints.
    """
    if not image_model:
        raise HTTPException(500, "Image model not loaded.")
    try:
        contents = await file.read()
        img = Image.open(io.BytesIO(contents)).convert('RGB')
        processed_tensor = preprocess_image_array(np.array(img))

        preds = image_model.predict(processed_tensor)
        # Model output convention: index 0 = AI score, index 1 = Real score.
        ai_score = float(preds[0][0])
        real_score = float(preds[0][1])

        confidence = max(ai_score, real_score) * 100
        label = "AI" if ai_score > real_score else "Real"
        pred_index = 0 if ai_score > real_score else 1

        heatmap = generate_heatmap_safe(processed_tensor, pred_index)
        heatmap_b64 = overlay_heatmap(img, heatmap)

        return {
            "type": "image",
            "prediction": label,
            "confidence": round(confidence, 2),
            "heatmap_base64": heatmap_b64,
            "probabilities": {"ai": ai_score, "real": real_score}
        }
    except Exception as e:
        return {"error": str(e)}


# ==========================================
# 5. ENDPOINT: VIDEO ANALYSIS
# ==========================================
@app.post("/api/analyze-video")
async def analyze_video(file: UploadFile = File(...)):
    """Sample up to 10 evenly spaced frames of an uploaded video and
    report a per-frame FAKE/REAL timeline plus an overall verdict.

    FIX: the ``cv2.VideoCapture`` handle and the temporary file are now
    released in a ``finally`` block — previously any exception raised
    mid-analysis leaked both, since cleanup only ran on the success path.
    """
    if not image_model:
        raise HTTPException(500, "Image model not loaded.")
    temp_path = None
    cap = None
    try:
        # Persist the upload to disk so OpenCV can open it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_vid:
            temp_vid.write(await file.read())
            temp_path = temp_vid.name

        cap = cv2.VideoCapture(temp_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))

        frames_to_analyze = 10
        # Even stride over the whole clip; at least 1 so short clips still advance.
        step = max(1, frame_count // frames_to_analyze)

        timeline_results = []
        fake_frame_count = 0

        for i in range(0, frame_count, step):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if not ret:
                break

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            processed_tensor = preprocess_image_array(frame_rgb)
            preds = image_model.predict(processed_tensor)
            ai_score = float(preds[0][0])

            timestamp = round(i / fps, 2) if fps > 0 else 0
            timeline_results.append({
                "timestamp": timestamp,
                "ai_score": ai_score,
                "status": "FAKE" if ai_score > 0.5 else "REAL"
            })
            if ai_score > 0.5:
                fake_frame_count += 1
            if len(timeline_results) >= frames_to_analyze:
                break

        overall_fake_percent = (fake_frame_count / len(timeline_results)) * 100 if len(timeline_results) > 0 else 0
        # Verdict threshold: more than 40% of sampled frames flagged as fake.
        final_verdict = "DEEPFAKE DETECTED" if overall_fake_percent > 40 else "AUTHENTIC VIDEO"

        return {
            "type": "video",
            "prediction": final_verdict,
            "fake_percentage": round(overall_fake_percent, 2),
            "timeline": timeline_results
        }
    except Exception as e:
        return {"error": str(e)}
    finally:
        # Always release resources, even when the exception above was
        # converted into the JSON error payload.
        if cap is not None:
            cap.release()
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)


# ==========================================
# 6. ENDPOINT: AUDIO ANALYSIS
# ==========================================
@app.post("/api/analyze-audio")
async def analyze_audio(file: UploadFile = File(...)):
    """Classify an uploaded audio clip as spoofed or authentic via the
    Wav2Vec2 classifier loaded at startup.

    FIX: the temporary file is now removed in a ``finally`` block —
    previously it leaked whenever decoding or inference raised.
    """
    if not audio_model:
        raise HTTPException(500, "Audio model not loaded.")
    temp_path = None
    try:
        # Persist the upload so librosa can decode it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
            temp_audio.write(await file.read())
            temp_path = temp_audio.name

        # The model expects 16 kHz mono input.
        audio_input, sample_rate = librosa.load(temp_path, sr=16000)
        inputs = feature_extractor(audio_input, sampling_rate=16000, return_tensors="pt")

        with torch.no_grad():
            logits = audio_model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # --- FIXED MAPPING FOR MELODY MACHINE V2 ---
        # Label 0 is usually REAL, Label 1 is usually FAKE/SPOOF
        real_score = float(probs[0][0])
        fake_score = float(probs[0][1])

        verdict = "FAKE AUDIO DETECTED" if fake_score > real_score else "AUTHENTIC AUDIO"
        confidence = max(fake_score, real_score) * 100

        return {
            "type": "audio",
            "prediction": verdict,
            "confidence": round(confidence, 2),
            "probabilities": {"ai": fake_score, "real": real_score}
        }
    except Exception as e:
        return {"error": str(e)}
    finally:
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)