| import os |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision.transforms.functional import to_tensor, to_pil_image |
| import torchvision.transforms as transforms |
| from transformers import AutoModel |
| from transformers import AutoTokenizer, AutoConfig |
| import torch |
| import torch.nn as nn |
| from torch.autograd import Variable |
| from torch.utils.data import Dataset, DataLoader |
| from torch.cuda.amp import autocast, GradScaler |
| from torch.utils.data.distributed import DistributedSampler |
|
|
| from tqdm import tqdm |
| import random |
| import numpy as np |
| from collections import OrderedDict |
| from rich import print |
| import time |
| import cv2 |
| from glob import glob |
| import string |
| from torch.optim import AdamW |
| from transformers import get_linear_schedule_with_warmup |
| from models import get_model |
| from dataset import MyDataset |
| from utils import save_checkpoint, AverageMeter, ProgressMeter |
|
|
|
|
def test_epoch(model, epoch, dataloader, tokenizer):
    """Run one evaluation pass over ``dataloader`` and collect predictions.

    Args:
        model: classification model; called as ``model(**tokens)`` and expected
            to return per-class logits of shape (batch, num_classes).
        epoch: epoch number, used only in the progress display.
        dataloader: yields ``(context_str_batch, sms_id)`` pairs; assumes
            batch_size == 1 — only element ``[0]`` of each batch is recorded.
        tokenizer: HuggingFace-style tokenizer returning PyTorch tensors.

    Returns:
        Tuple ``(predictions, probs, sms_ids)``: argmax class per sample,
        probability of class index 1 per sample, and the sample ids.
    """
    print("\n\n=> val")
    data_time = AverageMeter('- data', ':4.3f')
    batch_time = AverageMeter('- batch', ':6.3f')
    progress = ProgressMeter(
        len(dataloader), data_time, batch_time, prefix=f"Epoch: [{epoch}]")

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()

    sms_ids = []
    probs = []
    predictions = []

    end = time.time()
    # no_grad(): inference only — skip autograd bookkeeping to save memory
    # and speed up the forward pass.
    with torch.no_grad():
        for batch_index, data_batch in enumerate(tqdm(dataloader)):
            context_str_batch, sms_id = data_batch
            data_time.update(time.time() - end)  # time spent fetching the batch
            sms_ids.append(sms_id.detach().cpu().numpy()[0])

            context_token_batch = tokenizer(
                context_str_batch, padding=True, truncation=True,
                max_length=500, return_tensors='pt')
            context_token_batch = {k: v.to(device) for k, v in context_token_batch.items()}

            output_batch = model(**context_token_batch)

            pred_batch = output_batch.softmax(dim=-1)
            # probability of the positive class (index 1) for the first sample
            probs.append(pred_batch.detach().cpu().numpy()[0][1])
            pred = torch.argmax(pred_batch, dim=-1)
            predictions.extend(pred.cpu().numpy())

            batch_time.update(time.time() - end)
            end = time.time()

            if batch_index % 50 == 0:
                progress.print(batch_index)

    return predictions, probs, sms_ids
|
|
|
|
def infer20221212():
    """Run the 2022-12-12 travel-SMS inference job.

    Loads a fine-tuned MacBERT checkpoint, runs inference over the test CSV,
    writes ``sms_id,prob,label`` rows to ``output_file``, then prints the
    confusion matrix and every misclassified sample.

    Raises:
        FileNotFoundError: if the checkpoint file is missing.
    """
    checkpoint_file = '/home/elaine/Desktop/macbert_code/checkpoints_travel/checkpoint_epoch017_acc0.9988.pth.tar'
    output_file = r'/home/elaine/Desktop/macbert_code/travel_v2_output.csv'

    cache_dir = '/home/elaine/Desktop/macbert_code/cache'
    ann_file_test = r'/home/elaine/Desktop/macbert_code/dataset/travel_test_9000.csv'

    model_cfg = {
        "pretrained_transformers": "hfl/chinese-macbert-base",
        "cache_dir": cache_dir
    }
    model_dict = get_model(model_cfg, mode='base')
    model = model_dict['model']
    tokenizer = model_dict['tokenizer']
    print(model)

    data_loader_cfg = {}
    test_dataset = MyDataset(ann_file_test, data_loader_cfg, mode='test')
    # batch_size=1: test_epoch records only the first element of each batch.
    test_loader = DataLoader(test_dataset, batch_size=1, pin_memory=True, shuffle=False)

    # Raise instead of assert: asserts are stripped under `python -O`.
    if checkpoint_file is None or not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f"checkpoint not found: {checkpoint_file}")
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    # Strip the 'module.' prefix that DataParallel/DDP prepends to param names.
    model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()})
    print(f"=> Resume: loaded checkpoint {checkpoint_file} (epoch {checkpoint['epoch']})")

    pred_res, probs, sms_ids = test_epoch(model, 1, test_loader, tokenizer)
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("sms_id,prob,label\n")
        for pred, prob, sms_id in zip(pred_res, probs, sms_ids):
            f.write(f"{sms_id},{prob},{pred}\n")

    # Parse the annotation file once (it was previously read twice) into
    # (sms_id, sms_text, label) tuples; columns: row[0]=id, row[1]=text,
    # row[2]=label.
    import csv
    rows = []
    with open(ann_file_test, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader)  # skip header
        for row in reader:
            rows.append((int(row[0]), row[1], int(row[2])))
    true_labels = [label for _, _, label in rows]

    from sklearn.metrics import confusion_matrix
    cm = confusion_matrix(true_labels, pred_res)
    print('Confusion Matrix:')
    print(cm)

    # Report every misclassified sample. `sample_id` avoids shadowing the
    # builtin `id` that the original loop clobbered.
    for idx, (sample_id, sms, label) in enumerate(rows):
        pred = pred_res[idx]
        if pred != label:
            print(f"錯誤: sms_id={sample_id},sms='{sms}',預測={pred},正確={label}")
|
# Script entry point: run the 2022-12-12 inference job when executed directly.
if __name__ == '__main__':
    infer20221212()
|
|