# Source: Hugging Face Hub, uploaded by ash12321
# Commit message: "Upload inference.py with huggingface_hub"
# Commit: e55a650 (verified)
"""
Easy inference script for Fake Image Detection
Usage: python inference.py --image path/to/image.jpg
"""
import torch
from torchvision import transforms
from PIL import Image
import pickle
import json
import argparse
from huggingface_hub import hf_hub_download
from model import EnhancedFreqVAE, EdgeNormalizingFlow, SemanticDeepSVDD, Ensemble
def _fetch_from_hub(repo_id, filename):
    """Return the local path hf_hub_download reports for `filename` in `repo_id`."""
    return hf_hub_download(repo_id=repo_id, filename=filename)


def _load_torch_checkpoint(repo_id, filename, device):
    """Fetch `filename` and deserialize it with torch.load onto `device`."""
    # SECURITY: torch.load unpickles the file, so a tampered checkpoint can
    # run arbitrary code. Only point this at a repository you trust.
    return torch.load(_fetch_from_hub(repo_id, filename), map_location=device)


def _load_pickled_model(repo_id, filename):
    """Fetch `filename` and unpickle the object stored in it."""
    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only point this at a repository you trust.
    with open(_fetch_from_hub(repo_id, filename), 'rb') as f:
        return pickle.load(f)


def load_models(device='cuda'):
    """Download every ensemble component from Hugging Face and assemble it.

    Args:
        device: Device string handed to ``torch.load(map_location=...)`` and
            to each PyTorch model's ``.to()``. No availability check is done
            here; the caller is expected to pass a usable device.

    Returns:
        A ``(ensemble, device)`` tuple: the configured ``Ensemble`` and the
        same ``device`` value that was passed in.
    """
    repo_id = "ash12321/fake-image-detection-ensemble"
    print("📥 Downloading models from Hugging Face...")

    # Ensemble settings (weights, norms, threshold) live in config.json.
    # JSON is UTF-8 by specification, so do not rely on the platform default.
    with open(_fetch_from_hub(repo_id, "config.json"), 'r', encoding='utf-8') as f:
        config = json.load(f)

    # --- PyTorch models ---------------------------------------------------
    print("Loading Frequency VAE...")
    freq_vae = EnhancedFreqVAE()
    freq_vae.load_state_dict(_load_torch_checkpoint(repo_id, "freq_vae.pth", device))
    freq_vae.to(device)
    freq_vae.eval()

    print("Loading Edge Flow...")
    edge_flow = EdgeNormalizingFlow()
    edge_flow.load_state_dict(_load_torch_checkpoint(repo_id, "edge_flow.pth", device))
    edge_flow.to(device)
    edge_flow.eval()

    print("Loading Semantic SVDD...")
    semantic_svdd = SemanticDeepSVDD()
    # This checkpoint is a dict holding both the state dict ('model') and a
    # separate 'center' value that is attached to the model as an attribute.
    checkpoint = _load_torch_checkpoint(repo_id, "semantic_svdd.pth", device)
    semantic_svdd.load_state_dict(checkpoint['model'])
    semantic_svdd.center = checkpoint['center']
    semantic_svdd.to(device)
    semantic_svdd.eval()

    # --- Pickled (non-PyTorch) models --------------------------------------
    print("Loading traditional ML models...")
    # Each model is stored as "<name>.pkl"; the tuple order is the download order.
    pickled_names = ('texture_ocsvm', 'color_model', 'stat', 'iforest', 'lof', 'gmm')
    pickled = {
        name: _load_pickled_model(repo_id, f"{name}.pkl")
        for name in pickled_names
    }

    # Key order is kept exactly as before in case Ensemble iterates the dict.
    models_dict = {
        'freq_vae': freq_vae,
        'texture_ocsvm': pickled['texture_ocsvm'],
        'color_model': pickled['color_model'],
        'edge_flow': edge_flow,
        'semantic_svdd': semantic_svdd,
        'stat': pickled['stat'],
        'iforest': pickled['iforest'],
        'lof': pickled['lof'],
        'gmm': pickled['gmm'],
    }

    ensemble = Ensemble(models_dict)
    ensemble.wts = config['weights']
    ensemble.norms = config['norms']
    ensemble.thresh = config['thresh']

    print("✓ All models loaded!\n")
    return ensemble, device
def predict_image(image_path, ensemble, device):
    """Run the ensemble on a single image file and summarize the verdict.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        ensemble: Object exposing ``predict(img_tensor, device)`` that returns
            ``(is_fake, score, individual_scores)``.
        device: Passed through unchanged to ``ensemble.predict``.

    Returns:
        dict with keys:
            'prediction': ``'FAKE'`` when ``is_fake`` is truthy, else ``'REAL'``.
            'confidence': ``abs(score)``.
            'anomaly_score': the raw ``score``.
            'individual_scores': the third value returned by the ensemble.
    """
    # Open under a context manager so the underlying file handle is released
    # deterministically; resize/convert return new in-memory images that stay
    # valid after the source file is closed.
    with Image.open(image_path) as source:
        # NOTE(review): resize happens before the RGB conversion, as in the
        # original; swapping the order would change pixel values, so it is
        # left alone — confirm it matches the training-time preprocessing.
        img = source.resize((256, 256), Image.LANCZOS).convert('RGB')

    # Convert to a tensor and normalize with fixed per-channel mean/std.
    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img_tensor = tfm(img)

    is_fake, score, individual_scores = ensemble.predict(img_tensor, device)

    return {
        'prediction': 'FAKE' if is_fake else 'REAL',
        'confidence': abs(score),
        'anomaly_score': score,
        'individual_scores': individual_scores,
    }
def resolve_device(requested):
    """Pick the device to run on, falling back to CPU only for missing CUDA.

    Args:
        requested: Device string from the command line (e.g. 'cuda', 'cuda:0',
            'cpu').

    Returns:
        'cpu' when a CUDA device was requested but CUDA is unavailable;
        otherwise ``requested`` unchanged. Non-CUDA requests are never
        overridden, so they are no longer silently forced to 'cpu' on
        machines without CUDA.
    """
    if requested.startswith('cuda') and not torch.cuda.is_available():
        return 'cpu'
    return requested


def main():
    """Parse CLI arguments, load the ensemble, and print the prediction report."""
    parser = argparse.ArgumentParser(description='Detect fake images')
    parser.add_argument('--image', type=str, required=True, help='Path to image')
    parser.add_argument('--device', type=str, default='cuda', help='Device (cuda/cpu)')
    args = parser.parse_args()

    device = resolve_device(args.device)
    print(f"Using device: {device}\n")

    ensemble, device = load_models(device)

    print(f"Analyzing: {args.image}")
    result = predict_image(args.image, ensemble, device)

    separator = "=" * 50
    print("\n" + separator)
    print("RESULT")
    print(separator)
    print(f"Prediction: {result['prediction']}")
    print(f"Confidence: {result['confidence']:.4f}")
    print(f"Anomaly Score: {result['anomaly_score']:.4f}")
    print("\nIndividual Model Scores:")
    for model_name, model_score in result['individual_scores'].items():
        print(f" {model_name}: {model_score:.4f}")
    print(separator)


if __name__ == "__main__":
    main()