# Commit 14399c4: Use correct deepfake checkpoint and reject incompatible weights
import io
import os
import time
import cv2
import numpy as np
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
# FastAPI application for the deepfake-detection service.
app = FastAPI(title="Deepfake Detection API", version="2.0.0")
# Fully permissive CORS so browser front-ends hosted elsewhere can call the
# API directly; tighten allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Fake-probability cutoff used for per-frame FAKE/REAL votes.
DETECTION_THRESHOLD = 0.40
class DeepfakeEfficientNet(nn.Module):
    """EfficientNet-B0 backbone with a custom MLP head emitting one logit.

    The single output logit is interpreted downstream (after a sigmoid) as
    the probability that the input frame is a deepfake.
    """

    def __init__(self, pretrained: bool = True, dropout: float = 0.5):
        super().__init__()
        # ImageNet-pretrained backbone when requested; random init otherwise.
        if pretrained:
            self.net = EfficientNet.from_pretrained("efficientnet-b0")
        else:
            self.net = EfficientNet.from_name("efficientnet-b0")
        in_features = self.net._fc.in_features
        # Replace the stock classifier with a deeper head. Dropout tapers
        # toward the output (0.5 -> 0.35 -> 0.25 with the default dropout).
        # NOTE: layer order/shapes must stay exactly as the training run that
        # produced the checkpoints, or checkpoint keys will not match.
        self.net._fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout * 0.7),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(256, 1),
        )

    def forward(self, rgb_input, freq_input=None):
        # freq_input is accepted for call-site compatibility but is unused here.
        return self.net(rgb_input)
class TemporalTracker:
    """Rolling window of per-frame fake scores and FAKE/REAL votes.

    Keeps at most ``window_size`` recent scores and ``voting_window`` recent
    votes, and derives smoothed statistics from them.
    """

    def __init__(self, window_size: int = 60, voting_window: int = 10, threshold: float = DETECTION_THRESHOLD):
        self.window_size = window_size
        self.voting_window = voting_window
        self.threshold = threshold
        self.score_history = []
        self.frame_votes = []

    def update(self, fake_probability: float):
        """Record one frame's score and vote, trimming both windows."""
        self.score_history.append(float(fake_probability))
        # del of an empty slice is a no-op, so this only trims when the
        # window is actually over-full.
        del self.score_history[: -self.window_size]
        self.frame_votes.append("FAKE" if fake_probability > self.threshold else "REAL")
        del self.frame_votes[: -self.voting_window]

    def get_temporal_average(self) -> float:
        """Mean score over the retained history; 0.0 when empty."""
        history = self.score_history
        return float(sum(history) / len(history)) if history else 0.0

    def get_stability_score(self) -> float:
        """1.0 for perfectly steady recent scores, 0.0 for noisy/short history."""
        if len(self.score_history) < 10:
            return 0.0
        recent = np.array(self.score_history[-10:], dtype=np.float32)
        variance = float(np.var(recent))
        return float(1.0 - min(variance * 4.0, 1.0))

    def get_confidence_level(self) -> str:
        """Majority vote over the voting window; ties resolve to REAL."""
        votes = self.frame_votes
        if len(votes) < self.voting_window:
            return "UNCERTAIN"
        fake_votes = votes.count("FAKE")
        return "FAKE" if fake_votes > len(votes) - fake_votes else "REAL"

    def reset(self):
        """Drop all accumulated scores and votes."""
        self.score_history.clear()
        self.frame_votes.clear()
class ForensicAnalyzer:
    """Heuristic, model-free frame forensics.

    Combines five weak signals — frequency spectrum, noise residual, JPEG
    error-level analysis (ELA), edge/sharpness statistics, and frame-to-frame
    temporal change — into a single fake-probability estimate. All numeric
    thresholds below are hand-tuned heuristics, not learned values.
    """

    def __init__(self, analysis_size=(256, 256)):
        # All signals are computed on a fixed-size copy of the frame.
        self.analysis_size = analysis_size
        # Previous grayscale frame, used by the temporal-difference signal.
        self.prev_gray = None

    def analyze(self, frame_bgr: np.ndarray):
        """Score one BGR frame.

        Returns a dict with the weighted overall ``fake_probability`` (clipped
        to [0, 1]) and the individual per-signal ``scores``.
        """
        resized = cv2.resize(frame_bgr, self.analysis_size, interpolation=cv2.INTER_LINEAR)
        frequency = self._analyze_frequency(resized)
        noise = self._analyze_noise(resized)
        ela = self._analyze_ela(resized)
        edge = self._analyze_edges(resized)
        temporal = self._analyze_temporal(resized)
        # Fixed linear blend; frequency analysis carries the largest weight.
        score = (
            0.32 * frequency
            + 0.20 * noise
            + 0.18 * ela
            + 0.18 * edge
            + 0.12 * temporal
        )
        return {
            "fake_probability": float(np.clip(score, 0.0, 1.0)),
            "scores": {
                "frequency": float(frequency),
                "noise": float(noise),
                "ela": float(ela),
                "edge": float(edge),
                "temporal": float(temporal),
            },
        }

    def _analyze_frequency(self, frame):
        # Ratio of mid/high-frequency to low-frequency energy in the log FFT
        # magnitude spectrum. A low ratio (little fine detail) scores as more
        # likely fake; the 0.18/0.24 cutoffs are heuristic.
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
        magnitude = np.log1p(np.abs(np.fft.fftshift(np.fft.fft2(gray))))
        h, w = magnitude.shape
        cy, cx = h // 2, w // 2
        y_grid, x_grid = np.ogrid[:h, :w]
        dist = np.sqrt((x_grid - cx) ** 2 + (y_grid - cy) ** 2)
        # Annular bands around the spectrum center: "low" disc and "high" ring.
        inner = min(h, w) // 8
        outer = min(h, w) // 3
        low = magnitude[dist <= inner]
        high = magnitude[(dist > inner) & (dist <= outer)]
        low_mean = float(low.mean()) if low.size else 0.0
        high_mean = float(high.mean()) if high.size else 0.0
        ratio = high_mean / (low_mean + high_mean + 1e-9)
        if ratio < 0.18:
            return 0.75
        if ratio < 0.24:
            return 0.45
        return 0.10

    def _analyze_noise(self, frame):
        # Std-dev of the high-frequency residual (frame minus Gaussian blur).
        # Unnaturally clean frames (very low residual) score as more fake.
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        residual = gray - blurred
        std = float(np.std(residual))
        if std < 2.0:
            return 0.70
        if std < 4.0:
            return 0.35
        return 0.12

    def _analyze_ela(self, frame):
        # Error-level analysis: re-encode at JPEG quality 90 and measure the
        # mean absolute difference from the original. Large differences are
        # treated as a manipulation indicator. Encode/decode failures score 0.
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 90]
        ok, encoded = cv2.imencode(".jpg", frame, encode_param)
        if not ok:
            return 0.0
        recompressed = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        if recompressed is None:
            return 0.0
        diff = cv2.absdiff(frame, recompressed)
        mean_diff = float(np.mean(diff))
        if mean_diff > 14:
            return 0.65
        if mean_diff > 8:
            return 0.35
        return 0.08

    def _analyze_edges(self, frame):
        # Two softness indicators added together, then clipped to [0, 1]:
        # low Canny edge density and low Laplacian variance (blurriness).
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150)
        edge_density = float(np.mean(edges > 0))
        lap_var = float(np.var(cv2.Laplacian(gray, cv2.CV_64F)))
        score = 0.0
        if edge_density < 0.02:
            score += 0.45
        elif edge_density < 0.04:
            score += 0.20
        if lap_var < 60:
            score += 0.35
        elif lap_var < 120:
            score += 0.15
        return float(np.clip(score, 0.0, 1.0))

    def _analyze_temporal(self, frame):
        # Mean absolute difference against the previous analyzed frame;
        # unusually static sequences score higher. The first frame (or the
        # first after reset()) always returns 0.0 and just primes the state.
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
        if self.prev_gray is None:
            self.prev_gray = gray
            return 0.0
        diff = cv2.absdiff(gray, self.prev_gray)
        self.prev_gray = gray
        mean_delta = float(np.mean(diff))
        if mean_delta < 1.2:
            return 0.40
        if mean_delta < 2.5:
            return 0.20
        return 0.08

    def reset(self):
        # Drop temporal state so the next frame starts a fresh sequence.
        self.prev_gray = None
# Module-level detector state shared across all requests.
model = None          # DeepfakeEfficientNet instance once constructed
model_loaded = False  # True only when a compatible checkpoint was applied
tracker = TemporalTracker()
forensics = ForensicAnalyzer()
frame_count = 0       # frames processed since startup / last reset
def _extract_state_dict(checkpoint):
    """Return the raw parameter mapping from common checkpoint layouts.

    Handles checkpoints saved as a bare state dict, or wrapped under the
    conventional "model_state_dict" / "state_dict" keys, and strips the
    "module." prefix that (Distributed)DataParallel training adds to every
    parameter name.
    """
    if isinstance(checkpoint, dict):
        for wrapper_key in ("model_state_dict", "state_dict"):
            if wrapper_key in checkpoint:
                checkpoint = checkpoint[wrapper_key]
                break
    return {
        (key[len("module."):] if key.startswith("module.") else key): value
        for key, value in checkpoint.items()
    }


def load_checkpoint_model():
    """Build the detector and load the first compatible checkpoint found.

    Tries each candidate path in order. A checkpoint is applied only when at
    least 70% of this architecture's parameters match by name and shape, so
    unrelated weights are never silently loaded. Sets the module globals
    ``model`` and ``model_loaded``; when nothing compatible is found the API
    keeps running in forensic-only mode.
    """
    global model
    global model_loaded
    checkpoint_candidates = [
        os.path.join(os.path.dirname(__file__), "models", "best_model.pth"),
        os.path.join(os.path.dirname(__file__), "models", "latest_model_ft.pth"),
    ]
    model = DeepfakeEfficientNet(pretrained=True).to(DEVICE)
    loaded_any = False
    for path in checkpoint_candidates:
        if not os.path.exists(path):
            continue
        try:
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code — only load checkpoints from trusted sources.
            checkpoint = torch.load(path, map_location=DEVICE, weights_only=False)
            state_dict = _extract_state_dict(checkpoint)
            # Only keep keys that belong to this architecture and match tensor shapes.
            model_state = model.state_dict()
            filtered_state = {}
            for key, value in state_dict.items():
                if key in model_state and hasattr(value, "shape") and model_state[key].shape == value.shape:
                    filtered_state[key] = value
            if not filtered_state:
                print(f"Rejected checkpoint (no compatible keys): {path}")
                continue
            load_result = model.load_state_dict(filtered_state, strict=False)
            # Guardrail: require substantial overlap so unrelated checkpoints don't load.
            loaded_ratio = len(filtered_state) / max(len(model_state), 1)
            if loaded_ratio < 0.7:
                print(
                    f"Rejected checkpoint (too few compatible keys: {len(filtered_state)}/{len(model_state)} = {loaded_ratio:.2%}): {path}"
                )
                continue
            loaded_any = True
            print(f"Loaded checkpoint: {path}")
            print(
                f"Compatible keys: {len(filtered_state)}/{len(model_state)} | "
                f"Missing: {len(load_result.missing_keys)} | Unexpected ignored: {len(load_result.unexpected_keys)}"
            )
            break
        except Exception as ex:
            # Best-effort: a malformed candidate must not prevent trying the next.
            print(f"Failed loading checkpoint {path}: {ex}")
    if loaded_any:
        model.eval()
        model_loaded = True
    else:
        model_loaded = False
        print("No compatible deepfake checkpoint found; running forensic-only mode.")
@app.on_event("startup")
async def startup_event():
    # Load model weights once at process start so requests never pay the cost.
    load_checkpoint_model()
@app.get("/")
async def root_health():
    """Minimal liveness payload (subset of /health)."""
    payload = {"status": "healthy"}
    payload["model_loaded"] = model_loaded
    payload["device"] = DEVICE
    payload["frame_count"] = frame_count
    return payload
@app.get("/health")
async def health_check():
    """Liveness payload plus a static capability matrix."""
    capabilities = {
        "frame_forensics": True,
        "temporal_tracking": True,
        "face_detection": False,
    }
    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "device": DEVICE,
        "frame_count": frame_count,
        "capabilities": capabilities,
    }
@app.post("/reset")
async def reset_state():
    """Clear temporal/forensic state and zero the frame counter."""
    global frame_count
    frame_count = 0
    tracker.reset()
    forensics.reset()
    return {"success": True, "message": "Detector state reset"}
def _prepare_model_tensor(frame_bgr: np.ndarray) -> torch.Tensor:
    """Convert a BGR frame to a normalized 1x3x224x224 float tensor on DEVICE.

    Applies the standard ImageNet mean/std normalization expected by the
    EfficientNet backbone.
    """
    imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (224, 224), interpolation=cv2.INTER_AREA)
    normalized = (resized.astype(np.float32) / 255.0 - imagenet_mean) / imagenet_std
    chw = np.transpose(normalized, (2, 0, 1))
    return torch.from_numpy(chw).unsqueeze(0).to(DEVICE)
def _run_model(frame_bgr: np.ndarray):
    """Run the neural detector on one frame.

    Returns the sigmoid fake probability clipped to [0, 1], or None when the
    model is unavailable or inference fails (callers then fall back to
    forensic-only scoring).
    """
    if model is None or not model_loaded:
        return None
    try:
        batch = _prepare_model_tensor(frame_bgr)
        with torch.no_grad():
            probability = torch.sigmoid(model(batch).squeeze()).item()
        return float(np.clip(probability, 0.0, 1.0))
    except Exception as ex:
        # Inference errors are logged and degrade gracefully to forensics-only.
        print(f"Model inference failed: {ex}")
        return None
@app.post("/analyze")
async def analyze_frame(frame: UploadFile = File(None), file: UploadFile = File(None)):
    """Score one uploaded image frame for deepfake likelihood.

    Accepts the image under multipart form field 'frame' (preferred) or
    'file'. Combines neural and forensic scores when the model is loaded,
    otherwise reports forensic-only results.
    """
    global frame_count
    started_at = time.time()
    upload = frame or file
    if upload is None:
        raise HTTPException(status_code=400, detail="No frame provided. Use multipart form field 'frame'.")
    payload = await upload.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Empty file")
    decoded = cv2.imdecode(np.frombuffer(payload, np.uint8), cv2.IMREAD_COLOR)
    if decoded is None:
        raise HTTPException(status_code=400, detail="Invalid image format")
    frame_count += 1
    forensic_result = forensics.analyze(decoded)
    forensic_prob = float(forensic_result["fake_probability"])
    model_prob = _run_model(decoded)
    if model_prob is None:
        analysis_mode = "frame_only"
        combined_prob = forensic_prob
    else:
        # 70/30 blend favoring the neural score when it is available.
        analysis_mode = "model+frame"
        combined_prob = float(np.clip(0.70 * model_prob + 0.30 * forensic_prob, 0.0, 1.0))
    tracker.update(combined_prob)
    elapsed_ms = (time.time() - started_at) * 1000.0
    return {
        "success": True,
        "analysis_mode": analysis_mode,
        "faces_detected": 0,
        "fake_probability": combined_prob,
        "model_probability": model_prob,
        "frame_forensic_probability": forensic_prob,
        "real_probability": float(1.0 - combined_prob),
        "confidence_level": tracker.get_confidence_level(),
        "temporal_average": tracker.get_temporal_average(),
        "stability_score": tracker.get_stability_score(),
        "frame_count": frame_count,
        "processing_time_ms": round(elapsed_ms, 1),
        "forensic_scores": forensic_result["scores"],
    }
if __name__ == "__main__":
    import uvicorn

    # PORT env var overrides the default 7860; bind all interfaces so the
    # service is reachable inside a container.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)