""" Easy inference script for Fake Image Detection Usage: python inference.py --image path/to/image.jpg """ import torch from torchvision import transforms from PIL import Image import pickle import json import argparse from huggingface_hub import hf_hub_download from model import EnhancedFreqVAE, EdgeNormalizingFlow, SemanticDeepSVDD, Ensemble def load_models(device='cuda'): """Load all models from Hugging Face""" repo_id = "ash12321/fake-image-detection-ensemble" print("📥 Downloading models from Hugging Face...") # Load config config_path = hf_hub_download(repo_id=repo_id, filename="config.json") with open(config_path, 'r') as f: config = json.load(f) # Load PyTorch models print("Loading Frequency VAE...") freq_vae = EnhancedFreqVAE() vae_path = hf_hub_download(repo_id=repo_id, filename="freq_vae.pth") freq_vae.load_state_dict(torch.load(vae_path, map_location=device)) freq_vae.to(device) freq_vae.eval() print("Loading Edge Flow...") edge_flow = EdgeNormalizingFlow() flow_path = hf_hub_download(repo_id=repo_id, filename="edge_flow.pth") edge_flow.load_state_dict(torch.load(flow_path, map_location=device)) edge_flow.to(device) edge_flow.eval() print("Loading Semantic SVDD...") semantic_svdd = SemanticDeepSVDD() svdd_path = hf_hub_download(repo_id=repo_id, filename="semantic_svdd.pth") checkpoint = torch.load(svdd_path, map_location=device) semantic_svdd.load_state_dict(checkpoint['model']) semantic_svdd.center = checkpoint['center'] semantic_svdd.to(device) semantic_svdd.eval() # Load sklearn models print("Loading traditional ML models...") texture_path = hf_hub_download(repo_id=repo_id, filename="texture_ocsvm.pkl") with open(texture_path, 'rb') as f: texture_ocsvm = pickle.load(f) color_path = hf_hub_download(repo_id=repo_id, filename="color_model.pkl") with open(color_path, 'rb') as f: color_model = pickle.load(f) stat_path = hf_hub_download(repo_id=repo_id, filename="stat.pkl") with open(stat_path, 'rb') as f: stat = pickle.load(f) iforest_path = hf_hub_download(repo_id=repo_id, filename="iforest.pkl") with open(iforest_path, 'rb') as f: iforest = pickle.load(f) lof_path = hf_hub_download(repo_id=repo_id, filename="lof.pkl") with open(lof_path, 'rb') as f: lof = pickle.load(f) gmm_path = hf_hub_download(repo_id=repo_id, filename="gmm.pkl") with open(gmm_path, 'rb') as f: gmm = pickle.load(f) # Create ensemble models_dict = { 'freq_vae': freq_vae, 'texture_ocsvm': texture_ocsvm, 'color_model': color_model, 'edge_flow': edge_flow, 'semantic_svdd': semantic_svdd, 'stat': stat, 'iforest': iforest, 'lof': lof, 'gmm': gmm } ensemble = Ensemble(models_dict) ensemble.wts = config['weights'] ensemble.norms = config['norms'] ensemble.thresh = config['thresh'] print("✓ All models loaded!\n") return ensemble, device def predict_image(image_path, ensemble, device): """Predict if an image is fake""" # Load and preprocess image img = Image.open(image_path) img = img.resize((256, 256), Image.LANCZOS).convert('RGB') tfm = transforms.Compose([ transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]) ]) img_tensor = tfm(img) # Predict is_fake, score, individual_scores = ensemble.predict(img_tensor, device) return { 'prediction': 'FAKE' if is_fake else 'REAL', 'confidence': abs(score), 'anomaly_score': score, 'individual_scores': individual_scores } if __name__ == "__main__": parser = argparse.ArgumentParser(description='Detect fake images') parser.add_argument('--image', type=str, required=True, help='Path to image') parser.add_argument('--device', type=str, default='cuda', help='Device (cuda/cpu)') args = parser.parse_args() # Check device device = args.device if torch.cuda.is_available() else 'cpu' print(f"Using device: {device}\n") # Load models ensemble, device = load_models(device) # Predict print(f"Analyzing: {args.image}") result = predict_image(args.image, ensemble, device) print("\n" + "="*50) print("RESULT") print("="*50) print(f"Prediction: {result['prediction']}") print(f"Confidence: {result['confidence']:.4f}") print(f"Anomaly Score: {result['anomaly_score']:.4f}") print(f"\nIndividual Model Scores:") for model, score in result['individual_scores'].items(): print(f" {model}: {score:.4f}") print("="*50)