LoliRimuru commited on
Commit
516b412
·
verified ·
1 Parent(s): d38655f

Upload 2 files

Browse files
Files changed (2) hide show
  1. inference-example.py +259 -0
  2. model.pt +3 -0
inference-example.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import json
4
+ from pathlib import Path
5
+ from typing import List, Dict
6
+ import random
7
+ import tqdm
8
+ from PIL import Image
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from torchvision import transforms
11
+
12
+
13
+ # ==================== TEST CONFIGURATION (EDIT THESE) ====================
14
+ class TestConfig:
15
+ # Folder paths - edit these directly
16
+ AI_IMAGE_DIR = "/path/to/ai-images"
17
+ REAL_IMAGE_DIR = "path/to/images"
18
+ CHECKPOINT_PATH = "./checkpoints/model.pt"
19
+
20
+ # Test parameters
21
+ SAMPLE_SIZE = 400 # How many images to randomly sample from each folder
22
+ CROP_SIZE = 512 # Must match training crop size
23
+ BATCH_SIZE = 1 # Adjust based on GPU memory
24
+ DEVICE = "cpu" # or "cuda"
25
+
26
+ # Model heads (match training config)
27
+ MODELS = ['flux', 'flux2', 'sdxl', 'sd15']
28
+
29
+
30
+ # ==================== MODEL DEFINITION ====================
31
+ class BAILU(nn.Module):
32
+ """Same model architecture as training"""
33
+
34
+ def __init__(self):
35
+ super().__init__()
36
+ self.conv_blocks = nn.Sequential(
37
+ nn.Conv2d(3, 16, kernel_size=4, stride=1, padding=0), nn.GELU(),
38
+ nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=0), nn.GELU(),
39
+ nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.GELU(),
40
+ nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0), nn.GELU(),
41
+ nn.Conv2d(128, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),
42
+ nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),
43
+ nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0), nn.GELU(),
44
+ nn.AdaptiveAvgPool2d(1)
45
+ )
46
+ self.head = nn.Sequential(
47
+ nn.Linear(256, 32), nn.GELU(),
48
+ nn.Linear(32, 4), # 4 heads: flux, flux2, sdxl, sd15
49
+ )
50
+
51
+ def forward(self, x):
52
+ features = self.conv_blocks(x) # (B, 256, 1, 1)
53
+ features = features.view(features.size(0), -1)
54
+ return self.head(features) # (B, 4)
55
+
56
+
57
+ # ==================== TEST DATASET ====================
58
+ class TestDataset(Dataset):
59
+ """Loads and processes images from AI and Real folders"""
60
+
61
+ def __init__(self, ai_paths: List[Path], real_paths: List[Path], sample_size: int):
62
+ # Randomly sample images from each category
63
+ ai_sample = random.sample(ai_paths, min(sample_size, len(ai_paths))) if ai_paths else []
64
+ real_sample = random.sample(real_paths, min(sample_size, len(real_paths))) if real_paths else []
65
+
66
+ self.image_paths = ai_sample + real_sample
67
+ self.labels = [1] * len(ai_sample) + [0] * len(real_sample) # 1=AI, 0=Real
68
+
69
+ # Inference transform: deterministic pad + center crop
70
+ self.transform = transforms.Compose([
71
+ transforms.CenterCrop(TestConfig.CROP_SIZE),
72
+ transforms.ToTensor(),
73
+ ])
74
+
75
+ def __len__(self):
76
+ return len(self.image_paths)
77
+
78
+ def __getitem__(self, idx):
79
+ path = self.image_paths[idx]
80
+ try:
81
+ with Image.open(path) as img:
82
+ image = img.convert('RGB')
83
+ image_tensor = self.transform(image)
84
+ return {
85
+ 'image': image_tensor,
86
+ 'label': self.labels[idx],
87
+ 'path': str(path)
88
+ }
89
+ except Exception as e:
90
+ print(f"Warning: Could not load {path} - {e}")
91
+ # Return a dummy image and mark as error
92
+ dummy = torch.zeros(3, TestConfig.CROP_SIZE, TestConfig.CROP_SIZE)
93
+ return {'image': dummy, 'label': self.labels[idx], 'path': str(path), 'error': True}
94
+
95
+
96
+ # ==================== EVALUATION FUNCTION ====================
97
+ def evaluate_model():
98
+ """Main evaluation loop"""
99
+ print("=" * 60)
100
+ print("BAILU Model Test Evaluation")
101
+ print("=" * 60)
102
+ print(f"AI folder: {TestConfig.AI_IMAGE_DIR}")
103
+ print(f"Real folder: {TestConfig.REAL_IMAGE_DIR}")
104
+ print(f"Checkpoint: {TestConfig.CHECKPOINT_PATH}")
105
+ print(f"Sample size: {TestConfig.SAMPLE_SIZE} images per class")
106
+
107
+ # Setup device
108
+ device = torch.device(TestConfig.DEVICE)
109
+ torch.manual_seed(42) # For reproducible sampling
110
+
111
+ # Load model
112
+ print("\n📦 Loading model...")
113
+ model = BAILU().to(device)
114
+
115
+ if not Path(TestConfig.CHECKPOINT_PATH).exists():
116
+ raise FileNotFoundError(f"Checkpoint not found: {TestConfig.CHECKPOINT_PATH}")
117
+
118
+ checkpoint = torch.load(TestConfig.CHECKPOINT_PATH, map_location=device)
119
+ model.load_state_dict(checkpoint['model_state_dict'])
120
+ model.eval()
121
+ print(f"✓ Model loaded (epoch {checkpoint.get('epoch', 'unknown')})")
122
+
123
+ # Load image paths
124
+ print("\n📂 Scanning folders...")
125
+ ai_paths = []
126
+ real_paths = []
127
+ for ext in ['*.png', '*.jpg', '*.jpeg']:
128
+ ai_paths.extend(Path(TestConfig.AI_IMAGE_DIR).rglob(ext))
129
+ real_paths.extend(Path(TestConfig.REAL_IMAGE_DIR).rglob(ext))
130
+
131
+ print(f"Found {len(ai_paths)} AI images, {len(real_paths)} real images")
132
+
133
+ if not ai_paths and not real_paths:
134
+ raise ValueError("No images found! Check folder paths.")
135
+
136
+ # Create dataset and dataloader
137
+ test_dataset = TestDataset(ai_paths, real_paths, TestConfig.SAMPLE_SIZE)
138
+ test_loader = DataLoader(
139
+ test_dataset,
140
+ batch_size=TestConfig.BATCH_SIZE,
141
+ shuffle=False,
142
+ num_workers=0 # Simpler for single-threaded inference
143
+ )
144
+
145
+ print(f"\n🧪 Evaluating {len(test_dataset)} images...")
146
+
147
+ # Metrics tracking
148
+ total_correct = 0
149
+ total_samples = 0
150
+ ai_correct = 0
151
+ real_correct = 0
152
+ ai_total = 0
153
+ real_total = 0
154
+
155
+ # Per-format tracking
156
+ num_formats = 4
157
+ per_format_ai_correct = torch.zeros(num_formats, device=device)
158
+ per_format_real_correct = torch.zeros(num_formats, device=device)
159
+ ai_count = 0
160
+ real_count = 0
161
+
162
+ # Run inference
163
+ with torch.no_grad():
164
+ pbar = tqdm.tqdm(test_loader, desc="Processing", unit="batch")
165
+ for batch in pbar:
166
+ images = batch['image'].to(device)
167
+ labels = batch['label'].to(device)
168
+
169
+ # Forward pass
170
+ predictions = model(images) # (B, 4)
171
+ probs = torch.sigmoid(predictions)
172
+
173
+ # Classification rule: AI if ANY head > 0.5
174
+ max_probs, _ = probs.max(dim=1)
175
+ pred_labels = (max_probs > 0.5).long()
176
+
177
+ # Update overall metrics
178
+ correct = (pred_labels == labels).float()
179
+ total_correct += correct.sum().item()
180
+ total_samples += len(labels)
181
+
182
+ # Update per-class metrics
183
+ ai_mask = labels == 1
184
+ real_mask = labels == 0
185
+
186
+ ai_correct += correct[ai_mask].sum().item()
187
+ real_correct += correct[real_mask].sum().item()
188
+ ai_total += ai_mask.sum().item()
189
+ real_total += real_mask.sum().item()
190
+
191
+ # Per-format metrics
192
+ if ai_mask.any():
193
+ ai_probs = probs[ai_mask]
194
+ per_format_ai_correct += (ai_probs > 0.5).sum(dim=0)
195
+ ai_count += ai_probs.shape[0]
196
+
197
+ if real_mask.any():
198
+ real_probs = probs[real_mask]
199
+ per_format_real_correct += (real_probs <= 0.5).sum(dim=0)
200
+ real_count += real_probs.shape[0]
201
+
202
+ # Update progress bar
203
+ current_acc = total_correct / total_samples * 100 if total_samples > 0 else 0
204
+ pbar.set_postfix_str(f"Acc: {current_acc:.2f}%")
205
+
206
+ # Calculate final metrics
207
+ print("\n" + "=" * 60)
208
+ print("RESULTS")
209
+ print("=" * 60)
210
+
211
+ overall_acc = total_correct / total_samples * 100
212
+ ai_acc = ai_correct / ai_total * 100 if ai_total > 0 else 0
213
+ real_acc = real_correct / real_total * 100 if real_total > 0 else 0
214
+
215
+ print(f"Overall Accuracy: {overall_acc:.2f}% ({total_correct:.0f}/{total_samples})")
216
+ print(f"AI Detection Rate: {ai_acc:.2f}% ({ai_correct:.0f}/{ai_total})")
217
+ print(f"Real Accuracy: {real_acc:.2f}% ({real_correct:.0f}/{real_total})")
218
+
219
+ # Per-format results
220
+ per_format_ai_acc = (per_format_ai_correct / ai_count * 100).cpu().tolist() if ai_count > 0 else [0] * 4
221
+ per_format_real_acc = (per_format_real_correct / real_count * 100).cpu().tolist() if real_count > 0 else [0] * 4
222
+
223
+ print(f"\nPer-Format AI Detection (true positive rate):")
224
+ for i, name in enumerate(TestConfig.MODELS):
225
+ print(f" {name:6s}: {per_format_ai_acc[i]:6.2f}%")
226
+
227
+ print(f"\nPer-Format Real Rejection (true negative rate):")
228
+ for i, name in enumerate(TestConfig.MODELS):
229
+ print(f" {name:6s}: {per_format_real_acc[i]:6.2f}%")
230
+
231
+ # Save results
232
+ results = {
233
+ 'config': {
234
+ 'ai_folder': TestConfig.AI_IMAGE_DIR,
235
+ 'real_folder': TestConfig.REAL_IMAGE_DIR,
236
+ 'checkpoint': TestConfig.CHECKPOINT_PATH,
237
+ 'sample_size': TestConfig.SAMPLE_SIZE,
238
+ },
239
+ 'metrics': {
240
+ 'overall_accuracy': overall_acc,
241
+ 'ai_detection_accuracy': ai_acc,
242
+ 'real_detection_accuracy': real_acc,
243
+ 'per_format_ai_detection': dict(zip(TestConfig.MODELS, per_format_ai_acc)),
244
+ 'per_format_real_rejection': dict(zip(TestConfig.MODELS, per_format_real_acc)),
245
+ }
246
+ }
247
+
248
+ output_dir = Path("./results")
249
+ output_dir.mkdir(exist_ok=True)
250
+ output_file = output_dir / "test_evaluation_results.json"
251
+
252
+ with open(output_file, 'w') as f:
253
+ json.dump(results, f, indent=2, default=str)
254
+
255
+ print(f"\n✓ Detailed results saved to: {output_file}")
256
+
257
+
258
+ if __name__ == "__main__":
259
+ evaluate_model()
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24e4001fc11d4bee1bc01b6ecbf7765a1c2c48b7f0e75fad5abf733e62f531da
3
+ size 9386549