# -*- coding: utf-8 -*-
import datetime
import json
import os
import time

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from metrics import get_roc_metrics, get_precision_recall_metrics
from utils import GpuMem

try:
    from transformers import AdamW
except ImportError:
    # Newer transformers versions no longer ship AdamW; fall back to torch's.
    from torch.optim import AdamW


def evaluate_model_SPO(model, data, DEVICE):
    model.to(DEVICE)
    model.eval()
    loss = 0
    eval_loader = DataLoader(data, batch_size=1, shuffle=False)
    epoch_crit_train_original, epoch_crit_train_sampled = [], []

    start_time = time.time()
    with torch.no_grad():
        for batch in tqdm.tqdm(eval_loader, desc="Evaluating"):
            text = batch
            output = model(text)
            loss += output['loss'].item()
            # crit[1]: criterion scores for the original (real) texts,
            # crit[3]: criterion scores for the sampled (generated) texts.
            epoch_crit_train_original.extend(output['crit'][1].tolist())
            epoch_crit_train_sampled.extend(output['crit'][3].tolist())
    print(f"Total time: {time.time() - start_time:.4f}s")

    avg_loss = loss / len(eval_loader)
    fpr, tpr, roc_auc = get_roc_metrics(epoch_crit_train_original, epoch_crit_train_sampled)
    p, r, pr_auc = get_precision_recall_metrics(epoch_crit_train_original, epoch_crit_train_sampled)
    # print(f"val_loss: {avg_loss:.6f}")
    print(f"val_ROC_AUC: {roc_auc:.4f}, PR AUC: {pr_auc:.4f}")
    print(
        f"val_Real_mean/std: {np.mean(epoch_crit_train_original):.2f}/{np.std(epoch_crit_train_original):.2f}, "
        f"val_Samples_mean/std: {np.mean(epoch_crit_train_sampled):.2f}/{np.std(epoch_crit_train_sampled):.2f}"
    )
    print("=" * 10)

    results_dict = {
        'name': 'imbd',
        'info': {'n_samples': len(epoch_crit_train_original)},
        'predictions': {'real': epoch_crit_train_original, 'samples': epoch_crit_train_sampled},
        'metrics': {'roc_auc': roc_auc, 'fpr': fpr, 'tpr': tpr},
        'pr_metrics': {'pr_auc': pr_auc, 'precision': p, 'recall': r},
    }
    return results_dict


def fine_tune_ours(model, data, DEVICE, ckpt_dir='./ckpt', args=None):
    current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    writer = SummaryWriter(
        log_dir=f"./scripts/ImBD/logs/{args.task_name}_spo_lr_{args.lr}_beta_{args.beta}"
                f"_a_{args.a}_{current_time}/train_ai_detection"
    )
    # data[0] is the training split; data[1] (the evaluation split) is used in run().
    train_loader = DataLoader(data[0], batch_size=1, shuffle=True)
    epochs = args.epochs
    optimizer = AdamW(model.parameters(), lr=args.lr)
    # The cosine schedule is stepped once per batch, so T_max spans the whole run.
    scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * epochs, eta_min=0, last_epoch=-1)
    scaler = GradScaler()
    model.to(DEVICE)
    model.train()

    # Number of batches to accumulate before each optimizer step.
    accumulation_steps = args.a
    epoch_losses, i, loss = [], 0, torch.tensor(0.0).to(DEVICE)
    epoch_crit_train_original, epoch_crit_train_sampled = [], []

    for epoch in range(epochs):
        optimizer.zero_grad()
        start_time = time.time()
        for batch in tqdm.tqdm(train_loader, desc=f"Fine-tuning: {epoch} epoch"):
            text = batch
            scheduler.step()
            with autocast():
                outputs_1 = model(text)
                epoch_crit_train_original.extend(outputs_1['crit'][1].tolist())
                epoch_crit_train_sampled.extend(outputs_1['crit'][3].tolist())
                loss += outputs_1['loss'].to(torch.float32) / accumulation_steps
            if ((i + 1) % accumulation_steps) == 0:
                # Backward and optimizer step on the accumulated loss, then reset it.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                optimizer.zero_grad()
                scaler.update()
                writer.add_scalar('Loss/train', loss.item(), i)
                epoch_losses.append(loss.item())
                loss = torch.tensor(0.0).to(DEVICE)
            i += 1
        print(f"Total time: {time.time() - start_time:.4f}s")

        fpr, tpr, roc_auc = get_roc_metrics(epoch_crit_train_original, epoch_crit_train_sampled)
        p, r, pr_auc = get_precision_recall_metrics(epoch_crit_train_original, epoch_crit_train_sampled)
        print(f"ROC AUC: {roc_auc:.4f}, PR AUC: {pr_auc:.4f}")
        print(
            f"Real mean/std: {np.mean(epoch_crit_train_original):.2f}/{np.std(epoch_crit_train_original):.2f}, "
            f"Samples mean/std: {np.mean(epoch_crit_train_sampled):.2f}/{np.std(epoch_crit_train_sampled):.2f}"
        )

        epoch_avg_loss = np.mean(epoch_losses)
        writer.add_scalar('Loss/epoch', epoch_avg_loss, epoch)
        writer.add_scalar('ROC_AUC/epoch', roc_auc, epoch)
        writer.add_scalar('PR_AUC/epoch', pr_auc, epoch)
        writer.add_scalar('Real_mean/epoch', np.mean(epoch_crit_train_original), epoch)
        writer.add_scalar('Real_std/epoch', np.std(epoch_crit_train_original), epoch)
        writer.add_scalar('Sampled_mean/epoch', np.mean(epoch_crit_train_sampled), epoch)
        writer.add_scalar('Sampled_std/epoch', np.std(epoch_crit_train_sampled), epoch)
        epoch_crit_train_original, epoch_crit_train_sampled = [], []  # reset crit
        print(f"\nAverage Loss for Epoch {epoch}: {epoch_avg_loss}")

    # if not os.path.exists(ckpt_dir):
    #     os.makedirs(ckpt_dir)
    # model.save_pretrained(ckpt_dir)
    # print(f"Saved finetuned model to {os.path.join(ckpt_dir, 'ours-finetuned.pth')}")
    writer.close()
    return model


def run(model, data, DEVICE, args, ckpt_dir='./ckpt'):
    if args.ebt or args.eval_only:
        print("Evaluating model before tuning...")
        d = evaluate_model_SPO(model, data[1], DEVICE)
        if args.SPOtrained:
            output_path = f"{args.output_file}.imbd.json"
        else:
            method_name = args.base_model.split("_")[-1]
            output_path = f"{args.output_file}.{method_name}.json"
        with open(output_path, "w") as j:
            json.dump(d, j)
        print(f"Results saved to {output_path}.")
        if args.eval_only:
            return

    tracker = GpuMem()
    print('Fine-tuning model...')
    start = time.perf_counter()
    with tracker:
        model = fine_tune_ours(model, data, DEVICE=DEVICE, ckpt_dir=ckpt_dir, args=args)
    pre_time = time.perf_counter() - start
    pre_memory = tracker.memory_usage()

    if args.eval_after_train:
        print("Evaluating model after tuning...")
        start = time.perf_counter()
        with tracker:
            d = evaluate_model_SPO(model, data[1], DEVICE)
        eval_time = time.perf_counter() - start
        # Normalize to per-text time: each item yields both a real and a sampled
        # score, hence the division by 2 * len(data[1]).
        eval_time = eval_time / (len(data[1]) << 1)
        eval_memory = tracker.memory_usage()
        d['compute_info'] = {
            'pre_time': pre_time,
            'eval_time': eval_time,
            'pre_memory': pre_memory,
            'eval_memory': eval_memory,
        }
        if args.SPOtrained:
            output_path = f"{args.output_file}.imbd.json"
        else:
            method_name = args.base_model.split("_")[-1]
            output_path = f"{args.output_file}.{method_name}.json"
        with open(output_path, "w") as j:
            json.dump(d, j)
        print(f"Results saved to {output_path}.")
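

# Usage sketch (illustrative only; the real CLI that constructs `args` lives
# elsewhere in the project). It mirrors the attributes this module reads from
# `args`; the values shown are placeholders, not recommended settings.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       task_name="demo", lr=1e-4, beta=0.05, a=1, epochs=1,
#       ebt=False, eval_only=False, eval_after_train=True,
#       SPOtrained=True, base_model="", output_file="./results/demo",
#   )
#   # data is a (train_split, eval_split) pair of datasets yielding text batches.
#   run(model, data, DEVICE="cuda", args=args)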