import os

# Must be set before transformers tokenizers are instantiated; silences
# fork-related parallelism warnings from the Rust tokenizer backend.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import random
import string
import time
from collections import OrderedDict
from glob import glob

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from rich import print
from torch.autograd import Variable
from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.functional import to_tensor, to_pil_image
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers import get_linear_schedule_with_warmup

from models import get_model
from dataset import MyDataset
from utils import save_checkpoint, AverageMeter, ProgressMeter


def test_epoch(travel_model, name_model, epoch, dataloader, tokenizer):
    """Run one evaluation pass of both classifiers over ``dataloader``.

    Both models receive the *same* tokenized batch; each is assumed to
    return class logits where index 1 is the positive class
    (TODO confirm against ``get_model``).

    Args:
        travel_model: binary classifier for the "travel" label.
        name_model: binary classifier for the "name" flag.
        epoch: epoch number, used only for progress display.
        dataloader: yields ``(context_str_batch, sms_id)`` pairs;
            assumes batch_size == 1 (``[0]`` indexing below depends on it).
        tokenizer: HuggingFace tokenizer shared by both models.

    Returns:
        Tuple of (travel_predictions, travel_probs,
        name_predictions, name_probs, sms_ids), all Python lists
        aligned by sample order.
    """
    print(f"\n\n=> val")
    # NOTE(review): data_time is created and shown in the progress line
    # but never .update()d anywhere — it will always display its initial
    # value. Kept as-is to preserve output format; consider updating it
    # (or removing it) in a follow-up.
    data_time = AverageMeter('- data', ':4.3f')
    batch_time = AverageMeter('- batch', ':6.3f')
    progress = ProgressMeter(
        len(dataloader), data_time, batch_time,
        prefix=f"Epoch: [{epoch}]")

    end = time.time()
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    travel_model.to(device)
    travel_model.eval()
    name_model.to(device)
    name_model.eval()

    sms_ids = []
    travel_probs = []
    travel_predictions = []
    name_probs = []
    name_predictions = []

    for batch_index, data_batch in enumerate(tqdm(dataloader)):
        context_str_batch, sms_id = data_batch
        # batch_size == 1: take the single scalar id out of the batch tensor
        sms_ids.append(sms_id.detach().cpu().numpy()[0])

        # Tokenize the raw text batch.
        context_token_batch = tokenizer(
            context_str_batch, padding=True, truncation=True,
            max_length=500, return_tensors='pt')
        # Move every tensor in the encoding to the target device.
        context_token_batch = {k: v.to(device) for k, v in context_token_batch.items()}

        # Forward both models on the identical input.
        data_input_batch = context_token_batch
        travel_output_batch = travel_model(**data_input_batch)
        name_output_batch = name_model(**data_input_batch)

        # Travel head: positive-class probability ([0][1]) and argmax label.
        travel_pred_batch = travel_output_batch.softmax(dim=-1)
        travel_probs.append(travel_pred_batch.detach().cpu().numpy()[0][1])
        travel_pred = torch.argmax(travel_pred_batch, dim=-1)
        travel_predictions.extend(travel_pred.cpu().numpy())

        # Name head: same treatment.
        name_pred_batch = name_output_batch.softmax(dim=-1)
        name_probs.append(name_pred_batch.detach().cpu().numpy()[0][1])
        name_pred = torch.argmax(name_pred_batch, dim=-1)
        name_predictions.extend(name_pred.cpu().numpy())

        batch_time.update(time.time() - end)
        end = time.time()
        if batch_index % 50 == 0:
            progress.print(batch_index)

    return travel_predictions, travel_probs, name_predictions, name_probs, sms_ids


def inference():
    """Load both checkpoints, run test_epoch on the test CSV, and write
    per-sample probabilities/labels to ``output_file`` as CSV."""
    travel_checkpoint_file = 'checkpoints/saved_checkpoints/travel_checkpoint15_train8000.pth.tar'
    name_checkpoint_file = 'checkpoints/saved_checkpoints/name_checkpoint17_train9000.pth.tar'
    ann_file_test = 'dataset/datagame_sms_stage1(in).csv'
    output_file = 'both_macbertBase_20250731_2.csv'
    cache_dir = 'cache'

    model_cfg = {
        "pretrained_transformers": "hfl/chinese-macbert-base",
        "cache_dir": cache_dir
    }

    # Models: two independent instances of the same architecture, one per task.
    travel_model_dict = get_model(model_cfg, mode='base')
    travel_model = travel_model_dict['model']
    name_model_dict = get_model(model_cfg, mode='base')
    name_model = name_model_dict['model']
    # Tokenizer is shared (same pretrained backbone for both models).
    tokenizer = travel_model_dict['tokenizer']

    data_loader_cfg = {}
    test_dataset = MyDataset(ann_file_test, data_loader_cfg, mode='test')
    # batch_size must stay 1: test_epoch indexes outputs with [0].
    test_loader = DataLoader(test_dataset, batch_size=1, pin_memory=True, shuffle=False)

    # Resume from checkpoints; both must exist.
    assert travel_checkpoint_file is not None and os.path.exists(travel_checkpoint_file)
    assert name_checkpoint_file is not None and os.path.exists(name_checkpoint_file)
    travel_checkpoint = torch.load(travel_checkpoint_file, map_location='cpu')
    name_checkpoint = torch.load(name_checkpoint_file, map_location='cpu')

    # Strip the 'module.' prefix left over from DataParallel/DDP training
    # so the state dict matches the bare (non-wrapped) model.
    travel_model.load_state_dict({k.replace('module.', ''): v for k, v in
                                  travel_checkpoint['state_dict'].items()})
    print(f"=> Resume: loaded travel checkpoint {travel_checkpoint_file} (epoch {travel_checkpoint['epoch']})")
    name_model.load_state_dict({k.replace('module.', ''): v for k, v in
                                name_checkpoint['state_dict'].items()})
    print(f"=> Resume: loaded name checkpoint {name_checkpoint_file} (epoch {name_checkpoint['epoch']})")

    travel_predictions, travel_probs, name_predictions, name_probs, sms_ids = \
        test_epoch(travel_model, name_model, 1, test_loader, tokenizer)

    # Write one CSV row per sample, aligned across all five result lists.
    with open(output_file, 'w') as f:
        f.write("sms_id,travel_prob,label,name_prob,name_flg\n")
        for travel_pred, travel_prob, name_pred, name_prob, sms_id in zip(
                travel_predictions, travel_probs, name_predictions, name_probs, sms_ids):
            f.write(f"{sms_id},{travel_prob},{travel_pred},{name_prob},{name_pred}\n")
    print('Output file saved!')


if __name__ == '__main__':
    inference()