from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
import torch
import librosa
from PIL import Image
import numpy as np
import cv2
import io
import os
import base64
import tempfile
# ==========================================
# 1. INITIAL SETUP
# ==========================================
app = FastAPI(title="Veritas Forensic Engine")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- GLOBAL VARIABLES ---
IMAGE_MODEL_PATH = "with_flux_model.keras"
AUDIO_MODEL_ID = "MelodyMachine/Deepfake-audio-detection-V2"
image_model = None
audio_model = None
feature_extractor = None
# ==========================================
# 2. STARTUP EVENT (Loads BOTH Models)
# ==========================================
@app.on_event("startup")
async def startup_event():
    global image_model, audio_model, feature_extractor

    # A. Load Image Model
    if os.path.exists(IMAGE_MODEL_PATH):
        print("📷 Loading Image/Video Model...")
        image_model = load_model(IMAGE_MODEL_PATH, compile=False)
        print("✅ Image Model Ready.")
    else:
        print(f"❌ CRITICAL: {IMAGE_MODEL_PATH} not found.")

    # B. Load Audio Model
    print("🎤 Loading Audio Model (Wav2Vec2)...")
    try:
        feature_extractor = AutoFeatureExtractor.from_pretrained(AUDIO_MODEL_ID)
        audio_model = AutoModelForAudioClassification.from_pretrained(AUDIO_MODEL_ID)
        print("✅ Audio Model Ready.")
    except Exception as e:
        print(f"❌ Audio Model Failed: {e}")
# ==========================================
# 3. HELPER FUNCTIONS (Image/GradCAM)
# ==========================================
def preprocess_image_array(img_array):
    img_resized = tf.image.resize(img_array, (224, 224), method='lanczos3')
    img_expanded = tf.expand_dims(img_resized, axis=0)
    return preprocess_input(img_expanded)
def generate_heatmap_safe(img_tensor, pred_index):
    try:
        # Locate the EfficientNet backbone (or its final activation layer) to use
        # as the Grad-CAM target.
        target_layer = None
        for layer in image_model.layers:
            if "efficientnet" in layer.name or "top_activation" in layer.name:
                target_layer = layer
                break
        if target_layer is None:
            return None

        grad_model = tf.keras.models.Model(
            [image_model.inputs],
            [target_layer.output, image_model.output]
        )
        with tf.GradientTape() as tape:
            last_conv_layer_output, preds = grad_model(img_tensor)
            class_channel = preds[:, pred_index]
        grads = tape.gradient(class_channel, last_conv_layer_output)
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
        last_conv_layer_output = last_conv_layer_output[0]
        heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
        heatmap = tf.squeeze(heatmap)
        # Normalise to [0, 1]; the small epsilon guards against division by zero
        # when the map is all zeros.
        heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-8)
        return heatmap.numpy()
    except Exception as e:
        print(f"⚠️ GradCAM Skipped: {e}")
        return None
def overlay_heatmap(original_img_pil, heatmap):
    img = np.array(original_img_pil)
    if heatmap is None:
        # No heatmap available: return the untouched image, JPEG-encoded as base64.
        is_success, buffer = cv2.imencode(".jpg", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        return base64.b64encode(buffer).decode("utf-8")

    heatmap = np.uint8(255 * heatmap)
    jet = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    jet = cv2.cvtColor(jet, cv2.COLOR_BGR2RGB)  # applyColorMap returns BGR; img is RGB
    jet = cv2.resize(jet, (img.shape[1], img.shape[0]))
    superimposed_img = jet * 0.4 + img * 0.6
    superimposed_img = np.clip(superimposed_img, 0, 255).astype("uint8")
    is_success, buffer = cv2.imencode(".jpg", cv2.cvtColor(superimposed_img, cv2.COLOR_RGB2BGR))
    return base64.b64encode(buffer).decode("utf-8")
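
# Illustrative only (not called anywhere in this file): how the three helpers above
# chain together for a single PIL image once image_model is loaded. `pil_img` is a
# placeholder name.
#   tensor = preprocess_image_array(np.array(pil_img))
#   heatmap = generate_heatmap_safe(tensor, pred_index=0)  # index 0 = "AI" class
#   jpeg_b64 = overlay_heatmap(pil_img, heatmap)           # base64-encoded JPEG string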
# ==========================================
# 4. ENDPOINT: IMAGE ANALYSIS
# ==========================================
@app.post("/api/analyze-image")
async def analyze_image(file: UploadFile = File(...)):
    if image_model is None:
        raise HTTPException(500, "Image model not loaded.")
    try:
        contents = await file.read()
        img = Image.open(io.BytesIO(contents)).convert('RGB')
        processed_tensor = preprocess_image_array(np.array(img))

        preds = image_model.predict(processed_tensor)
        ai_score = float(preds[0][0])    # index 0 = AI / generated
        real_score = float(preds[0][1])  # index 1 = real
        confidence = max(ai_score, real_score) * 100
        label = "AI" if ai_score > real_score else "Real"
        pred_index = 0 if ai_score > real_score else 1

        heatmap = generate_heatmap_safe(processed_tensor, pred_index)
        heatmap_b64 = overlay_heatmap(img, heatmap)

        return {
            "type": "image",
            "prediction": label,
            "confidence": round(confidence, 2),
            "heatmap_base64": heatmap_b64,
            "probabilities": {"ai": ai_score, "real": real_score}
        }
    except Exception as e:
        return {"error": str(e)}
# ==========================================
# 5. ENDPOINT: VIDEO ANALYSIS
# ==========================================
@app.post("/api/analyze-video")
async def analyze_video(file: UploadFile = File(...)):
    if image_model is None:
        raise HTTPException(500, "Image model not loaded.")
    try:
        # Persist the upload to a temp file so OpenCV can seek within it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_vid:
            temp_vid.write(await file.read())
            temp_path = temp_vid.name

        cap = cv2.VideoCapture(temp_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))

        # Sample up to 10 evenly spaced frames instead of scoring every frame.
        frames_to_analyze = 10
        step = max(1, frame_count // frames_to_analyze)

        timeline_results = []
        fake_frame_count = 0
        for i in range(0, frame_count, step):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            processed_tensor = preprocess_image_array(frame_rgb)
            preds = image_model.predict(processed_tensor)
            ai_score = float(preds[0][0])

            timestamp = round(i / fps, 2) if fps > 0 else 0
            timeline_results.append({
                "timestamp": timestamp,
                "ai_score": ai_score,
                "status": "FAKE" if ai_score > 0.5 else "REAL"
            })
            if ai_score > 0.5:
                fake_frame_count += 1
            if len(timeline_results) >= frames_to_analyze:
                break

        cap.release()
        os.unlink(temp_path)

        overall_fake_percent = (fake_frame_count / len(timeline_results)) * 100 if timeline_results else 0
        final_verdict = "DEEPFAKE DETECTED" if overall_fake_percent > 40 else "AUTHENTIC VIDEO"

        return {
            "type": "video",
            "prediction": final_verdict,
            "fake_percentage": round(overall_fake_percent, 2),
            "timeline": timeline_results
        }
    except Exception as e:
        return {"error": str(e)}
# ==========================================
# 6. ENDPOINT: AUDIO ANALYSIS
# ==========================================
@app.post("/api/analyze-audio")
async def analyze_audio(file: UploadFile = File(...)):
    if audio_model is None or feature_extractor is None:
        raise HTTPException(500, "Audio model not loaded.")
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
            temp_audio.write(await file.read())
            temp_path = temp_audio.name

        # Wav2Vec2-style checkpoints expect 16 kHz mono audio.
        audio_input, sample_rate = librosa.load(temp_path, sr=16000)
        inputs = feature_extractor(audio_input, sampling_rate=16000, return_tensors="pt")

        with torch.no_grad():
            logits = audio_model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # --- LABEL MAPPING FOR MelodyMachine/Deepfake-audio-detection-V2 ---
        # Index 0 is treated as REAL and index 1 as FAKE/SPOOF; confirm against
        # audio_model.config.id2label if the checkpoint is ever swapped out.
        real_score = float(probs[0][0])
        fake_score = float(probs[0][1])

        os.unlink(temp_path)

        verdict = "FAKE AUDIO DETECTED" if fake_score > real_score else "AUTHENTIC AUDIO"
        confidence = max(fake_score, real_score) * 100

        return {
            "type": "audio",
            "prediction": verdict,
            "confidence": round(confidence, 2),
            "probabilities": {"ai": fake_score, "real": real_score}
        }
    except Exception as e:
        return {"error": str(e)}