| | import os |
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | from torch.utils.data import DataLoader |
| | from tqdm import tqdm |
| | import random |
| | import ssl |
| | |
# WARNING: globally disables HTTPS certificate verification for this process
# (presumably to let pretrained-weight downloads succeed behind broken/self-signed
# cert chains). NOTE(review): this weakens security for ALL https requests made
# by the process — confirm it is still needed and consider scoping it down.
ssl._create_default_https_context = ssl._create_unverified_context
| | from torch.cuda.amp import GradScaler, autocast |
| |
|
| | from src.config import Config |
| | from src.models import DeepfakeDetector |
| | from src.dataset import DeepfakeDataset |
| |
|
# Optional dependency: prefer safetensors for checkpoint serialization.
# SAFETENSORS_AVAILABLE gates the save/load paths in save_checkpoint()/train().
try:
    from safetensors.torch import save_file, load_file
    SAFETENSORS_AVAILABLE = True
except ImportError:
    # Fall back to torch.save (.pt/.pth) when safetensors is not installed.
    SAFETENSORS_AVAILABLE = False
    print("Warning: safetensors not installed. Checkpoints will be saved as .pt")
| |
|
def train():
    """Train the multi-branch DeepfakeDetector.

    Pipeline:
      1. Build train/val datasets. If TRAIN and TEST paths are identical,
         perform an automatic 80/20 shuffled split of a single directory scan.
      2. Build DataLoaders (pinned memory + persistent workers where applicable).
      3. Create model, BCE-with-logits loss, AdamW optimizer, StepLR scheduler.
      4. Auto-resume weights from the newest checkpoint in Config.CHECKPOINT_DIR
         (best_model -> latest checkpoint_ep* -> patched_model), if present.
      5. Run the epoch loop with optional CUDA AMP, saving a checkpoint every
         epoch and the best model by validation accuracy.

    Returns None. Side effects: writes checkpoints via save_checkpoint().
    """
    Config.setup()
    device = torch.device(Config.DEVICE)

    # --- Dataset construction -------------------------------------------------
    if Config.TRAIN_DATA_PATH == Config.TEST_DATA_PATH:
        print("Train and Test paths are identical. Performing automatic 80/20 shuffle split...")
        all_paths, all_labels = DeepfakeDataset.scan_directory(Config.TRAIN_DATA_PATH)

        if len(all_paths) == 0:
            print(f"No images found in {Config.TRAIN_DATA_PATH}")
            return

        # Shuffle paths/labels together so the split is random but aligned.
        combined = list(zip(all_paths, all_labels))
        random.shuffle(combined)

        split_idx = int(len(combined) * 0.8)
        train_data = combined[:split_idx]
        val_data = combined[split_idx:]

        train_paths, train_labels = zip(*train_data)
        val_paths, val_labels = zip(*val_data)

        train_dataset = DeepfakeDataset(file_paths=list(train_paths), labels=list(train_labels), phase='train')
        val_dataset = DeepfakeDataset(file_paths=list(val_paths), labels=list(val_labels), phase='val')
    else:
        # Separate directories: let the dataset scan each root itself.
        train_dataset = DeepfakeDataset(root_dir=Config.TRAIN_DATA_PATH, phase='train')
        val_dataset = DeepfakeDataset(root_dir=Config.TEST_DATA_PATH, phase='val')

    # --- DataLoaders ----------------------------------------------------------
    # Hoist the shared loader options instead of repeating the conditionals.
    pin_memory = device.type == 'cuda'
    persistent = Config.NUM_WORKERS > 0
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
                              num_workers=Config.NUM_WORKERS,
                              pin_memory=pin_memory,
                              persistent_workers=persistent)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False,
                            num_workers=Config.NUM_WORKERS,
                            pin_memory=pin_memory,
                            persistent_workers=persistent)

    # --- Model / loss / optimizer --------------------------------------------
    print("Initializing Multi-Branch DeepfakeDetector...")
    model = DeepfakeDetector(pretrained=True).to(device)

    # BUG FIX: criterion and optimizer were previously constructed twice in a
    # row; a single instance of each is kept.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=Config.WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # --- Mixed precision ------------------------------------------------------
    # FIX: compare the parsed device type so strings like 'cuda:0' also enable AMP
    # (the old string equality against Config.DEVICE would miss them).
    use_amp = device.type == 'cuda'
    scaler = GradScaler() if use_amp else None
    if use_amp:
        print("π Mixed Precision (AMP) Enabled for RTX GPU")
    else:
        print("π Standard Precision (No AMP) for CPU/MPS")

    start_epoch = 0
    best_acc = 0.0

    # --- Auto-resume ----------------------------------------------------------
    # Preference order: best_model -> newest checkpoint_ep* -> patched_model.
    # NOTE(review): only model weights are restored; optimizer/scheduler state
    # and best_acc are NOT resumed — confirm this is intentional.
    resume_path = os.path.join(Config.CHECKPOINT_DIR, "best_model.safetensors")
    if not os.path.exists(resume_path):
        import glob
        import re
        checkpoints = glob.glob(os.path.join(Config.CHECKPOINT_DIR, "checkpoint_ep*.safetensors"))
        if checkpoints:
            def get_epoch(p):
                # Extract the epoch number embedded in the checkpoint filename.
                match = re.search(r"checkpoint_ep(\d+)", p)
                return int(match.group(1)) if match else 0

            latest_ckpt = max(checkpoints, key=get_epoch)
            resume_path = latest_ckpt
            start_epoch = get_epoch(latest_ckpt)
            print(f"π Auto-Resuming from latest epoch: {start_epoch}")
        else:
            resume_path = os.path.join(Config.CHECKPOINT_DIR, "patched_model.safetensors")

    if os.path.exists(resume_path):
        print(f"\nπ Found existing checkpoint: {resume_path}")
        print("Auto-resuming to FINETUNE this model...")

        try:
            if resume_path.endswith(".safetensors") and SAFETENSORS_AVAILABLE:
                state_dict = load_file(resume_path)
            else:
                state_dict = torch.load(resume_path, map_location=device)

            # strict=False tolerates architecture drift between checkpoints.
            model.load_state_dict(state_dict, strict=False)
            print("β Weights loaded. Starting Fine-Tuning.")
        except Exception as e:
            print(f"β Failed to load checkpoint: {e}")
            print("Starting from ImageNet weights.")

    # --- Epoch loop -----------------------------------------------------------
    for epoch in range(start_epoch, Config.EPOCHS):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.EPOCHS}")
        for images, labels in loop:
            images = images.to(device)
            # BCEWithLogitsLoss expects targets shaped like the (N, 1) logits.
            labels = labels.to(device).unsqueeze(1)

            optimizer.zero_grad()

            if use_amp:
                with autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            train_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct = (preds == labels).sum().item()
            train_correct += correct
            train_total += labels.size(0)

            loop.set_postfix(loss=loss.item(), acc=correct/labels.size(0))

        train_acc = train_correct / train_total if train_total > 0 else 0
        print(f"Epoch {epoch+1} Train Loss: {train_loss/len(train_loader):.4f} Acc: {train_acc:.4f}")

        # Periodic checkpoint every epoch, regardless of validation outcome.
        save_checkpoint(model, epoch+1, train_acc, best=False)

        if len(val_dataset) > 0:
            val_loss, val_acc = validate(model, val_loader, criterion, device)
            print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

            if val_acc > best_acc:
                best_acc = val_acc
                print(f"β New best model! Validation Accuracy: {val_acc:.4f}")
                save_checkpoint(model, epoch+1, val_acc, best=True)

        # LR decay once per epoch (StepLR halves the LR every 5 epochs).
        scheduler.step()

    print(f"\nπ Training Complete!")
    print(f"Best Validation Accuracy: {best_acc:.4f}")
| |
|
def validate(model, loader, criterion, device):
    """Run one full evaluation pass over `loader`.

    Args:
        model: the detector; called as model(images) returning (N, 1) logits.
        loader: iterable of (images, labels) batches.
        criterion: loss taking (logits, targets), e.g. BCEWithLogitsLoss.
        device: torch.device to move batches to.

    Returns:
        (avg_loss, accuracy) tuple. BUG FIX: an empty loader previously raised
        ZeroDivisionError on both divisions; it now returns (0.0, 0.0).
    """
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            # Match the (N, 1) logit shape expected by the criterion.
            labels = labels.to(device).unsqueeze(1)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            # Threshold sigmoid probabilities at 0.5 for binary predictions.
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            num_batches += 1

    # Guard against an empty loader / zero samples.
    if num_batches == 0 or total == 0:
        return 0.0, 0.0
    return val_loss / num_batches, correct / total
| |
|
def save_checkpoint(model, epoch, acc, best=False):
    """Persist model weights and best-effort append training-history logs.

    Args:
        model: the model whose weights are saved.
        epoch: 1-based epoch number (used in the checkpoint filename and logs).
        acc: accuracy to record in the markdown logs (fraction, not percent).
        best: when True, save as "best_model.safetensors" instead of the
              per-epoch "checkpoint_ep{epoch}.safetensors".

    Prefers safetensors; falls back to torch.save (.pth) when safetensors is
    unavailable or fails. Log-writing failures never abort the save.
    """
    state_dict = model.state_dict()
    name = "best_model.safetensors" if best else f"checkpoint_ep{epoch}.safetensors"
    path = os.path.join(Config.CHECKPOINT_DIR, name)
    legacy_path = path.replace(".safetensors", ".pth")

    if SAFETENSORS_AVAILABLE:
        try:
            # save_model handles shared/tied tensors that raw save_file rejects.
            from safetensors.torch import save_model
            save_model(model, path)
            print(f"Saved Checkpoint: {path}")
            _append_history_logs(name, epoch, acc)
        except Exception as e:
            print(f"SafeTensors save failed ({e}), falling back to .pth format")
            torch.save(state_dict, legacy_path)
            print(f"Saved Checkpoint (Legacy): {legacy_path}")
    else:
        torch.save(state_dict, legacy_path)
        # BUG FIX: previously printed the .safetensors path here even though
        # the file was actually written with the .pth extension.
        print(f"Saved Checkpoint (Legacy): {legacy_path}")


def _append_history_logs(name, epoch, acc):
    """Best-effort: append entries to TRAINING_HISTORY.md and DETAILED_HISTORY.md.

    Any failure is caught and reported so logging never breaks checkpointing.
    """
    try:
        from datetime import datetime
        log_path = os.path.join(Config.PROJECT_ROOT, "TRAINING_HISTORY.md")
        timestamp = datetime.now().strftime("%Y-%m-%d | %I:%M %p")

        # Create the summary table header on first use.
        if not os.path.exists(log_path):
            with open(log_path, "w", encoding="utf-8") as f:
                f.write("# π Training History Log\n\n")
                f.write("| Date | Time | Model Name | Dataset | Epochs | Accuracy | Loss | Status |\n")
                f.write("| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |\n")

        with open(log_path, "a", encoding="utf-8") as f:
            dataset_name = os.path.basename(Config.DATA_DIR)
            entry = f"| **{timestamp.split(' | ')[0]}** | {timestamp.split(' | ')[1]} | {name} | {dataset_name} | {epoch} | {acc*100:.2f}% | N/A | β Saved |\n"
            f.write(entry)
        print(f"π Logged to TRAINING_HISTORY.md")

        # Per-checkpoint detail section with hyperparameters.
        detail_path = os.path.join(Config.PROJECT_ROOT, "DETAILED_HISTORY.md")
        with open(detail_path, "a", encoding="utf-8") as f:
            f.write(f"\n## Model: {name} (Epoch {epoch})\n")
            f.write(f"| Feature | Detail |\n| :--- | :--- |\n")
            f.write(f"| **Date** | {timestamp} |\n")
            f.write(f"| **Training Accuracy** | {acc*100:.2f}% |\n")
            f.write(f"| **Dataset** | {Config.DATA_DIR} |\n")
            f.write(f"| **Batch Size** | {Config.BATCH_SIZE} |\n")
            f.write(f"| **Optimizer** | AdamW (lr={Config.LEARNING_RATE}) |\n")
            f.write(f"| **Device** | {Config.DEVICE.upper()} |\n")
            f.write("\n---\n")
        print(f"π Detailed log written to DETAILED_HISTORY.md")

    except Exception as e:
        print(f"β οΈ Failed to write log: {e}")
| |
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    train()
| |
|