|
|
""" |
|
|
Easy inference script for Fake Image Detection |
|
|
Usage: python inference.py --image path/to/image.jpg |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import pickle |
|
|
import json |
|
|
import argparse |
|
|
from huggingface_hub import hf_hub_download |
|
|
from model import EnhancedFreqVAE, EdgeNormalizingFlow, SemanticDeepSVDD, Ensemble |
|
|
|
|
|
|
|
|
def load_models(device='cuda'):
    """Download the ensemble's components from Hugging Face and assemble them.

    Fetches ``config.json``, three PyTorch checkpoints and six pickled
    objects from the ``ash12321/fake-image-detection-ensemble`` repository,
    wraps them in an ``Ensemble`` and copies the ``weights``, ``norms`` and
    ``thresh`` entries of the config onto it.

    Args:
        device: Device the PyTorch modules are loaded onto. A CUDA device is
            replaced by ``'cpu'`` when CUDA is not available, so the default
            also works on CPU-only machines.

    Returns:
        tuple: ``(ensemble, device)``, where ``device`` is the device that
        was actually used (it may differ from the requested one).

    Raises:
        KeyError: If ``config.json`` lacks ``weights``, ``norms`` or
            ``thresh``, or the SVDD checkpoint lacks ``model`` / ``center``.
    """
    repo_id = "ash12321/fake-image-detection-ensemble"

    # Loading onto an unavailable CUDA device would fail inside torch.load;
    # fall back to the CPU instead and report the device actually used.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        print("CUDA is not available; falling back to CPU.")
        device = 'cpu'

    def _download(filename):
        # Returns the local path of the requested repository file.
        return hf_hub_download(repo_id=repo_id, filename=filename)

    print("📥 Downloading models from Hugging Face...")

    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open(_download("config.json"), 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Fail before the large downloads if the config is unusable.
    missing = [key for key in ('weights', 'norms', 'thresh') if key not in config]
    if missing:
        raise KeyError(f"config.json is missing required keys: {missing}")

    # SECURITY NOTE: torch.load and pickle.load below deserialize files
    # fetched from the network and can execute arbitrary code. This is only
    # safe as long as the repository above is trusted.
    print("Loading Frequency VAE...")
    freq_vae = EnhancedFreqVAE()
    freq_vae.load_state_dict(torch.load(_download("freq_vae.pth"), map_location=device))
    freq_vae.to(device)
    freq_vae.eval()

    print("Loading Edge Flow...")
    edge_flow = EdgeNormalizingFlow()
    edge_flow.load_state_dict(torch.load(_download("edge_flow.pth"), map_location=device))
    edge_flow.to(device)
    edge_flow.eval()

    print("Loading Semantic SVDD...")
    semantic_svdd = SemanticDeepSVDD()
    # This checkpoint is a dict holding the state dict under 'model' and a
    # separate 'center' value that is attached to the module as an attribute.
    checkpoint = torch.load(_download("semantic_svdd.pth"), map_location=device)
    semantic_svdd.load_state_dict(checkpoint['model'])
    semantic_svdd.center = checkpoint['center']
    semantic_svdd.to(device)
    semantic_svdd.eval()

    print("Loading traditional ML models...")
    pickled_files = (
        ('texture_ocsvm', "texture_ocsvm.pkl"),
        ('color_model', "color_model.pkl"),
        ('stat', "stat.pkl"),
        ('iforest', "iforest.pkl"),
        ('lof', "lof.pkl"),
        ('gmm', "gmm.pkl"),
    )
    pickled = {}
    for key, filename in pickled_files:
        with open(_download(filename), 'rb') as f:
            pickled[key] = pickle.load(f)

    # Key order is kept exactly as before in case Ensemble (or the config
    # values) depends on the iteration order of this dict.
    models_dict = {
        'freq_vae': freq_vae,
        'texture_ocsvm': pickled['texture_ocsvm'],
        'color_model': pickled['color_model'],
        'edge_flow': edge_flow,
        'semantic_svdd': semantic_svdd,
        'stat': pickled['stat'],
        'iforest': pickled['iforest'],
        'lof': pickled['lof'],
        'gmm': pickled['gmm'],
    }

    ensemble = Ensemble(models_dict)
    ensemble.wts = config['weights']
    ensemble.norms = config['norms']
    ensemble.thresh = config['thresh']

    print("✓ All models loaded!\n")
    return ensemble, device
|
|
|
|
|
|
|
|
def predict_image(image_path, ensemble, device):
    """Run the ensemble on one image file and report whether it looks fake.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        ensemble: Object exposing ``predict(tensor, device)`` that returns
            ``(is_fake, score, individual_scores)``.
        device: Device identifier, forwarded unchanged to ``ensemble.predict``.

    Returns:
        dict: ``'prediction'`` (``'FAKE'`` or ``'REAL'``), ``'confidence'``
        (``abs(score)``), ``'anomaly_score'`` (the raw score) and
        ``'individual_scores'`` (as returned by the ensemble).
    """
    # The context manager releases the file handle once the pixels are copied.
    with Image.open(image_path) as source:
        # Convert before resizing: Pillow forces NEAREST resampling for
        # palette ('P') and bilevel ('1') images, so resizing first would
        # silently ignore the LANCZOS filter for such inputs.
        img = source.convert('RGB').resize((256, 256), Image.LANCZOS)

    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img_tensor = tfm(img)

    is_fake, score, individual_scores = ensemble.predict(img_tensor, device)

    return {
        'prediction': 'FAKE' if is_fake else 'REAL',
        'confidence': abs(score),
        'anomaly_score': score,
        'individual_scores': individual_scores,
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
parser = argparse.ArgumentParser(description='Detect fake images') |
|
|
parser.add_argument('--image', type=str, required=True, help='Path to image') |
|
|
parser.add_argument('--device', type=str, default='cuda', help='Device (cuda/cpu)') |
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
device = args.device if torch.cuda.is_available() else 'cpu' |
|
|
print(f"Using device: {device}\n") |
|
|
|
|
|
|
|
|
ensemble, device = load_models(device) |
|
|
|
|
|
|
|
|
print(f"Analyzing: {args.image}") |
|
|
result = predict_image(args.image, ensemble, device) |
|
|
|
|
|
print("\n" + "="*50) |
|
|
print("RESULT") |
|
|
print("="*50) |
|
|
print(f"Prediction: {result['prediction']}") |
|
|
print(f"Confidence: {result['confidence']:.4f}") |
|
|
print(f"Anomaly Score: {result['anomaly_score']:.4f}") |
|
|
print(f"\nIndividual Model Scores:") |
|
|
for model, score in result['individual_scores'].items(): |
|
|
print(f" {model}: {score:.4f}") |
|
|
print("="*50) |
|
|
|