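# NOTE(review): the lines below look like repository web-page header text that
# was captured together with the file; they are kept as comments so that the
# module parses as Python.
# Veritas-AI / test_debug.py
# Aditya-Jadhav150
# Deploy explainable 9-feature XGBoost Fusion Engine and Dynamic Dashboard
# f2584f0
# Raw
# History Blame Contribute Delete
# 4.34 kB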
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
from facenet_pytorch import MTCNN
from torchvision import transforms
import timm
# --- Runtime device ---
# Use CUDA when torch reports it as available, otherwise run on the CPU.
# Every model created below is moved to this device.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print("Evaluating on:", device)

# Face detector used by debug_ensemble() to locate the face to crop.
mtcnn = MTCNN(device=device, keep_all=False)

# --- Load ViT ---
# Processor and classifier are both fetched by the same model id.
model_id = "prithivMLmods/Deep-Fake-Detector-v2-Model"
processor = AutoImageProcessor.from_pretrained(model_id)
vit_model = AutoModelForImageClassification.from_pretrained(model_id)
vit_model.to(device)
vit_model.eval()

# --- Load Local EfficientNet ---
# Two-class EfficientNet-B3 whose weights come from a checkpoint file that is
# resolved relative to the current working directory.
# NOTE(review): torch.load unpickles the file -- only use checkpoints from a
# trusted source.
eff_model = timm.create_model('efficientnet_b3', num_classes=2, pretrained=False)
eff_model.load_state_dict(
    torch.load("model_best.pth", map_location=device)
)
eff_model.to(device)
eff_model.eval()

# Preprocessing applied to the EfficientNet input: fixed 224x224 resize,
# tensor conversion, then per-channel normalisation.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
eff_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])
def debug_ensemble(image_path):
    """Print a step-by-step diagnostic of the ViT + EfficientNet ensemble.

    The image is loaded as RGB and passed to the module-level MTCNN detector.
    When at least one box is returned, the first box is cropped twice: a tight
    crop for EfficientNet and a crop padded by 15% per side (clamped to the
    image bounds) for the ViT. When no box is returned both models receive
    the full image. Each model's class probabilities are printed, followed by
    two weighted blends of the fake/real scores (70/30 and 50/50
    ViT/EfficientNet).

    Relies on the module-level objects ``mtcnn``, ``processor``,
    ``vit_model``, ``eff_model``, ``eff_transform`` and ``device``.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (typically a path).

    Returns:
        None. All results are written to stdout.
    """
    print("\n======================================")
    print(f"DEBUGGING INFERENCE OVER: {image_path}")
    print("======================================")

    # Release the file handle as soon as the pixels are copied: convert()
    # returns a separate image, so nothing below needs the open file.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # Defaults: both models see the raw image unless a face is found.
    vit_image = image
    eff_image = image

    boxes, _ = mtcnn.detect(image)
    if boxes is not None and len(boxes) > 0:
        left, top, right, bottom = boxes[0][:4]

        # 1. Zero-Margin Crop (Exactly matches how extract_faces.py generated
        #    the EfficientNet training dataset)
        eff_image = image.crop((int(left), int(top), int(right), int(bottom)))

        # 2. 15% Padded Crop (for ViT), clamped so it never leaves the image.
        pad_w = (right - left) * 0.15
        pad_h = (bottom - top) * 0.15
        vit_image = image.crop((
            int(max(0, left - pad_w)),
            int(max(0, top - pad_h)),
            int(min(image.width, right + pad_w)),
            int(min(image.height, bottom + pad_h)),
        ))
    else:
        print(">> WARNING: MTCNN found no face! Both models falling back to raw image.")

    print("\n--- 1. CLOUD ViT DIAGNOSTIC ---")
    inputs = processor(images=vit_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = vit_model(**inputs)
        model_probs = torch.nn.functional.softmax(outputs.logits, dim=1)

    # Class meaning is derived from the label text in the model config rather
    # than from fixed indices, so each probability is bucketed by keyword.
    vit_fake_prob = 0.0
    vit_real_prob = 0.0
    for idx, label_name in vit_model.config.id2label.items():
        prob = model_probs[0][idx].item()
        label_lower = label_name.lower()
        print(f">> IDX [{idx}] ({label_name}): {prob * 100:.2f}%")
        # 'fake' also covers labels such as 'deepfake'.
        if 'fake' in label_lower or 'spoof' in label_lower:
            vit_fake_prob += prob
        elif 'real' in label_lower or 'pristine' in label_lower:
            vit_real_prob += prob
        else:
            # Without this notice an unrecognised label would silently drop
            # its probability from both scores and skew the blends below.
            print(f">> WARNING: label '{label_name}' matched neither fake nor real keywords; "
                  "its probability is excluded from the ViT scores.")
    print(f">> ViT Final Fake Score: {vit_fake_prob * 100:.2f}%")
    print(f">> ViT Final Real Score: {vit_real_prob * 100:.2f}%")

    print("\n--- 2. LOCAL EFFICIENTNET DIAGNOSTIC ---")
    eff_input = eff_transform(eff_image).unsqueeze(0).to(device)
    with torch.no_grad():
        eff_outputs = eff_model(eff_input)
        eff_probs = torch.nn.functional.softmax(eff_outputs, dim=1)
    eff_fake_prob = eff_probs[0][0].item()  # 0 is Fake
    eff_real_prob = eff_probs[0][1].item()  # 1 is Real
    print(f">> Local Model FAKE (IDX 0): {eff_fake_prob * 100:.2f}%")
    print(f">> Local Model REAL (IDX 1): {eff_real_prob * 100:.2f}%")

    print("\n--- 3. MATHEMATICAL ENSEMBLE ---")
    # Current app.py weighting
    final_fake_prob = (0.7 * vit_fake_prob) + (0.3 * eff_fake_prob)
    final_real_prob = (0.7 * vit_real_prob) + (0.3 * eff_real_prob)
    print(f"> Current Output [70% ViT / 30% Eff]: FAKE={final_fake_prob*100:.2f}%, REAL={final_real_prob*100:.2f}%")

    # Proposed balanced weighting
    bal_fake = (0.5 * vit_fake_prob) + (0.5 * eff_fake_prob)
    bal_real = (0.5 * vit_real_prob) + (0.5 * eff_real_prob)
    print(f"> Balanced Output [50% ViT / 50% Eff]: FAKE={bal_fake*100:.2f}%, REAL={bal_real*100:.2f}%")
if __name__ == "__main__":
    import sys

    # The first positional argument is the image path; extra arguments are
    # ignored.
    if len(sys.argv) > 1:
        debug_ensemble(sys.argv[1])
    else:
        # A missing argument is a usage error: sys.exit with a string writes
        # it to stderr and exits with status 1, so calling scripts can detect
        # the failure instead of seeing a successful run.
        sys.exit("Please provide an image path: python test_debug.py <path_to_image>")