Spaces:
Sleeping
Sleeping
| import argparse | |
| import os | |
| import random | |
| import sys | |
| from contextlib import nullcontext | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from dataset import AIGIBenchDataset, get_train_transforms, get_val_transforms | |
| from datasets import load_dataset | |
| from dotenv import load_dotenv | |
| from model import DeForge_AI_Model | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
# Ensure the directory containing this script is importable.
# NOTE(review): the `dataset` and `model` imports above execute before this
# statement, so it can only help imports performed later -- confirm intent.
current_dir = os.path.split(os.path.abspath(__file__))[0]
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
def seed_everything(seed=123):
    """Seed the Python and PyTorch random number generators.

    Args:
        seed: Integer seed applied to ``random``, exported as
            ``PYTHONHASHSEED``, and passed to PyTorch (CPU and, when
            available, CUDA).
    """
    # Exported for child processes; the running interpreter's own hash
    # randomization was fixed at startup and is not affected by this.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(seed)
    # Prefer reproducible cuDNN kernels over autotuned ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def parse_args(argv=None):
    """Parse the training command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` holding every option defined below.

    Raises:
        SystemExit: On unparsable input, or when ``--max-steps`` is not
            positive, or when validation is enabled and ``--val-every``
            is not positive.
    """
    parser = argparse.ArgumentParser(description='Train DeForge-AI Model')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--backbone-lr-scale', type=float, default=0.25)
    parser.add_argument('--weight-decay', type=float, default=0.01)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--max-steps', type=int, default=10000)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--image-size', type=int, default=256)
    parser.add_argument('--gradient-clip', type=float, default=1.0)
    parser.add_argument('--lora-r', type=int, default=16)
    parser.add_argument('--lora-alpha', type=int, default=32)
    parser.add_argument('--lora-dropout', type=float, default=0.5)
    parser.add_argument('--forensic-dim', type=int, default=256)
    parser.add_argument('--unfreeze-last-blocks', type=int, default=0)
    parser.add_argument(
        '--lora-target-modules',
        type=str,
        default='q_proj,k_proj,v_proj,out_proj,fc1,fc2',
    )
    parser.add_argument('--no-val', action='store_true')
    parser.add_argument('--val-every', type=int, default=1)
    parser.add_argument('--pct-start', type=float, default=0.1)
    args = parser.parse_args(argv)

    # The LR schedule is sized from max_steps; reject a non-positive value
    # here instead of failing after the dataset and model have been built.
    if args.max_steps < 1:
        parser.error('--max-steps must be a positive integer')
    # The training loop evaluates `epoch % val_every`; zero would raise
    # ZeroDivisionError only after a full epoch of training.
    if not args.no_val and args.val_every < 1:
        parser.error(
            '--val-every must be a positive integer unless --no-val is set'
        )
    return args
def get_amp_context(device):
    """Return the context manager used to wrap forward passes on ``device``.

    Args:
        device: A ``torch.device``; only its ``type`` attribute is read.

    Returns:
        A float16 CUDA autocast context when ``device.type`` is ``'cuda'``,
        otherwise a no-op ``nullcontext``.
    """
    if device.type != 'cuda':
        return nullcontext()
    return torch.amp.autocast(device_type='cuda', dtype=torch.float16)
def count_parameters(model):
    """Count the elements held in a model's parameters.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.

    Returns:
        A ``(trainable, total)`` tuple, where ``trainable`` counts only the
        parameters whose ``requires_grad`` is true.
    """
    trainable = 0
    total = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return trainable, total
def build_optimizer(model, args):
    """Create an AdamW optimizer with a reduced learning rate for the backbone.

    Trainable parameters whose name starts with ``backbone`` and does not
    contain ``lora_`` are trained at ``args.lr * args.backbone_lr_scale``;
    every other trainable parameter is trained at ``args.lr``. Frozen
    parameters are left out, and an empty group is not created.

    Args:
        model: Module providing ``named_parameters()``.
        args: Namespace with ``lr``, ``backbone_lr_scale`` and ``weight_decay``.

    Returns:
        The configured ``optim.AdamW`` instance.
    """

    def is_plain_backbone(param_name):
        return param_name.startswith('backbone') and 'lora_' not in param_name

    trainable = [
        (param_name, param)
        for param_name, param in model.named_parameters()
        if param.requires_grad
    ]
    slow = [param for param_name, param in trainable if is_plain_backbone(param_name)]
    fast = [
        param for param_name, param in trainable if not is_plain_backbone(param_name)
    ]

    # Backbone group first, matching the order downstream code sees in
    # optimizer.param_groups.
    candidates = (
        (slow, args.lr * args.backbone_lr_scale),
        (fast, args.lr),
    )
    groups = [
        {'params': params, 'lr': group_lr}
        for params, group_lr in candidates
        if params
    ]
    return optim.AdamW(groups, lr=args.lr, weight_decay=args.weight_decay)
def save_checkpoint(path, epoch, global_step, model, optimizer, metrics, args):
    """Serialize a training checkpoint to ``path`` without clobbering on failure.

    The payload is written to ``<path>.tmp`` and then moved over ``path``
    with ``os.replace``, so a crash or serialization error part-way through
    leaves any previously saved checkpoint at ``path`` intact.

    Args:
        path: Destination file path.
        epoch: Epoch number stored under ``'epoch'``.
        global_step: Step counter stored under ``'global_step'``.
        model: Object whose ``state_dict()`` is stored under
            ``'model_state_dict'``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        metrics: Metrics mapping stored as-is under ``'metrics'``.
        args: Namespace; ``vars(args)`` is stored under ``'args'``.

    Raises:
        Exception: Whatever ``torch.save`` or ``os.replace`` raises; the
            temporary file is removed before the error propagates.
    """
    payload = {
        'epoch': epoch,
        'global_step': global_step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics,
        'args': vars(args),
    }
    # Same directory as the destination so the final rename stays on one
    # filesystem.
    tmp_path = os.fspath(path) + '.tmp'
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only present here if the save or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def run_validation(model, val_loader, criterion, device, epoch):
    """Evaluate ``model`` on ``val_loader`` and print a one-line summary.

    Predictions are ``sigmoid(logits)`` thresholded at 0.5. Labels equal to
    0 are reported as "real" and labels equal to 1 as "fake"; a class with
    no samples contributes an accuracy of 0.

    Args:
        model: Model to evaluate; it is switched to eval mode and left there.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the printed summary.

    Returns:
        Dict with ``val_loss`` (sample-weighted mean), ``val_acc``,
        ``val_real_acc``, ``val_fake_acc`` and ``val_balanced_acc``.
    """
    model.eval()
    print('Running validation...')

    loss_sum = 0.0
    sample_count = 0
    prob_chunks = []
    label_chunks = []
    with torch.inference_mode():
        for images, labels in tqdm(val_loader, desc='Validating', leave=False):
            images = images.to(device, non_blocking=True)
            # Add a trailing dimension to the labels before the loss call.
            labels = labels.to(device, non_blocking=True).unsqueeze(1)
            with get_amp_context(device):
                logits = model(images)
                loss = criterion(logits, labels)
            n_samples = labels.size(0)
            # Weight each batch loss by its size so the final mean is per sample.
            loss_sum += loss.item() * n_samples
            sample_count += n_samples
            prob_chunks.append(torch.sigmoid(logits).cpu())
            label_chunks.append(labels.cpu())

    probabilities = torch.cat(prob_chunks, dim=0).numpy()
    targets = torch.cat(label_chunks, dim=0).numpy()

    threshold = 0.5
    decisions = (probabilities > threshold).astype(float)
    hits = decisions == targets

    def class_accuracy(class_value):
        """Accuracy restricted to samples labelled ``class_value`` (0 if none)."""
        selector = targets == class_value
        return hits[selector].mean() if selector.any() else 0

    overall_acc = hits.mean()
    real_acc = class_accuracy(0)
    fake_acc = class_accuracy(1)
    balanced_acc = 0.5 * (real_acc + fake_acc)

    metrics = {
        'val_loss': loss_sum / max(sample_count, 1),
        'val_acc': float(overall_acc),
        'val_real_acc': float(real_acc),
        'val_fake_acc': float(fake_acc),
        'val_balanced_acc': float(balanced_acc),
    }
    print(
        f'Epoch {epoch} | Val Loss: {metrics["val_loss"]:.4f} | '
        f'Val Acc: {metrics["val_acc"]:.4f} | '
        f'Val BAcc: {metrics["val_balanced_acc"]:.4f} | '
        f'Real Acc: {metrics["val_real_acc"]:.4f} | '
        f'Fake Acc: {metrics["val_fake_acc"]:.4f}'
    )
    return metrics
def train():
    """Run the training job configured from the command line.

    Steps, in order:
      1. Parse arguments, seed the RNGs, and read ``HF_TOKEN`` from the
         environment (after loading a ``.env`` file).
      2. Load the ``TheKernel01/AIGIBench`` dataset and build the train
         loader (plus a validation loader unless ``--no-val`` is given).
      3. Build the model, the AdamW optimizer, a ``OneCycleLR`` schedule
         spanning ``--max-steps`` steps, and (on CUDA) a gradient scaler.
      4. Train for up to ``--epochs`` epochs or ``--max-steps`` batches,
         whichever comes first, writing a checkpoint every 1000 steps.
      5. After each epoch optionally validate, then write a per-epoch
         checkpoint, update ``model_epoch_best.pth`` and overwrite
         ``model_epoch_last.pth`` in ``<script dir>/checkpoints``.
    """
    args = parse_args()
    seed_everything(args.seed)
    load_dotenv()
    hf_token = os.getenv('HF_TOKEN')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')
    print(
        f'Config: lr={args.lr}, batch_size={args.batch_size}, epochs={args.epochs}, '
        f'max_steps={args.max_steps}, image_size={args.image_size}, '
        f'val={"disabled" if args.no_val else f"every {args.val_every} epoch(s)"}'
    )
    checkpoints_dir = os.path.join(current_dir, 'checkpoints')
    os.makedirs(checkpoints_dir, exist_ok=True)

    print('Loading AIGIBench dataset from HuggingFace...')
    dataset = load_dataset('TheKernel01/AIGIBench', token=hf_token)
    train_ds = AIGIBenchDataset(
        dataset['train'],
        transform=get_train_transforms(size=args.image_size),
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=device.type == 'cuda',
    )
    val_loader = None
    if not args.no_val:
        val_ds = AIGIBenchDataset(
            dataset['validation'],
            transform=get_val_transforms(
                size=args.image_size,
            ),
        )
        val_loader = DataLoader(
            val_ds,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=device.type == 'cuda',
        )

    model = DeForge_AI_Model(
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        lora_target_modules=[
            module.strip()
            for module in args.lora_target_modules.split(',')
            if module.strip()
        ],
        forensic_dim=args.forensic_dim,
        unfreeze_last_blocks=args.unfreeze_last_blocks,
        image_size=args.image_size,
    ).to(device)
    trainable_params, total_params = count_parameters(model)
    print(
        f'Trainable params: {trainable_params / 1e6:.2f}M / '
        f'{total_params / 1e6:.2f}M ({100 * trainable_params / max(total_params, 1):.2f}%)'
    )
    # requires_grad is not modified anywhere below, so the list of
    # parameters to clip can be built once instead of on every step.
    clip_parameters = [p for p in model.parameters() if p.requires_grad]

    optimizer = build_optimizer(model, args)
    criterion = nn.BCEWithLogitsLoss()
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        # One peak LR per parameter group, preserving the backbone scaling.
        max_lr=[group['lr'] for group in optimizer.param_groups],
        total_steps=args.max_steps,
        pct_start=args.pct_start,
        anneal_strategy='cos',
    )
    scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None
    best_balanced_acc = float('-inf')
    global_step = 0

    for epoch in range(1, args.epochs + 1):
        model.train()
        running_loss = 0.0
        remaining_steps = max(args.max_steps - global_step, 0)
        epoch_total = min(len(train_loader), remaining_steps) if remaining_steps else 0
        pbar = tqdm(
            train_loader, desc=f'Epoch {epoch}/{args.epochs}', total=epoch_total
        )
        for step_idx, (images, labels) in enumerate(pbar, start=1):
            if global_step >= args.max_steps:
                break
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True).unsqueeze(1)
            optimizer.zero_grad(set_to_none=True)
            with get_amp_context(device):
                logits = model(images)
                loss = criterion(logits, labels)

            if scaler is not None:
                scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient norms.
                scaler.unscale_(optimizer)
            else:
                loss.backward()

            if args.gradient_clip is not None:
                torch.nn.utils.clip_grad_norm_(clip_parameters, args.gradient_clip)

            if scaler is not None:
                scale_before = scaler.get_scale()
                scaler.step(optimizer)
                scaler.update()
                # The scaler lowers its scale only when it found inf/NaN
                # gradients, in which case optimizer.step() was skipped.
                optimizer_stepped = scaler.get_scale() >= scale_before
            else:
                optimizer.step()
                optimizer_stepped = True

            # Keep the LR schedule aligned with real optimizer updates.
            if optimizer_stepped:
                scheduler.step()
            global_step += 1
            running_loss += loss.item()

            if global_step % 1000 == 0:
                step_checkpoint_path = os.path.join(
                    checkpoints_dir, f'model_step_{global_step}.pth'
                )
                save_checkpoint(
                    step_checkpoint_path,
                    epoch,
                    global_step,
                    model,
                    optimizer,
                    {},
                    args,
                )
                print(f'Saved periodic checkpoint to {step_checkpoint_path}')

            if step_idx % 10 == 0 or step_idx == 1:
                current_lr = max(group['lr'] for group in optimizer.param_groups)
                pbar.set_postfix(
                    {
                        'loss': f'{running_loss / step_idx:.4f}',
                        'lr': f'{current_lr:.2e}',
                    }
                )
        # Leaving via `break` does not exhaust the iterator, so close the
        # bar explicitly; this is harmless if it already finished.
        pbar.close()

        metrics = {}
        should_validate = (
            not args.no_val and val_loader is not None and epoch % args.val_every == 0
        )
        if should_validate:
            metrics = run_validation(model, val_loader, criterion, device, epoch)

        checkpoint_name = (
            f'deforge_ai_epoch_{epoch}_bacc_{metrics["val_balanced_acc"]:.4f}.pth'
            if metrics
            else f'deforge_ai_epoch_{epoch}_step_{global_step}.pth'
        )
        checkpoint_path = os.path.join(checkpoints_dir, checkpoint_name)
        save_checkpoint(
            checkpoint_path,
            epoch,
            global_step,
            model,
            optimizer,
            metrics,
            args,
        )
        print(f'Saved checkpoint to {checkpoint_path}')

        if metrics and metrics['val_balanced_acc'] > best_balanced_acc:
            best_balanced_acc = metrics['val_balanced_acc']
            best_path = os.path.join(checkpoints_dir, 'model_epoch_best.pth')
            save_checkpoint(
                best_path,
                epoch,
                global_step,
                model,
                optimizer,
                metrics,
                args,
            )
            print(f'Updated best checkpoint at {best_path}')
        elif not metrics:
            # No validation this epoch: only seed the "best" file if none
            # exists yet, so an earlier validated best is never replaced.
            best_path = os.path.join(checkpoints_dir, 'model_epoch_best.pth')
            if not os.path.exists(best_path):
                save_checkpoint(
                    best_path,
                    epoch,
                    global_step,
                    model,
                    optimizer,
                    metrics,
                    args,
                )
                print(f'Initialized best checkpoint at {best_path}')

        latest_path = os.path.join(checkpoints_dir, 'model_epoch_last.pth')
        save_checkpoint(
            latest_path,
            epoch,
            global_step,
            model,
            optimizer,
            metrics,
            args,
        )

        if global_step >= args.max_steps:
            print(f'Reached max_steps={args.max_steps}, stopping.')
            break

    print('Training complete!')
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    train()