"""Kaeva Verify V10 β€” Full Ensemble Deepfake Detection API.
Models loaded:
- image_ensemble_v2 (EfficientNet-B0, 15.6MB) β€” general image deepfake
- ai_gen_detector (EfficientNet-B3, 44.4MB) β€” AI-generated image detection
- spectral_detector (Dual-stream ResNet18, 131MB) β€” frequency/spectral analysis
- frequency_analyzer (MLP+CNN, 1.5MB) β€” DCT/wavelet/Benford features
- audio_deepfake_v10 (Wav2Vec2 full, 361MB) β€” 3-class audio (real/tts/vc)
- audio_deepfake_model (Wav2Vec2 probe, 0.8MB) β€” binary audio fallback
Endpoints:
POST /image β€” V10 ensemble image detection (4 models, platform-aware)
POST /audio β€” Audio deepfake detection (v10 primary, v1 fallback)
POST /video β€” Video: frame ensemble + audio analysis
POST /ocr β€” Extract text from image via pytesseract
GET /health β€” Health check
"""
import io, os, traceback, tempfile, subprocess, time, json
import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, UploadFile, File, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from torchvision import transforms
from torchvision.models import efficientnet_b0, efficientnet_b3, resnet18
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
import librosa
app = FastAPI(title="Kaeva Verify V10", version="10.1.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
DEVICE = torch.device("cpu")
# ── Ensemble configs ──
ENSEMBLE_CONFIGS = {
"clean": {"weights": {"image_ensemble_v2": 0.35, "ai_gen": 0.30, "spectral": 0.20, "frequency": 0.15}, "threshold": 0.50},
"whatsapp": {"weights": {"image_ensemble_v2": 0.40, "ai_gen": 0.25, "spectral": 0.20, "frequency": 0.15}, "threshold": 0.55},
"instagram": {"weights": {"image_ensemble_v2": 0.35, "ai_gen": 0.30, "spectral": 0.20, "frequency": 0.15}, "threshold": 0.50},
"telegram": {"weights": {"image_ensemble_v2": 0.35, "ai_gen": 0.30, "spectral": 0.20, "frequency": 0.15}, "threshold": 0.50},
"screenshot": {"weights": {"image_ensemble_v2": 0.40, "ai_gen": 0.25, "spectral": 0.15, "frequency": 0.20}, "threshold": 0.50},
}
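# Worked example (illustrative): with the "clean" profile and per-model fake scores
# image_ensemble_v2=0.80, ai_gen=0.60, spectral=0.40, frequency=0.50, the fused score is
#   0.35*0.80 + 0.30*0.60 + 0.20*0.40 + 0.15*0.50 = 0.615
# which exceeds the 0.50 threshold, so the verdict is "fake".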
# ── Transforms ──
img_transform_224 = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img_transform_300 = transforms.Compose([
transforms.Resize((300, 300)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# ── Model registry ──
models = {}
# ═══════════════════════════════════════════
# 1. IMAGE ENSEMBLE V2 (EfficientNet-B0)
# ═══════════════════════════════════════════
def load_image_ensemble_v2():
if "image_ensemble_v2" in models:
return
print("Loading image_ensemble_v2 (EfficientNet-B0)...", flush=True)
model = efficientnet_b0(weights=None)
model.classifier[1] = nn.Linear(1280, 2)
sd = torch.load("image_ensemble_v2.pt", map_location=DEVICE, weights_only=False)
if isinstance(sd, dict) and "model_state_dict" in sd:
sd = sd["model_state_dict"]
model.load_state_dict(sd, strict=True)
model.eval()
models["image_ensemble_v2"] = model
print(" ok image_ensemble_v2", flush=True)
def infer_image_ensemble_v2(img: Image.Image) -> float:
load_image_ensemble_v2()
tensor = img_transform_224(img).unsqueeze(0).to(DEVICE)
with torch.no_grad():
logits = models["image_ensemble_v2"](tensor)
return torch.softmax(logits, dim=1)[0, 1].item()
# ═══════════════════════════════════════════
# 2. AI GEN DETECTOR (EfficientNet-B3)
# ═══════════════════════════════════════════
def load_ai_gen():
if "ai_gen" in models:
return
print("Loading ai_gen (EfficientNet-B3)...", flush=True)
model = efficientnet_b3(weights=None)
model.classifier = nn.Sequential(
nn.Dropout(p=0.3, inplace=True),
nn.Linear(1536, 512),
nn.ReLU(inplace=True),
nn.Dropout(p=0.2),
nn.Linear(512, 2),
)
ckpt = torch.load("ai_gen_detector.pt", map_location=DEVICE, weights_only=False)
sd = ckpt.get("model_state_dict", ckpt)
clean_sd = {}
for k, v in sd.items():
new_k = k.replace("backbone.", "", 1) if k.startswith("backbone.") else k
clean_sd[new_k] = v
model.load_state_dict(clean_sd, strict=True)
model.eval()
models["ai_gen"] = model
print(f" ok ai_gen (val_acc={ckpt.get('val_acc', 'N/A')})", flush=True)
def infer_ai_gen(img: Image.Image) -> float:
load_ai_gen()
tensor = img_transform_300(img).unsqueeze(0).to(DEVICE)
with torch.no_grad():
logits = models["ai_gen"](tensor)
# Class 0 = fake, Class 1 = real for this model
return torch.softmax(logits, dim=1)[0, 0].item()
# ═══════════════════════════════════════════
# 3. SPECTRAL DETECTOR (Dual-stream ResNet18)
# ═══════════════════════════════════════════
class DualStreamSpectral(nn.Module):
"""Two ResNet18 streams: (A) magnitude spectrum, (B) phase spectrum β†’ fused classifier."""
def __init__(self):
super().__init__()
# Stream A: magnitude spectrum
base_a = resnet18(weights=None)
base_a.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
base_a.fc = nn.Linear(512, 256)
self.stream_a = base_a
# Stream B: phase spectrum
base_b = resnet18(weights=None)
base_b.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
base_b.fc = nn.Linear(512, 128)
self.stream_b = base_b
        # Classifier: 256 + 128 = 384 → 1
self.classifier = nn.Sequential(
nn.LayerNorm(384), # classifier.0
nn.Linear(384, 128), # classifier.2
nn.ReLU(), # classifier.3
nn.Dropout(0.3), # classifier.4
nn.Linear(128, 1), # classifier.5
)
def forward(self, mag, phase):
a = self.stream_a(mag)
b = self.stream_b(phase)
fused = torch.cat([a, b], dim=1)
return self.classifier(fused)
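# Shape flow for DualStreamSpectral (from the definitions above): mag and phase are
# (N, 1, 224, 224) -> stream_a gives (N, 256), stream_b gives (N, 128) -> concat
# (N, 384) -> classifier -> (N, 1) logit.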
def load_spectral():
if "spectral" in models:
return
print("Loading spectral_detector (DualStreamSpectral)...", flush=True)
model = DualStreamSpectral()
ckpt = torch.load("spectral_detector.pt", map_location=DEVICE, weights_only=False)
sd = ckpt.get("model", ckpt)
model.load_state_dict(sd, strict=True)
model.eval()
models["spectral"] = model
print(f" ok spectral (best_f1={ckpt.get('best_f1', 'N/A')})", flush=True)
def compute_spectrum(img: Image.Image):
"""Convert PIL image β†’ grayscale β†’ FFT β†’ magnitude & phase tensors (1,1,224,224)."""
gray = np.array(img.convert("L").resize((224, 224)), dtype=np.float32) / 255.0
f = np.fft.fft2(gray)
fshift = np.fft.fftshift(f)
mag = np.log1p(np.abs(fshift))
phase = np.angle(fshift)
# Normalize
mag = (mag - mag.mean()) / (mag.std() + 1e-8)
phase = (phase - phase.mean()) / (phase.std() + 1e-8)
mag_t = torch.from_numpy(mag).unsqueeze(0).unsqueeze(0).float()
phase_t = torch.from_numpy(phase).unsqueeze(0).unsqueeze(0).float()
return mag_t, phase_t
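# Quick shape check (illustrative sketch):
#   mag_t, phase_t = compute_spectrum(Image.new("RGB", (640, 480)))
#   assert mag_t.shape == phase_t.shape == (1, 1, 224, 224)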
def infer_spectral(img: Image.Image) -> float:
load_spectral()
mag_t, phase_t = compute_spectrum(img)
with torch.no_grad():
logit = models["spectral"](mag_t.to(DEVICE), phase_t.to(DEVICE))
return torch.sigmoid(logit).item()
# ═══════════════════════════════════════════
# 4. FREQUENCY ANALYZER (MLP + CNN)
# ═══════════════════════════════════════════
class FrequencyAnalyzer(nn.Module):
"""Handcrafted frequency features (54-dim) through MLP + spectrum (64x64) through CNN β†’ fused classifier."""
def __init__(self, handcrafted_dim=54, spectrum_size=64):
super().__init__()
# MLP for handcrafted features
self.mlp = nn.Sequential(
nn.Linear(handcrafted_dim, 128), # mlp.0
nn.BatchNorm1d(128), # mlp.1
nn.ReLU(), # mlp.2
nn.Dropout(0.3), # mlp.3
nn.Linear(128, 64), # mlp.4
nn.ReLU(), # mlp.5
)
# CNN for spectrum image (3 conv blocks with BN)
self.cnn = nn.Sequential(
nn.Conv2d(1, 32, 3, padding=1), # cnn.0
nn.BatchNorm2d(32), # cnn.1
nn.ReLU(), # cnn.2
nn.MaxPool2d(2), # cnn.3 -> 32x32
nn.Conv2d(32, 64, 3, padding=1), # cnn.4
nn.BatchNorm2d(64), # cnn.5
nn.ReLU(), # cnn.6
nn.MaxPool2d(2), # cnn.7 -> 16x16
nn.Conv2d(64, 128, 3, padding=1), # cnn.8
nn.BatchNorm2d(128), # cnn.9
nn.ReLU(), # cnn.10
nn.AdaptiveAvgPool2d(4), # cnn.11 -> 4x4
nn.Flatten(), # cnn.12 -> 128*4*4 = 2048
)
self.cnn_fc = nn.Linear(2048, 128)
# Classifier: 64 (mlp) + 128 (cnn) = 192
self.classifier = nn.Sequential(
nn.LayerNorm(192), # classifier.0
nn.Linear(192, 64), # classifier.2
nn.ReLU(), # classifier.3
nn.Dropout(0.3), # classifier.4
nn.Linear(64, 1), # classifier.5
)
def forward(self, handcrafted, spectrum):
mlp_out = self.mlp(handcrafted)
cnn_out = self.cnn(spectrum)
cnn_out = self.cnn_fc(cnn_out)
cnn_out = torch.relu(cnn_out)
fused = torch.cat([mlp_out, cnn_out], dim=1)
return self.classifier(fused)
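# Shape flow for FrequencyAnalyzer (from the definitions above): handcrafted (N, 54)
# -> mlp (N, 64); spectrum (N, 1, 64, 64) -> cnn (N, 2048) -> cnn_fc + relu (N, 128);
# concat (N, 192) -> classifier -> (N, 1) logit.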
def load_frequency():
if "frequency" in models:
return
print("Loading frequency_analyzer...", flush=True)
ckpt = torch.load("frequency_analyzer.pt", map_location=DEVICE, weights_only=False)
fdims = ckpt.get("feature_dims", {})
hc_dim = fdims.get("handcrafted_total", 54)
spec_size = fdims.get("spectrum_size", 64)
model = FrequencyAnalyzer(handcrafted_dim=hc_dim, spectrum_size=spec_size)
sd = ckpt.get("model", ckpt)
model.load_state_dict(sd, strict=True)
model.eval()
models["frequency"] = {"model": model, "hc_dim": hc_dim, "spec_size": spec_size}
print(f" ok frequency (best_f1={ckpt.get('best_f1', 'N/A')})", flush=True)
def extract_frequency_features(img: Image.Image):
"""Extract handcrafted frequency features + spectrum from image."""
gray = np.array(img.convert("L").resize((256, 256)), dtype=np.float32) / 255.0
    from scipy.fft import dct as scipy_dct
    # 2-D orthonormal DCT, computed once and shared by the Benford and DCT-statistics features
    dct_2d = scipy_dct(scipy_dct(gray, axis=0, norm='ortho'), axis=1, norm='ortho')
    # Benford's law features (18-dim): first-digit distribution of DCT coefficients
    dct_coeffs = dct_2d.flatten()
abs_coeffs = np.abs(dct_coeffs[dct_coeffs != 0])
if len(abs_coeffs) > 0:
first_digits = (abs_coeffs / (10 ** np.floor(np.log10(abs_coeffs + 1e-10)))).astype(int)
first_digits = np.clip(first_digits, 1, 9)
benford = np.bincount(first_digits, minlength=10)[1:].astype(np.float32)
benford = benford / (benford.sum() + 1e-8)
# Expected Benford distribution
expected = np.log10(1 + 1.0 / np.arange(1, 10)).astype(np.float32)
benford_features = np.concatenate([benford, expected]) # 18-dim
else:
benford_features = np.zeros(18, dtype=np.float32)
    # DCT statistics (10-dim), computed on the shared 2-D DCT
    dct_stats = np.array([
        dct_2d.mean(), dct_2d.std(), np.median(dct_2d),
        dct_2d.min(), dct_2d.max(),
        np.percentile(dct_2d, 25), np.percentile(dct_2d, 75),
        float(np.abs(dct_2d).mean()),
        float((np.abs(dct_2d) > 0.01).sum()) / dct_2d.size,  # sparsity
        float(np.abs(dct_2d[:32, :32]).sum()) / (float(np.abs(dct_2d).sum()) + 1e-8),  # low-freq energy ratio
    ], dtype=np.float32)
# Wavelet features (26-dim) β€” simplified using numpy
# Use multi-level Haar wavelet approximation
    def haar_wavelet_1level(x):
        # Average/difference the four pixels of each 2x2 block into the LL/LH/HL/HH bands
        ll = (x[0::2, 0::2] + x[1::2, 0::2] + x[0::2, 1::2] + x[1::2, 1::2]) / 4
        lh = (x[0::2, 0::2] - x[1::2, 0::2] + x[0::2, 1::2] - x[1::2, 1::2]) / 4
        hl = (x[0::2, 0::2] + x[1::2, 0::2] - x[0::2, 1::2] - x[1::2, 1::2]) / 4
        hh = (x[0::2, 0::2] - x[1::2, 0::2] - x[0::2, 1::2] + x[1::2, 1::2]) / 4
        return ll, lh, hl, hh
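    # Worked 2x2 example (illustrative): for x = [[1, 2], [3, 4]] (rows x columns),
    # ll = (1+3+2+4)/4 = 2.5, lh = (1-3+2-4)/4 = -1.0,
    # hl = (1+3-2-4)/4 = -0.5, hh = (1-3-2+4)/4 = 0.0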
wavelet_feats = []
current = gray
for level in range(3):
if current.shape[0] < 4 or current.shape[1] < 4:
break
h = (current.shape[0] // 2) * 2
w = (current.shape[1] // 2) * 2
current_even = current[:h, :w]
ll, lh, hl, hh = haar_wavelet_1level(current_even)
for band in [lh, hl, hh]:
wavelet_feats.extend([band.mean(), band.std()])
# Energy ratio
total_energy = float(np.sum(current_even ** 2)) + 1e-8
detail_energy = float(np.sum(lh**2) + np.sum(hl**2) + np.sum(hh**2))
wavelet_feats.append(detail_energy / total_energy)
wavelet_feats.append(float(np.abs(hh).mean())) # diagonal detail
current = ll
# Pad to 26 dims
wavelet_arr = np.array(wavelet_feats[:26], dtype=np.float32)
if len(wavelet_arr) < 26:
wavelet_arr = np.pad(wavelet_arr, (0, 26 - len(wavelet_arr)))
# Combine all handcrafted features (18 + 10 + 26 = 54)
handcrafted = np.concatenate([benford_features, dct_stats, wavelet_arr])
# Spectrum image (64x64)
f = np.fft.fft2(gray)
fshift = np.fft.fftshift(f)
mag = np.log1p(np.abs(fshift))
# Resize to 64x64
    mag_img = Image.fromarray(((mag - mag.min()) / (mag.max() - mag.min() + 1e-8) * 255).astype(np.uint8))
mag_img = mag_img.resize((64, 64))
spectrum = np.array(mag_img, dtype=np.float32) / 255.0
return handcrafted, spectrum
def infer_frequency(img: Image.Image) -> float:
load_frequency()
handcrafted, spectrum = extract_frequency_features(img)
freq_data = models["frequency"]
hc_tensor = torch.from_numpy(handcrafted).unsqueeze(0).float().to(DEVICE)
spec_tensor = torch.from_numpy(spectrum).unsqueeze(0).unsqueeze(0).float().to(DEVICE)
with torch.no_grad():
logit = freq_data["model"](hc_tensor, spec_tensor)
return torch.sigmoid(logit).item()
# ═══════════════════════════════════════════
# 5. AUDIO V10 (Wav2Vec2 full, 3-class)
# ═══════════════════════════════════════════
class AudioV10Model(nn.Module):
"""Full Wav2Vec2 backbone + classification head for 3-class audio deepfake."""
def __init__(self, num_classes=3):
super().__init__()
self.backbone = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
self.head = nn.Sequential(
nn.LayerNorm(768), # head.0
nn.Linear(768, 256), # head.1
nn.ReLU(), # head.2
nn.Dropout(0.3), # head.3
nn.Linear(256, 128), # head.4
nn.ReLU(), # head.5
nn.Dropout(0.2), # head.6
nn.Linear(128, num_classes), # head.7
)
def forward(self, input_values, attention_mask=None):
outputs = self.backbone(input_values=input_values, attention_mask=attention_mask)
hidden = outputs.last_hidden_state.mean(dim=1)
return self.head(hidden)
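# Shape flow for AudioV10Model (from the definitions above): input_values (N, T) ->
# Wav2Vec2 last_hidden_state (N, T', 768) -> mean over time (N, 768) -> head ->
# (N, num_classes) logits.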
audio_v10 = None
audio_feature_extractor = None
def load_audio_v10():
global audio_v10, audio_feature_extractor
if audio_v10 is not None:
return
print("Loading audio_deepfake_v10 (Wav2Vec2 full, 3-class)...", flush=True)
audio_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
ckpt = torch.load("audio_deepfake_v10.pt", map_location=DEVICE, weights_only=False)
num_classes = ckpt.get("num_classes", 3)
audio_v10 = AudioV10Model(num_classes=num_classes)
sd = ckpt.get("model_state_dict", ckpt)
if isinstance(sd, dict) and any(k.startswith("backbone.") or k.startswith("head.") for k in sd.keys()):
audio_v10.load_state_dict(sd, strict=True)
else:
# Try loading just the head
audio_v10.load_state_dict(sd, strict=False)
audio_v10.eval()
print(f" ok audio_v10 (val_acc={ckpt.get('val_acc', 'N/A')}, classes={ckpt.get('classes', [])})", flush=True)
# ═══════════════════════════════════════════
# 6. AUDIO V1 FALLBACK (Wav2Vec2 probe)
# ═══════════════════════════════════════════
class AudioClassifierV1(nn.Module):
def __init__(self, input_dim=768, hidden_dim=256):
super().__init__()
self.classifier = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim, 64),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(64, 2)
)
def forward(self, x):
return self.classifier(x)
audio_v1_backbone = None
audio_v1_classifier = None
def load_audio_v1():
global audio_v1_backbone, audio_v1_classifier, audio_feature_extractor
if audio_v1_classifier is not None:
return
print("Loading audio_deepfake_model (v1 fallback)...", flush=True)
if audio_feature_extractor is None:
audio_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
audio_v1_backbone = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
audio_v1_backbone.eval()
audio_v1_classifier = AudioClassifierV1()
state = torch.load("audio_deepfake_model.pt", map_location="cpu", weights_only=False)
sd = state.get("classifier_state_dict", state)
if isinstance(sd, dict) and any(k[0].isdigit() for k in sd.keys()):
sd = {f"classifier.{k}": v for k, v in sd.items()}
audio_v1_classifier.load_state_dict(sd)
audio_v1_classifier.eval()
print(" ok audio_v1", flush=True)
def process_audio(data_bytes, max_seconds=10):
    """Common audio preprocessing: decode to 16 kHz mono, then truncate/pad."""
    global audio_feature_extractor
    if audio_feature_extractor is None:
        # Load only the feature extractor here so preprocessing does not require the
        # 361MB v10 checkpoint; the v1 fallback can then still run if v10 is missing.
        audio_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
    audio_np, _ = librosa.load(io.BytesIO(data_bytes), sr=16000, mono=True)
max_len = 16000 * max_seconds
if len(audio_np) > max_len:
audio_np = audio_np[:max_len]
elif len(audio_np) < 16000:
audio_np = np.pad(audio_np, (0, 16000 - len(audio_np)))
return audio_np
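# Worked example (illustrative): a 3 s clip decodes to 48,000 samples at 16 kHz and
# passes through unchanged; a 0.5 s clip is zero-padded to 16,000 samples (1 s); a
# 30 s clip is truncated to 160,000 samples (max_seconds=10).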
# ═══════════════════════════════════════════
# ENDPOINTS
# ═══════════════════════════════════════════
@app.get("/health")
def health():
loaded = list(models.keys())
if audio_v10 is not None:
loaded.append("audio_v10")
if audio_v1_classifier is not None:
loaded.append("audio_v1")
return {
"status": "ok",
"version": "10.1.0",
"models_loaded": loaded,
"available_models": ["image_ensemble_v2", "ai_gen", "spectral", "frequency", "audio_v10", "audio_v1"],
"platforms": list(ENSEMBLE_CONFIGS.keys()),
"endpoints": ["/image", "/audio", "/video", "/ocr", "/health"],
"ensemble": "4-model image ensemble (EfficientNet-B0 + EfficientNet-B3 + DualStreamSpectral + FrequencyAnalyzer)"
}
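# Example (illustrative): GET /health right after startup reports
# "models_loaded": [] because every model is lazy-loaded on its first request.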
@app.post("/image")
async def analyze_image(
file: UploadFile = File(...),
platform: str = Query("clean", pattern="^(clean|whatsapp|instagram|telegram|screenshot)$")
):
"""Analyze image using full V10 4-model ensemble with platform-specific weighting."""
try:
data = await file.read()
img = Image.open(io.BytesIO(data)).convert("RGB")
config = ENSEMBLE_CONFIGS[platform]
weights = config["weights"]
threshold = config["threshold"]
infer_fns = {
"image_ensemble_v2": infer_image_ensemble_v2,
"ai_gen": infer_ai_gen,
"spectral": infer_spectral,
"frequency": infer_frequency,
}
ensemble_scores = {}
for name, weight in weights.items():
try:
ensemble_scores[name] = infer_fns[name](img)
except Exception as e:
print(f"Model {name} failed: {e}", flush=True)
traceback.print_exc()
ensemble_scores[name] = 0.5
# Weighted average
total_weight = 0
weighted_score = 0
for name, weight in weights.items():
if name in ensemble_scores:
weighted_score += ensemble_scores[name] * weight
total_weight += weight
fake_prob = weighted_score / total_weight if total_weight > 0 else 0.5
real_prob = 1 - fake_prob
verdict = "fake" if fake_prob > threshold else "real"
confidence = max(fake_prob, real_prob)
return {
"verdict": verdict,
"confidence": round(confidence, 4),
"scores": {"real": round(real_prob, 4), "fake": round(fake_prob, 4)},
"ensemble_scores": {k: round(v, 4) for k, v in ensemble_scores.items()},
"platform": platform,
"models_used": list(weights.keys()),
"threshold": threshold,
"model": "kaeva-v10-full-ensemble",
"version": "10.1.0"
}
except Exception as e:
traceback.print_exc()
raise HTTPException(500, str(e))
@app.post("/audio")
async def analyze_audio(file: UploadFile = File(...)):
"""Analyze audio using V10 3-class model (real/tts/vc), with v1 binary fallback."""
try:
data = await file.read()
audio_np = process_audio(data, max_seconds=10)
results = {}
# V10: 3-class (real, tts, vc)
try:
load_audio_v10()
inputs = audio_feature_extractor(audio_np, sampling_rate=16000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = audio_v10(inputs["input_values"])
probs = torch.softmax(logits, dim=-1)[0]
classes = ["real", "tts", "vc"]
class_scores = {c: round(float(probs[i]), 4) for i, c in enumerate(classes)}
fake_prob = 1 - float(probs[0]) # tts + vc combined
results["v10"] = {
"class_scores": class_scores,
"predicted_class": classes[int(probs.argmax())],
"fake_prob": round(fake_prob, 4),
}
except Exception as e:
print(f"Audio V10 failed: {e}", flush=True)
traceback.print_exc()
results["v10"] = {"error": str(e)}
# V1 fallback: binary
try:
load_audio_v1()
inputs = audio_feature_extractor(audio_np, sampling_rate=16000, return_tensors="pt", padding=True)
with torch.no_grad():
outputs = audio_v1_backbone(**inputs)
hidden = outputs.last_hidden_state.mean(dim=1)
logits = audio_v1_classifier(hidden)
probs = torch.softmax(logits, dim=-1)[0]
results["v1"] = {
"real": round(float(probs[0]), 4),
"fake": round(float(probs[1]), 4),
}
except Exception as e:
print(f"Audio V1 failed: {e}", flush=True)
results["v1"] = {"error": str(e)}
# Combined verdict: prefer v10, fallback to v1
v10 = results.get("v10", {})
v1 = results.get("v1", {})
if "error" not in v10:
fake_prob = v10["fake_prob"]
verdict = "fake" if fake_prob > 0.5 else "real"
detail = v10["predicted_class"]
elif "error" not in v1:
fake_prob = v1["fake"]
verdict = "fake" if fake_prob > 0.5 else "real"
detail = "binary"
else:
fake_prob = 0.5
verdict = "inconclusive"
detail = "both models failed"
return {
"verdict": verdict,
"confidence": round(max(fake_prob, 1 - fake_prob), 4),
"scores": {"real": round(1 - fake_prob, 4), "fake": round(fake_prob, 4)},
"detail": detail,
"model_results": results,
"model": "kaeva-v10-audio",
"version": "10.1.0"
}
except Exception as e:
traceback.print_exc()
raise HTTPException(500, str(e))
@app.post("/video")
async def analyze_video(
file: UploadFile = File(...),
platform: str = Query("clean", pattern="^(clean|whatsapp|instagram|telegram|screenshot)$")
):
"""Analyze video: extract frames -> full 4-model ensemble, extract audio -> v10 audio."""
start_time = time.time()
try:
data = await file.read()
with tempfile.TemporaryDirectory() as tmpdir:
video_path = os.path.join(tmpdir, "input_video")
with open(video_path, "wb") as f:
f.write(data)
# Get video info
probe_cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", "-show_format", "-show_streams", video_path]
probe_result = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=15)
video_info = json.loads(probe_result.stdout) if probe_result.returncode == 0 else {}
duration = float(video_info.get("format", {}).get("duration", 0))
resolution = "unknown"
fps = 30.0
has_audio = False
for stream in video_info.get("streams", []):
if stream.get("codec_type") == "video":
resolution = f"{stream.get('width', '?')}x{stream.get('height', '?')}"
                    try:
                        num, den = stream.get("r_frame_rate", "30/1").split("/")
                        fps = float(num) / float(den)
                    except (ValueError, ZeroDivisionError):
                        pass
elif stream.get("codec_type") == "audio":
has_audio = True
max_frames = min(8, max(1, int(duration))) if duration > 0 else 5
frame_interval = max(1.0, duration / max_frames) if duration > 0 else 1.0
frame_dir = os.path.join(tmpdir, "frames")
os.makedirs(frame_dir)
ffmpeg_cmd = [
"ffmpeg", "-i", video_path,
"-vf", f"fps=1/{frame_interval}",
"-frames:v", str(max_frames),
"-q:v", "2",
os.path.join(frame_dir, "frame_%03d.jpg"),
"-y", "-loglevel", "error"
]
subprocess.run(ffmpeg_cmd, timeout=30, check=True)
# Run full 4-model ensemble on each frame
frame_files = sorted([f for f in os.listdir(frame_dir) if f.endswith(".jpg")])
config = ENSEMBLE_CONFIGS[platform]
weights = config["weights"]
infer_fns = {
"image_ensemble_v2": infer_image_ensemble_v2,
"ai_gen": infer_ai_gen,
"spectral": infer_spectral,
"frequency": infer_frequency,
}
frame_scores = []
per_model_scores = {name: [] for name in weights}
for fname in frame_files:
fpath = os.path.join(frame_dir, fname)
img = Image.open(fpath).convert("RGB")
frame_model_scores = {}
for name in weights:
try:
score = infer_fns[name](img)
                    except Exception:
                        score = 0.5  # neutral score when a model fails on this frame
frame_model_scores[name] = score
per_model_scores[name].append(score)
weighted = sum(frame_model_scores.get(n, 0.5) * w for n, w in weights.items())
total_w = sum(weights.values())
frame_scores.append(round(weighted / total_w, 4))
temporal_consistency = 1.0 - float(np.std(frame_scores)) if len(frame_scores) > 1 else 1.0
avg_frame_score = float(np.mean(frame_scores)) if frame_scores else 0.5
# Audio analysis
audio_result = None
if has_audio:
audio_path = os.path.join(tmpdir, "audio.wav")
audio_cmd = ["ffmpeg", "-i", video_path, "-vn", "-acodec", "pcm_s16le", "-ar", "16000", "-ac", "1", audio_path, "-y", "-loglevel", "error"]
audio_extract = subprocess.run(audio_cmd, timeout=20)
if audio_extract.returncode == 0 and os.path.exists(audio_path):
try:
with open(audio_path, "rb") as af:
audio_bytes = af.read()
audio_np = process_audio(audio_bytes, max_seconds=10)
# V10 audio
load_audio_v10()
inputs = audio_feature_extractor(audio_np, sampling_rate=16000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = audio_v10(inputs["input_values"])
probs = torch.softmax(logits, dim=-1)[0]
classes = ["real", "tts", "vc"]
audio_fake_prob = 1 - float(probs[0])
audio_result = {
"verdict": "fake" if audio_fake_prob > 0.5 else "real",
"confidence": round(max(audio_fake_prob, 1 - audio_fake_prob), 4),
"scores": {"real": round(float(probs[0]), 4), "fake": round(audio_fake_prob, 4)},
"predicted_class": classes[int(probs.argmax())],
"class_scores": {c: round(float(probs[i]), 4) for i, c in enumerate(classes)},
}
except Exception as ae:
print(f"Audio analysis error: {ae}", flush=True)
# Overall: 70% visual, 30% audio (if available)
if audio_result:
overall_score = avg_frame_score * 0.7 + audio_result["scores"]["fake"] * 0.3
else:
overall_score = avg_frame_score
flags = []
if avg_frame_score > 0.7:
flags.append("HIGH_FAKE_SCORE_ACROSS_FRAMES")
if temporal_consistency < 0.8:
flags.append("INCONSISTENT_FRAME_SCORES")
if audio_result and audio_result["scores"]["fake"] > 0.7:
flags.append("AUDIO_FAKE_DETECTED")
if audio_result and ((avg_frame_score > 0.5) != (audio_result["scores"]["fake"] > 0.5)):
flags.append("AUDIO_VISUAL_DISAGREEMENT")
verdict = "fake" if overall_score > config["threshold"] else "real"
return {
"verdict": verdict,
"confidence": round(max(overall_score, 1 - overall_score), 4),
"overall_score": round(overall_score, 4),
"frame_scores": frame_scores,
"per_model_averages": {name: round(float(np.mean(scores)), 4) for name, scores in per_model_scores.items() if scores},
"temporal_consistency": round(temporal_consistency, 4),
"frame_count": len(frame_scores),
"fps": round(fps, 2),
"resolution": resolution,
"duration_seconds": round(duration, 2),
"flags": flags,
"audio_analysis": audio_result,
"platform": platform,
"model": "kaeva-v10-full-ensemble",
"version": "10.1.0",
"processing_time_ms": int((time.time() - start_time) * 1000),
}
except subprocess.TimeoutExpired:
raise HTTPException(504, "Video processing timed out")
except Exception as e:
traceback.print_exc()
raise HTTPException(500, str(e))
@app.post("/ocr")
async def extract_text(file: UploadFile = File(...)):
"""Extract text from image using pytesseract OCR."""
try:
import pytesseract
data = await file.read()
img = Image.open(io.BytesIO(data))
text = pytesseract.image_to_string(img)
# Also get confidence data
ocr_data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)
words = []
for i, word in enumerate(ocr_data["text"]):
if word.strip():
words.append({
"text": word,
"confidence": ocr_data["conf"][i],
"x": ocr_data["left"][i],
"y": ocr_data["top"][i],
"w": ocr_data["width"][i],
"h": ocr_data["height"][i],
})
avg_conf = np.mean([w["confidence"] for w in words]) if words else 0
return {
"text": text.strip(),
"word_count": len(words),
"average_confidence": round(float(avg_conf), 2),
"words": words,
}
except ImportError:
raise HTTPException(501, "pytesseract not installed")
except Exception as e:
traceback.print_exc()
raise HTTPException(500, str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
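# To run locally (illustrative): `python app.py`, or equivalently
#   uvicorn app:app --host 0.0.0.0 --port 7860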