|
|
import os

# Silence HuggingFace tokenizers' fork warning and avoid deadlocks when the
# tokenizer is used from DataLoader worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import time

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from transformers import get_linear_schedule_with_warmup

from rich import print

from models import get_model
from dataset import MyDataset
from utils import save_checkpoint, AverageMeter, ProgressMeter


if __name__ == '__main__':

    # Runtime globals (`device`, `scaler`) used by the train/val helpers
    # below are initialised here, so the script must be run directly.
    # Under torchrun, RANK/WORLD_SIZE are set and each process is bound to
    # its own GPU; otherwise fall back to a single device.
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        torch.distributed.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        print(f"[Distributed] Rank {os.environ['RANK']} using device {local_rank}")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[Single] Using device: {device}")

    # Gradient scaler for mixed-precision training. Note that the training
    # loop below currently runs in full precision and never touches it.
    scaler = torch.amp.GradScaler(device='cuda' if torch.cuda.is_available() else 'cpu')


# Rank-0-only printing: every other rank stays silent.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
print_raw = print

def print(*info):
    if local_rank == 0:
        print_raw(*info)
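
# If mixed precision were enabled, the forward/backward steps in train_epoch
# would be rewritten roughly as follows (a sketch, not wired in; assumes a
# CUDA device):
#
#     with torch.autocast(device_type='cuda'):
#         output_batch = model(**data_input_batch)
#         loss = crossentropy(target_batch, output_batch)
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()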
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def crossentropy(y_true, y_pred):
    # F.cross_entropy expects (logits, targets); label smoothing softens the
    # one-hot targets by 0.2. The result is already reduced to the batch mean.
    return F.cross_entropy(y_pred, y_true, label_smoothing=0.2)
|
|
|
|
|
|
|
|
def evaluate(predictions, labels):
    nb_all = len(predictions)
    # Epoch-level accuracy; the 1e-8 guards against an empty prediction list.
    acc = sum(int(p == l) for p, l in zip(predictions, labels)) / (nb_all + 1e-8)

    eval_results = {'acc': acc}

    return eval_results
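
# Example: evaluate([1, 0, 2], [1, 1, 2]) returns {'acc': 0.666...}, since
# two of the three predictions match their labels.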
|
|
|
|
|
|
|
|
def train_epoch(model, optimizer, epoch, dataloader, sampler, tokenizer, scheduler):
    print("\n\n=> train")
    data_time = AverageMeter('- data', ':4.3f')
    batch_time = AverageMeter('- batch', ':6.3f')
    losses = AverageMeter('- loss', ':.4e')
    acces = AverageMeter('- acc', ':.4f')
    progress = ProgressMeter(
        len(dataloader), data_time, batch_time, losses, acces, prefix=f"Epoch: [{epoch}]")

    end = time.time()
    model.train()
    # DistributedSampler needs the epoch to reshuffle deterministically;
    # plain samplers have no set_epoch, hence the hasattr guard.
    if hasattr(sampler, "set_epoch"):
        sampler.set_epoch(epoch)

    predictions, labels = [], []

    for batch_index, data_batch in enumerate(dataloader):
        data_time.update(time.time() - end)  # time spent waiting on data
        optimizer.zero_grad()

        context_str_batch, target_batch = data_batch

        # Tokenize on the fly; sequences longer than 500 tokens are truncated.
        context_token_batch = tokenizer(context_str_batch, padding=True, truncation=True, max_length=500, return_tensors='pt')
        context_token_batch = {k: v.to(device) for k, v in context_token_batch.items()}
        target_batch = target_batch.to(device)

        data_input_batch = context_token_batch
        output_batch = model(**data_input_batch)

        pred_batch = output_batch.softmax(dim=-1)

        # crossentropy already returns the batch-mean scalar loss.
        loss = crossentropy(target_batch, output_batch)

        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        loss_value = loss.item()
        losses.update(loss_value, len(target_batch))
        pred = torch.argmax(pred_batch, dim=-1)
        predictions.extend(pred.cpu().numpy())
        labels.extend(target_batch.cpu().numpy())
        acc_batch = (target_batch == pred).sum().cpu().numpy() / (len(target_batch) + 1e-8)
        acces.update(acc_batch, len(target_batch))
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_index % 50 == 0:
            progress.print(batch_index)

    results = evaluate(predictions, labels)
    print(results)
    return results
|
|
|
|
|
|
|
|
@torch.no_grad()
def val_epoch(model, optimizer, epoch, dataloader, sampler, tokenizer):
    # `optimizer` is unused during validation; kept for call-site symmetry.
    print("\n\n=> val")
    data_time = AverageMeter('- data', ':4.3f')
    batch_time = AverageMeter('- batch', ':6.3f')
    losses = AverageMeter('- loss', ':.4e')
    acces = AverageMeter('- acc', ':.4f')
    progress = ProgressMeter(
        len(dataloader), data_time, batch_time, losses, acces, prefix=f"Epoch: [{epoch}]")

    end = time.time()
    model.eval()  # evaluation mode: disables dropout etc.
    if hasattr(sampler, "set_epoch"):
        sampler.set_epoch(epoch)

    predictions, labels = [], []

    for batch_index, data_batch in enumerate(dataloader):
        data_time.update(time.time() - end)

        context_str_batch, target_batch = data_batch

        context_token_batch = tokenizer(context_str_batch, padding=True, truncation=True, max_length=500, return_tensors='pt')
        context_token_batch = {k: v.to(device) for k, v in context_token_batch.items()}
        target_batch = target_batch.to(device)

        data_input_batch = context_token_batch
        output_batch = model(**data_input_batch)

        pred_batch = output_batch.softmax(dim=-1)

        loss = crossentropy(target_batch, output_batch)

        loss_value = loss.item()
        losses.update(loss_value, len(target_batch))
        pred = torch.argmax(pred_batch, dim=-1)
        predictions.extend(pred.cpu().numpy())
        labels.extend(target_batch.cpu().numpy())
        acc_batch = (target_batch == pred).sum().cpu().numpy() / (len(target_batch) + 1e-8)
        acces.update(acc_batch, len(target_batch))
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_index % 50 == 0:
            progress.print(batch_index)

    results = evaluate(predictions, labels)
    print(results)
    return results
|
|
|
|
|
|
|
|
|
|
|
def gogogo():

    # Paths, data files, and training hyper-parameters.
    output_dir = '/home/elaine/Desktop/macbert_code/checkpoints_travel'
    ann_file_tra = '/home/elaine/Desktop/macbert_code/dataset/travel_train_9000.csv'
    ann_file_val = '/home/elaine/Desktop/macbert_code/dataset/travel_val_9000.csv'
    checkpoint_file = None

    batch_size = 4
    epochs = 20
    cache_dir = '/home/elaine/Desktop/macbert_code/cache'

    model_cfg = {
        "pretrained_transformers": "hfl/chinese-macbert-base",
        "cache_dir": cache_dir
    }
|
|
|
|
|
|
|
|
    model_dict = get_model(model_cfg, mode='base')
    model = model_dict['model']
    tokenizer = model_dict['tokenizer']
    print(model)

    # Standard BERT fine-tuning recipe: no weight decay on biases or
    # LayerNorm weights.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5, eps=1e-8)
    scheduler = None
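
    # A warmup schedule could be plugged in here with the imported helper; a
    # minimal sketch (the 10% warmup ratio is an assumption, and this would
    # have to run after `tra_loader` is built below):
    #
    #     num_training_steps = len(tra_loader) * epochs
    #     scheduler = get_linear_schedule_with_warmup(
    #         optimizer,
    #         num_warmup_steps=int(0.1 * num_training_steps),
    #         num_training_steps=num_training_steps)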
|
|
|
|
|
|
|
|
    data_loader_cfg = {}
    tra_dataset = MyDataset(ann_file_tra, data_loader_cfg, mode='tra')
    val_dataset = MyDataset(ann_file_val, {}, mode='val')

    sampler_tra = RandomSampler(tra_dataset)
    sampler_val = SequentialSampler(val_dataset)

    # DataLoader forbids passing both `sampler` and `shuffle`, so shuffling
    # is expressed through the samplers, which are also what the epoch
    # helpers receive.
    tra_loader = DataLoader(tra_dataset, batch_size=batch_size, sampler=sampler_tra, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=sampler_val, num_workers=8, pin_memory=True)
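
    # For actual multi-GPU training, the datasets would need distributed
    # samplers and the model a DDP wrapper; roughly (a sketch, not wired in):
    #
    #     from torch.utils.data.distributed import DistributedSampler
    #     sampler_tra = DistributedSampler(tra_dataset)
    #     model = torch.nn.parallel.DistributedDataParallel(
    #         model.to(device), device_ids=[local_rank])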
|
|
|
|
|
|
|
|
    if checkpoint_file is not None and os.path.exists(checkpoint_file):
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        init_epoch = checkpoint['epoch'] + 1
        # Strip any 'module.' prefix left over from a DDP-wrapped checkpoint.
        model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()})
        optimizer.load_state_dict(checkpoint['optimizer'])
        if torch.cuda.is_available():
            # The optimizer state was loaded onto the CPU; move it back to the GPU.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()
        print(f"=> Resume: loaded checkpoint {checkpoint_file} (epoch {checkpoint['epoch']})")
    else:
        init_epoch = 1
        print("=> No checkpoint, training from scratch.")

    # Reuse the module-level `device` selected at startup.
    model = model.to(device)
|
|
|
|
|
|
|
|
    best_acc = 0.
    for epoch in range(init_epoch, epochs + 1):
        train_epoch(model, optimizer, epoch, tra_loader, sampler_tra, tokenizer, scheduler)
        results_val = val_epoch(model, optimizer, epoch, val_loader, sampler_val, tokenizer)
        acc_val = results_val['acc']
        # Checkpoint whenever validation accuracy ties or beats the best so far.
        if acc_val >= best_acc:
            best_acc = acc_val
            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, outname=f'{output_dir}/checkpoint_epoch{epoch:03d}_acc{best_acc:.4f}.pth.tar', local_rank=0)
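

# Typical invocations (assuming this file is saved as train.py):
#   single GPU/CPU:  python train.py
#   multi-GPU:       torchrun --nproc_per_node=<N> train.py
# Note that multi-GPU use additionally requires the DDP wiring sketched above.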
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
gogogo() |
|
|
|