"""Standalone test-set evaluation script for the BAILU AI-image detector.

Samples images from an "AI" folder and a "real" folder, runs them through a
trained BAILU checkpoint, prints overall / per-class / per-head accuracy and
writes the metrics to ``./results/test_evaluation_results.json``.
"""

import json
import random
from pathlib import Path
from typing import Dict, List

import torch
import torch.nn as nn
import tqdm
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


# ==================== TEST CONFIGURATION (EDIT THESE) ====================
class TestConfig:
    """User-editable settings for the evaluation run."""

    # Folder paths - edit these directly
    AI_IMAGE_DIR = "/path/to/ai-images"
    REAL_IMAGE_DIR = "path/to/images"
    CHECKPOINT_PATH = "./checkpoints/model.pt"

    # Test parameters
    SAMPLE_SIZE = 400  # How many images to randomly sample from each folder
    CROP_SIZE = 512    # Must match training crop size
    BATCH_SIZE = 1     # Adjust based on GPU memory
    DEVICE = "cpu"     # or "cuda"
    SEED = 42          # Seeds both `random` (sampling) and torch

    # File suffixes (compared case-insensitively) treated as images
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    # Model heads (match training config)
    MODELS = ['flux', 'flux2', 'sdxl', 'sd15']


# ==================== MODEL DEFINITION ====================
class BAILU(nn.Module):
    """Same model architecture as training.

    A stack of unpadded convolutions with GELU activations, global average
    pooling, and a two-layer MLP head producing one logit per generator head.

    Args:
        num_heads: Number of output logits. Defaults to 4
            (flux, flux2, sdxl, sd15).
    """

    def __init__(self, num_heads: int = 4):
        super().__init__()

        # NOTE: the layer order inside each nn.Sequential fixes the
        # state_dict key names (conv_blocks.0, conv_blocks.2, ...), so it
        # must not be rearranged or checkpoints will stop loading.
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0),
            nn.GELU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=4, padding=0),
            nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=0),
            nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
        )

        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, num_heads),  # one logit per head
        )

    def forward(self, x):
        """Return raw (pre-sigmoid) logits of shape (B, num_heads)."""
        features = self.conv_blocks(x)  # (B, 256, 1, 1)
        features = features.view(features.size(0), -1)
        return self.head(features)  # (B, num_heads)


# ==================== TEST DATASET ====================
class TestDataset(Dataset):
    """Loads and processes images from AI and Real folders.

    Up to ``sample_size`` paths are drawn at random (via the ``random``
    module) from each list. Labels are 1 for AI images and 0 for real images.

    Args:
        ai_paths: Candidate paths of AI-generated images.
        real_paths: Candidate paths of real images.
        sample_size: Maximum number of images to draw from each list.
    """

    def __init__(self, ai_paths: List[Path], real_paths: List[Path], sample_size: int):
        # Randomly sample images from each category
        ai_sample = random.sample(ai_paths, min(sample_size, len(ai_paths))) if ai_paths else []
        real_sample = random.sample(real_paths, min(sample_size, len(real_paths))) if real_paths else []

        self.image_paths = ai_sample + real_sample
        self.labels = [1] * len(ai_sample) + [0] * len(real_sample)  # 1=AI, 0=Real

        # Inference transform: deterministic pad + center crop
        # (CenterCrop zero-pads images smaller than the crop size).
        self.transform = transforms.Compose([
            transforms.CenterCrop(TestConfig.CROP_SIZE),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return a dict with keys 'image', 'label', 'path' and 'error'.

        'error' is always present so that every sample has the same keys and
        batches collate cleanly. When an image cannot be loaded, 'image' is an
        all-zero placeholder and 'error' is True; callers should exclude such
        samples from metrics.
        """
        path = self.image_paths[idx]
        try:
            with Image.open(path) as img:
                image = img.convert('RGB')
                image_tensor = self.transform(image)
        except Exception as e:
            # Deliberately broad: one corrupt file should not abort the whole
            # evaluation, whatever error the decoder raises.
            print(f"Warning: Could not load {path} - {e}")
            # Return a dummy image and mark as error
            image_tensor = torch.zeros(3, TestConfig.CROP_SIZE, TestConfig.CROP_SIZE)
            failed = True
        else:
            failed = False

        return {
            'image': image_tensor,
            'label': self.labels[idx],
            'path': str(path),
            'error': failed,
        }


# ==================== EVALUATION HELPERS ====================
def _collect_image_paths(folder: str) -> List[Path]:
    """Recursively list image files under ``folder`` in sorted order.

    Suffixes are matched case-insensitively against
    ``TestConfig.IMAGE_EXTENSIONS``. Sorting makes the subsequent seeded
    sampling independent of filesystem enumeration order.
    """
    allowed = {ext.lower() for ext in TestConfig.IMAGE_EXTENSIONS}
    return sorted(
        p for p in Path(folder).rglob('*')
        if p.suffix.lower() in allowed and p.is_file()
    )


def _load_model(device: torch.device) -> BAILU:
    """Build a BAILU model, load the configured checkpoint, set eval mode.

    Raises:
        FileNotFoundError: If ``TestConfig.CHECKPOINT_PATH`` does not exist.
    """
    model = BAILU(num_heads=len(TestConfig.MODELS)).to(device)

    if not Path(TestConfig.CHECKPOINT_PATH).exists():
        raise FileNotFoundError(f"Checkpoint not found: {TestConfig.CHECKPOINT_PATH}")

    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source (or pass
    # weights_only=True on PyTorch versions that support it).
    checkpoint = torch.load(TestConfig.CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"✓ Model loaded (epoch {checkpoint.get('epoch', 'unknown')})")
    return model


def _percent(numerator: float, denominator: float) -> float:
    """Return ``numerator / denominator`` as a percentage, or 0.0 if empty."""
    return numerator / denominator * 100 if denominator > 0 else 0.0


# ==================== EVALUATION FUNCTION ====================
def evaluate_model():
    """Main evaluation loop.

    Loads the checkpoint, samples images from the configured folders, runs
    inference, prints the metrics and saves them as JSON under ``./results``.
    Images that fail to load are reported and excluded from all metrics.

    Returns:
        The results dictionary that was written to disk.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        ValueError: If neither folder contains any image.
    """
    print("=" * 60)
    print("BAILU Model Test Evaluation")
    print("=" * 60)
    print(f"AI folder: {TestConfig.AI_IMAGE_DIR}")
    print(f"Real folder: {TestConfig.REAL_IMAGE_DIR}")
    print(f"Checkpoint: {TestConfig.CHECKPOINT_PATH}")
    print(f"Sample size: {TestConfig.SAMPLE_SIZE} images per class")

    # Setup device
    device = torch.device(TestConfig.DEVICE)

    # For reproducible sampling: TestDataset draws with `random.sample`, which
    # is governed by the `random` module's seed, not by torch's.
    random.seed(TestConfig.SEED)
    torch.manual_seed(TestConfig.SEED)

    # Load model
    print("\n📦 Loading model...")
    model = _load_model(device)

    # Load image paths
    print("\n📂 Scanning folders...")
    ai_paths = _collect_image_paths(TestConfig.AI_IMAGE_DIR)
    real_paths = _collect_image_paths(TestConfig.REAL_IMAGE_DIR)

    print(f"Found {len(ai_paths)} AI images, {len(real_paths)} real images")

    if not ai_paths and not real_paths:
        raise ValueError("No images found! Check folder paths.")

    # Create dataset and dataloader
    test_dataset = TestDataset(ai_paths, real_paths, TestConfig.SAMPLE_SIZE)
    test_loader = DataLoader(
        test_dataset,
        batch_size=TestConfig.BATCH_SIZE,
        shuffle=False,
        num_workers=0,  # Simpler for single-threaded inference
    )

    print(f"\n🧪 Evaluating {len(test_dataset)} images...")

    # Metrics tracking
    total_correct = 0
    total_samples = 0
    ai_correct = 0
    real_correct = 0
    ai_total = 0
    real_total = 0
    skipped = 0  # samples whose image could not be loaded

    # Per-format tracking (one counter per model head)
    num_formats = len(TestConfig.MODELS)
    per_format_ai_correct = torch.zeros(num_formats, device=device)
    per_format_real_correct = torch.zeros(num_formats, device=device)

    # Run inference
    with torch.no_grad():
        pbar = tqdm.tqdm(test_loader, desc="Processing", unit="batch")
        for batch in pbar:
            # Drop placeholder samples so unreadable files do not count as
            # (mis)classified images.
            valid = ~batch['error'].bool()
            skipped += int((~valid).sum().item())
            if not valid.any():
                continue

            images = batch['image'][valid].to(device)
            labels = batch['label'][valid].to(device)

            # Forward pass
            predictions = model(images)  # (B, num_formats)
            probs = torch.sigmoid(predictions)

            # Classification rule: AI if ANY head > 0.5
            max_probs, _ = probs.max(dim=1)
            pred_labels = (max_probs > 0.5).long()

            # Update overall metrics
            correct = pred_labels == labels
            total_correct += int(correct.sum().item())
            total_samples += len(labels)

            # Update per-class metrics
            ai_mask = labels == 1
            real_mask = labels == 0

            ai_correct += int(correct[ai_mask].sum().item())
            real_correct += int(correct[real_mask].sum().item())
            ai_total += int(ai_mask.sum().item())
            real_total += int(real_mask.sum().item())

            # Per-format metrics: each head is scored independently
            if ai_mask.any():
                per_format_ai_correct += (probs[ai_mask] > 0.5).sum(dim=0)

            if real_mask.any():
                per_format_real_correct += (probs[real_mask] <= 0.5).sum(dim=0)

            # Update progress bar
            current_acc = _percent(total_correct, total_samples)
            pbar.set_postfix_str(f"Acc: {current_acc:.2f}%")

    # Calculate final metrics
    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)

    if skipped:
        print(f"Skipped {skipped} unreadable image(s); they are excluded from all metrics.")

    overall_acc = _percent(total_correct, total_samples)
    ai_acc = _percent(ai_correct, ai_total)
    real_acc = _percent(real_correct, real_total)

    print(f"Overall Accuracy: {overall_acc:.2f}% ({total_correct:.0f}/{total_samples})")
    print(f"AI Detection Rate: {ai_acc:.2f}% ({ai_correct:.0f}/{ai_total})")
    print(f"Real Accuracy: {real_acc:.2f}% ({real_correct:.0f}/{real_total})")

    # Per-format results
    if ai_total > 0:
        per_format_ai_acc = (per_format_ai_correct / ai_total * 100).cpu().tolist()
    else:
        per_format_ai_acc = [0.0] * num_formats

    if real_total > 0:
        per_format_real_acc = (per_format_real_correct / real_total * 100).cpu().tolist()
    else:
        per_format_real_acc = [0.0] * num_formats

    print(f"\nPer-Format AI Detection (true positive rate):")
    for name, acc in zip(TestConfig.MODELS, per_format_ai_acc):
        print(f"  {name:6s}: {acc:6.2f}%")

    print(f"\nPer-Format Real Rejection (true negative rate):")
    for name, acc in zip(TestConfig.MODELS, per_format_real_acc):
        print(f"  {name:6s}: {acc:6.2f}%")

    # Save results
    results = {
        'config': {
            'ai_folder': TestConfig.AI_IMAGE_DIR,
            'real_folder': TestConfig.REAL_IMAGE_DIR,
            'checkpoint': TestConfig.CHECKPOINT_PATH,
            'sample_size': TestConfig.SAMPLE_SIZE,
        },
        'metrics': {
            'overall_accuracy': overall_acc,
            'ai_detection_accuracy': ai_acc,
            'real_detection_accuracy': real_acc,
            'per_format_ai_detection': dict(zip(TestConfig.MODELS, per_format_ai_acc)),
            'per_format_real_rejection': dict(zip(TestConfig.MODELS, per_format_real_acc)),
            'num_evaluated': total_samples,
            'num_skipped_unreadable': skipped,
        },
    }

    output_dir = Path("./results")
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / "test_evaluation_results.json"

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=str)

    print(f"\n✓ Detailed results saved to: {output_file}")
    return results


if __name__ == "__main__":
    evaluate_model()