import json
import random
from pathlib import Path
from typing import List

import torch
import torch.nn as nn
import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class TestConfig:
    AI_IMAGE_DIR = "/path/to/ai-images"
    REAL_IMAGE_DIR = "/path/to/real-images"
    CHECKPOINT_PATH = "./checkpoints/model.pt"

    SAMPLE_SIZE = 400   # images sampled per class
    CROP_SIZE = 512     # center-crop size fed to the model
    BATCH_SIZE = 1
    DEVICE = "cpu"

    # One detection head per generator family, in output-logit order.
    MODELS = ['flux', 'flux2', 'sdxl', 'sd15']


class BAILU(nn.Module):
    """Same model architecture as training."""

    def __init__(self):
        super().__init__()
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=1, padding=0), nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=0), nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0), nn.GELU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0), nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
        )
        # One output logit per generator family in TestConfig.MODELS.
        self.head = nn.Sequential(
            nn.Linear(256, 32), nn.GELU(),
            nn.Linear(32, 4),
        )

    def forward(self, x):
        features = self.conv_blocks(x)
        features = features.view(features.size(0), -1)  # (B, 256)
        return self.head(features)                      # (B, 4) logits
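
# Shape sanity check (illustrative; the numbers assume CROP_SIZE = 512).
# The conv stack shrinks the feature map 512 -> 509 -> 506 -> 252 -> 125
# -> 31 -> 7 -> 3, and AdaptiveAvgPool2d(1) collapses it to a 256-vector:
#
#     model = BAILU()
#     logits = model(torch.zeros(1, 3, 512, 512))
#     assert logits.shape == (1, 4)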


class TestDataset(Dataset):
    """Loads and preprocesses images from the AI and real folders."""

    def __init__(self, ai_paths: List[Path], real_paths: List[Path], sample_size: int):
        ai_sample = random.sample(ai_paths, min(sample_size, len(ai_paths))) if ai_paths else []
        real_sample = random.sample(real_paths, min(sample_size, len(real_paths))) if real_paths else []

        self.image_paths = ai_sample + real_sample
        self.labels = [1] * len(ai_sample) + [0] * len(real_sample)  # 1 = AI, 0 = real

        self.transform = transforms.Compose([
            # CenterCrop pads with zeros if an image is smaller than CROP_SIZE.
            transforms.CenterCrop(TestConfig.CROP_SIZE),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        try:
            with Image.open(path) as img:
                image = img.convert('RGB')
            image_tensor = self.transform(image)
            # Both branches return the same keys so the default collate_fn
            # keeps working if BATCH_SIZE is ever raised above 1.
            return {
                'image': image_tensor,
                'label': self.labels[idx],
                'path': str(path),
                'error': False,
            }
        except Exception as e:
            print(f"Warning: Could not load {path} - {e}")
            dummy = torch.zeros(3, TestConfig.CROP_SIZE, TestConfig.CROP_SIZE)
            return {'image': dummy, 'label': self.labels[idx], 'path': str(path), 'error': True}
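
# Minimal usage sketch (hypothetical folders; not part of the evaluation run):
#
#     ds = TestDataset(list(Path("/tmp/ai").glob("*.png")),
#                      list(Path("/tmp/real").glob("*.png")), sample_size=10)
#     sample = ds[0]
#     print(sample['image'].shape, sample['label'], sample['path'])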


def evaluate_model():
    """Main evaluation loop."""
    print("=" * 60)
    print("BAILU Model Test Evaluation")
    print("=" * 60)
    print(f"AI folder: {TestConfig.AI_IMAGE_DIR}")
    print(f"Real folder: {TestConfig.REAL_IMAGE_DIR}")
    print(f"Checkpoint: {TestConfig.CHECKPOINT_PATH}")
    print(f"Sample size: {TestConfig.SAMPLE_SIZE} images per class")

    device = torch.device(TestConfig.DEVICE)
    # Seed both generators: random.sample() drives the subsampling in
    # TestDataset, so seeding torch alone would not make runs repeatable.
    torch.manual_seed(42)
    random.seed(42)

    print("\n📦 Loading model...")
    model = BAILU().to(device)

    if not Path(TestConfig.CHECKPOINT_PATH).exists():
        raise FileNotFoundError(f"Checkpoint not found: {TestConfig.CHECKPOINT_PATH}")

    checkpoint = torch.load(TestConfig.CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"✓ Model loaded (epoch {checkpoint.get('epoch', 'unknown')})")

    print("\n📂 Scanning folders...")
    ai_paths = []
    real_paths = []
    # rglob is case-sensitive on most filesystems, so e.g. *.JPG files
    # will not match these patterns (see the sketch below the scan loop).
    for ext in ['*.png', '*.jpg', '*.jpeg']:
        ai_paths.extend(Path(TestConfig.AI_IMAGE_DIR).rglob(ext))
        real_paths.extend(Path(TestConfig.REAL_IMAGE_DIR).rglob(ext))

    print(f"Found {len(ai_paths)} AI images, {len(real_paths)} real images")

    if not ai_paths and not real_paths:
        raise ValueError("No images found! Check folder paths.")

    test_dataset = TestDataset(ai_paths, real_paths, TestConfig.SAMPLE_SIZE)
    test_loader = DataLoader(
        test_dataset,
        batch_size=TestConfig.BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )

    print(f"\n🧪 Evaluating {len(test_dataset)} images...")
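
    # A case-insensitive variant of the scan, if mixed-case extensions
    # matter (a sketch, not used above):
    #
    #     ai_paths = [p for p in Path(TestConfig.AI_IMAGE_DIR).rglob('*')
    #                 if p.suffix.lower() in {'.png', '.jpg', '.jpeg'}]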

    total_correct = 0
    total_samples = 0
    ai_correct = 0
    real_correct = 0
    ai_total = 0
    real_total = 0

    # Per-head tallies: how often each detection head fires on AI images
    # (true positives) and stays below threshold on real images (true
    # negatives); ai_total and real_total serve as the denominators.
    num_formats = len(TestConfig.MODELS)
    per_format_ai_correct = torch.zeros(num_formats, device=device)
    per_format_real_correct = torch.zeros(num_formats, device=device)

    with torch.no_grad():
        pbar = tqdm.tqdm(test_loader, desc="Processing", unit="batch")
        for batch in pbar:
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            predictions = model(images)
            probs = torch.sigmoid(predictions)

            # Decision rule: each head emits an independent probability that
            # the image came from its generator family; the image is called
            # "AI" when any head exceeds 0.5.
            max_probs, _ = probs.max(dim=1)
            pred_labels = (max_probs > 0.5).long()

            correct = (pred_labels == labels).float()
            total_correct += correct.sum().item()
            total_samples += len(labels)

            ai_mask = labels == 1
            real_mask = labels == 0

            ai_correct += correct[ai_mask].sum().item()
            real_correct += correct[real_mask].sum().item()
            ai_total += ai_mask.sum().item()
            real_total += real_mask.sum().item()

            # Per-head tallies over all AI / real images; the actual source
            # format of each image is unknown here, so these are per-head
            # firing rates rather than per-source accuracies.
            if ai_mask.any():
                per_format_ai_correct += (probs[ai_mask] > 0.5).sum(dim=0)
            if real_mask.any():
                per_format_real_correct += (probs[real_mask] <= 0.5).sum(dim=0)

            current_acc = total_correct / total_samples * 100 if total_samples > 0 else 0
            pbar.set_postfix_str(f"Acc: {current_acc:.2f}%")

    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)

    overall_acc = total_correct / total_samples * 100
    ai_acc = ai_correct / ai_total * 100 if ai_total > 0 else 0
    real_acc = real_correct / real_total * 100 if real_total > 0 else 0

    print(f"Overall Accuracy: {overall_acc:.2f}% ({total_correct:.0f}/{total_samples})")
    print(f"AI Detection Rate: {ai_acc:.2f}% ({ai_correct:.0f}/{ai_total})")
    print(f"Real Accuracy: {real_acc:.2f}% ({real_correct:.0f}/{real_total})")

    per_format_ai_acc = (per_format_ai_correct / ai_total * 100).cpu().tolist() if ai_total > 0 else [0.0] * num_formats
    per_format_real_acc = (per_format_real_correct / real_total * 100).cpu().tolist() if real_total > 0 else [0.0] * num_formats

    print("\nPer-Format AI Detection (true positive rate):")
    for i, name in enumerate(TestConfig.MODELS):
        print(f"  {name:6s}: {per_format_ai_acc[i]:6.2f}%")

    print("\nPer-Format Real Rejection (true negative rate):")
    for i, name in enumerate(TestConfig.MODELS):
        print(f"  {name:6s}: {per_format_real_acc[i]:6.2f}%")

    results = {
        'config': {
            'ai_folder': TestConfig.AI_IMAGE_DIR,
            'real_folder': TestConfig.REAL_IMAGE_DIR,
            'checkpoint': TestConfig.CHECKPOINT_PATH,
            'sample_size': TestConfig.SAMPLE_SIZE,
        },
        'metrics': {
            'overall_accuracy': overall_acc,
            'ai_detection_accuracy': ai_acc,
            'real_detection_accuracy': real_acc,
            'per_format_ai_detection': dict(zip(TestConfig.MODELS, per_format_ai_acc)),
            'per_format_real_rejection': dict(zip(TestConfig.MODELS, per_format_real_acc)),
        },
    }

    output_dir = Path("./results")
    output_dir.mkdir(exist_ok=True)
    output_file = output_dir / "test_evaluation_results.json"

    with open(output_file, 'w') as f:
        json.dump(results, f, indent=2, default=str)

    print(f"\n✓ Detailed results saved to: {output_file}")


if __name__ == "__main__":
    evaluate_model()
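
# To run (assuming the TestConfig paths point at populated folders and the
# checkpoint exists; the script name here is hypothetical):
#
#     python evaluate_bailu.py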