TheKernel01's picture
Sync from GitHub via hub-sync
0788e19 verified
import argparse
import os
import random
import sys
from contextlib import nullcontext
import torch
import torch.nn as nn
import torch.optim as optim
from dataset import AIGIBenchDataset, get_train_transforms, get_val_transforms
from datasets import load_dataset
from dotenv import load_dotenv
from model import DeForge_AI_Model
from torch.utils.data import DataLoader
from tqdm import tqdm
# Add current directory to sys.path
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
sys.path.insert(0, current_dir)
def seed_everything(seed=123):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def parse_args():
parser = argparse.ArgumentParser(description='Train DeForge-AI Model')
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--backbone-lr-scale', type=float, default=0.25)
parser.add_argument('--weight-decay', type=float, default=0.01)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--max-steps', type=int, default=10000)
parser.add_argument('--num-workers', type=int, default=8)
parser.add_argument('--seed', type=int, default=123)
parser.add_argument('--image-size', type=int, default=256)
parser.add_argument('--gradient-clip', type=float, default=1.0)
parser.add_argument('--lora-r', type=int, default=16)
parser.add_argument('--lora-alpha', type=int, default=32)
parser.add_argument('--lora-dropout', type=float, default=0.5)
parser.add_argument('--forensic-dim', type=int, default=256)
parser.add_argument('--unfreeze-last-blocks', type=int, default=0)
parser.add_argument(
'--lora-target-modules',
type=str,
default='q_proj,k_proj,v_proj,out_proj,fc1,fc2',
)
parser.add_argument('--no-val', action='store_true')
parser.add_argument('--val-every', type=int, default=1)
parser.add_argument('--pct-start', type=float, default=0.1)
return parser.parse_args()
def get_amp_context(device):
if device.type == 'cuda':
return torch.amp.autocast(device_type='cuda', dtype=torch.float16)
return nullcontext()
def count_parameters(model):
total = sum(parameter.numel() for parameter in model.parameters())
trainable = sum(
parameter.numel() for parameter in model.parameters() if parameter.requires_grad
)
return trainable, total
def build_optimizer(model, args):
backbone_params = []
fast_params = []
for name, parameter in model.named_parameters():
if not parameter.requires_grad:
continue
if name.startswith('backbone') and 'lora_' not in name:
backbone_params.append(parameter)
else:
fast_params.append(parameter)
parameter_groups = []
if backbone_params:
parameter_groups.append(
{
'params': backbone_params,
'lr': args.lr * args.backbone_lr_scale,
}
)
if fast_params:
parameter_groups.append({'params': fast_params, 'lr': args.lr})
return optim.AdamW(
parameter_groups,
lr=args.lr,
weight_decay=args.weight_decay,
)
def save_checkpoint(path, epoch, global_step, model, optimizer, metrics, args):
torch.save(
{
'epoch': epoch,
'global_step': global_step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'metrics': metrics,
'args': vars(args),
},
path,
)
def run_validation(model, val_loader, criterion, device, epoch):
model.eval()
val_loss = 0.0
total = 0
all_preds = []
all_labels = []
print('Running validation...')
with torch.inference_mode():
for images, labels in tqdm(val_loader, desc='Validating', leave=False):
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True).unsqueeze(1)
with get_amp_context(device):
logits = model(images)
loss = criterion(logits, labels)
batch_size = labels.size(0)
val_loss += loss.item() * batch_size
total += batch_size
all_preds.append(torch.sigmoid(logits).cpu())
all_labels.append(labels.cpu())
all_preds = torch.cat(all_preds, dim=0).numpy()
all_labels = torch.cat(all_labels, dim=0).numpy()
threshold = 0.5
preds = (all_preds > threshold).astype(float)
acc = (preds == all_labels).mean()
real_mask = all_labels == 0
fake_mask = all_labels == 1
real_acc = (
(preds[real_mask] == all_labels[real_mask]).mean() if real_mask.any() else 0
)
fake_acc = (
(preds[fake_mask] == all_labels[fake_mask]).mean() if fake_mask.any() else 0
)
balanced_acc = 0.5 * (real_acc + fake_acc)
metrics = {
'val_loss': val_loss / max(total, 1),
'val_acc': float(acc),
'val_real_acc': float(real_acc),
'val_fake_acc': float(fake_acc),
'val_balanced_acc': float(balanced_acc),
}
print(
f'Epoch {epoch} | Val Loss: {metrics["val_loss"]:.4f} | '
f'Val Acc: {metrics["val_acc"]:.4f} | '
f'Val BAcc: {metrics["val_balanced_acc"]:.4f} | '
f'Real Acc: {metrics["val_real_acc"]:.4f} | '
f'Fake Acc: {metrics["val_fake_acc"]:.4f}'
)
return metrics
def train():
args = parse_args()
seed_everything(args.seed)
load_dotenv()
hf_token = os.getenv('HF_TOKEN')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
print(
f'Config: lr={args.lr}, batch_size={args.batch_size}, epochs={args.epochs}, '
f'max_steps={args.max_steps}, image_size={args.image_size}, '
f'val={"disabled" if args.no_val else f"every {args.val_every} epoch(s)"}'
)
checkpoints_dir = os.path.join(current_dir, 'checkpoints')
os.makedirs(checkpoints_dir, exist_ok=True)
print('Loading AIGIBench dataset from HuggingFace...')
dataset = load_dataset('TheKernel01/AIGIBench', token=hf_token)
train_ds = AIGIBenchDataset(
dataset['train'],
transform=get_train_transforms(size=args.image_size),
)
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=device.type == 'cuda',
)
val_loader = None
if not args.no_val:
val_ds = AIGIBenchDataset(
dataset['validation'],
transform=get_val_transforms(
size=args.image_size,
),
)
val_loader = DataLoader(
val_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
pin_memory=device.type == 'cuda',
)
model = DeForge_AI_Model(
lora_r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
lora_target_modules=[
module.strip()
for module in args.lora_target_modules.split(',')
if module.strip()
],
forensic_dim=args.forensic_dim,
unfreeze_last_blocks=args.unfreeze_last_blocks,
image_size=args.image_size,
).to(device)
trainable_params, total_params = count_parameters(model)
print(
f'Trainable params: {trainable_params / 1e6:.2f}M / '
f'{total_params / 1e6:.2f}M ({100 * trainable_params / max(total_params, 1):.2f}%)'
)
optimizer = build_optimizer(model, args)
criterion = nn.BCEWithLogitsLoss()
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=[group['lr'] for group in optimizer.param_groups],
total_steps=args.max_steps,
pct_start=args.pct_start,
anneal_strategy='cos',
)
scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None
best_balanced_acc = float('-inf')
global_step = 0
for epoch in range(1, args.epochs + 1):
model.train()
running_loss = 0.0
remaining_steps = max(args.max_steps - global_step, 0)
epoch_total = min(len(train_loader), remaining_steps) if remaining_steps else 0
pbar = tqdm(
train_loader, desc=f'Epoch {epoch}/{args.epochs}', total=epoch_total
)
for step_idx, (images, labels) in enumerate(pbar, start=1):
if global_step >= args.max_steps:
break
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True).unsqueeze(1)
optimizer.zero_grad(set_to_none=True)
with get_amp_context(device):
logits = model(images)
loss = criterion(logits, labels)
if scaler is not None:
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
if args.gradient_clip is not None:
torch.nn.utils.clip_grad_norm_(
[p for p in model.parameters() if p.requires_grad],
args.gradient_clip,
)
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
if args.gradient_clip is not None:
torch.nn.utils.clip_grad_norm_(
[p for p in model.parameters() if p.requires_grad],
args.gradient_clip,
)
optimizer.step()
scheduler.step()
global_step += 1
running_loss += loss.item()
if global_step % 1000 == 0:
step_checkpoint_path = os.path.join(
checkpoints_dir, f'model_step_{global_step}.pth'
)
save_checkpoint(
step_checkpoint_path,
epoch,
global_step,
model,
optimizer,
{},
args,
)
print(f'Saved periodic checkpoint to {step_checkpoint_path}')
if step_idx % 10 == 0 or step_idx == 1:
current_lr = max(group['lr'] for group in optimizer.param_groups)
pbar.set_postfix(
{
'loss': f'{running_loss / step_idx:.4f}',
'lr': f'{current_lr:.2e}',
}
)
metrics = {}
should_validate = (
not args.no_val and val_loader is not None and epoch % args.val_every == 0
)
if should_validate:
metrics = run_validation(model, val_loader, criterion, device, epoch)
checkpoint_name = (
f'deforge_ai_epoch_{epoch}_bacc_{metrics["val_balanced_acc"]:.4f}.pth'
if metrics
else f'deforge_ai_epoch_{epoch}_step_{global_step}.pth'
)
checkpoint_path = os.path.join(checkpoints_dir, checkpoint_name)
save_checkpoint(
checkpoint_path,
epoch,
global_step,
model,
optimizer,
metrics,
args,
)
print(f'Saved checkpoint to {checkpoint_path}')
if metrics and metrics['val_balanced_acc'] > best_balanced_acc:
best_balanced_acc = metrics['val_balanced_acc']
best_path = os.path.join(checkpoints_dir, 'model_epoch_best.pth')
save_checkpoint(
best_path,
epoch,
global_step,
model,
optimizer,
metrics,
args,
)
print(f'Updated best checkpoint at {best_path}')
elif not metrics:
best_path = os.path.join(checkpoints_dir, 'model_epoch_best.pth')
if not os.path.exists(best_path):
save_checkpoint(
best_path,
epoch,
global_step,
model,
optimizer,
metrics,
args,
)
print(f'Initialized best checkpoint at {best_path}')
latest_path = os.path.join(checkpoints_dir, 'model_epoch_last.pth')
save_checkpoint(
latest_path,
epoch,
global_step,
model,
optimizer,
metrics,
args,
)
if global_step >= args.max_steps:
print(f'Reached max_steps={args.max_steps}, stopping.')
break
print('Training complete!')
if __name__ == '__main__':
train()