# NOTE(review): the three lines below were a Hugging Face Spaces status banner
# ("Spaces: Sleeping") captured by the scrape — not part of the program.
"""Kaeva Verify V10 — Full Ensemble Deepfake Detection API.
Models loaded:
- image_ensemble_v2 (EfficientNet-B0, 15.6MB) — general image deepfake
- ai_gen_detector (EfficientNet-B3, 44.4MB) — AI-generated image detection
- spectral_detector (Dual-stream ResNet18, 131MB) — frequency/spectral analysis
- frequency_analyzer (MLP+CNN, 1.5MB) — DCT/wavelet/Benford features
- audio_deepfake_v10 (Wav2Vec2 full, 361MB) — 3-class audio (real/tts/vc)
- audio_deepfake_model (Wav2Vec2 probe, 0.8MB) — binary audio fallback
Endpoints:
    POST /image  -> V10 ensemble image detection (4 models, platform-aware)
    POST /audio  -> Audio deepfake detection (v10 primary, v1 fallback)
    POST /video  -> Video: frame ensemble + audio analysis
    POST /ocr    -> Extract text from image via pytesseract
    GET  /health -> Health check
"""
| import io, os, traceback, tempfile, subprocess, time, json | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from PIL import Image | |
| from torchvision import transforms | |
| from torchvision.models import efficientnet_b0, efficientnet_b3, resnet18 | |
| from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor | |
| import librosa | |
# FastAPI application with fully permissive CORS so browser clients from any
# origin can call the API directly.
app = FastAPI(title="Kaeva Verify V10", version="10.1.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# All models are loaded onto and run on CPU.
DEVICE = torch.device("cpu")
# -- Ensemble configs --
# Per-platform fusion settings for the 4-model image ensemble.
#   "weights":   relative weight of each model's fake-probability in the
#                weighted average (each platform's weights sum to 1.0).
#   "threshold": fake-probability above which the verdict becomes "fake".
# Keys match the `platform` query parameter accepted by /image and /video.
ENSEMBLE_CONFIGS = {
    "clean": {"weights": {"image_ensemble_v2": 0.35, "ai_gen": 0.30, "spectral": 0.20, "frequency": 0.15}, "threshold": 0.50},
    "whatsapp": {"weights": {"image_ensemble_v2": 0.40, "ai_gen": 0.25, "spectral": 0.20, "frequency": 0.15}, "threshold": 0.55},
    "instagram": {"weights": {"image_ensemble_v2": 0.35, "ai_gen": 0.30, "spectral": 0.20, "frequency": 0.15}, "threshold": 0.50},
    "telegram": {"weights": {"image_ensemble_v2": 0.35, "ai_gen": 0.30, "spectral": 0.20, "frequency": 0.15}, "threshold": 0.50},
    "screenshot": {"weights": {"image_ensemble_v2": 0.40, "ai_gen": 0.25, "spectral": 0.15, "frequency": 0.20}, "threshold": 0.50},
}
# -- Transforms --
# Preprocessing pipelines using the standard ImageNet mean/std normalization;
# the resize target matches each backbone's expected input resolution.
img_transform_224 = transforms.Compose([  # used by EfficientNet-B0 (224x224)
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img_transform_300 = transforms.Compose([  # used by EfficientNet-B3 (300x300)
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# -- Model registry --
# Lazily populated cache of loaded image models, keyed by model name
# ("image_ensemble_v2", "ai_gen", "spectral", "frequency").
models = {}
# -------------------------------------------
# 1. IMAGE ENSEMBLE V2 (EfficientNet-B0)
# -------------------------------------------
def load_image_ensemble_v2():
    """Lazily load the EfficientNet-B0 image-ensemble model into the registry (idempotent)."""
    if "image_ensemble_v2" in models:
        return
    print("Loading image_ensemble_v2 (EfficientNet-B0)...", flush=True)
    net = efficientnet_b0(weights=None)
    # Replace the stock 1000-class head with a binary real/fake head.
    net.classifier[1] = nn.Linear(1280, 2)
    checkpoint = torch.load("image_ensemble_v2.pt", map_location=DEVICE, weights_only=False)
    # Checkpoint may be either a bare state dict or a training wrapper dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        checkpoint = checkpoint["model_state_dict"]
    net.load_state_dict(checkpoint, strict=True)
    net.eval()
    models["image_ensemble_v2"] = net
    print(" ok image_ensemble_v2", flush=True)
def infer_image_ensemble_v2(img: Image.Image) -> float:
    """Return P(fake) from the EfficientNet-B0 ensemble model for one RGB image."""
    load_image_ensemble_v2()
    batch = img_transform_224(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = torch.softmax(models["image_ensemble_v2"](batch), dim=1)
    # Index 1 is the "fake" class for this model.
    return probs[0, 1].item()
# -------------------------------------------
# 2. AI GEN DETECTOR (EfficientNet-B3)
# -------------------------------------------
def load_ai_gen():
    """Lazily load the EfficientNet-B3 AI-generation detector (idempotent)."""
    if "ai_gen" in models:
        return
    print("Loading ai_gen (EfficientNet-B3)...", flush=True)
    net = efficientnet_b3(weights=None)
    # Custom classification head; layout must match the fine-tuned checkpoint.
    net.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(1536, 512),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.2),
        nn.Linear(512, 2),
    )
    ckpt = torch.load("ai_gen_detector.pt", map_location=DEVICE, weights_only=False)
    raw_sd = ckpt.get("model_state_dict", ckpt)
    # Strip the training-time "backbone." prefix from parameter names.
    cleaned = {
        (key.replace("backbone.", "", 1) if key.startswith("backbone.") else key): tensor
        for key, tensor in raw_sd.items()
    }
    net.load_state_dict(cleaned, strict=True)
    net.eval()
    models["ai_gen"] = net
    print(f" ok ai_gen (val_acc={ckpt.get('val_acc', 'N/A')})", flush=True)
def infer_ai_gen(img: Image.Image) -> float:
    """Return P(fake) from the AI-generation detector for one RGB image.

    For this model class 0 = fake and class 1 = real, hence index 0 below.
    """
    load_ai_gen()
    batch = img_transform_300(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = torch.softmax(models["ai_gen"](batch), dim=1)
    return probs[0, 0].item()
# -------------------------------------------
# 3. SPECTRAL DETECTOR (Dual-stream ResNet18)
# -------------------------------------------
class DualStreamSpectral(nn.Module):
    """Two ResNet18 streams: (A) magnitude spectrum, (B) phase spectrum -> fused classifier.

    Module layout must match the keys in spectral_detector.pt exactly
    (load_spectral loads with strict=True) — do not rename or reorder modules.
    Inputs are single-channel spectra (see compute_spectrum), hence the
    1-channel conv1 replacements.
    """
    def __init__(self):
        super().__init__()
        # Stream A: magnitude spectrum (1-channel input instead of RGB)
        base_a = resnet18(weights=None)
        base_a.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
        base_a.fc = nn.Linear(512, 256)
        self.stream_a = base_a
        # Stream B: phase spectrum
        base_b = resnet18(weights=None)
        base_b.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
        base_b.fc = nn.Linear(512, 128)
        self.stream_b = base_b
        # Classifier: 256+128 = 384 -> 1
        # NOTE(review): the inline index labels below skip "classifier.1" —
        # presumably copied from a training-time module with an extra layer.
        # The actual Sequential indices here are 0..4; since the checkpoint
        # loads with strict=True its keys must also be 0..4 — confirm labels.
        self.classifier = nn.Sequential(
            nn.LayerNorm(384),       # classifier.0
            nn.Linear(384, 128),     # classifier.2 (original label; see NOTE)
            nn.ReLU(),               # classifier.3
            nn.Dropout(0.3),         # classifier.4
            nn.Linear(128, 1),       # classifier.5
        )
    def forward(self, mag, phase):
        """Return a single real/fake logit for (mag, phase) batches of 1-channel spectra."""
        a = self.stream_a(mag)
        b = self.stream_b(phase)
        fused = torch.cat([a, b], dim=1)
        return self.classifier(fused)
def load_spectral():
    """Lazily load the dual-stream spectral detector (idempotent)."""
    if "spectral" in models:
        return
    print("Loading spectral_detector (DualStreamSpectral)...", flush=True)
    net = DualStreamSpectral()
    ckpt = torch.load("spectral_detector.pt", map_location=DEVICE, weights_only=False)
    # Checkpoint may wrap the weights under "model" or be a bare state dict.
    net.load_state_dict(ckpt.get("model", ckpt), strict=True)
    net.eval()
    models["spectral"] = net
    print(f" ok spectral (best_f1={ckpt.get('best_f1', 'N/A')})", flush=True)
def compute_spectrum(img: Image.Image):
    """Convert PIL image -> grayscale -> FFT -> magnitude & phase tensors (1,1,224,224).

    Both spectra are standardized (zero mean, unit std) before being wrapped
    as float tensors with batch and channel dimensions.
    """
    gray = np.array(img.convert("L").resize((224, 224)), dtype=np.float32) / 255.0
    centered = np.fft.fftshift(np.fft.fft2(gray))
    magnitude = np.log1p(np.abs(centered))
    angle = np.angle(centered)
    magnitude = (magnitude - magnitude.mean()) / (magnitude.std() + 1e-8)
    angle = (angle - angle.mean()) / (angle.std() + 1e-8)
    mag_t = torch.from_numpy(magnitude)[None, None].float()
    phase_t = torch.from_numpy(angle)[None, None].float()
    return mag_t, phase_t
def infer_spectral(img: Image.Image) -> float:
    """Return the spectral detector's fake probability for one image."""
    load_spectral()
    magnitude, phase = compute_spectrum(img)
    with torch.no_grad():
        raw = models["spectral"](magnitude.to(DEVICE), phase.to(DEVICE))
    return torch.sigmoid(raw).item()
# -------------------------------------------
# 4. FREQUENCY ANALYZER (MLP + CNN)
# -------------------------------------------
class FrequencyAnalyzer(nn.Module):
    """Handcrafted frequency features (54-dim) through MLP + spectrum (64x64) through CNN -> fused classifier.

    Module layout must match frequency_analyzer.pt keys exactly (load_frequency
    uses strict=True). NOTE(review): `spectrum_size` is accepted but never used
    in the body — AdaptiveAvgPool2d makes the CNN input-size agnostic.
    """
    def __init__(self, handcrafted_dim=54, spectrum_size=64):
        super().__init__()
        # MLP for handcrafted features
        self.mlp = nn.Sequential(
            nn.Linear(handcrafted_dim, 128), # mlp.0
            nn.BatchNorm1d(128),             # mlp.1
            nn.ReLU(),                       # mlp.2
            nn.Dropout(0.3),                 # mlp.3
            nn.Linear(128, 64),              # mlp.4
            nn.ReLU(),                       # mlp.5
        )
        # CNN for spectrum image (3 conv blocks with BN)
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),  # cnn.0
            nn.BatchNorm2d(32),              # cnn.1
            nn.ReLU(),                       # cnn.2
            nn.MaxPool2d(2),                 # cnn.3 -> 32x32
            nn.Conv2d(32, 64, 3, padding=1), # cnn.4
            nn.BatchNorm2d(64),              # cnn.5
            nn.ReLU(),                       # cnn.6
            nn.MaxPool2d(2),                 # cnn.7 -> 16x16
            nn.Conv2d(64, 128, 3, padding=1),# cnn.8
            nn.BatchNorm2d(128),             # cnn.9
            nn.ReLU(),                       # cnn.10
            nn.AdaptiveAvgPool2d(4),         # cnn.11 -> 4x4
            nn.Flatten(),                    # cnn.12 -> 128*4*4 = 2048
        )
        self.cnn_fc = nn.Linear(2048, 128)
        # Classifier: 64 (mlp) + 128 (cnn) = 192
        # NOTE(review): index labels skip "classifier.1" (same stale-label
        # pattern as DualStreamSpectral); actual Sequential indices are 0..4.
        self.classifier = nn.Sequential(
            nn.LayerNorm(192),               # classifier.0
            nn.Linear(192, 64),              # classifier.2 (original label)
            nn.ReLU(),                       # classifier.3
            nn.Dropout(0.3),                 # classifier.4
            nn.Linear(64, 1),                # classifier.5
        )
    def forward(self, handcrafted, spectrum):
        """Fuse MLP(handcrafted) with ReLU(fc(CNN(spectrum))) and return a single logit."""
        mlp_out = self.mlp(handcrafted)
        cnn_out = self.cnn(spectrum)
        cnn_out = self.cnn_fc(cnn_out)
        cnn_out = torch.relu(cnn_out)
        fused = torch.cat([mlp_out, cnn_out], dim=1)
        return self.classifier(fused)
def load_frequency():
    """Lazily load the frequency analyzer and its feature dimensions (idempotent).

    The registry entry is a dict (not a bare module) because inference needs
    the handcrafted-feature and spectrum sizes alongside the model.
    Fix: dropped the unused local `config = ckpt.get("config", {})`.
    """
    if "frequency" in models:
        return
    print("Loading frequency_analyzer...", flush=True)
    ckpt = torch.load("frequency_analyzer.pt", map_location=DEVICE, weights_only=False)
    fdims = ckpt.get("feature_dims", {})
    hc_dim = fdims.get("handcrafted_total", 54)
    spec_size = fdims.get("spectrum_size", 64)
    model = FrequencyAnalyzer(handcrafted_dim=hc_dim, spectrum_size=spec_size)
    sd = ckpt.get("model", ckpt)
    model.load_state_dict(sd, strict=True)
    model.eval()
    models["frequency"] = {"model": model, "hc_dim": hc_dim, "spec_size": spec_size}
    print(f" ok frequency (best_f1={ckpt.get('best_f1', 'N/A')})", flush=True)
def extract_frequency_features(img: Image.Image):
    """Extract handcrafted frequency features + spectrum from image.

    Returns:
        handcrafted: float32 vector of 54 features
            (18 Benford + 10 DCT statistics + 26 Haar-wavelet).
        spectrum: float32 array (64, 64) in [0, 1] — min-max-normalized
            log-magnitude FFT, resized for the analyzer's CNN input.

    Fixes vs original: the identical 2-D DCT was computed twice (once for the
    Benford features, once for the statistics) — now computed once; removed
    unused locals and the redundant local PIL re-import.
    """
    from scipy.fft import dct as scipy_dct
    gray = np.array(img.convert("L").resize((256, 256)), dtype=np.float32) / 255.0
    # 2-D orthonormal DCT, shared by the Benford features and the statistics.
    dct2d = scipy_dct(scipy_dct(gray, axis=0, norm='ortho'), axis=1, norm='ortho')
    # Benford's law features (18-dim): first-digit distribution of DCT coefficients
    dct_coeffs = dct2d.flatten()
    abs_coeffs = np.abs(dct_coeffs[dct_coeffs != 0])
    if len(abs_coeffs) > 0:
        first_digits = (abs_coeffs / (10 ** np.floor(np.log10(abs_coeffs + 1e-10)))).astype(int)
        first_digits = np.clip(first_digits, 1, 9)
        benford = np.bincount(first_digits, minlength=10)[1:].astype(np.float32)
        benford = benford / (benford.sum() + 1e-8)
        # Ideal Benford distribution appended so the model sees observed vs expected
        expected = np.log10(1 + 1.0 / np.arange(1, 10)).astype(np.float32)
        benford_features = np.concatenate([benford, expected])  # 18-dim
    else:
        benford_features = np.zeros(18, dtype=np.float32)
    # DCT statistics (10-dim)
    dct_stats = np.array([
        dct2d.mean(), dct2d.std(), np.median(dct2d),
        dct2d.min(), dct2d.max(),
        np.percentile(dct2d, 25), np.percentile(dct2d, 75),
        float(np.abs(dct2d).mean()),
        float((np.abs(dct2d) > 0.01).sum()) / dct2d.size,  # sparsity
        float(np.abs(dct2d[:32, :32]).sum()) / (float(np.abs(dct2d).sum()) + 1e-8),  # low-freq energy ratio
    ], dtype=np.float32)
    # Wavelet features (26-dim) — simplified multi-level Haar decomposition in numpy
    def haar_wavelet_1level(x):
        # One Haar level on an even-sized array: approximation + 3 detail bands.
        ll = (x[0::2, 0::2] + x[1::2, 0::2] + x[0::2, 1::2] + x[1::2, 1::2]) / 4
        lh = (x[0::2, 0::2] - x[1::2, 0::2] + x[0::2, 1::2] - x[1::2, 1::2]) / 4
        hl = (x[0::2, 0::2] + x[1::2, 0::2] - x[0::2, 1::2] - x[1::2, 1::2]) / 4
        hh = (x[0::2, 0::2] - x[1::2, 0::2] - x[0::2, 1::2] + x[1::2, 1::2]) / 4
        return ll, lh, hl, hh
    wavelet_feats = []
    current = gray
    for _ in range(3):  # up to 3 decomposition levels, 8 features each
        if current.shape[0] < 4 or current.shape[1] < 4:
            break
        # Crop to even dimensions so the 2x2 Haar blocks tile exactly.
        h = (current.shape[0] // 2) * 2
        w = (current.shape[1] // 2) * 2
        current_even = current[:h, :w]
        ll, lh, hl, hh = haar_wavelet_1level(current_even)
        for band in (lh, hl, hh):
            wavelet_feats.extend([band.mean(), band.std()])
        # Detail-to-total energy ratio for this level
        total_energy = float(np.sum(current_even ** 2)) + 1e-8
        detail_energy = float(np.sum(lh**2) + np.sum(hl**2) + np.sum(hh**2))
        wavelet_feats.append(detail_energy / total_energy)
        wavelet_feats.append(float(np.abs(hh).mean()))  # diagonal detail
        current = ll
    # Pad to 26 dims (3 full levels yield 24; padding also covers early exit)
    wavelet_arr = np.array(wavelet_feats[:26], dtype=np.float32)
    if len(wavelet_arr) < 26:
        wavelet_arr = np.pad(wavelet_arr, (0, 26 - len(wavelet_arr)))
    # Combine all handcrafted features (18 + 10 + 26 = 54)
    handcrafted = np.concatenate([benford_features, dct_stats, wavelet_arr])
    # Spectrum image (64x64)
    f = np.fft.fft2(gray)
    fshift = np.fft.fftshift(f)
    mag = np.log1p(np.abs(fshift))
    # Min-max normalize to 8-bit, then resize to 64x64 via PIL
    mag_img = Image.fromarray(((mag - mag.min()) / (mag.max() - mag.min() + 1e-8) * 255).astype(np.uint8))
    mag_img = mag_img.resize((64, 64))
    spectrum = np.array(mag_img, dtype=np.float32) / 255.0
    return handcrafted, spectrum
def infer_frequency(img: Image.Image) -> float:
    """Return the frequency analyzer's fake probability for one image."""
    load_frequency()
    entry = models["frequency"]
    handcrafted, spectrum = extract_frequency_features(img)
    hc_batch = torch.from_numpy(handcrafted).float().unsqueeze(0).to(DEVICE)
    spec_batch = torch.from_numpy(spectrum).float()[None, None].to(DEVICE)
    with torch.no_grad():
        raw = entry["model"](hc_batch, spec_batch)
    return torch.sigmoid(raw).item()
# -------------------------------------------
# 5. AUDIO V10 (Wav2Vec2 full, 3-class)
# -------------------------------------------
class AudioV10Model(nn.Module):
    """Full Wav2Vec2 backbone + classification head for 3-class audio deepfake.

    Constructing this loads the pretrained facebook/wav2vec2-base weights
    (downloading them if not cached); the fine-tuned state dict is applied
    afterwards in load_audio_v10. Head layout must match checkpoint keys
    (head.0 .. head.7).
    """
    def __init__(self, num_classes=3):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        self.head = nn.Sequential(
            nn.LayerNorm(768),             # head.0
            nn.Linear(768, 256),           # head.1
            nn.ReLU(),                     # head.2
            nn.Dropout(0.3),               # head.3
            nn.Linear(256, 128),           # head.4
            nn.ReLU(),                     # head.5
            nn.Dropout(0.2),               # head.6
            nn.Linear(128, num_classes),   # head.7
        )
    def forward(self, input_values, attention_mask=None):
        """Mean-pool the final hidden states over the time axis, then classify."""
        outputs = self.backbone(input_values=input_values, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state.mean(dim=1)  # temporal mean-pool -> (B, 768)
        return self.head(hidden)
# Lazily initialized globals: the v10 audio model and the Wav2Vec2 feature
# extractor (the extractor is shared with the v1 fallback path).
audio_v10 = None
audio_feature_extractor = None
def load_audio_v10():
    """Lazily load the 3-class Wav2Vec2 audio model and feature extractor (idempotent)."""
    global audio_v10, audio_feature_extractor
    if audio_v10 is not None:
        return
    print("Loading audio_deepfake_v10 (Wav2Vec2 full, 3-class)...", flush=True)
    audio_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
    ckpt = torch.load("audio_deepfake_v10.pt", map_location=DEVICE, weights_only=False)
    num_classes = ckpt.get("num_classes", 3)
    audio_v10 = AudioV10Model(num_classes=num_classes)
    sd = ckpt.get("model_state_dict", ckpt)
    # Full checkpoints carry backbone.*/head.* keys and must load strictly;
    # anything else is treated as a partial (e.g. head-only) dict.
    if isinstance(sd, dict) and any(k.startswith("backbone.") or k.startswith("head.") for k in sd.keys()):
        audio_v10.load_state_dict(sd, strict=True)
    else:
        # Try loading just the head
        # NOTE(review): strict=False silently ignores missing/unexpected keys;
        # consider logging load_state_dict's returned IncompatibleKeys.
        audio_v10.load_state_dict(sd, strict=False)
    audio_v10.eval()
    print(f" ok audio_v10 (val_acc={ckpt.get('val_acc', 'N/A')}, classes={ckpt.get('classes', [])})", flush=True)
# -------------------------------------------
# 6. AUDIO V1 FALLBACK (Wav2Vec2 probe)
# -------------------------------------------
class AudioClassifierV1(nn.Module):
    """Small MLP probe over mean-pooled Wav2Vec2 features (binary classifier).

    Layer layout must match audio_deepfake_model.pt keys ("classifier.N.*" —
    see the key-prefixing logic in load_audio_v1). Output index 0 is treated
    as "real" and index 1 as "fake" by analyze_audio.
    """
    def __init__(self, input_dim=768, hidden_dim=256):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 2)
        )
    def forward(self, x):
        """x: (B, input_dim) pooled features -> (B, 2) logits."""
        return self.classifier(x)
# Lazily initialized globals for the v1 fallback pipeline.
audio_v1_backbone = None
audio_v1_classifier = None
def load_audio_v1():
    """Lazily load the v1 binary audio probe: frozen Wav2Vec2 backbone + MLP head (idempotent)."""
    global audio_v1_backbone, audio_v1_classifier, audio_feature_extractor
    if audio_v1_classifier is not None:
        return
    print("Loading audio_deepfake_model (v1 fallback)...", flush=True)
    # Reuse the shared feature extractor if v10 already created it.
    if audio_feature_extractor is None:
        audio_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
    audio_v1_backbone = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
    audio_v1_backbone.eval()
    audio_v1_classifier = AudioClassifierV1()
    ckpt = torch.load("audio_deepfake_model.pt", map_location="cpu", weights_only=False)
    weights = ckpt.get("classifier_state_dict", ckpt)
    # Keys saved without the "classifier." prefix (e.g. "0.weight") get it re-added.
    if isinstance(weights, dict) and any(name[0].isdigit() for name in weights):
        weights = {"classifier." + name: tensor for name, tensor in weights.items()}
    audio_v1_classifier.load_state_dict(weights)
    audio_v1_classifier.eval()
    print(" ok audio_v1", flush=True)
def process_audio(data_bytes, max_seconds=10):
    """Decode raw audio bytes into a mono 16 kHz waveform for the audio models.

    - Resamples/downmixes to 16 kHz mono via librosa.
    - Truncates to at most `max_seconds` seconds.
    - Zero-pads anything shorter than one second up to one second.
    Also ensures the shared Wav2Vec2 feature extractor exists (loading the v10
    model as a side effect if needed). Fix: unused local `sr` dropped.
    """
    if audio_feature_extractor is None:
        load_audio_v10()
    # Sample rate is forced to 16000, so the returned rate is not needed.
    audio_np, _ = librosa.load(io.BytesIO(data_bytes), sr=16000, mono=True)
    max_len = 16000 * max_seconds
    if len(audio_np) > max_len:
        audio_np = audio_np[:max_len]
    elif len(audio_np) < 16000:
        audio_np = np.pad(audio_np, (0, 16000 - len(audio_np)))
    return audio_np
# -------------------------------------------
# ENDPOINTS
# -------------------------------------------
@app.get("/health")
def health():
    """Health check: report service version and which models are currently loaded.

    Fix: the route decorator was missing, so the handler documented as
    GET /health in the module docstring was never registered with the app.
    """
    loaded = list(models.keys())
    if audio_v10 is not None:
        loaded.append("audio_v10")
    if audio_v1_classifier is not None:
        loaded.append("audio_v1")
    return {
        "status": "ok",
        "version": "10.1.0",
        "models_loaded": loaded,
        "available_models": ["image_ensemble_v2", "ai_gen", "spectral", "frequency", "audio_v10", "audio_v1"],
        "platforms": list(ENSEMBLE_CONFIGS.keys()),
        "endpoints": ["/image", "/audio", "/video", "/ocr", "/health"],
        "ensemble": "4-model image ensemble (EfficientNet-B0 + EfficientNet-B3 + DualStreamSpectral + FrequencyAnalyzer)"
    }
@app.post("/image")
async def analyze_image(
    file: UploadFile = File(...),
    platform: str = Query("clean", pattern="^(clean|whatsapp|instagram|telegram|screenshot)$")
):
    """Analyze image using full V10 4-model ensemble with platform-specific weighting.

    A model that fails at inference contributes a neutral 0.5 instead of
    aborting the request. Fix: the POST /image route decorator was missing,
    so the endpoint was never registered.
    """
    try:
        data = await file.read()
        img = Image.open(io.BytesIO(data)).convert("RGB")
        config = ENSEMBLE_CONFIGS[platform]
        weights = config["weights"]
        threshold = config["threshold"]
        infer_fns = {
            "image_ensemble_v2": infer_image_ensemble_v2,
            "ai_gen": infer_ai_gen,
            "spectral": infer_spectral,
            "frequency": infer_frequency,
        }
        ensemble_scores = {}
        for name in weights:  # only model names needed here, not weights
            try:
                ensemble_scores[name] = infer_fns[name](img)
            except Exception as e:
                print(f"Model {name} failed: {e}", flush=True)
                traceback.print_exc()
                ensemble_scores[name] = 0.5  # neutral fallback keeps the ensemble usable
        # Weighted average of per-model fake probabilities
        total_weight = 0
        weighted_score = 0
        for name, weight in weights.items():
            if name in ensemble_scores:
                weighted_score += ensemble_scores[name] * weight
                total_weight += weight
        fake_prob = weighted_score / total_weight if total_weight > 0 else 0.5
        real_prob = 1 - fake_prob
        verdict = "fake" if fake_prob > threshold else "real"
        confidence = max(fake_prob, real_prob)
        return {
            "verdict": verdict,
            "confidence": round(confidence, 4),
            "scores": {"real": round(real_prob, 4), "fake": round(fake_prob, 4)},
            "ensemble_scores": {k: round(v, 4) for k, v in ensemble_scores.items()},
            "platform": platform,
            "models_used": list(weights.keys()),
            "threshold": threshold,
            "model": "kaeva-v10-full-ensemble",
            "version": "10.1.0"
        }
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(500, str(e))
@app.post("/audio")
async def analyze_audio(file: UploadFile = File(...)):
    """Analyze audio using V10 3-class model (real/tts/vc), with v1 binary fallback.

    Runs both models when possible and reports each result; the verdict prefers
    v10 and falls back to v1, returning "inconclusive" only if both fail.
    Fix: the POST /audio route decorator was missing, so the endpoint was
    never registered.
    """
    try:
        data = await file.read()
        audio_np = process_audio(data, max_seconds=10)
        results = {}
        # V10: 3-class (real, tts, vc)
        try:
            load_audio_v10()
            inputs = audio_feature_extractor(audio_np, sampling_rate=16000, return_tensors="pt", padding=True)
            with torch.no_grad():
                logits = audio_v10(inputs["input_values"])
            probs = torch.softmax(logits, dim=-1)[0]
            # NOTE(review): label order is hardcoded; load_audio_v10 reads
            # ckpt["classes"] — confirm it matches ["real", "tts", "vc"].
            classes = ["real", "tts", "vc"]
            class_scores = {c: round(float(probs[i]), 4) for i, c in enumerate(classes)}
            fake_prob = 1 - float(probs[0])  # tts + vc combined
            results["v10"] = {
                "class_scores": class_scores,
                "predicted_class": classes[int(probs.argmax())],
                "fake_prob": round(fake_prob, 4),
            }
        except Exception as e:
            print(f"Audio V10 failed: {e}", flush=True)
            traceback.print_exc()
            results["v10"] = {"error": str(e)}
        # V1 fallback: binary real/fake over mean-pooled backbone features
        try:
            load_audio_v1()
            inputs = audio_feature_extractor(audio_np, sampling_rate=16000, return_tensors="pt", padding=True)
            with torch.no_grad():
                outputs = audio_v1_backbone(**inputs)
                hidden = outputs.last_hidden_state.mean(dim=1)
                logits = audio_v1_classifier(hidden)
            probs = torch.softmax(logits, dim=-1)[0]
            results["v1"] = {
                "real": round(float(probs[0]), 4),
                "fake": round(float(probs[1]), 4),
            }
        except Exception as e:
            print(f"Audio V1 failed: {e}", flush=True)
            results["v1"] = {"error": str(e)}
        # Combined verdict: prefer v10, fallback to v1
        v10 = results.get("v10", {})
        v1 = results.get("v1", {})
        if "error" not in v10:
            fake_prob = v10["fake_prob"]
            verdict = "fake" if fake_prob > 0.5 else "real"
            detail = v10["predicted_class"]
        elif "error" not in v1:
            fake_prob = v1["fake"]
            verdict = "fake" if fake_prob > 0.5 else "real"
            detail = "binary"
        else:
            fake_prob = 0.5
            verdict = "inconclusive"
            detail = "both models failed"
        return {
            "verdict": verdict,
            "confidence": round(max(fake_prob, 1 - fake_prob), 4),
            "scores": {"real": round(1 - fake_prob, 4), "fake": round(fake_prob, 4)},
            "detail": detail,
            "model_results": results,
            "model": "kaeva-v10-audio",
            "version": "10.1.0"
        }
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(500, str(e))
@app.post("/video")
async def analyze_video(
    file: UploadFile = File(...),
    platform: str = Query("clean", pattern="^(clean|whatsapp|instagram|telegram|screenshot)$")
):
    """Analyze video: extract frames -> full 4-model ensemble, extract audio -> v10 audio.

    Fixes vs original: registers the POST /video route (decorator was missing)
    and narrows two bare `except:` clauses to `except Exception:` so
    KeyboardInterrupt/SystemExit are not swallowed.
    """
    start_time = time.time()
    try:
        data = await file.read()
        with tempfile.TemporaryDirectory() as tmpdir:
            video_path = os.path.join(tmpdir, "input_video")
            with open(video_path, "wb") as f:
                f.write(data)
            # Probe container metadata (duration, streams) with ffprobe
            probe_cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", "-show_format", "-show_streams", video_path]
            probe_result = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=15)
            video_info = json.loads(probe_result.stdout) if probe_result.returncode == 0 else {}
            duration = float(video_info.get("format", {}).get("duration", 0))
            resolution = "unknown"
            fps = 30.0  # fallback when the stream reports no usable frame rate
            has_audio = False
            for stream in video_info.get("streams", []):
                if stream.get("codec_type") == "video":
                    resolution = f"{stream.get('width', '?')}x{stream.get('height', '?')}"
                    try:
                        num, den = stream.get("r_frame_rate", "30/1").split("/")
                        fps = float(num) / float(den)
                    except Exception:
                        pass  # malformed rate string; keep the 30.0 default
                elif stream.get("codec_type") == "audio":
                    has_audio = True
            # Sample up to 8 frames, spread roughly evenly across the clip
            max_frames = min(8, max(1, int(duration))) if duration > 0 else 5
            frame_interval = max(1.0, duration / max_frames) if duration > 0 else 1.0
            frame_dir = os.path.join(tmpdir, "frames")
            os.makedirs(frame_dir)
            ffmpeg_cmd = [
                "ffmpeg", "-i", video_path,
                "-vf", f"fps=1/{frame_interval}",
                "-frames:v", str(max_frames),
                "-q:v", "2",
                os.path.join(frame_dir, "frame_%03d.jpg"),
                "-y", "-loglevel", "error"
            ]
            subprocess.run(ffmpeg_cmd, timeout=30, check=True)
            # Run full 4-model ensemble on each extracted frame
            frame_files = sorted([f for f in os.listdir(frame_dir) if f.endswith(".jpg")])
            config = ENSEMBLE_CONFIGS[platform]
            weights = config["weights"]
            infer_fns = {
                "image_ensemble_v2": infer_image_ensemble_v2,
                "ai_gen": infer_ai_gen,
                "spectral": infer_spectral,
                "frequency": infer_frequency,
            }
            frame_scores = []
            per_model_scores = {name: [] for name in weights}
            for fname in frame_files:
                fpath = os.path.join(frame_dir, fname)
                img = Image.open(fpath).convert("RGB")
                frame_model_scores = {}
                for name in weights:
                    try:
                        score = infer_fns[name](img)
                    except Exception:
                        score = 0.5  # neutral fallback when one model fails
                    frame_model_scores[name] = score
                    per_model_scores[name].append(score)
                weighted = sum(frame_model_scores.get(n, 0.5) * w for n, w in weights.items())
                total_w = sum(weights.values())
                frame_scores.append(round(weighted / total_w, 4))
            # Low std across frame scores = temporally consistent predictions
            temporal_consistency = 1.0 - float(np.std(frame_scores)) if len(frame_scores) > 1 else 1.0
            avg_frame_score = float(np.mean(frame_scores)) if frame_scores else 0.5
            # Audio analysis (only when the container has an audio stream)
            audio_result = None
            if has_audio:
                audio_path = os.path.join(tmpdir, "audio.wav")
                audio_cmd = ["ffmpeg", "-i", video_path, "-vn", "-acodec", "pcm_s16le", "-ar", "16000", "-ac", "1", audio_path, "-y", "-loglevel", "error"]
                audio_extract = subprocess.run(audio_cmd, timeout=20)
                if audio_extract.returncode == 0 and os.path.exists(audio_path):
                    try:
                        with open(audio_path, "rb") as af:
                            audio_bytes = af.read()
                        audio_np = process_audio(audio_bytes, max_seconds=10)
                        # V10 audio (3-class); fake prob = 1 - P(real)
                        load_audio_v10()
                        inputs = audio_feature_extractor(audio_np, sampling_rate=16000, return_tensors="pt", padding=True)
                        with torch.no_grad():
                            logits = audio_v10(inputs["input_values"])
                        probs = torch.softmax(logits, dim=-1)[0]
                        classes = ["real", "tts", "vc"]
                        audio_fake_prob = 1 - float(probs[0])
                        audio_result = {
                            "verdict": "fake" if audio_fake_prob > 0.5 else "real",
                            "confidence": round(max(audio_fake_prob, 1 - audio_fake_prob), 4),
                            "scores": {"real": round(float(probs[0]), 4), "fake": round(audio_fake_prob, 4)},
                            "predicted_class": classes[int(probs.argmax())],
                            "class_scores": {c: round(float(probs[i]), 4) for i, c in enumerate(classes)},
                        }
                    except Exception as ae:
                        print(f"Audio analysis error: {ae}", flush=True)
            # Overall: 70% visual, 30% audio (if available)
            if audio_result:
                overall_score = avg_frame_score * 0.7 + audio_result["scores"]["fake"] * 0.3
            else:
                overall_score = avg_frame_score
            flags = []
            if avg_frame_score > 0.7:
                flags.append("HIGH_FAKE_SCORE_ACROSS_FRAMES")
            if temporal_consistency < 0.8:
                flags.append("INCONSISTENT_FRAME_SCORES")
            if audio_result and audio_result["scores"]["fake"] > 0.7:
                flags.append("AUDIO_FAKE_DETECTED")
            if audio_result and ((avg_frame_score > 0.5) != (audio_result["scores"]["fake"] > 0.5)):
                flags.append("AUDIO_VISUAL_DISAGREEMENT")
            verdict = "fake" if overall_score > config["threshold"] else "real"
            return {
                "verdict": verdict,
                "confidence": round(max(overall_score, 1 - overall_score), 4),
                "overall_score": round(overall_score, 4),
                "frame_scores": frame_scores,
                "per_model_averages": {name: round(float(np.mean(scores)), 4) for name, scores in per_model_scores.items() if scores},
                "temporal_consistency": round(temporal_consistency, 4),
                "frame_count": len(frame_scores),
                "fps": round(fps, 2),
                "resolution": resolution,
                "duration_seconds": round(duration, 2),
                "flags": flags,
                "audio_analysis": audio_result,
                "platform": platform,
                "model": "kaeva-v10-full-ensemble",
                "version": "10.1.0",
                "processing_time_ms": int((time.time() - start_time) * 1000),
            }
    except subprocess.TimeoutExpired:
        raise HTTPException(504, "Video processing timed out")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(500, str(e))
@app.post("/ocr")
async def extract_text(file: UploadFile = File(...)):
    """Extract text from image using pytesseract OCR.

    Returns the full text plus per-word bounding boxes and confidences;
    responds 501 when pytesseract is not installed. Fix: the POST /ocr route
    decorator was missing, so the endpoint was never registered.
    """
    try:
        import pytesseract
        data = await file.read()
        img = Image.open(io.BytesIO(data))
        text = pytesseract.image_to_string(img)
        # Also get per-word confidence and bounding-box data
        ocr_data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)
        words = []
        for i, word in enumerate(ocr_data["text"]):
            if word.strip():
                words.append({
                    "text": word,
                    "confidence": ocr_data["conf"][i],
                    "x": ocr_data["left"][i],
                    "y": ocr_data["top"][i],
                    "w": ocr_data["width"][i],
                    "h": ocr_data["height"][i],
                })
        avg_conf = np.mean([w["confidence"] for w in words]) if words else 0
        return {
            "text": text.strip(),
            "word_count": len(words),
            "average_confidence": round(float(avg_conf), 2),
            "words": words,
        }
    except ImportError:
        raise HTTPException(501, "pytesseract not installed")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(500, str(e))
if __name__ == "__main__":
    # Local/dev entry point; port 7860 is the convention for HF Spaces-style
    # deployments (the "Spaces" banner in the scrape suggests that target).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)