Upload inference.py with huggingface_hub
Browse files- inference.py +148 -0
inference.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Easy inference script for Fake Image Detection
|
| 3 |
+
Usage: python inference.py --image path/to/image.jpg
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import pickle
|
| 10 |
+
import json
|
| 11 |
+
import argparse
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
+
from model import EnhancedFreqVAE, EdgeNormalizingFlow, SemanticDeepSVDD, Ensemble
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_models(device='cuda'):
    """Download and assemble the fake-image-detection ensemble from the Hub.

    Fetches every artifact from the fixed repo
    ``ash12321/fake-image-detection-ensemble``: three PyTorch sub-models
    (frequency VAE, edge normalizing flow, semantic Deep-SVDD), six pickled
    sklearn-style detectors, and ``config.json`` holding the ensemble's
    weights, score normalizers, and decision threshold.

    Args:
        device: torch device string the PyTorch models are moved to
            (e.g. ``'cuda'`` or ``'cpu'``).

    Returns:
        Tuple ``(ensemble, device)``: the configured ``Ensemble`` and the
        device string echoed back for the caller.
    """
    repo_id = "ash12321/fake-image-detection-ensemble"

    print("📥 Downloading models from Hugging Face...")

    def _fetch(filename):
        # Downloads (or reuses the local cache of) one artifact; returns its path.
        return hf_hub_download(repo_id=repo_id, filename=filename)

    # Ensemble configuration: per-model weights, normalizers, threshold.
    with open(_fetch("config.json"), 'r') as f:
        config = json.load(f)

    # --- PyTorch sub-models -------------------------------------------------
    # NOTE(security): torch.load unpickles arbitrary objects. Accepted here
    # only because the checkpoints come from a fixed first-party Hub repo.
    print("Loading Frequency VAE...")
    freq_vae = EnhancedFreqVAE()
    freq_vae.load_state_dict(torch.load(_fetch("freq_vae.pth"), map_location=device))
    freq_vae.to(device)
    freq_vae.eval()

    print("Loading Edge Flow...")
    edge_flow = EdgeNormalizingFlow()
    edge_flow.load_state_dict(torch.load(_fetch("edge_flow.pth"), map_location=device))
    edge_flow.to(device)
    edge_flow.eval()

    print("Loading Semantic SVDD...")
    semantic_svdd = SemanticDeepSVDD()
    # This checkpoint bundles both the state dict and the learned SVDD center.
    checkpoint = torch.load(_fetch("semantic_svdd.pth"), map_location=device)
    semantic_svdd.load_state_dict(checkpoint['model'])
    semantic_svdd.center = checkpoint['center']
    semantic_svdd.to(device)
    semantic_svdd.eval()

    # --- sklearn-style detectors -------------------------------------------
    print("Loading traditional ML models...")
    pickled_models = {}
    for key, filename in [
        ('texture_ocsvm', 'texture_ocsvm.pkl'),
        ('color_model', 'color_model.pkl'),
        ('stat', 'stat.pkl'),
        ('iforest', 'iforest.pkl'),
        ('lof', 'lof.pkl'),
        ('gmm', 'gmm.pkl'),
    ]:
        # NOTE(security): pickle.load can execute arbitrary code; acceptable
        # only because these files come from the same fixed first-party repo.
        with open(_fetch(filename), 'rb') as f:
            pickled_models[key] = pickle.load(f)

    # --- Assemble the ensemble ---------------------------------------------
    models_dict = {
        'freq_vae': freq_vae,
        'edge_flow': edge_flow,
        'semantic_svdd': semantic_svdd,
        **pickled_models,
    }

    ensemble = Ensemble(models_dict)
    ensemble.wts = config['weights']
    ensemble.norms = config['norms']
    ensemble.thresh = config['thresh']

    print("✓ All models loaded!\n")
    return ensemble, device
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def predict_image(image_path, ensemble, device):
    """Classify a single image file with the ensemble.

    Args:
        image_path: path to the image to analyze.
        ensemble: loaded ``Ensemble`` (see ``load_models``).
        device: torch device string used by ``ensemble.predict``.

    Returns:
        Dict with keys ``prediction`` ('FAKE'/'REAL'), ``confidence``
        (absolute anomaly score), ``anomaly_score`` (signed score), and
        ``individual_scores`` (per-model scores from the ensemble).
    """
    # Preprocessing: 256x256 LANCZOS resize, RGB, ImageNet mean/std.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    image = Image.open(image_path)
    image = image.resize((256, 256), Image.LANCZOS).convert('RGB')
    image_tensor = preprocess(image)

    # The ensemble handles batching/device placement internally.
    is_fake, score, individual_scores = ensemble.predict(image_tensor, device)

    verdict = 'FAKE' if is_fake else 'REAL'
    return {
        'prediction': verdict,
        'confidence': abs(score),
        'anomaly_score': score,
        'individual_scores': individual_scores,
    }
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _main():
    """CLI entry point: parse arguments, load the ensemble, classify one image."""
    parser = argparse.ArgumentParser(description='Detect fake images')
    parser.add_argument('--image', type=str, required=True, help='Path to image')
    parser.add_argument('--device', type=str, default='cuda', help='Device (cuda/cpu)')
    args = parser.parse_args()

    # Fall back to CPU when CUDA is unavailable, regardless of the request.
    device = args.device if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}\n")

    ensemble, device = load_models(device)

    print(f"Analyzing: {args.image}")
    result = predict_image(args.image, ensemble, device)

    separator = "=" * 50
    print("\n" + separator)
    print("RESULT")
    print(separator)
    print(f"Prediction: {result['prediction']}")
    print(f"Confidence: {result['confidence']:.4f}")
    print(f"Anomaly Score: {result['anomaly_score']:.4f}")
    print(f"\nIndividual Model Scores:")
    for model, score in result['individual_scores'].items():
        print(f"  {model}: {score:.4f}")
    print(separator)


if __name__ == "__main__":
    _main()
|