ash12321 commited on
Commit
0a19a23
·
verified ·
1 Parent(s): bb9059d

Upload inference_example.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference_example.py +98 -0
inference_example.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example: Using the model for deepfake detection
3
+ """
4
+
5
+ import torch
6
+ from torchvision import transforms
7
+ from PIL import Image
8
+ from model import load_model
9
+ import json
10
+
11
+ # Load model
12
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ model = load_model('model_best_checkpoint.ckpt', device=device)
14
+
15
+ # Load calibrated thresholds
16
+ with open('thresholds_calibrated.json', 'r') as f:
17
+ config = json.load(f)
18
+ threshold = config['reconstruction_thresholds']['thresholds']['balanced']['value']
19
+
20
+ print(f"Using threshold: {threshold:.6f}")
21
+
22
+ # Prepare image preprocessing
23
+ transform = transforms.Compose([
24
+ transforms.Resize((128, 128)),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
27
+ ])
28
+
29
+ def detect_deepfake(image_path, model, threshold, device):
30
+ """
31
+ Detect if an image is likely a deepfake based on reconstruction error.
32
+
33
+ Args:
34
+ image_path: Path to image file
35
+ model: Loaded autoencoder model
36
+ threshold: MSE threshold for detection
37
+ device: torch device
38
+
39
+ Returns:
40
+ is_fake: Boolean indicating if image is likely fake
41
+ error: Reconstruction error value
42
+ confidence: Confidence score (0-1)
43
+ """
44
+ # Load and preprocess image
45
+ image = Image.open(image_path).convert('RGB')
46
+ input_tensor = transform(image).unsqueeze(0).to(device)
47
+
48
+ # Calculate reconstruction error
49
+ with torch.no_grad():
50
+ error = model.reconstruction_error(input_tensor, reduction='none')
51
+
52
+ error_value = error.item()
53
+ is_fake = error_value > threshold
54
+
55
+ # Calculate confidence (normalized error relative to threshold)
56
+ confidence = min(abs(error_value - threshold) / threshold, 1.0)
57
+
58
+ return is_fake, error_value, confidence
59
+
60
+ # Example usage
61
+ image_path = "test_image.jpg"
62
+ is_fake, error, confidence = detect_deepfake(image_path, model, threshold, device)
63
+
64
+ print(f"\nResults for: {image_path}")
65
+ print(f"Reconstruction Error: {error:.6f}")
66
+ print(f"Threshold: {threshold:.6f}")
67
+ print(f"Classification: {'FAKE' if is_fake else 'REAL'}")
68
+ print(f"Confidence: {confidence:.2%}")
69
+
70
+ # Batch processing example
71
+ def batch_detect(image_paths, model, threshold, device):
72
+ """Process multiple images at once"""
73
+ images = []
74
+ for path in image_paths:
75
+ img = Image.open(path).convert('RGB')
76
+ images.append(transform(img))
77
+
78
+ batch = torch.stack(images).to(device)
79
+
80
+ with torch.no_grad():
81
+ errors = model.reconstruction_error(batch, reduction='none')
82
+
83
+ results = []
84
+ for i, error in enumerate(errors):
85
+ is_fake = error.item() > threshold
86
+ results.append({
87
+ 'path': image_paths[i],
88
+ 'error': error.item(),
89
+ 'is_fake': is_fake
90
+ })
91
+
92
+ return results
93
+
94
+ # Example batch processing
95
+ # image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"]
96
+ # results = batch_detect(image_paths, model, threshold, device)
97
+ # for r in results:
98
+ # print(f"{r['path']}: {'FAKE' if r['is_fake'] else 'REAL'} (error: {r['error']:.6f})")