ash12321 commited on
Commit
e55a650
·
verified ·
1 Parent(s): 121343d

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +148 -0
inference.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Easy inference script for Fake Image Detection
3
+ Usage: python inference.py --image path/to/image.jpg
4
+ """
5
+
6
+ import torch
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ import pickle
10
+ import json
11
+ import argparse
12
+ from huggingface_hub import hf_hub_download
13
+ from model import EnhancedFreqVAE, EdgeNormalizingFlow, SemanticDeepSVDD, Ensemble
14
+
15
+
16
+ def load_models(device='cuda'):
17
+ """Load all models from Hugging Face"""
18
+ repo_id = "ash12321/fake-image-detection-ensemble"
19
+
20
+ print("📥 Downloading models from Hugging Face...")
21
+
22
+ # Load config
23
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
24
+ with open(config_path, 'r') as f:
25
+ config = json.load(f)
26
+
27
+ # Load PyTorch models
28
+ print("Loading Frequency VAE...")
29
+ freq_vae = EnhancedFreqVAE()
30
+ vae_path = hf_hub_download(repo_id=repo_id, filename="freq_vae.pth")
31
+ freq_vae.load_state_dict(torch.load(vae_path, map_location=device))
32
+ freq_vae.to(device)
33
+ freq_vae.eval()
34
+
35
+ print("Loading Edge Flow...")
36
+ edge_flow = EdgeNormalizingFlow()
37
+ flow_path = hf_hub_download(repo_id=repo_id, filename="edge_flow.pth")
38
+ edge_flow.load_state_dict(torch.load(flow_path, map_location=device))
39
+ edge_flow.to(device)
40
+ edge_flow.eval()
41
+
42
+ print("Loading Semantic SVDD...")
43
+ semantic_svdd = SemanticDeepSVDD()
44
+ svdd_path = hf_hub_download(repo_id=repo_id, filename="semantic_svdd.pth")
45
+ checkpoint = torch.load(svdd_path, map_location=device)
46
+ semantic_svdd.load_state_dict(checkpoint['model'])
47
+ semantic_svdd.center = checkpoint['center']
48
+ semantic_svdd.to(device)
49
+ semantic_svdd.eval()
50
+
51
+ # Load sklearn models
52
+ print("Loading traditional ML models...")
53
+ texture_path = hf_hub_download(repo_id=repo_id, filename="texture_ocsvm.pkl")
54
+ with open(texture_path, 'rb') as f:
55
+ texture_ocsvm = pickle.load(f)
56
+
57
+ color_path = hf_hub_download(repo_id=repo_id, filename="color_model.pkl")
58
+ with open(color_path, 'rb') as f:
59
+ color_model = pickle.load(f)
60
+
61
+ stat_path = hf_hub_download(repo_id=repo_id, filename="stat.pkl")
62
+ with open(stat_path, 'rb') as f:
63
+ stat = pickle.load(f)
64
+
65
+ iforest_path = hf_hub_download(repo_id=repo_id, filename="iforest.pkl")
66
+ with open(iforest_path, 'rb') as f:
67
+ iforest = pickle.load(f)
68
+
69
+ lof_path = hf_hub_download(repo_id=repo_id, filename="lof.pkl")
70
+ with open(lof_path, 'rb') as f:
71
+ lof = pickle.load(f)
72
+
73
+ gmm_path = hf_hub_download(repo_id=repo_id, filename="gmm.pkl")
74
+ with open(gmm_path, 'rb') as f:
75
+ gmm = pickle.load(f)
76
+
77
+ # Create ensemble
78
+ models_dict = {
79
+ 'freq_vae': freq_vae,
80
+ 'texture_ocsvm': texture_ocsvm,
81
+ 'color_model': color_model,
82
+ 'edge_flow': edge_flow,
83
+ 'semantic_svdd': semantic_svdd,
84
+ 'stat': stat,
85
+ 'iforest': iforest,
86
+ 'lof': lof,
87
+ 'gmm': gmm
88
+ }
89
+
90
+ ensemble = Ensemble(models_dict)
91
+ ensemble.wts = config['weights']
92
+ ensemble.norms = config['norms']
93
+ ensemble.thresh = config['thresh']
94
+
95
+ print("✓ All models loaded!\n")
96
+ return ensemble, device
97
+
98
+
99
+ def predict_image(image_path, ensemble, device):
100
+ """Predict if an image is fake"""
101
+ # Load and preprocess image
102
+ img = Image.open(image_path)
103
+ img = img.resize((256, 256), Image.LANCZOS).convert('RGB')
104
+
105
+ tfm = transforms.Compose([
106
+ transforms.ToTensor(),
107
+ transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
108
+ ])
109
+ img_tensor = tfm(img)
110
+
111
+ # Predict
112
+ is_fake, score, individual_scores = ensemble.predict(img_tensor, device)
113
+
114
+ return {
115
+ 'prediction': 'FAKE' if is_fake else 'REAL',
116
+ 'confidence': abs(score),
117
+ 'anomaly_score': score,
118
+ 'individual_scores': individual_scores
119
+ }
120
+
121
+
122
+ if __name__ == "__main__":
123
+ parser = argparse.ArgumentParser(description='Detect fake images')
124
+ parser.add_argument('--image', type=str, required=True, help='Path to image')
125
+ parser.add_argument('--device', type=str, default='cuda', help='Device (cuda/cpu)')
126
+ args = parser.parse_args()
127
+
128
+ # Check device
129
+ device = args.device if torch.cuda.is_available() else 'cpu'
130
+ print(f"Using device: {device}\n")
131
+
132
+ # Load models
133
+ ensemble, device = load_models(device)
134
+
135
+ # Predict
136
+ print(f"Analyzing: {args.image}")
137
+ result = predict_image(args.image, ensemble, device)
138
+
139
+ print("\n" + "="*50)
140
+ print("RESULT")
141
+ print("="*50)
142
+ print(f"Prediction: {result['prediction']}")
143
+ print(f"Confidence: {result['confidence']:.4f}")
144
+ print(f"Anomaly Score: {result['anomaly_score']:.4f}")
145
+ print(f"\nIndividual Model Scores:")
146
+ for model, score in result['individual_scores'].items():
147
+ print(f" {model}: {score:.4f}")
148
+ print("="*50)