File size: 5,181 Bytes
0857e86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import pickle

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import one_hot,softmax
import matplotlib.pyplot as plt
import random
import torch.utils.data as data
# Probability of applying teacher forcing during training; unused in the
# visible code — presumably consumed by the commented-out training loop or
# by another module that imports this one. TODO confirm before removing.
teacher_forcing_ratio = 0.5


def get_data_set(mode="train"):
    """Load trajectory pickles and build padded one-hot / index tensors.

    Parameters
    ----------
    mode : str
        "train" uses only the normal trajectories (NT); any other value
        additionally appends the abnormal trajectories (AT).

    Returns
    -------
    tuple
        (sampletensor, sampletensor_index, labeltensor, max_len) where
        sampletensor       : float Tensor (num_samples, max_len, len(voc)), one-hot;
        sampletensor_index : float Tensor (num_samples, max_len) of vocab positions;
        labeltensor        : float Tensor (num_samples,), 1 = normal, 0 = abnormal;
        max_len            : int padded trajectory length (floor of 10).
    """
    # NOTE(review): pickle.load must only ever see trusted local dataset
    # artifacts — unpickling untrusted data executes arbitrary code.
    with open(r'dataset/' + "NT.pkl", "rb") as f1:
        NT = pickle.load(f1)
    with open(r'dataset/' + "AT.pkl", "rb") as f2:
        AT = pickle.load(f2)

    # Spatial vocabulary = the nodes of the project graph.
    from read_graph import get_graph
    G_new = get_graph()
    voc = list(G_new)
    # Persist the node vocabulary for other scripts. (The original code also
    # reloaded it immediately — a no-op pickle round-trip that was dropped.)
    with open('nodes.pkl', 'wb') as f:
        pickle.dump(voc, f, pickle.HIGHEST_PROTOCOL)

    voc.append(0)    # padding symbol
    voc.append('s')  # START token
    voc.append('e')  # EOF token

    # Trajectory labels: 1 = normal, 0 = abnormal. Normal samples are used
    # in every mode; abnormal ones only outside training.
    samples = []
    labels = []
    for tr in NT:
        samples.append(tr)
        labels.append(1)
    if mode != "train":
        for tr in AT:
            samples.append(tr)
            labels.append(0)

    def padding(x, max_length):
        # Truncate, or right-pad with [0, 0] (padding-symbol, zero-time) points.
        if len(x) > max_length:
            return x[:max_length]
        return x + [[0, 0]] * (max_length - len(x))

    # Longest trajectory (with a floor of 10), then pad every sample to it.
    max_len = 10
    for tr in samples:
        max_len = max(max_len, len(tr))
    samples_padded = [padding(tr, max_len) for tr in samples]

    # First-occurrence position of each vocabulary entry: behaves exactly like
    # voc.index(...) (first match wins) but is O(1) per lookup instead of
    # O(|voc|). Graph nodes are hashable by networkx's contract, so they are
    # valid dict keys.
    first_pos = {}
    for i, c in enumerate(voc):
        first_pos.setdefault(c, i)

    def onehot_encode(char):
        # One-hot encode a symbol over the full vocabulary; the padding
        # symbol 0 is encoded as the all-zero vector.
        encoded = [0] * len(voc)
        if char != 0:
            encoded[first_pos[char]] = 1
        return encoded

    samples_one_hot = []
    samples_index = []
    for tr in samples_padded:
        tr_rep = []
        tr_rep_index = []
        for pt in tr:
            spatial = onehot_encode(pt[0])
            # Temporal feature is parsed but currently unused — kept so that
            # malformed timestamps still fail loudly. TODO: feed into the model.
            temporal = int(pt[1])
            tr_rep.append(spatial)
            tr_rep_index.append(first_pos[pt[0]])
        samples_one_hot.append(tr_rep)
        samples_index.append(tr_rep_index)

    sampletensor = torch.Tensor(samples_one_hot)
    sampletensor_index = torch.Tensor(samples_index)
    labeltensor = torch.Tensor(labels)
    return sampletensor, sampletensor_index, labeltensor, max_len

# Select the computation device once at import time.
# (The original `global device` here was removed: `global` at module scope
# is a no-op — `device` is simply a module-level name.)
if torch.cuda.is_available():
    # cudnn disabled — presumably to avoid non-determinism or a cudnn bug
    # in this model; TODO confirm it is still needed.
    torch.backends.cudnn.enabled = False
    device = torch.device("cuda:0")
    torch.cuda.set_device(0)
    import os
    # NOTE(review): setting CUDA_VISIBLE_DEVICES after CUDA has already been
    # initialized by torch has no effect — if device masking is intended,
    # it must be set before importing torch. Kept for behavioral parity.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    print("Working on GPU")
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")

import torch.nn as nn
# from VAE import AE,RNN

if __name__ == '__main__':
    # Build the training tensors and wrap them in a deterministic
    # (non-shuffling, non-dropping) mini-batch loader.
    sampletensor, sampletensor_index, labeltensor, max_len = get_data_set("train")

    batch_size = 2
    train_set = data.TensorDataset(sampletensor, sampletensor_index, labeltensor)
    train_iter = data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )

    # rnn = RNN(input_size=2694,hidden_size=64,batch_size=2,maxlen=max_len)
    # loss = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adamax(rnn.parameters(),lr=1e-2)
    #
    # net = rnn.to(device)
    # num_epochs = 120
    #
    # h_hat_avg = None
    #
    # from tqdm import tqdm
    # for epoch in tqdm(range(num_epochs)):
    #     epoch_total_loss = 0
    #     for x, x_label,y in train_iter:
    #         # RNN
    #         xhat,kld,h_hat = net(x,x,"train",None)
    #         # print(xhat.shape)
    #         # print(x_label.shape)
    #         len_all = (x_label.shape[0])*(x_label.shape[1])
    #         xhat = xhat.reshape(len_all,-1)
    #         x_label = x_label.reshape(len_all).long().to(device)
    #         # print(x_label)
    #         # print("xhat",xhat.shape)
    #         # print("x_label",x_label.shape)
    #         l = loss(xhat,x_label)
    #         # print("reconstruction loss:",l,"kld loss:",kld)
    #         total_loss = l + kld
    #         epoch_total_loss += total_loss
    #         optimizer.zero_grad()
    #         total_loss.backward()
    #         optimizer.step()
    #         if epoch == num_epochs - 1:
    #             if h_hat_avg is None:
    #                 h_hat_avg = h_hat/ torch.full(h_hat.shape,len(sampletensor)).to(device)
    #             else:
    #                 h_hat_avg += h_hat / torch.full(h_hat.shape, len(sampletensor)).to(device)
    #             print(">>> h_hat_avg",h_hat_avg.shape)
    #     print(" epoch_total_loss = ",epoch_total_loss)
    #
    # print("training ends")
    # torch.save(net,"LSTM-VAE.pth")
    # torch.save(h_hat_avg, 'h_hat_avg.pt')
    #
    #
    #
    #
    #
    #
    #
    #
    #
    #
    #
    #