|
|
import os |
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torchvision.transforms.functional import to_tensor, to_pil_image |
|
|
import torchvision.transforms as transforms |
|
|
from transformers import AutoModel |
|
|
from transformers import AutoTokenizer, AutoConfig |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.autograd import Variable |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torch.cuda.amp import autocast, GradScaler |
|
|
from torch.utils.data.distributed import DistributedSampler |
|
|
|
|
|
from tqdm import tqdm |
|
|
import random |
|
|
import numpy as np |
|
|
from collections import OrderedDict |
|
|
from rich import print |
|
|
import time |
|
|
from glob import glob |
|
|
import string |
|
|
from torch.optim import AdamW |
|
|
from transformers import get_linear_schedule_with_warmup |
|
|
from models import get_model |
|
|
from dataset import MyDataset |
|
|
from utils import save_checkpoint, AverageMeter, ProgressMeter |
|
|
|
|
|
|
|
|
def test_epoch(travel_model, name_model, epoch, dataloader, tokenizer):
    """Run one evaluation pass of both classifiers over *dataloader*.

    Both models receive the same tokenized text batch; each is assumed to
    return logits of shape (batch, 2) whose index 1 is the positive class
    (TODO confirm against the model definition in ``models.get_model``).

    Args:
        travel_model: binary classifier for the "travel" label.
        name_model: binary classifier for the "name" flag.
        epoch: epoch number, used only for progress display.
        dataloader: yields ``(context_str_batch, sms_id)`` pairs; the code
            below assumes batch_size == 1 (it indexes ``[0]`` everywhere).
        tokenizer: HuggingFace tokenizer used to encode the text.

    Returns:
        Tuple of (travel_predictions, travel_probs, name_predictions,
        name_probs, sms_ids), all plain Python lists aligned by sample.
    """
    print(f"\n\n=> val")

    data_time = AverageMeter('- data', ':4.3f')
    batch_time = AverageMeter('- batch', ':6.3f')
    progress = ProgressMeter(
        len(dataloader), data_time, batch_time, prefix=f"Epoch: [{epoch}]")

    end = time.time()
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    travel_model.to(device)
    travel_model.eval()

    name_model.to(device)
    name_model.eval()

    sms_ids = []
    travel_probs = []
    travel_predictions = []
    name_probs = []
    name_predictions = []

    # Inference only: disable autograd so no graph is built (saves memory/time).
    with torch.no_grad():
        for batch_index, data_batch in enumerate(tqdm(dataloader)):
            # Time spent waiting on the dataloader (meter was previously never updated).
            data_time.update(time.time() - end)

            context_str_batch, sms_id = data_batch
            # NOTE(review): [0] assumes batch_size == 1 — the caller uses batch_size=1.
            sms_ids.append(sms_id.detach().cpu().numpy()[0])

            context_token_batch = tokenizer(
                context_str_batch, padding=True, truncation=True,
                max_length=500, return_tensors='pt')
            context_token_batch = {k: v.to(device) for k, v in context_token_batch.items()}

            data_input_batch = context_token_batch
            travel_output_batch = travel_model(**data_input_batch)
            name_output_batch = name_model(**data_input_batch)

            # Positive-class probability (index 1) for the single sample in the batch.
            travel_pred_batch = travel_output_batch.softmax(dim=-1)
            travel_probs.append(travel_pred_batch.detach().cpu().numpy()[0][1])
            travel_pred = torch.argmax(travel_pred_batch, dim=-1)
            travel_predictions.extend(travel_pred.cpu().numpy())

            name_pred_batch = name_output_batch.softmax(dim=-1)
            name_probs.append(name_pred_batch.detach().cpu().numpy()[0][1])
            name_pred = torch.argmax(name_pred_batch, dim=-1)
            name_predictions.extend(name_pred.cpu().numpy())

            batch_time.update(time.time() - end)
            end = time.time()

            if batch_index % 50 == 0:
                progress.print(batch_index)

    return travel_predictions, travel_probs, name_predictions, name_probs, sms_ids
|
|
|
|
|
def inference():
    """Load both trained checkpoints, score the test CSV, and write a result CSV.

    Builds two identically-configured MacBERT-base classifiers (travel and
    name), restores their weights from fixed checkpoint paths, runs
    ``test_epoch`` over the test set with batch_size=1, and writes one line
    per SMS: ``sms_id,travel_prob,label,name_prob,name_flg``.

    Raises:
        FileNotFoundError: if either checkpoint file is missing.
    """
    travel_checkpoint_file = 'checkpoints/saved_checkpoints/travel_checkpoint15_train8000.pth.tar'
    name_checkpoint_file = 'checkpoints/saved_checkpoints/name_checkpoint17_train9000.pth.tar'
    ann_file_test = 'dataset/datagame_sms_stage1(in).csv'
    output_file = 'both_macbertBase_20250731_2.csv'
    cache_dir = 'cache'

    model_cfg = {
        "pretrained_transformers": "hfl/chinese-macbert-base",
        "cache_dir": cache_dir
    }

    travel_model_dict = get_model(model_cfg, mode='base')
    travel_model = travel_model_dict['model']

    name_model_dict = get_model(model_cfg, mode='base')
    name_model = name_model_dict['model']

    # Both models share the same pretrained backbone, so one tokenizer serves both.
    tokenizer = travel_model_dict['tokenizer']

    data_loader_cfg = {}
    test_dataset = MyDataset(ann_file_test, data_loader_cfg, mode='test')
    # batch_size must stay 1: test_epoch indexes [0] into each batch.
    test_loader = DataLoader(test_dataset, batch_size=1, pin_memory=True, shuffle=False)

    # Explicit checks instead of `assert`: asserts are stripped under `python -O`.
    if travel_checkpoint_file is None or not os.path.exists(travel_checkpoint_file):
        raise FileNotFoundError(f"travel checkpoint not found: {travel_checkpoint_file}")
    if name_checkpoint_file is None or not os.path.exists(name_checkpoint_file):
        raise FileNotFoundError(f"name checkpoint not found: {name_checkpoint_file}")

    travel_checkpoint = torch.load(travel_checkpoint_file, map_location='cpu')
    name_checkpoint = torch.load(name_checkpoint_file, map_location='cpu')

    # Strip the 'module.' prefix that DataParallel/DistributedDataParallel adds.
    travel_model.load_state_dict({k.replace('module.', ''): v for k, v in travel_checkpoint['state_dict'].items()})
    print(f"=> Resume: loaded travel checkpoint {travel_checkpoint_file} (epoch {travel_checkpoint['epoch']})")

    name_model.load_state_dict({k.replace('module.', ''): v for k, v in name_checkpoint['state_dict'].items()})
    print(f"=> Resume: loaded name checkpoint {name_checkpoint_file} (epoch {name_checkpoint['epoch']})")

    travel_predictions, travel_probs, name_predictions, name_probs, sms_ids = test_epoch(travel_model, name_model, 1, test_loader, tokenizer)

    with open(output_file, 'w') as f:
        f.write("sms_id,travel_prob,label,name_prob,name_flg\n")
        for travel_pred, travel_prob, name_pred, name_prob, sms_id in zip(travel_predictions, travel_probs, name_predictions, name_probs, sms_ids):
            f.write(f"{sms_id},{travel_prob},{travel_pred},{name_prob},{name_pred}\n")
    print('Output file saved!')
|
|
|
|
|
# Script entry point: run the full two-model inference pipeline when executed directly.
if __name__ == '__main__':


    inference()
|
|
|