|
|
import os |
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torchvision.transforms.functional import to_tensor, to_pil_image |
|
|
import torchvision.transforms as transforms |
|
|
from transformers import AutoModel |
|
|
from transformers import AutoTokenizer, AutoConfig |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.autograd import Variable |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torch.cuda.amp import autocast, GradScaler |
|
|
from torch.utils.data.distributed import DistributedSampler |
|
|
|
|
|
from tqdm import tqdm |
|
|
import random |
|
|
import numpy as np |
|
|
from collections import OrderedDict |
|
|
from rich import print |
|
|
import time |
|
|
import cv2 |
|
|
from glob import glob |
|
|
import string |
|
|
from torch.optim import AdamW |
|
|
from transformers import get_linear_schedule_with_warmup |
|
|
from models import get_model |
|
|
from dataset import MyDataset |
|
|
from utils import save_checkpoint, AverageMeter, ProgressMeter |
|
|
|
|
|
|
|
|
def test_epoch(model, epoch, dataloader, tokenizer):
    """Run one evaluation pass over *dataloader* and return predicted class ids.

    Args:
        model: classification model whose forward accepts tokenizer outputs
            (``input_ids``, ``attention_mask``, ...) and returns logits.
        epoch: epoch number, used only for progress display.
        dataloader: yields batches of raw text strings (list[str] per batch).
        tokenizer: HuggingFace tokenizer used to encode each text batch.

    Returns:
        list of per-sample predicted class indices (numpy ints).
    """
    print("\n\n=> val")  # plain string: no placeholders, so no f-string needed

    data_time = AverageMeter('- data', ':4.3f')
    batch_time = AverageMeter('- batch', ':6.3f')
    progress = ProgressMeter(
        len(dataloader), data_time, batch_time, prefix=f"Epoch: [{epoch}]")

    # Single device computation (the original computed this twice).
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()

    predictions = []
    end = time.time()

    # no_grad: inference only — avoids building autograd graphs, saving
    # memory and time (the original accumulated gradients needlessly).
    with torch.no_grad():
        for batch_index, context_str_batch in enumerate(tqdm(dataloader)):
            # Time spent waiting on the dataloader (meter was never
            # updated in the original despite being displayed).
            data_time.update(time.time() - end)

            context_token_batch = tokenizer(
                context_str_batch, padding=True, truncation=True,
                max_length=500, return_tensors='pt')
            context_token_batch = {k: v.to(device)
                                   for k, v in context_token_batch.items()}

            output_batch = model(**context_token_batch)

            # softmax then argmax over the class dimension; argmax alone
            # would suffice, but softmax is kept for parity with training code.
            pred_batch = output_batch.softmax(dim=-1)
            pred = torch.argmax(pred_batch, dim=-1)
            predictions.extend(pred.cpu().numpy())

            batch_time.update(time.time() - end)
            end = time.time()

            if batch_index % 50 == 0:
                progress.print(batch_index)

    return predictions
|
|
|
|
|
|
|
|
def infer20221212():
    """Load a trained checkpoint, run inference on the test CSV, write the
    predictions to ``output_file``, then print a confusion matrix and every
    misclassified sample.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        KeyError: if the checkpoint lacks ``state_dict``/``epoch`` entries.
    """
    checkpoint_file = '/home/elaine/Desktop/macbert_code/checkpoints_name/checkpoint_epoch015_acc1.0000.pth.tar'
    output_file = r'/home/elaine/Desktop/macbert_code/output.csv'

    cache_dir = '/home/elaine/Desktop/macbert_code/code/cache'
    ann_file_test = r'/home/elaine/Desktop/macbert_code/dataset/name_test_8000.csv'

    model_cfg = {
        "pretrained_transformers": "hfl/chinese-macbert-base",
        "cache_dir": cache_dir,
    }

    # Build model + tokenizer from the project factory.
    model_dict = get_model(model_cfg, mode='base')
    model = model_dict['model']
    tokenizer = model_dict['tokenizer']
    print(model)

    data_loader_cfg = {}
    test_dataset = MyDataset(ann_file_test, data_loader_cfg, mode='test')
    test_loader = DataLoader(test_dataset, batch_size=8, num_workers=4, pin_memory=True)

    # Explicit error instead of `assert`: asserts are stripped under `python -O`.
    if checkpoint_file is None or not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f"checkpoint not found: {checkpoint_file}")
    checkpoint = torch.load(checkpoint_file, map_location='cpu')

    # Strip the DataParallel/DDP 'module.' prefix so keys match a bare model.
    model.load_state_dict({k.replace('module.', ''): v
                           for k, v in checkpoint['state_dict'].items()})
    print(f"=> Resume: loaded checkpoint {checkpoint_file} (epoch {checkpoint['epoch']})")

    pred_res = test_epoch(model, 1, test_loader, tokenizer)
    with open(output_file, 'w') as f:
        f.writelines(f"{pred}\n" for pred in pred_res)

    import csv

    # Parse the annotation file ONCE (the original read and header-skipped
    # it twice). newline='' is the csv-module-documented way to open files.
    sms_list = []
    true_labels = []
    with open(ann_file_test, 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader)  # skip header row
        for row in reader:
            sms_list.append(row[1])          # column 1: message text — TODO confirm schema
            true_labels.append(int(row[3]))  # column 3: integer class label

    from sklearn.metrics import confusion_matrix
    cm = confusion_matrix(true_labels, pred_res)
    print('Confusion Matrix:')
    print(cm)

    # Report each misclassified sample (same output strings as before).
    for sms, label, pred in zip(sms_list, true_labels, pred_res):
        if pred != label:
            print(f"錯誤: sms='{sms}',預測={pred},正確={label}")
|
|
|
|
|
|
# Script entry point: run the one-off inference/evaluation job.
if __name__ == '__main__':


    infer20221212()