File size: 6,238 Bytes
3118055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6be67e1
 
d1b0194
 
3118055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_tensor, to_pil_image
import torchvision.transforms as transforms
from transformers import AutoModel
from transformers import AutoTokenizer, AutoConfig
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data.distributed import DistributedSampler

from tqdm import tqdm
import random
import numpy as np
from collections import OrderedDict
from rich import print
import time
from glob import glob
import string
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from models import get_model
from dataset import MyDataset
from utils import save_checkpoint, AverageMeter, ProgressMeter


def test_epoch(travel_model, name_model, epoch, dataloader, tokenizer):
    """Run the travel and name classifiers over ``dataloader`` (inference only).

    Despite its name this is not a unit test: it is the evaluation loop used by
    ``inference()``. Both models are moved to ``cuda:0`` when available (else
    CPU), switched to eval mode, and fed the same tokenized batch.

    Args:
        travel_model: Model called as ``travel_model(**encoded_batch)``. Its
            output is softmaxed over the last dimension and column 1 is
            recorded as the probability. Assumes the output is shaped
            ``(batch, num_classes)`` with ``num_classes >= 2`` (TODO: confirm
            against ``models.get_model``).
        name_model: Second model, used exactly like ``travel_model``.
        epoch: Used only in the progress-meter prefix.
        dataloader: Iterable yielding ``(context_str_batch, sms_id)`` pairs,
            where ``sms_id`` is a tensor holding the batch's ids.
        tokenizer: Hugging Face-style tokenizer applied to
            ``context_str_batch``. Inputs are padded and truncated to 500
            tokens.

    Returns:
        ``(travel_predictions, travel_probs, name_predictions, name_probs,
        sms_ids)``: five lists with one entry per sample, in dataloader
        order. Predictions are argmax class indices. Probs are the softmax
        value of class 1.
    """
    print("\n\n=> val")
    data_time = AverageMeter('- data', ':4.3f')
    batch_time = AverageMeter('- batch', ':6.3f')
    progress = ProgressMeter(
        len(dataloader), data_time, batch_time, prefix=f"Epoch: [{epoch}]")

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    for model in (travel_model, name_model):
        model.to(device)
        model.eval()

    sms_ids = []
    travel_probs = []
    travel_predictions = []
    name_probs = []
    name_predictions = []

    end = time.time()
    # Pure inference: without no_grad every forward pass would build an
    # autograd graph, wasting memory and time for nothing.
    with torch.no_grad():
        for batch_index, (context_str_batch, sms_id) in enumerate(tqdm(dataloader)):
            # Time spent waiting for the dataloader (this meter was never
            # updated before, so it always printed zero).
            data_time.update(time.time() - end)

            # Keep every id in the batch, not just the first one, so that all
            # five returned lists stay aligned when batch_size > 1.
            sms_ids.extend(sms_id.cpu().numpy().reshape(-1))

            encoded = tokenizer(context_str_batch, padding=True, truncation=True,
                                max_length=500, return_tensors='pt')
            encoded = {k: v.to(device) for k, v in encoded.items()}

            travel_pred_batch = travel_model(**encoded).softmax(dim=-1)
            name_pred_batch = name_model(**encoded).softmax(dim=-1)

            # Column 1 holds the class-1 probability; take it for every row.
            # Iterating a numpy array yields numpy scalars, which keeps the
            # same element types (and CSV formatting) as before.
            travel_probs.extend(travel_pred_batch[:, 1].cpu().numpy())
            travel_predictions.extend(torch.argmax(travel_pred_batch, dim=-1).cpu().numpy())
            name_probs.extend(name_pred_batch[:, 1].cpu().numpy())
            name_predictions.extend(torch.argmax(name_pred_batch, dim=-1).cpu().numpy())

            batch_time.update(time.time() - end)
            end = time.time()

            if batch_index % 50 == 0:
                progress.print(batch_index)

    return travel_predictions, travel_probs, name_predictions, name_probs, sms_ids

def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``'module.'`` removed from keys.

    Models wrapped in ``nn.DataParallel`` / ``DistributedDataParallel`` save
    their parameters as ``module.<name>``. Only the *leading* prefix is
    stripped. A plain ``str.replace`` would also mangle keys that merely
    contain the substring (e.g. ``'submodule.weight'`` -> ``'subweight'``).
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_weights(model, checkpoint_file, label):
    """Load ``checkpoint['state_dict']`` from ``checkpoint_file`` into ``model`` and log its epoch."""
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    model.load_state_dict(_strip_module_prefix(checkpoint['state_dict']))
    print(f"=> Resume: loaded {label} checkpoint {checkpoint_file} (epoch {checkpoint['epoch']})")


def inference(
    travel_checkpoint_file='checkpoints/saved_checkpoints/travel_checkpoint15_train8000.pth.tar',
    name_checkpoint_file='checkpoints/saved_checkpoints/name_checkpoint17_train9000.pth.tar',
    ann_file_test='dataset/datagame_sms_stage1(in).csv',
    output_file='both_macbertBase_20250731_2.csv',
    cache_dir='cache',
):
    """Score the SMS test set with the travel and name classifiers and write a CSV.

    Builds two ``hfl/chinese-macbert-base`` models via ``get_model``, loads the
    fine-tuned weights into each, runs ``test_epoch`` over the test set, and
    writes one row per SMS with the columns
    ``sms_id,travel_prob,label,name_prob,name_flg``.

    Args:
        travel_checkpoint_file: Checkpoint holding ``state_dict`` and ``epoch``
            for the travel model.
        name_checkpoint_file: Same, for the name model.
        ann_file_test: Test annotation CSV passed to ``MyDataset``.
        output_file: Destination CSV path (overwritten).
        cache_dir: Cache directory passed to ``get_model``.

    Raises:
        FileNotFoundError: If either checkpoint file is missing.
    """
    # Fail fast, before building any model, if a checkpoint is missing. An
    # explicit raise (unlike ``assert``) is not stripped under ``python -O``.
    for checkpoint_file in (travel_checkpoint_file, name_checkpoint_file):
        if checkpoint_file is None or not os.path.exists(checkpoint_file):
            raise FileNotFoundError(f"checkpoint not found: {checkpoint_file}")

    model_cfg = {
        "pretrained_transformers": "hfl/chinese-macbert-base",
        "cache_dir": cache_dir,
    }
    # Two independent instances of the same backbone; each receives its own
    # fine-tuned weights below.
    travel_model_dict = get_model(model_cfg, mode='base')
    travel_model = travel_model_dict['model']
    name_model = get_model(model_cfg, mode='base')['model']
    # Both models come from the same model_cfg, so one tokenizer serves both.
    tokenizer = travel_model_dict['tokenizer']

    data_loader_cfg = {}
    test_dataset = MyDataset(ann_file_test, data_loader_cfg, mode='test')
    test_loader = DataLoader(test_dataset, batch_size=1, pin_memory=True, shuffle=False)

    _load_weights(travel_model, travel_checkpoint_file, 'travel')
    _load_weights(name_model, name_checkpoint_file, 'name')

    travel_predictions, travel_probs, name_predictions, name_probs, sms_ids = test_epoch(
        travel_model, name_model, 1, test_loader, tokenizer)

    rows = zip(sms_ids, travel_probs, travel_predictions, name_probs, name_predictions)
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("sms_id,travel_prob,label,name_prob,name_flg\n")
        for sms_id, travel_prob, travel_pred, name_prob, name_pred in rows:
            f.write(f"{sms_id},{travel_prob},{travel_pred},{name_prob},{name_pred}\n")
    print('Output file saved!')

"""

    # 讀取val.csv的label
    import csv
    true_labels = []
    with open(ann_file_test, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader)  # skip header
        for row in reader:
            true_labels.append(int(row[2]))

    # 計算confusion matrix
    from sklearn.metrics import confusion_matrix
    cm = confusion_matrix(true_labels, pred_res)
    print('Confusion Matrix:')
    print(cm)

    # 印出預測錯誤的內容、預測值和正確答案
    with open(ann_file_test, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader)  # skip header
        for idx, row in enumerate(reader):
            id, sms, label = int(row[0]), row[1], int(row[2])
            pred = pred_res[idx]
            if pred != label:
                print(f"錯誤: sms_id={id},sms='{sms}',預測={pred},正確={label}")
"""

def _run_as_script():
    """Entry point used when this module is executed directly."""
    inference()


if __name__ == '__main__':
    # Guarded so that importing this module (e.g. to reuse test_epoch) does
    # not trigger a full inference run.
    _run_as_script()