Spaces:
Runtime error
Runtime error
Commit ·
6f2ff70
1
Parent(s): 480ece8
changs
Browse files- README.md +3 -0
- app.py +77 -37
- bert_pytorch/dataset/dataset.py +131 -0
- bert_pytorch/dataset/log_dataset.py +134 -0
- bert_pytorch/dataset/sample.py +117 -0
- bert_pytorch/dataset/utils.py +19 -0
- bert_pytorch/dataset/vocab.py +169 -0
- bert_pytorch/model/bert.py +49 -0
- bert_pytorch/model/embedding/bert.py +42 -0
- bert_pytorch/model/embedding/position.py +25 -0
- bert_pytorch/model/embedding/segment.py +6 -0
- bert_pytorch/model/embedding/time_embed.py +10 -0
- bert_pytorch/model/embedding/token.py +6 -0
- bert_pytorch/model/language_model.py +61 -0
- bert_pytorch/model/log_model.py +74 -0
- bert_pytorch/model/transformer.py +31 -0
- bert_pytorch/predict_log.py +290 -0
- bert_pytorch/train_log.py +222 -0
- logbert_rca_pipeline_api.py +209 -0
- requirements.txt +19 -0
README.md
CHANGED
|
@@ -8,3 +8,6 @@ pinned: false
|
|
| 8 |
---
|
| 9 |
|
| 10 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
app.py
CHANGED
|
@@ -1,59 +1,99 @@
|
|
| 1 |
|
| 2 |
-
|
|
|
|
| 3 |
import os
|
|
|
|
|
|
|
|
|
|
| 4 |
import redis
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
import asyncio
|
| 7 |
-
from sql import insert_rca_result, connect_to_database, disconnect_from_database
|
| 8 |
|
| 9 |
-
#
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
#
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
REDIS_QUEUE = os.environ["REDIS_QUEUE"]
|
| 18 |
-
READY_FOR_RCA_QUEUE = os.environ.get("READY_FOR_RCA_QUEUE", "logbert_ready_for_rca")
|
| 19 |
|
| 20 |
-
# Initialize Redis client
|
| 21 |
-
redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
|
| 22 |
|
| 23 |
-
def process_log(
|
| 24 |
-
#
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
-
|
| 31 |
-
# Store anomaly result in rca_results table using sql.py
|
| 32 |
-
async def _save():
|
| 33 |
-
await connect_to_database()
|
| 34 |
-
await insert_rca_result(rca_result)
|
| 35 |
-
await disconnect_from_database()
|
| 36 |
-
asyncio.run(_save())
|
| 37 |
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
while True:
|
| 40 |
-
|
| 41 |
-
filename = redis_client.rpop
|
| 42 |
if filename:
|
| 43 |
-
|
| 44 |
-
if
|
| 45 |
-
rca_result = process_log(
|
| 46 |
-
save_rca_to_db(rca_result)
|
| 47 |
-
# Notify ready-for-rca queue
|
| 48 |
try:
|
| 49 |
-
redis_client.lpush
|
| 50 |
print(f"Notified {READY_FOR_RCA_QUEUE} for {filename}")
|
| 51 |
except Exception as redis_exc:
|
| 52 |
print(f"Failed to notify ready-for-rca queue: {redis_exc}")
|
| 53 |
else:
|
| 54 |
-
print(f"File not
|
| 55 |
else:
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
if __name__ == "__main__":
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
|
| 2 |
+
|
| 3 |
+
|
| 4 |
import os
|
| 5 |
+
import asyncio
|
| 6 |
+
import tempfile
|
| 7 |
+
from logbert_rca_pipeline_api import detect_anomalies_and_explain
|
| 8 |
import redis
|
| 9 |
+
import boto3
|
| 10 |
+
from botocore.exceptions import ClientError
|
| 11 |
+
from fastapi import FastAPI, HTTPException
|
| 12 |
+
from pydantic import BaseModel
|
| 13 |
+
import uvicorn
|
| 14 |
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
# FastAPI app
|
| 17 |
+
app = FastAPI()
|
| 18 |
+
|
| 19 |
+
# Initialize Redis client (adjust host/port/db as needed)
|
| 20 |
+
redis_client = redis.Redis(host='localhost', port=6379, db=0)
|
| 21 |
+
|
| 22 |
+
# Define the Redis queue name
|
| 23 |
+
REDIS_QUEUE = "log_queue"
|
| 24 |
|
| 25 |
+
# Request model
|
| 26 |
+
class LogRequest(BaseModel):
|
| 27 |
+
filename: str
|
|
|
|
|
|
|
| 28 |
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
async def process_log(filename, file_content):
|
| 31 |
+
# Save file_content to a temporary file and run RCA pipeline
|
| 32 |
+
with tempfile.NamedTemporaryFile(delete=False, mode="wb", suffix=".log") as tmp:
|
| 33 |
+
tmp.write(file_content)
|
| 34 |
+
tmp_path = tmp.name
|
| 35 |
+
loop = asyncio.get_event_loop()
|
| 36 |
|
| 37 |
+
def _run_pipeline():
|
| 38 |
+
return detect_anomalies_and_explain(tmp_path)
|
| 39 |
+
results = await loop.run_in_executor(None, _run_pipeline)
|
| 40 |
+
os.unlink(tmp_path)
|
| 41 |
+
if results and len(results) > 0:
|
| 42 |
+
return results[0]
|
| 43 |
+
else:
|
| 44 |
+
return {"filename": filename, "anomaly": False, "details": "No anomaly detected."}
|
| 45 |
|
| 46 |
|
| 47 |
+
S3_BUCKET = "your-s3-bucket-name"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
+
|
| 50 |
+
async def get_file_from_s3(filename):
|
| 51 |
+
loop = asyncio.get_event_loop()
|
| 52 |
+
|
| 53 |
+
def _download():
|
| 54 |
+
s3 = boto3.client("s3")
|
| 55 |
+
try:
|
| 56 |
+
response = s3.get_object(Bucket=S3_BUCKET, Key=filename.decode(
|
| 57 |
+
) if isinstance(filename, bytes) else filename)
|
| 58 |
+
return response["Body"].read()
|
| 59 |
+
except ClientError as e:
|
| 60 |
+
print(f"Error downloading {filename} from S3: {e}")
|
| 61 |
+
return None
|
| 62 |
+
return await loop.run_in_executor(None, _download)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
async def main():
|
| 66 |
while True:
|
| 67 |
+
loop = asyncio.get_event_loop()
|
| 68 |
+
filename = await loop.run_in_executor(None, redis_client.rpop, REDIS_QUEUE)
|
| 69 |
if filename:
|
| 70 |
+
file_content = await get_file_from_s3(filename)
|
| 71 |
+
if file_content is not None:
|
| 72 |
+
rca_result = await process_log(filename, file_content)
|
| 73 |
+
await save_rca_to_db(rca_result)
|
|
|
|
| 74 |
try:
|
| 75 |
+
await loop.run_in_executor(None, redis_client.lpush, READY_FOR_RCA_QUEUE, filename)
|
| 76 |
print(f"Notified {READY_FOR_RCA_QUEUE} for {filename}")
|
| 77 |
except Exception as redis_exc:
|
| 78 |
print(f"Failed to notify ready-for-rca queue: {redis_exc}")
|
| 79 |
else:
|
| 80 |
+
print(f"File {filename} could not be downloaded from S3.")
|
| 81 |
else:
|
| 82 |
+
await asyncio.sleep(2)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# FastAPI endpoint to process a log file from S3
|
| 86 |
+
@app.post("/process-log")
|
| 87 |
+
async def process_log_endpoint(request: LogRequest):
|
| 88 |
+
file_content = await get_file_from_s3(request.filename)
|
| 89 |
+
if file_content is None:
|
| 90 |
+
raise HTTPException(status_code=404, detail=f"File {request.filename} not found in S3 bucket.")
|
| 91 |
+
result = await process_log(request.filename, file_content)
|
| 92 |
+
return result
|
| 93 |
|
| 94 |
if __name__ == "__main__":
|
| 95 |
+
import sys
|
| 96 |
+
if len(sys.argv) > 1 and sys.argv[1] == "serve":
|
| 97 |
+
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
|
| 98 |
+
else:
|
| 99 |
+
asyncio.run(main())
|
bert_pytorch/dataset/dataset.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Dataset
|
| 2 |
+
import tqdm
|
| 3 |
+
import torch
|
| 4 |
+
import random
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
class BERTDataset(Dataset):
|
| 8 |
+
def __init__(self, corpus_path, vocab, seq_len, corpus_lines=None, encoding="utf-8", on_memory=True, predict_mode=False):
|
| 9 |
+
self.vocab = vocab
|
| 10 |
+
self.seq_len = seq_len
|
| 11 |
+
|
| 12 |
+
self.on_memory = on_memory
|
| 13 |
+
self.corpus_lines = corpus_lines
|
| 14 |
+
self.corpus_path = corpus_path
|
| 15 |
+
self.encoding = encoding
|
| 16 |
+
|
| 17 |
+
self.predict_mode = predict_mode
|
| 18 |
+
self.lines = corpus_path
|
| 19 |
+
self.corpus_lines = len(self.lines)
|
| 20 |
+
|
| 21 |
+
if not on_memory:
|
| 22 |
+
self.file = open(corpus_path, "r", encoding=encoding)
|
| 23 |
+
self.random_file = open(corpus_path, "r", encoding=encoding)
|
| 24 |
+
|
| 25 |
+
for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):
|
| 26 |
+
self.random_file.__next__()
|
| 27 |
+
|
| 28 |
+
def __len__(self):
|
| 29 |
+
return self.corpus_lines
|
| 30 |
+
|
| 31 |
+
def __getitem__(self, item):
|
| 32 |
+
t1, t2, is_next_label = self.random_sent(item)
|
| 33 |
+
t1_random, t1_label = self.random_word(t1)
|
| 34 |
+
t2_random, t2_label = self.random_word(t2)
|
| 35 |
+
|
| 36 |
+
# [CLS] tag = SOS tag, [SEP] tag = EOS tag
|
| 37 |
+
t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]
|
| 38 |
+
t2 = t2_random + [self.vocab.eos_index]
|
| 39 |
+
|
| 40 |
+
t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]
|
| 41 |
+
t2_label = t2_label + [self.vocab.pad_index]
|
| 42 |
+
|
| 43 |
+
segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
|
| 44 |
+
bert_input = (t1 + t2)[:self.seq_len]
|
| 45 |
+
bert_label = (t1_label + t2_label)[:self.seq_len]
|
| 46 |
+
|
| 47 |
+
padding = [self.vocab.pad_index for _ in range(self.seq_len - len(bert_input))]
|
| 48 |
+
bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)
|
| 49 |
+
|
| 50 |
+
output = {"bert_input": bert_input,
|
| 51 |
+
"bert_label": bert_label,
|
| 52 |
+
"segment_label": segment_label,
|
| 53 |
+
"is_next": is_next_label}
|
| 54 |
+
|
| 55 |
+
return {key: torch.tensor(value) for key, value in output.items()}
|
| 56 |
+
|
| 57 |
+
def random_word(self, sentence):
|
| 58 |
+
tokens = list(sentence)
|
| 59 |
+
output_label = []
|
| 60 |
+
|
| 61 |
+
for i, token in enumerate(tokens):
|
| 62 |
+
prob = random.random()
|
| 63 |
+
# replace 15% of tokens in a sequence to a masked token
|
| 64 |
+
if prob < 0.15:
|
| 65 |
+
if self.predict_mode:
|
| 66 |
+
tokens[i] = self.vocab.mask_index
|
| 67 |
+
output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))
|
| 68 |
+
continue
|
| 69 |
+
|
| 70 |
+
prob /= 0.15
|
| 71 |
+
|
| 72 |
+
# 80% randomly change token to mask token
|
| 73 |
+
if prob < 0.8:
|
| 74 |
+
tokens[i] = self.vocab.mask_index
|
| 75 |
+
|
| 76 |
+
# 10% randomly change token to random token
|
| 77 |
+
elif prob < 0.9:
|
| 78 |
+
tokens[i] = random.randrange(len(self.vocab))
|
| 79 |
+
|
| 80 |
+
# 10% randomly change token to current token
|
| 81 |
+
else:
|
| 82 |
+
tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
|
| 83 |
+
|
| 84 |
+
output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))
|
| 85 |
+
|
| 86 |
+
else:
|
| 87 |
+
tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
|
| 88 |
+
output_label.append(0)
|
| 89 |
+
|
| 90 |
+
return tokens, output_label
|
| 91 |
+
|
| 92 |
+
def random_sent(self, index):
|
| 93 |
+
t1, t2 = self.get_corpus_line(index)
|
| 94 |
+
|
| 95 |
+
if self.predict_mode:
|
| 96 |
+
return t1, t2, 1
|
| 97 |
+
|
| 98 |
+
# output_text, label(isNotNext:0, isNext:1)
|
| 99 |
+
if random.random() > 0.5:
|
| 100 |
+
return t1, t2, 1
|
| 101 |
+
else:
|
| 102 |
+
return t1, self.get_random_line(), 0
|
| 103 |
+
|
| 104 |
+
def get_corpus_line(self, item):
|
| 105 |
+
if self.on_memory:
|
| 106 |
+
return self.lines[item][0], self.lines[item][1]
|
| 107 |
+
else:
|
| 108 |
+
line = self.file.__next__()
|
| 109 |
+
if line is None:
|
| 110 |
+
self.file.close()
|
| 111 |
+
self.file = open(self.corpus_path, "r", encoding=self.encoding)
|
| 112 |
+
line = self.file.__next__()
|
| 113 |
+
|
| 114 |
+
t1, t2 = line[:-1].split("\t")
|
| 115 |
+
return t1, t2
|
| 116 |
+
|
| 117 |
+
def get_random_line(self):
|
| 118 |
+
if self.on_memory:
|
| 119 |
+
return self.lines[random.randrange(len(self.lines))][1]
|
| 120 |
+
|
| 121 |
+
line = self.file.__next__()
|
| 122 |
+
if line is None:
|
| 123 |
+
self.file.close()
|
| 124 |
+
self.file = open(self.corpus_path, "r", encoding=self.encoding)
|
| 125 |
+
for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):
|
| 126 |
+
self.random_file.__next__()
|
| 127 |
+
line = self.random_file.__next__()
|
| 128 |
+
return line[:-1].split("\t")[1]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
|
bert_pytorch/dataset/log_dataset.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Dataset
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
|
| 7 |
+
class LogDataset(Dataset):
|
| 8 |
+
def __init__(self, log_corpus, time_corpus, vocab, seq_len, corpus_lines=None, encoding="utf-8", on_memory=True, predict_mode=False, mask_ratio=0.15):
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
:param corpus: log sessions/line
|
| 12 |
+
:param vocab: log events collection including pad, ukn ...
|
| 13 |
+
:param seq_len: max sequence length
|
| 14 |
+
:param corpus_lines: number of log sessions
|
| 15 |
+
:param encoding:
|
| 16 |
+
:param on_memory:
|
| 17 |
+
:param predict_mode: if predict
|
| 18 |
+
"""
|
| 19 |
+
self.vocab = vocab
|
| 20 |
+
self.seq_len = seq_len
|
| 21 |
+
|
| 22 |
+
self.on_memory = on_memory
|
| 23 |
+
self.encoding = encoding
|
| 24 |
+
|
| 25 |
+
self.predict_mode = predict_mode
|
| 26 |
+
self.log_corpus = log_corpus
|
| 27 |
+
self.time_corpus = time_corpus
|
| 28 |
+
self.corpus_lines = len(log_corpus)
|
| 29 |
+
|
| 30 |
+
self.mask_ratio = mask_ratio
|
| 31 |
+
|
| 32 |
+
def __len__(self):
|
| 33 |
+
return self.corpus_lines
|
| 34 |
+
|
| 35 |
+
def __getitem__(self, idx):
|
| 36 |
+
k, t = self.log_corpus[idx], self.time_corpus[idx]
|
| 37 |
+
|
| 38 |
+
k_masked, k_label, t_masked, t_label = self.random_item(k, t)
|
| 39 |
+
|
| 40 |
+
# [CLS] tag = SOS tag, [SEP] tag = EOS tag
|
| 41 |
+
k = [self.vocab.sos_index] + k_masked
|
| 42 |
+
k_label = [self.vocab.pad_index] + k_label
|
| 43 |
+
# k_label = [self.vocab.sos_index] + k_label
|
| 44 |
+
|
| 45 |
+
t = [0] + t_masked
|
| 46 |
+
t_label = [self.vocab.pad_index] + t_label
|
| 47 |
+
|
| 48 |
+
return k, k_label, t, t_label
|
| 49 |
+
|
| 50 |
+
def random_item(self, k, t):
|
| 51 |
+
tokens = list(k)
|
| 52 |
+
output_label = []
|
| 53 |
+
|
| 54 |
+
time_intervals = list(t)
|
| 55 |
+
time_label = []
|
| 56 |
+
|
| 57 |
+
for i, token in enumerate(tokens):
|
| 58 |
+
time_int = time_intervals[i]
|
| 59 |
+
prob = random.random()
|
| 60 |
+
# replace 15% of tokens in a sequence to a masked token
|
| 61 |
+
if prob < self.mask_ratio:
|
| 62 |
+
# raise AttributeError("no mask in visualization")
|
| 63 |
+
|
| 64 |
+
if self.predict_mode:
|
| 65 |
+
tokens[i] = self.vocab.mask_index
|
| 66 |
+
output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))
|
| 67 |
+
|
| 68 |
+
time_label.append(time_int)
|
| 69 |
+
time_intervals[i] = 0
|
| 70 |
+
continue
|
| 71 |
+
|
| 72 |
+
prob /= self.mask_ratio
|
| 73 |
+
|
| 74 |
+
# 80% randomly change token to mask token
|
| 75 |
+
if prob < 0.8:
|
| 76 |
+
tokens[i] = self.vocab.mask_index
|
| 77 |
+
|
| 78 |
+
# 10% randomly change token to random token
|
| 79 |
+
elif prob < 0.9:
|
| 80 |
+
tokens[i] = random.randrange(len(self.vocab))
|
| 81 |
+
|
| 82 |
+
# 10% randomly change token to current token
|
| 83 |
+
else:
|
| 84 |
+
tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
|
| 85 |
+
|
| 86 |
+
output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))
|
| 87 |
+
|
| 88 |
+
time_intervals[i] = 0 # time mask value = 0
|
| 89 |
+
time_label.append(time_int)
|
| 90 |
+
|
| 91 |
+
else:
|
| 92 |
+
tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
|
| 93 |
+
output_label.append(0)
|
| 94 |
+
time_label.append(0)
|
| 95 |
+
|
| 96 |
+
return tokens, output_label, time_intervals, time_label
|
| 97 |
+
|
| 98 |
+
def collate_fn(self, batch, percentile=100, dynamical_pad=True):
|
| 99 |
+
lens = [len(seq[0]) for seq in batch]
|
| 100 |
+
|
| 101 |
+
# find the max len in each batch
|
| 102 |
+
if dynamical_pad:
|
| 103 |
+
# dynamical padding
|
| 104 |
+
seq_len = int(np.percentile(lens, percentile))
|
| 105 |
+
if self.seq_len is not None:
|
| 106 |
+
seq_len = min(seq_len, self.seq_len)
|
| 107 |
+
else:
|
| 108 |
+
# fixed length padding
|
| 109 |
+
seq_len = self.seq_len
|
| 110 |
+
|
| 111 |
+
output = defaultdict(list)
|
| 112 |
+
for seq in batch:
|
| 113 |
+
bert_input = seq[0][:seq_len]
|
| 114 |
+
bert_label = seq[1][:seq_len]
|
| 115 |
+
time_input = seq[2][:seq_len]
|
| 116 |
+
time_label = seq[3][:seq_len]
|
| 117 |
+
|
| 118 |
+
padding = [self.vocab.pad_index for _ in range(seq_len - len(bert_input))]
|
| 119 |
+
bert_input.extend(padding), bert_label.extend(padding), time_input.extend(padding), time_label.extend(
|
| 120 |
+
padding)
|
| 121 |
+
|
| 122 |
+
time_input = np.array(time_input)[:, np.newaxis]
|
| 123 |
+
output["bert_input"].append(bert_input)
|
| 124 |
+
output["bert_label"].append(bert_label)
|
| 125 |
+
output["time_input"].append(time_input)
|
| 126 |
+
output["time_label"].append(time_label)
|
| 127 |
+
|
| 128 |
+
output["bert_input"] = torch.tensor(output["bert_input"], dtype=torch.long)
|
| 129 |
+
output["bert_label"] = torch.tensor(output["bert_label"], dtype=torch.long)
|
| 130 |
+
output["time_input"] = torch.tensor(output["time_input"], dtype=torch.float)
|
| 131 |
+
output["time_label"] = torch.tensor(output["time_label"], dtype=torch.float)
|
| 132 |
+
|
| 133 |
+
return output
|
| 134 |
+
|
bert_pytorch/dataset/sample.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import tqdm
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.model_selection import train_test_split
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def generate_pairs(line, window_size):
|
| 7 |
+
line = np.array(line)
|
| 8 |
+
line = line[:, 0]
|
| 9 |
+
|
| 10 |
+
seqs = []
|
| 11 |
+
for i in range(0, len(line), window_size):
|
| 12 |
+
seq = line[i:i + window_size]
|
| 13 |
+
seqs.append(seq)
|
| 14 |
+
seqs += []
|
| 15 |
+
seq_pairs = []
|
| 16 |
+
for i in range(1, len(seqs)):
|
| 17 |
+
seq_pairs.append([seqs[i - 1], seqs[i]])
|
| 18 |
+
return seqs
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def fixed_window(line, window_size, adaptive_window, seq_len=None, min_len=0):
|
| 22 |
+
line = [ln.split(",") for ln in line.split()]
|
| 23 |
+
|
| 24 |
+
# filter the line/session shorter than 10
|
| 25 |
+
if len(line) < min_len:
|
| 26 |
+
return [], []
|
| 27 |
+
|
| 28 |
+
# max seq len
|
| 29 |
+
if seq_len is not None:
|
| 30 |
+
line = line[:seq_len]
|
| 31 |
+
|
| 32 |
+
if adaptive_window:
|
| 33 |
+
window_size = len(line)
|
| 34 |
+
|
| 35 |
+
line = np.array(line)
|
| 36 |
+
|
| 37 |
+
# if time duration exists in data
|
| 38 |
+
if line.shape[1] == 2:
|
| 39 |
+
tim = line[:,1].astype(float)
|
| 40 |
+
line = line[:, 0]
|
| 41 |
+
|
| 42 |
+
# the first time duration of a session should be 0, so max is window_size(mins) * 60
|
| 43 |
+
tim[0] = 0
|
| 44 |
+
else:
|
| 45 |
+
line = line.squeeze()
|
| 46 |
+
# if time duration doesn't exist, then create a zero array for time
|
| 47 |
+
tim = np.zeros(line.shape)
|
| 48 |
+
|
| 49 |
+
logkey_seqs = []
|
| 50 |
+
time_seq = []
|
| 51 |
+
for i in range(0, len(line), window_size):
|
| 52 |
+
logkey_seqs.append(line[i:i + window_size])
|
| 53 |
+
time_seq.append(tim[i:i + window_size])
|
| 54 |
+
|
| 55 |
+
return logkey_seqs, time_seq
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def generate_train_valid(data_path, window_size=20, adaptive_window=True,
|
| 59 |
+
sample_ratio=1, valid_size=0.1, output_path=None,
|
| 60 |
+
scale=None, scale_path=None, seq_len=None, min_len=0):
|
| 61 |
+
with open(data_path, 'r') as f:
|
| 62 |
+
data_iter = f.readlines()
|
| 63 |
+
|
| 64 |
+
num_session = int(len(data_iter) * sample_ratio)
|
| 65 |
+
# only even number of samples, or drop_last=True in DataLoader API
|
| 66 |
+
# coz in parallel computing in CUDA, odd number of samples reports issue when merging the result
|
| 67 |
+
# num_session += num_session % 2
|
| 68 |
+
|
| 69 |
+
test_size = int(min(num_session, len(data_iter)) * valid_size)
|
| 70 |
+
# only even number of samples
|
| 71 |
+
# test_size += test_size % 2
|
| 72 |
+
|
| 73 |
+
print("before filtering short session")
|
| 74 |
+
print("train size ", int(num_session - test_size))
|
| 75 |
+
print("valid size ", int(test_size))
|
| 76 |
+
print("="*40)
|
| 77 |
+
|
| 78 |
+
logkey_seq_pairs = []
|
| 79 |
+
time_seq_pairs = []
|
| 80 |
+
session = 0
|
| 81 |
+
for line in tqdm(data_iter):
|
| 82 |
+
if session >= num_session:
|
| 83 |
+
break
|
| 84 |
+
session += 1
|
| 85 |
+
|
| 86 |
+
logkeys, times = fixed_window(line, window_size, adaptive_window, seq_len, min_len)
|
| 87 |
+
logkey_seq_pairs += logkeys
|
| 88 |
+
time_seq_pairs += times
|
| 89 |
+
|
| 90 |
+
logkey_seq_pairs = np.array(logkey_seq_pairs, dtype=object)
|
| 91 |
+
time_seq_pairs = np.array(time_seq_pairs, dtype=object)
|
| 92 |
+
|
| 93 |
+
logkey_trainset, logkey_validset, time_trainset, time_validset = train_test_split(logkey_seq_pairs,
|
| 94 |
+
time_seq_pairs,
|
| 95 |
+
test_size=test_size,
|
| 96 |
+
random_state=1234)
|
| 97 |
+
|
| 98 |
+
# sort seq_pairs by seq len
|
| 99 |
+
train_len = list(map(len, logkey_trainset))
|
| 100 |
+
valid_len = list(map(len, logkey_validset))
|
| 101 |
+
|
| 102 |
+
train_sort_index = np.argsort(-1 * np.array(train_len))
|
| 103 |
+
valid_sort_index = np.argsort(-1 * np.array(valid_len))
|
| 104 |
+
|
| 105 |
+
logkey_trainset = logkey_trainset[train_sort_index]
|
| 106 |
+
logkey_validset = logkey_validset[valid_sort_index]
|
| 107 |
+
|
| 108 |
+
time_trainset = time_trainset[train_sort_index]
|
| 109 |
+
time_validset = time_validset[valid_sort_index]
|
| 110 |
+
|
| 111 |
+
print("="*40)
|
| 112 |
+
print("Num of train seqs", len(logkey_trainset))
|
| 113 |
+
print("Num of valid seqs", len(logkey_validset))
|
| 114 |
+
print("="*40)
|
| 115 |
+
|
| 116 |
+
return logkey_trainset, logkey_validset, time_trainset, time_validset
|
| 117 |
+
|
bert_pytorch/dataset/utils.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def save_parameters(options, filename):
|
| 8 |
+
with open(filename, "w+") as f:
|
| 9 |
+
for key in options.keys():
|
| 10 |
+
f.write("{}: {}\n".format(key, options[key]))
|
| 11 |
+
|
| 12 |
+
# https://gist.github.com/KirillVladimirov/005ec7f762293d2321385580d3dbe335
|
| 13 |
+
def seed_everything(seed=1234):
|
| 14 |
+
random.seed(seed)
|
| 15 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 16 |
+
np.random.seed(seed)
|
| 17 |
+
torch.manual_seed(seed)
|
| 18 |
+
# torch.cuda.manual_seed(seed)
|
| 19 |
+
# torch.backends.cudnn.deterministic = True
|
bert_pytorch/dataset/vocab.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import tqdm
|
| 3 |
+
from collections import Counter
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append("../")
|
| 6 |
+
|
| 7 |
+
class TorchVocab(object):
|
| 8 |
+
"""Defines a vocabulary object that will be used to numericalize a field.
|
| 9 |
+
Attributes:
|
| 10 |
+
freqs: A collections.Counter object holding the frequencies of tokens
|
| 11 |
+
in the data used to build the Vocab.
|
| 12 |
+
stoi: A collections.defaultdict instance mapping token strings to
|
| 13 |
+
numerical identifiers.
|
| 14 |
+
itos: A list of token strings indexed by their numerical identifiers.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>'],
|
| 18 |
+
vectors=None, unk_init=None, vectors_cache=None):
|
| 19 |
+
"""Create a Vocab object from a collections.Counter.
|
| 20 |
+
Arguments:
|
| 21 |
+
counter: collections.Counter object holding the frequencies of
|
| 22 |
+
each value found in the data.
|
| 23 |
+
max_size: The maximum size of the vocabulary, or None for no
|
| 24 |
+
maximum. Default: None.
|
| 25 |
+
min_freq: The minimum frequency needed to include a token in the
|
| 26 |
+
vocabulary. Values less than 1 will be set to 1. Default: 1.
|
| 27 |
+
specials: The list of special tokens (e.g., padding or eos) that
|
| 28 |
+
will be prepended to the vocabulary in addition to an <unk>
|
| 29 |
+
token. Default: ['<pad>']
|
| 30 |
+
vectors: One of either the available pretrained vectors
|
| 31 |
+
or custom pretrained vectors (see Vocab.load_vectors);
|
| 32 |
+
or a list of aforementioned vectors
|
| 33 |
+
unk_init (callback): by default, initialize out-of-vocabulary word vectors
|
| 34 |
+
to zero vectors; can be any function that takes in a Tensor and
|
| 35 |
+
returns a Tensor of the same size. Default: torch.Tensor.zero_
|
| 36 |
+
vectors_cache: directory for cached vectors. Default: '.vector_cache'
|
| 37 |
+
"""
|
| 38 |
+
self.freqs = counter
|
| 39 |
+
counter = counter.copy()
|
| 40 |
+
min_freq = max(min_freq, 1)
|
| 41 |
+
|
| 42 |
+
self.itos = list(specials)
|
| 43 |
+
# frequencies of special tokens are not counted when building vocabulary
|
| 44 |
+
# in frequency order
|
| 45 |
+
for tok in specials:
|
| 46 |
+
del counter[tok]
|
| 47 |
+
|
| 48 |
+
max_size = None if max_size is None else max_size + len(self.itos)
|
| 49 |
+
|
| 50 |
+
# sort by frequency, then alphabetically
|
| 51 |
+
words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
|
| 52 |
+
words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
|
| 53 |
+
|
| 54 |
+
for word, freq in words_and_frequencies:
|
| 55 |
+
if freq < min_freq or len(self.itos) == max_size:
|
| 56 |
+
break
|
| 57 |
+
self.itos.append(word)
|
| 58 |
+
|
| 59 |
+
# stoi is simply a reverse dict for itos
|
| 60 |
+
self.stoi = {tok: i for i, tok in enumerate(self.itos)}
|
| 61 |
+
|
| 62 |
+
self.vectors = None
|
| 63 |
+
if vectors is not None:
|
| 64 |
+
self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
|
| 65 |
+
else:
|
| 66 |
+
assert unk_init is None and vectors_cache is None
|
| 67 |
+
|
| 68 |
+
def __eq__(self, other):
|
| 69 |
+
if self.freqs != other.freqs:
|
| 70 |
+
return False
|
| 71 |
+
if self.stoi != other.stoi:
|
| 72 |
+
return False
|
| 73 |
+
if self.itos != other.itos:
|
| 74 |
+
return False
|
| 75 |
+
if self.vectors != other.vectors:
|
| 76 |
+
return False
|
| 77 |
+
return True
|
| 78 |
+
|
| 79 |
+
def __len__(self):
|
| 80 |
+
return len(self.itos)
|
| 81 |
+
|
| 82 |
+
def vocab_rerank(self):
|
| 83 |
+
self.stoi = {word: i for i, word in enumerate(self.itos)}
|
| 84 |
+
|
| 85 |
+
def extend(self, v, sort=False):
|
| 86 |
+
words = sorted(v.itos) if sort else v.itos
|
| 87 |
+
for w in words:
|
| 88 |
+
if w not in self.stoi:
|
| 89 |
+
self.itos.append(w)
|
| 90 |
+
self.stoi[w] = len(self.itos) - 1
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class Vocab(TorchVocab):
|
| 94 |
+
def __init__(self, counter, max_size=None, min_freq=1):
|
| 95 |
+
self.pad_index = 0
|
| 96 |
+
self.unk_index = 1
|
| 97 |
+
self.eos_index = 2
|
| 98 |
+
self.sos_index = 3
|
| 99 |
+
self.mask_index = 4
|
| 100 |
+
super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"],
|
| 101 |
+
max_size=max_size, min_freq=min_freq)
|
| 102 |
+
|
| 103 |
+
def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list:
|
| 104 |
+
pass
|
| 105 |
+
|
| 106 |
+
def from_seq(self, seq, join=False, with_pad=False):
|
| 107 |
+
pass
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def load_vocab(vocab_path: str) -> 'Vocab':
|
| 111 |
+
with open(vocab_path, "rb") as f:
|
| 112 |
+
return pickle.load(f)
|
| 113 |
+
|
| 114 |
+
def save_vocab(self, vocab_path):
|
| 115 |
+
with open(vocab_path, "wb") as f:
|
| 116 |
+
pickle.dump(self, f)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# Building Vocab with text files
|
| 120 |
+
class WordVocab(Vocab):
|
| 121 |
+
def __init__(self, texts, max_size=None, min_freq=1):
|
| 122 |
+
print("Building Vocab")
|
| 123 |
+
counter = Counter()
|
| 124 |
+
for line in tqdm.tqdm(texts):
|
| 125 |
+
if isinstance(line, list):
|
| 126 |
+
words = line
|
| 127 |
+
else:
|
| 128 |
+
words = line.replace("\n", "").replace("\t", "").split()
|
| 129 |
+
|
| 130 |
+
for word in words:
|
| 131 |
+
counter[word] += 1
|
| 132 |
+
super().__init__(counter, max_size=max_size, min_freq=min_freq)
|
| 133 |
+
|
| 134 |
+
def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
|
| 135 |
+
if isinstance(sentence, str):
|
| 136 |
+
sentence = sentence.split()
|
| 137 |
+
|
| 138 |
+
seq = [self.stoi.get(word, self.unk_index) for word in sentence]
|
| 139 |
+
|
| 140 |
+
if with_eos:
|
| 141 |
+
seq += [self.eos_index] # this would be index 1
|
| 142 |
+
if with_sos:
|
| 143 |
+
seq = [self.sos_index] + seq
|
| 144 |
+
|
| 145 |
+
origin_seq_len = len(seq)
|
| 146 |
+
|
| 147 |
+
if seq_len is None:
|
| 148 |
+
pass
|
| 149 |
+
elif len(seq) <= seq_len:
|
| 150 |
+
seq += [self.pad_index for _ in range(seq_len - len(seq))]
|
| 151 |
+
else:
|
| 152 |
+
seq = seq[:seq_len]
|
| 153 |
+
|
| 154 |
+
return (seq, origin_seq_len) if with_len else seq
|
| 155 |
+
|
| 156 |
+
def from_seq(self, seq, join=False, with_pad=False):
|
| 157 |
+
words = [self.itos[idx]
|
| 158 |
+
if idx < len(self.itos)
|
| 159 |
+
else "<%d>" % idx
|
| 160 |
+
for idx in seq
|
| 161 |
+
if not with_pad or idx != self.pad_index]
|
| 162 |
+
|
| 163 |
+
return " ".join(words) if join else words
|
| 164 |
+
|
| 165 |
+
@staticmethod
|
| 166 |
+
def load_vocab(vocab_path: str) -> 'WordVocab':
|
| 167 |
+
with open(vocab_path, "rb") as f:
|
| 168 |
+
return pickle.load(f)
|
| 169 |
+
|
bert_pytorch/model/bert.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from .transformer import TransformerBlock
|
| 5 |
+
from .embedding import BERTEmbedding
|
| 6 |
+
|
| 7 |
+
class BERT(nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
BERT model : Bidirectional Encoder Representations from Transformers.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
def __init__(self, vocab_size, max_len=512, hidden=768, n_layers=12, attn_heads=12, dropout=0.1, is_logkey=True, is_time=False):
|
| 13 |
+
"""
|
| 14 |
+
:param vocab_size: vocab_size of total words
|
| 15 |
+
:param hidden: BERT model hidden size
|
| 16 |
+
:param n_layers: numbers of Transformer blocks(layers)
|
| 17 |
+
:param attn_heads: number of attention heads
|
| 18 |
+
:param dropout: dropout rate
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.hidden = hidden
|
| 23 |
+
self.n_layers = n_layers
|
| 24 |
+
self.attn_heads = attn_heads
|
| 25 |
+
|
| 26 |
+
# paper noted they used 4*hidden_size for ff_network_hidden_size
|
| 27 |
+
self.feed_forward_hidden = hidden * 2
|
| 28 |
+
|
| 29 |
+
# embedding for BERT, sum of positional, segment, token embeddings
|
| 30 |
+
self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden, max_len=max_len, is_logkey=is_logkey, is_time=is_time)
|
| 31 |
+
|
| 32 |
+
# multi-layers transformer blocks, deep network
|
| 33 |
+
self.transformer_blocks = nn.ModuleList(
|
| 34 |
+
[TransformerBlock(hidden, attn_heads, hidden * 2, dropout) for _ in range(n_layers)])
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def forward(self, x, segment_info=None, time_info=None):
|
| 38 |
+
# attention masking for padded token
|
| 39 |
+
# torch.ByteTensor([batch_size, 1, seq_len, seq_len)
|
| 40 |
+
mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
|
| 41 |
+
|
| 42 |
+
# embedding the indexed sequence to sequence of vectors
|
| 43 |
+
x = self.embedding(x, segment_info, time_info)
|
| 44 |
+
|
| 45 |
+
# running over multiple transformer blocks
|
| 46 |
+
for transformer in self.transformer_blocks:
|
| 47 |
+
x = transformer.forward(x, mask)
|
| 48 |
+
|
| 49 |
+
return x
|
bert_pytorch/model/embedding/bert.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
from .token import TokenEmbedding
|
| 4 |
+
from .position import PositionalEmbedding
|
| 5 |
+
from .segment import SegmentEmbedding
|
| 6 |
+
from .time_embed import TimeEmbedding
|
| 7 |
+
|
| 8 |
+
class BERTEmbedding(nn.Module):
|
| 9 |
+
"""
|
| 10 |
+
BERT Embedding which is consisted with under features
|
| 11 |
+
1. TokenEmbedding : normal embedding matrix
|
| 12 |
+
2. PositionalEmbedding : adding positional information using sin, cos
|
| 13 |
+
2. SegmentEmbedding : adding sentence segment info, (sent_A:1, sent_B:2)
|
| 14 |
+
|
| 15 |
+
sum of all these features are output of BERTEmbedding
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, vocab_size, embed_size, max_len, dropout=0.1, is_logkey=True, is_time=False):
|
| 19 |
+
"""
|
| 20 |
+
:param vocab_size: total vocab size
|
| 21 |
+
:param embed_size: embedding size of token embedding
|
| 22 |
+
:param dropout: dropout rate
|
| 23 |
+
"""
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
|
| 26 |
+
self.position = PositionalEmbedding(d_model=self.token.embedding_dim, max_len=max_len)
|
| 27 |
+
self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
|
| 28 |
+
self.time_embed = TimeEmbedding(embed_size=self.token.embedding_dim)
|
| 29 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 30 |
+
self.embed_size = embed_size
|
| 31 |
+
self.is_logkey = is_logkey
|
| 32 |
+
self.is_time = is_time
|
| 33 |
+
|
| 34 |
+
def forward(self, sequence, segment_label=None, time_info=None):
|
| 35 |
+
x = self.position(sequence)
|
| 36 |
+
# if self.is_logkey:
|
| 37 |
+
x = x + self.token(sequence)
|
| 38 |
+
if segment_label is not None:
|
| 39 |
+
x = x + self.segment(segment_label)
|
| 40 |
+
if self.is_time:
|
| 41 |
+
x = x + self.time_embed(time_info)
|
| 42 |
+
return self.dropout(x)
|
bert_pytorch/model/embedding/position.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PositionalEmbedding(nn.Module):
|
| 7 |
+
|
| 8 |
+
def __init__(self, d_model, max_len=512):
|
| 9 |
+
super().__init__()
|
| 10 |
+
|
| 11 |
+
# Compute the positional encodings once in log space.
|
| 12 |
+
pe = torch.zeros(max_len, d_model).float()
|
| 13 |
+
pe.require_grad = False
|
| 14 |
+
|
| 15 |
+
position = torch.arange(0, max_len).float().unsqueeze(1)
|
| 16 |
+
div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
|
| 17 |
+
|
| 18 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 19 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 20 |
+
|
| 21 |
+
pe = pe.unsqueeze(0)
|
| 22 |
+
self.register_buffer('pe', pe)
|
| 23 |
+
|
| 24 |
+
def forward(self, x):
|
| 25 |
+
return self.pe[:, :x.size(1)]
|
bert_pytorch/model/embedding/segment.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class SegmentEmbedding(nn.Embedding):
|
| 5 |
+
def __init__(self, embed_size=512):
|
| 6 |
+
super().__init__(3, embed_size, padding_idx=0)
|
bert_pytorch/model/embedding/time_embed.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TimeEmbedding(nn.Module):
|
| 5 |
+
def __init__(self, embed_size=512):
|
| 6 |
+
super().__init__()
|
| 7 |
+
self.time_embed = nn.Linear(1, embed_size)
|
| 8 |
+
|
| 9 |
+
def forward(self, time_interval):
|
| 10 |
+
return self.time_embed(time_interval)
|
bert_pytorch/model/embedding/token.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TokenEmbedding(nn.Embedding):
|
| 5 |
+
def __init__(self, vocab_size, embed_size=512):
|
| 6 |
+
super().__init__(vocab_size, embed_size, padding_idx=0)
|
bert_pytorch/model/language_model.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
from .bert import BERT
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class BERTLM(nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
BERT Language Model
|
| 9 |
+
Next Sentence Prediction Model + Masked Language Model
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
def __init__(self, bert: BERT, vocab_size):
|
| 13 |
+
"""
|
| 14 |
+
:param bert: BERT model which should be trained
|
| 15 |
+
:param vocab_size: total vocab size for masked_lm
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.bert = bert
|
| 20 |
+
self.next_sentence = NextSentencePrediction(self.bert.hidden)
|
| 21 |
+
self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size)
|
| 22 |
+
|
| 23 |
+
def forward(self, x, segment_label):
|
| 24 |
+
x = self.bert(x, segment_label)
|
| 25 |
+
return self.next_sentence(x), self.mask_lm(x)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class NextSentencePrediction(nn.Module):
|
| 29 |
+
"""
|
| 30 |
+
2-class classification model : is_next, is_not_next
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, hidden):
|
| 34 |
+
"""
|
| 35 |
+
:param hidden: BERT model output size
|
| 36 |
+
"""
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.linear = nn.Linear(hidden, 2)
|
| 39 |
+
self.softmax = nn.LogSoftmax(dim=-1)
|
| 40 |
+
|
| 41 |
+
def forward(self, x):
|
| 42 |
+
return self.softmax(self.linear(x[:, 0]))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class MaskedLanguageModel(nn.Module):
|
| 46 |
+
"""
|
| 47 |
+
predicting origin token from masked input sequence
|
| 48 |
+
n-class classification problem, n-class = vocab_size
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def __init__(self, hidden, vocab_size):
|
| 52 |
+
"""
|
| 53 |
+
:param hidden: output size of BERT model
|
| 54 |
+
:param vocab_size: total vocab size
|
| 55 |
+
"""
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.linear = nn.Linear(hidden, vocab_size)
|
| 58 |
+
self.softmax = nn.LogSoftmax(dim=-1)
|
| 59 |
+
|
| 60 |
+
def forward(self, x):
|
| 61 |
+
return self.softmax(self.linear(x))
|
bert_pytorch/model/log_model.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .bert import BERT
|
| 4 |
+
|
| 5 |
+
class BERTLog(nn.Module):
|
| 6 |
+
"""
|
| 7 |
+
BERT Log Model
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
def __init__(self, bert: BERT, vocab_size):
|
| 11 |
+
"""
|
| 12 |
+
:param bert: BERT model which should be trained
|
| 13 |
+
:param vocab_size: total vocab size for masked_lm
|
| 14 |
+
"""
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.bert = bert
|
| 17 |
+
self.mask_lm = MaskedLogModel(self.bert.hidden, vocab_size)
|
| 18 |
+
self.time_lm = TimeLogModel(self.bert.hidden)
|
| 19 |
+
# self.fnn_cls = LinearCLS(self.bert.hidden)
|
| 20 |
+
# self.cls_lm = LogClassifier(self.bert.hidden)
|
| 21 |
+
|
| 22 |
+
def forward(self, x, time_info):
|
| 23 |
+
x = self.bert(x, time_info=time_info) # [batch, seq_len, hidden]
|
| 24 |
+
|
| 25 |
+
cls_output = x[:, 0] # [CLS] token vector from BERT
|
| 26 |
+
|
| 27 |
+
return {
|
| 28 |
+
"logkey_output": self.mask_lm(x), # [batch, seq_len, vocab_size]
|
| 29 |
+
"time_output": self.time_lm(x), # optional
|
| 30 |
+
"cls_output": cls_output, # [batch, hidden]
|
| 31 |
+
"cls_fnn_output": None, # unused for now
|
| 32 |
+
"token_embeddings": x[0] # [seq_len, hidden] for first batch element
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class MaskedLogModel(nn.Module):
|
| 37 |
+
"""
|
| 38 |
+
Predicting original token from masked input sequence
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, hidden, vocab_size):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.linear = nn.Linear(hidden, vocab_size)
|
| 44 |
+
self.softmax = nn.LogSoftmax(dim=-1)
|
| 45 |
+
|
| 46 |
+
def forward(self, x):
|
| 47 |
+
return self.softmax(self.linear(x))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TimeLogModel(nn.Module):
|
| 51 |
+
def __init__(self, hidden, time_size=1):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.linear = nn.Linear(hidden, time_size)
|
| 54 |
+
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
return self.linear(x)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class LogClassifier(nn.Module):
|
| 60 |
+
def __init__(self, hidden):
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.linear = nn.Linear(hidden, hidden)
|
| 63 |
+
|
| 64 |
+
def forward(self, cls):
|
| 65 |
+
return self.linear(cls)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class LinearCLS(nn.Module):
|
| 69 |
+
def __init__(self, hidden):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.linear = nn.Linear(hidden, hidden)
|
| 72 |
+
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
return self.linear(x)
|
bert_pytorch/model/transformer.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
from .attention import MultiHeadedAttention
|
| 4 |
+
from .utils import SublayerConnection, PositionwiseFeedForward
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TransformerBlock(nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
Bidirectional Encoder = Transformer (self-attention)
|
| 10 |
+
Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
|
| 14 |
+
"""
|
| 15 |
+
:param hidden: hidden size of transformer
|
| 16 |
+
:param attn_heads: head sizes of multi-head attention
|
| 17 |
+
:param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size
|
| 18 |
+
:param dropout: dropout rate
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
|
| 23 |
+
self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
|
| 24 |
+
self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
|
| 25 |
+
self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
|
| 26 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 27 |
+
|
| 28 |
+
def forward(self, x, mask):
|
| 29 |
+
x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
|
| 30 |
+
x = self.output_sublayer(x, self.feed_forward)
|
| 31 |
+
return self.dropout(x)
|
bert_pytorch/predict_log.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import scipy.stats as stats
|
| 3 |
+
import seaborn as sns
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import pickle
|
| 6 |
+
import time
|
| 7 |
+
import torch
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
|
| 11 |
+
from bert_pytorch.dataset import WordVocab
|
| 12 |
+
from bert_pytorch.dataset import LogDataset
|
| 13 |
+
from bert_pytorch.dataset.sample import fixed_window
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def compute_anomaly(results, params, seq_threshold=0.5):
|
| 17 |
+
is_logkey = params["is_logkey"]
|
| 18 |
+
is_time = params["is_time"]
|
| 19 |
+
total_errors = 0
|
| 20 |
+
for seq_res in results:
|
| 21 |
+
# label pairs as anomaly when over half of masked tokens are undetected
|
| 22 |
+
if (is_logkey and seq_res["undetected_tokens"] > seq_res["masked_tokens"] * seq_threshold) or \
|
| 23 |
+
(is_time and seq_res["num_error"]> seq_res["masked_tokens"] * seq_threshold) or \
|
| 24 |
+
(params["hypersphere_loss_test"] and seq_res["deepSVDD_label"]):
|
| 25 |
+
total_errors += 1
|
| 26 |
+
return total_errors
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def find_best_threshold(test_normal_results, test_abnormal_results, params, th_range, seq_range):
|
| 30 |
+
best_result = [0] * 9
|
| 31 |
+
for seq_th in seq_range:
|
| 32 |
+
FP = compute_anomaly(test_normal_results, params, seq_th)
|
| 33 |
+
TP = compute_anomaly(test_abnormal_results, params, seq_th)
|
| 34 |
+
|
| 35 |
+
if TP == 0:
|
| 36 |
+
continue
|
| 37 |
+
|
| 38 |
+
TN = len(test_normal_results) - FP
|
| 39 |
+
FN = len(test_abnormal_results) - TP
|
| 40 |
+
P = 100 * TP / (TP + FP)
|
| 41 |
+
R = 100 * TP / (TP + FN)
|
| 42 |
+
F1 = 2 * P * R / (P + R)
|
| 43 |
+
|
| 44 |
+
if F1 > best_result[-1]:
|
| 45 |
+
best_result = [0, seq_th, FP, TP, TN, FN, P, R, F1]
|
| 46 |
+
return best_result
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Predictor():
|
| 50 |
+
def __init__(self, options):
|
| 51 |
+
self.model_path = options["model_path"]
|
| 52 |
+
self.vocab_path = options["vocab_path"]
|
| 53 |
+
self.device = options["device"]
|
| 54 |
+
self.window_size = options["window_size"]
|
| 55 |
+
self.adaptive_window = options["adaptive_window"]
|
| 56 |
+
self.seq_len = options["seq_len"]
|
| 57 |
+
self.corpus_lines = options["corpus_lines"]
|
| 58 |
+
self.on_memory = options["on_memory"]
|
| 59 |
+
self.batch_size = options["batch_size"]
|
| 60 |
+
self.num_workers = options["num_workers"]
|
| 61 |
+
self.num_candidates = options["num_candidates"]
|
| 62 |
+
self.output_dir = options["output_dir"]
|
| 63 |
+
self.model_dir = options["model_dir"]
|
| 64 |
+
self.gaussian_mean = options["gaussian_mean"]
|
| 65 |
+
self.gaussian_std = options["gaussian_std"]
|
| 66 |
+
|
| 67 |
+
self.is_logkey = options["is_logkey"]
|
| 68 |
+
self.is_time = options["is_time"]
|
| 69 |
+
self.scale_path = options["scale_path"]
|
| 70 |
+
|
| 71 |
+
self.hypersphere_loss = options["hypersphere_loss"]
|
| 72 |
+
self.hypersphere_loss_test = options["hypersphere_loss_test"]
|
| 73 |
+
|
| 74 |
+
self.lower_bound = self.gaussian_mean - 3 * self.gaussian_std
|
| 75 |
+
self.upper_bound = self.gaussian_mean + 3 * self.gaussian_std
|
| 76 |
+
|
| 77 |
+
self.center = None
|
| 78 |
+
self.radius = None
|
| 79 |
+
self.test_ratio = options["test_ratio"]
|
| 80 |
+
self.mask_ratio = options["mask_ratio"]
|
| 81 |
+
self.min_len=options["min_len"]
|
| 82 |
+
|
| 83 |
+
def detect_logkey_anomaly(self, masked_output, masked_label):
|
| 84 |
+
num_undetected_tokens = 0
|
| 85 |
+
output_maskes = []
|
| 86 |
+
for i, token in enumerate(masked_label):
|
| 87 |
+
# output_maskes.append(torch.argsort(-masked_output[i])[:30].cpu().numpy()) # extract top 30 candidates for mask labels
|
| 88 |
+
|
| 89 |
+
if token not in torch.argsort(-masked_output[i])[:self.num_candidates]:
|
| 90 |
+
num_undetected_tokens += 1
|
| 91 |
+
|
| 92 |
+
return num_undetected_tokens, [output_maskes, masked_label.cpu().numpy()]
|
| 93 |
+
|
| 94 |
+
@staticmethod
|
| 95 |
+
def generate_test(output_dir, file_name, window_size, adaptive_window, seq_len, scale, min_len):
|
| 96 |
+
"""
|
| 97 |
+
:return: log_seqs: num_samples x session(seq)_length, tim_seqs: num_samples x session_length
|
| 98 |
+
"""
|
| 99 |
+
log_seqs = []
|
| 100 |
+
tim_seqs = []
|
| 101 |
+
with open(output_dir + file_name, "r") as f:
|
| 102 |
+
for idx, line in tqdm(enumerate(f.readlines())):
|
| 103 |
+
#if idx > 40: break
|
| 104 |
+
log_seq, tim_seq = fixed_window(line, window_size,
|
| 105 |
+
adaptive_window=adaptive_window,
|
| 106 |
+
seq_len=seq_len, min_len=min_len)
|
| 107 |
+
if len(log_seq) == 0:
|
| 108 |
+
continue
|
| 109 |
+
|
| 110 |
+
# if scale is not None:
|
| 111 |
+
# times = tim_seq
|
| 112 |
+
# for i, tn in enumerate(times):
|
| 113 |
+
# tn = np.array(tn).reshape(-1, 1)
|
| 114 |
+
# times[i] = scale.transform(tn).reshape(-1).tolist()
|
| 115 |
+
# tim_seq = times
|
| 116 |
+
|
| 117 |
+
log_seqs += log_seq
|
| 118 |
+
tim_seqs += tim_seq
|
| 119 |
+
|
| 120 |
+
# sort seq_pairs by seq len
|
| 121 |
+
log_seqs = np.array(log_seqs, dtype=object)
|
| 122 |
+
tim_seqs = np.array(tim_seqs, dtype=object)
|
| 123 |
+
|
| 124 |
+
test_len = list(map(len, log_seqs))
|
| 125 |
+
test_sort_index = np.argsort(-1 * np.array(test_len))
|
| 126 |
+
|
| 127 |
+
log_seqs = log_seqs[test_sort_index]
|
| 128 |
+
tim_seqs = tim_seqs[test_sort_index]
|
| 129 |
+
|
| 130 |
+
print(f"{file_name} size: {len(log_seqs)}")
|
| 131 |
+
return log_seqs, tim_seqs
|
| 132 |
+
|
| 133 |
+
def helper(self, model, output_dir, file_name, vocab, scale=None, error_dict=None):
|
| 134 |
+
total_results = []
|
| 135 |
+
total_errors = []
|
| 136 |
+
output_results = []
|
| 137 |
+
total_dist = []
|
| 138 |
+
output_cls = []
|
| 139 |
+
logkey_test, time_test = self.generate_test(output_dir, file_name, self.window_size, self.adaptive_window, self.seq_len, scale, self.min_len)
|
| 140 |
+
|
| 141 |
+
# use 1/10 test data
|
| 142 |
+
if self.test_ratio != 1:
|
| 143 |
+
num_test = len(logkey_test)
|
| 144 |
+
rand_index = torch.randperm(num_test)
|
| 145 |
+
rand_index = rand_index[:int(num_test * self.test_ratio)] if isinstance(self.test_ratio, float) else rand_index[:self.test_ratio]
|
| 146 |
+
logkey_test, time_test = logkey_test[rand_index], time_test[rand_index]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
seq_dataset = LogDataset(logkey_test, time_test, vocab, seq_len=self.seq_len,
|
| 150 |
+
corpus_lines=self.corpus_lines, on_memory=self.on_memory, predict_mode=True, mask_ratio=self.mask_ratio)
|
| 151 |
+
|
| 152 |
+
# use large batch size in test data
|
| 153 |
+
data_loader = DataLoader(seq_dataset, batch_size=self.batch_size, num_workers=self.num_workers,
|
| 154 |
+
collate_fn=seq_dataset.collate_fn)
|
| 155 |
+
|
| 156 |
+
for idx, data in enumerate(data_loader):
|
| 157 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 158 |
+
|
| 159 |
+
result = model(data["bert_input"], data["time_input"])
|
| 160 |
+
|
| 161 |
+
# mask_lm_output, mask_tm_output: batch_size x session_size x vocab_size
|
| 162 |
+
# cls_output: batch_size x hidden_size
|
| 163 |
+
# bert_label, time_label: batch_size x session_size
|
| 164 |
+
# in session, some logkeys are masked
|
| 165 |
+
|
| 166 |
+
mask_lm_output, mask_tm_output = result["logkey_output"], result["time_output"]
|
| 167 |
+
output_cls += result["cls_output"].tolist()
|
| 168 |
+
|
| 169 |
+
# dist = torch.sum((result["cls_output"] - self.hyper_center) ** 2, dim=1)
|
| 170 |
+
# when visualization no mask
|
| 171 |
+
# continue
|
| 172 |
+
|
| 173 |
+
# loop though each session in batch
|
| 174 |
+
for i in range(len(data["bert_label"])):
|
| 175 |
+
seq_results = {"num_error": 0,
|
| 176 |
+
"undetected_tokens": 0,
|
| 177 |
+
"masked_tokens": 0,
|
| 178 |
+
"total_logkey": torch.sum(data["bert_input"][i] > 0).item(),
|
| 179 |
+
"deepSVDD_label": 0
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
mask_index = data["bert_label"][i] > 0
|
| 183 |
+
num_masked = torch.sum(mask_index).tolist()
|
| 184 |
+
seq_results["masked_tokens"] = num_masked
|
| 185 |
+
|
| 186 |
+
if self.is_logkey:
|
| 187 |
+
num_undetected, output_seq = self.detect_logkey_anomaly(
|
| 188 |
+
mask_lm_output[i][mask_index], data["bert_label"][i][mask_index])
|
| 189 |
+
seq_results["undetected_tokens"] = num_undetected
|
| 190 |
+
|
| 191 |
+
output_results.append(output_seq)
|
| 192 |
+
|
| 193 |
+
if self.hypersphere_loss_test:
|
| 194 |
+
# detect by deepSVDD distance
|
| 195 |
+
assert result["cls_output"][i].size() == self.center.size()
|
| 196 |
+
# dist = torch.sum((result["cls_fnn_output"][i] - self.center) ** 2)
|
| 197 |
+
dist = torch.sqrt(torch.sum((result["cls_output"][i] - self.center) ** 2))
|
| 198 |
+
total_dist.append(dist.item())
|
| 199 |
+
|
| 200 |
+
# user defined threshold for deepSVDD_label
|
| 201 |
+
seq_results["deepSVDD_label"] = int(dist.item() > self.radius)
|
| 202 |
+
#
|
| 203 |
+
# if dist > 0.25:
|
| 204 |
+
# pass
|
| 205 |
+
|
| 206 |
+
if idx < 10 or idx % 1000 == 0:
|
| 207 |
+
print(
|
| 208 |
+
"{}, #time anomaly: {} # of undetected_tokens: {}, # of masked_tokens: {} , "
|
| 209 |
+
"# of total logkey {}, deepSVDD_label: {} \n".format(
|
| 210 |
+
file_name,
|
| 211 |
+
seq_results["num_error"],
|
| 212 |
+
seq_results["undetected_tokens"],
|
| 213 |
+
seq_results["masked_tokens"],
|
| 214 |
+
seq_results["total_logkey"],
|
| 215 |
+
seq_results['deepSVDD_label']
|
| 216 |
+
)
|
| 217 |
+
)
|
| 218 |
+
total_results.append(seq_results)
|
| 219 |
+
|
| 220 |
+
# for time
|
| 221 |
+
# return total_results, total_errors
|
| 222 |
+
|
| 223 |
+
#for logkey
|
| 224 |
+
# return total_results, output_results
|
| 225 |
+
|
| 226 |
+
# for hypersphere distance
|
| 227 |
+
return total_results, output_cls
|
| 228 |
+
|
| 229 |
+
def predict(self):
|
| 230 |
+
model = torch.load(self.model_path, weights_only=False)
|
| 231 |
+
model.to(self.device)
|
| 232 |
+
model.eval()
|
| 233 |
+
print('model_path: {}'.format(self.model_path))
|
| 234 |
+
|
| 235 |
+
start_time = time.time()
|
| 236 |
+
vocab = WordVocab.load_vocab(self.vocab_path)
|
| 237 |
+
|
| 238 |
+
scale = None
|
| 239 |
+
error_dict = None
|
| 240 |
+
if self.is_time:
|
| 241 |
+
with open(self.scale_path, "rb") as f:
|
| 242 |
+
scale = pickle.load(f)
|
| 243 |
+
|
| 244 |
+
with open(self.model_dir + "error_dict.pkl", 'rb') as f:
|
| 245 |
+
error_dict = pickle.load(f)
|
| 246 |
+
|
| 247 |
+
if self.hypersphere_loss:
|
| 248 |
+
center_dict = torch.load(self.model_dir + "best_center.pt", weights_only=False)
|
| 249 |
+
self.center = center_dict["center"]
|
| 250 |
+
self.radius = center_dict["radius"]
|
| 251 |
+
# self.center = self.center.view(1,-1)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
print("test normal predicting")
|
| 255 |
+
test_normal_results, test_normal_errors = self.helper(model, self.output_dir, "test_normal", vocab, scale, error_dict)
|
| 256 |
+
|
| 257 |
+
print("test abnormal predicting")
|
| 258 |
+
test_abnormal_results, test_abnormal_errors = self.helper(model, self.output_dir, "test_abnormal", vocab, scale, error_dict)
|
| 259 |
+
|
| 260 |
+
print("Saving test normal results")
|
| 261 |
+
with open(self.model_dir + "test_normal_results", "wb") as f:
|
| 262 |
+
pickle.dump(test_normal_results, f)
|
| 263 |
+
|
| 264 |
+
print("Saving test abnormal results")
|
| 265 |
+
with open(self.model_dir + "test_abnormal_results", "wb") as f:
|
| 266 |
+
pickle.dump(test_abnormal_results, f)
|
| 267 |
+
|
| 268 |
+
print("Saving test normal errors")
|
| 269 |
+
with open(self.model_dir + "test_normal_errors.pkl", "wb") as f:
|
| 270 |
+
pickle.dump(test_normal_errors, f)
|
| 271 |
+
|
| 272 |
+
print("Saving test abnormal results")
|
| 273 |
+
with open(self.model_dir + "test_abnormal_errors.pkl", "wb") as f:
|
| 274 |
+
pickle.dump(test_abnormal_errors, f)
|
| 275 |
+
|
| 276 |
+
params = {"is_logkey": self.is_logkey, "is_time": self.is_time, "hypersphere_loss": self.hypersphere_loss,
|
| 277 |
+
"hypersphere_loss_test": self.hypersphere_loss_test}
|
| 278 |
+
best_th, best_seq_th, FP, TP, TN, FN, P, R, F1 = find_best_threshold(test_normal_results,
|
| 279 |
+
test_abnormal_results,
|
| 280 |
+
params=params,
|
| 281 |
+
th_range=np.arange(10),
|
| 282 |
+
seq_range=np.arange(0,1,0.1))
|
| 283 |
+
|
| 284 |
+
print("best threshold: {}, best threshold ratio: {}".format(best_th, best_seq_th))
|
| 285 |
+
print("TP: {}, TN: {}, FP: {}, FN: {}".format(TP, TN, FP, FN))
|
| 286 |
+
print('Precision: {:.2f}%, Recall: {:.2f}%, F1-measure: {:.2f}%'.format(P, R, F1))
|
| 287 |
+
elapsed_time = time.time() - start_time
|
| 288 |
+
print('elapsed_time: {}'.format(elapsed_time))
|
| 289 |
+
|
| 290 |
+
|
bert_pytorch/train_log.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gc
|
| 3 |
+
import torch
|
| 4 |
+
import tqdm
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import seaborn as sns
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
from bert_pytorch.model import BERT
|
| 11 |
+
from bert_pytorch.trainer import BERTTrainer
|
| 12 |
+
from bert_pytorch.dataset import LogDataset, WordVocab
|
| 13 |
+
from bert_pytorch.dataset.sample import generate_train_valid
|
| 14 |
+
from bert_pytorch.dataset.utils import save_parameters
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Trainer():
|
| 18 |
+
def __init__(self, options):
|
| 19 |
+
self.device = options["device"]
|
| 20 |
+
self.model_dir = options["model_dir"]
|
| 21 |
+
self.model_path = options["model_path"]
|
| 22 |
+
self.vocab_path = options["vocab_path"]
|
| 23 |
+
self.output_path = options["output_dir"]
|
| 24 |
+
self.window_size = options["window_size"]
|
| 25 |
+
self.adaptive_window = options["adaptive_window"]
|
| 26 |
+
self.sample_ratio = options["train_ratio"]
|
| 27 |
+
self.valid_ratio = options["valid_ratio"]
|
| 28 |
+
self.seq_len = options["seq_len"]
|
| 29 |
+
self.max_len = options["max_len"]
|
| 30 |
+
self.corpus_lines = options["corpus_lines"]
|
| 31 |
+
self.on_memory = options["on_memory"]
|
| 32 |
+
self.batch_size = options["batch_size"]
|
| 33 |
+
self.num_workers = options["num_workers"]
|
| 34 |
+
self.lr = options["lr"]
|
| 35 |
+
self.adam_beta1 = options["adam_beta1"]
|
| 36 |
+
self.adam_beta2 = options["adam_beta2"]
|
| 37 |
+
self.adam_weight_decay = options["adam_weight_decay"]
|
| 38 |
+
self.with_cuda = options["with_cuda"]
|
| 39 |
+
self.cuda_devices = options["cuda_devices"]
|
| 40 |
+
self.log_freq = options["log_freq"]
|
| 41 |
+
self.epochs = options["epochs"]
|
| 42 |
+
self.hidden = options["hidden"]
|
| 43 |
+
self.layers = options["layers"]
|
| 44 |
+
self.attn_heads = options["attn_heads"]
|
| 45 |
+
self.is_logkey = options["is_logkey"]
|
| 46 |
+
self.is_time = options["is_time"]
|
| 47 |
+
self.scale = options["scale"]
|
| 48 |
+
self.scale_path = options["scale_path"]
|
| 49 |
+
self.n_epochs_stop = options["n_epochs_stop"]
|
| 50 |
+
self.hypersphere_loss = options["hypersphere_loss"]
|
| 51 |
+
self.mask_ratio = options["mask_ratio"]
|
| 52 |
+
self.min_len = options["min_len"]
|
| 53 |
+
|
| 54 |
+
print("Save options parameters")
|
| 55 |
+
save_parameters(options, self.model_dir + "parameters.txt")
|
| 56 |
+
|
| 57 |
+
def train(self):
|
| 58 |
+
print("Loading vocab", self.vocab_path)
|
| 59 |
+
vocab = WordVocab.load_vocab(self.vocab_path)
|
| 60 |
+
print("vocab Size: ", len(vocab))
|
| 61 |
+
|
| 62 |
+
print("\nLoading Train Dataset")
|
| 63 |
+
train_file_path = os.path.join(self.output_path, "train")
|
| 64 |
+
logkey_train, logkey_valid, time_train, time_valid = generate_train_valid(
|
| 65 |
+
train_file_path,
|
| 66 |
+
window_size=self.window_size,
|
| 67 |
+
adaptive_window=self.adaptive_window,
|
| 68 |
+
valid_size=self.valid_ratio,
|
| 69 |
+
sample_ratio=self.sample_ratio,
|
| 70 |
+
scale=self.scale,
|
| 71 |
+
scale_path=self.scale_path,
|
| 72 |
+
seq_len=self.seq_len,
|
| 73 |
+
min_len=self.min_len
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
train_dataset = LogDataset(
|
| 77 |
+
logkey_train, time_train, vocab,
|
| 78 |
+
seq_len=self.seq_len,
|
| 79 |
+
corpus_lines=self.corpus_lines,
|
| 80 |
+
on_memory=self.on_memory,
|
| 81 |
+
mask_ratio=self.mask_ratio
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
print("\nLoading valid Dataset")
|
| 85 |
+
valid_dataset = LogDataset(
|
| 86 |
+
logkey_valid, time_valid, vocab,
|
| 87 |
+
seq_len=self.seq_len,
|
| 88 |
+
on_memory=self.on_memory,
|
| 89 |
+
mask_ratio=self.mask_ratio
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
print("Creating Dataloader")
|
| 93 |
+
self.train_data_loader = DataLoader(
|
| 94 |
+
train_dataset,
|
| 95 |
+
batch_size=self.batch_size,
|
| 96 |
+
num_workers=self.num_workers,
|
| 97 |
+
collate_fn=train_dataset.collate_fn,
|
| 98 |
+
drop_last=False
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
self.valid_data_loader = DataLoader(
|
| 102 |
+
valid_dataset,
|
| 103 |
+
batch_size=self.batch_size,
|
| 104 |
+
num_workers=self.num_workers,
|
| 105 |
+
collate_fn=train_dataset.collate_fn,
|
| 106 |
+
drop_last=False
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
del train_dataset
|
| 110 |
+
del valid_dataset
|
| 111 |
+
del logkey_train
|
| 112 |
+
del logkey_valid
|
| 113 |
+
del time_train
|
| 114 |
+
del time_valid
|
| 115 |
+
gc.collect()
|
| 116 |
+
|
| 117 |
+
print("Building BERT model")
|
| 118 |
+
bert = BERT(
|
| 119 |
+
len(vocab),
|
| 120 |
+
max_len=self.max_len,
|
| 121 |
+
hidden=self.hidden,
|
| 122 |
+
n_layers=self.layers,
|
| 123 |
+
attn_heads=self.attn_heads,
|
| 124 |
+
is_logkey=self.is_logkey,
|
| 125 |
+
is_time=self.is_time
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
print("Creating BERT Trainer")
|
| 129 |
+
self.trainer = BERTTrainer(
|
| 130 |
+
bert, len(vocab),
|
| 131 |
+
train_dataloader=self.train_data_loader,
|
| 132 |
+
valid_dataloader=self.valid_data_loader,
|
| 133 |
+
lr=self.lr,
|
| 134 |
+
betas=(self.adam_beta1, self.adam_beta2),
|
| 135 |
+
weight_decay=self.adam_weight_decay,
|
| 136 |
+
with_cuda=self.with_cuda,
|
| 137 |
+
cuda_devices=self.cuda_devices,
|
| 138 |
+
log_freq=self.log_freq,
|
| 139 |
+
is_logkey=self.is_logkey,
|
| 140 |
+
is_time=self.is_time,
|
| 141 |
+
hypersphere_loss=self.hypersphere_loss
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
self.start_iteration(surfix_log="log2")
|
| 145 |
+
self.plot_train_valid_loss("_log2")
|
| 146 |
+
|
| 147 |
+
def start_iteration(self, surfix_log):
|
| 148 |
+
print("Training Start")
|
| 149 |
+
best_loss = float('inf')
|
| 150 |
+
epochs_no_improve = 0
|
| 151 |
+
|
| 152 |
+
for epoch in range(self.epochs):
|
| 153 |
+
print("\n")
|
| 154 |
+
if self.hypersphere_loss:
|
| 155 |
+
center = self.calculate_center([self.train_data_loader, self.valid_data_loader])
|
| 156 |
+
self.trainer.hyper_center = center
|
| 157 |
+
|
| 158 |
+
_, train_dist = self.trainer.train(epoch)
|
| 159 |
+
avg_loss, valid_dist = self.trainer.valid(epoch)
|
| 160 |
+
|
| 161 |
+
self.trainer.save_log(self.model_dir, surfix_log)
|
| 162 |
+
|
| 163 |
+
if self.hypersphere_loss:
|
| 164 |
+
self.trainer.radius = self.trainer.get_radius(train_dist + valid_dist, self.trainer.nu)
|
| 165 |
+
|
| 166 |
+
if avg_loss < best_loss:
|
| 167 |
+
best_loss = avg_loss
|
| 168 |
+
self.trainer.save(self.model_path)
|
| 169 |
+
epochs_no_improve = 0
|
| 170 |
+
|
| 171 |
+
if epoch > 10 and self.hypersphere_loss:
|
| 172 |
+
best_center = self.trainer.hyper_center
|
| 173 |
+
best_radius = self.trainer.radius
|
| 174 |
+
total_dist = train_dist + valid_dist
|
| 175 |
+
|
| 176 |
+
if best_center is None:
|
| 177 |
+
raise TypeError("center is None")
|
| 178 |
+
|
| 179 |
+
print("best radius", best_radius)
|
| 180 |
+
|
| 181 |
+
best_center_path = self.model_dir + "best_center.pt"
|
| 182 |
+
print("Save best center", best_center_path)
|
| 183 |
+
torch.save({"center": best_center, "radius": best_radius}, best_center_path)
|
| 184 |
+
|
| 185 |
+
total_dist_path = self.model_dir + "best_total_dist.pt"
|
| 186 |
+
print("save total dist: ", total_dist_path)
|
| 187 |
+
torch.save(total_dist, total_dist_path)
|
| 188 |
+
else:
|
| 189 |
+
epochs_no_improve += 1
|
| 190 |
+
|
| 191 |
+
if epochs_no_improve == self.n_epochs_stop:
|
| 192 |
+
print("Early stopping")
|
| 193 |
+
break
|
| 194 |
+
|
| 195 |
+
def calculate_center(self, data_loader_list):
|
| 196 |
+
print("start calculate center")
|
| 197 |
+
with torch.no_grad():
|
| 198 |
+
outputs = 0
|
| 199 |
+
total_samples = 0
|
| 200 |
+
for data_loader in data_loader_list:
|
| 201 |
+
totol_length = len(data_loader)
|
| 202 |
+
data_iter = tqdm.tqdm(enumerate(data_loader), total=totol_length)
|
| 203 |
+
for i, data in data_iter:
|
| 204 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 205 |
+
result = self.trainer.model.forward(data["bert_input"], data["time_input"])
|
| 206 |
+
cls_output = result["cls_output"]
|
| 207 |
+
outputs += torch.sum(cls_output.detach().clone(), dim=0)
|
| 208 |
+
total_samples += cls_output.size(0)
|
| 209 |
+
center = outputs / total_samples
|
| 210 |
+
return center
|
| 211 |
+
|
| 212 |
+
def plot_train_valid_loss(self, surfix_log):
|
| 213 |
+
train_loss = pd.read_csv(self.model_dir + f"train{surfix_log}.csv")
|
| 214 |
+
valid_loss = pd.read_csv(self.model_dir + f"valid{surfix_log}.csv")
|
| 215 |
+
|
| 216 |
+
sns.lineplot(x="epoch", y="loss", data=train_loss, label="train loss")
|
| 217 |
+
sns.lineplot(x="epoch", y="loss", data=valid_loss, label="valid loss")
|
| 218 |
+
plt.title("epoch vs train loss vs valid loss")
|
| 219 |
+
plt.legend()
|
| 220 |
+
plt.savefig(self.model_dir + "train_valid_loss.png")
|
| 221 |
+
plt.show()
|
| 222 |
+
print("plot done")
|
logbert_rca_pipeline_api.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import re
|
| 4 |
+
import ast
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
import torch
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
|
| 15 |
+
sys.path.append('../')
|
| 16 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 17 |
+
|
| 18 |
+
from logparser import Drain
|
| 19 |
+
from bert_pytorch.dataset import LogDataset, WordVocab
|
| 20 |
+
from bert_pytorch.model.bert import BERT
|
| 21 |
+
from bert_pytorch.model.log_model import BERTLog
|
| 22 |
+
|
| 23 |
+
# === Constants ===
|
| 24 |
+
TOP_EVENTS = 5
|
| 25 |
+
MAX_RCA_TOKENS = 200
|
| 26 |
+
MISTRAL_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
| 27 |
+
# HF_CACHE = "/content/drive/MyDrive/hf_cache"
|
| 28 |
+
|
| 29 |
+
# === Log Parsing ===
|
| 30 |
+
def parse_log_with_drain(log_file, input_dir, output_dir):
|
| 31 |
+
regex = [
|
| 32 |
+
r"appattempt_\d+_\d+_\d+",
|
| 33 |
+
r"job_\d+_\d+",
|
| 34 |
+
r"task_\d+_\d+_[a-z]+_\d+",
|
| 35 |
+
r"container_\d+",
|
| 36 |
+
r"\b(?:\d{1,3}\.){3}\d{1,3}\b",
|
| 37 |
+
r"(?<!\w)\d{5,}(?!\w)",
|
| 38 |
+
r"[a-f0-9]{8,}"
|
| 39 |
+
]
|
| 40 |
+
log_format = r'\[<AppId>] <Date> <Time> <Level> \[<Process>] <Component>: <Content>'
|
| 41 |
+
parser = Drain.LogParser(log_format, indir=input_dir, outdir=output_dir, depth=5, st=0.5, rex=regex, keep_para=True)
|
| 42 |
+
parser.parse(log_file)
|
| 43 |
+
|
| 44 |
+
def hadoop_sampling(structured_log_path, sequence_output_path):
|
| 45 |
+
df = pd.read_csv(structured_log_path)
|
| 46 |
+
data_dict = defaultdict(list)
|
| 47 |
+
for _, row in tqdm(df.iterrows(), total=len(df), desc="🔍 Grouping logs by AppId"):
|
| 48 |
+
app_id = row.get("AppId")
|
| 49 |
+
event_id = row.get("EventId")
|
| 50 |
+
if pd.notnull(app_id) and pd.notnull(event_id):
|
| 51 |
+
data_dict[app_id].append(str(event_id))
|
| 52 |
+
pd.DataFrame(list(data_dict.items()), columns=['AppId', 'EventSequence']).to_csv(sequence_output_path, index=False)
|
| 53 |
+
|
| 54 |
+
# === Utility Functions ===
|
| 55 |
+
def load_parameters(param_path):
|
| 56 |
+
options = {}
|
| 57 |
+
with open(param_path, 'r') as f:
|
| 58 |
+
for line in f:
|
| 59 |
+
if ':' not in line: continue
|
| 60 |
+
key, val = line.strip().split(':', 1)
|
| 61 |
+
key, val = key.strip(), val.strip()
|
| 62 |
+
if val.lower() in ['true', 'false', 'none']:
|
| 63 |
+
val = eval(val.capitalize())
|
| 64 |
+
else:
|
| 65 |
+
try: val = int(val)
|
| 66 |
+
except ValueError:
|
| 67 |
+
try: val = float(val)
|
| 68 |
+
except ValueError: pass
|
| 69 |
+
options[key] = val
|
| 70 |
+
return options
|
| 71 |
+
|
| 72 |
+
def load_logbert_model(options, vocab):
|
| 73 |
+
try:
|
| 74 |
+
return torch.load(options["model_path"], map_location=options["device"])
|
| 75 |
+
except:
|
| 76 |
+
bert = BERT(len(vocab), options["hidden"], options["layers"], options["attn_heads"], options["max_len"])
|
| 77 |
+
model = BERTLog(bert, vocab_size=len(vocab)).to(options["device"])
|
| 78 |
+
model.load_state_dict(torch.load(options["model_path"], map_location=options["device"]))
|
| 79 |
+
return model
|
| 80 |
+
|
| 81 |
+
def load_center(path, device):
|
| 82 |
+
center = torch.load(path, map_location=device)
|
| 83 |
+
return center["center"] if isinstance(center, dict) else center
|
| 84 |
+
|
| 85 |
+
def extract_sequences(path, min_len):
|
| 86 |
+
df = pd.read_csv(path)
|
| 87 |
+
data, app_ids = [], []
|
| 88 |
+
for _, row in df.iterrows():
|
| 89 |
+
try:
|
| 90 |
+
seq = ast.literal_eval(row["EventSequence"])
|
| 91 |
+
if len(seq) >= min_len:
|
| 92 |
+
data.append(seq)
|
| 93 |
+
app_ids.append(row["AppId"])
|
| 94 |
+
except:
|
| 95 |
+
continue
|
| 96 |
+
return data, app_ids
|
| 97 |
+
|
| 98 |
+
def prepare_dataloader(sequences, vocab, options):
|
| 99 |
+
dummy_times = [[0] * len(seq) for seq in sequences]
|
| 100 |
+
dataset = LogDataset(sequences, dummy_times, vocab, seq_len=options["seq_len"], on_memory=True, mask_ratio=options["mask_ratio"])
|
| 101 |
+
return DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=dataset.collate_fn)
|
| 102 |
+
|
| 103 |
+
def calculate_mean_std(loader, model, center, device):
|
| 104 |
+
scores = []
|
| 105 |
+
with torch.no_grad():
|
| 106 |
+
for batch in tqdm(loader, desc="📏 Computing train distances..."):
|
| 107 |
+
batch = {k: v.to(device) for k, v in batch.items()}
|
| 108 |
+
cls_output = model(batch["bert_input"], batch["time_input"])["cls_output"]
|
| 109 |
+
scores.append(torch.norm(cls_output - center, dim=1).item())
|
| 110 |
+
return np.mean(scores), np.std(scores)
|
| 111 |
+
|
| 112 |
+
def generate_prompt(event_templates):
|
| 113 |
+
prompt = "The system encountered a failure. Below are the key log events preceding the anomaly:\n\n"
|
| 114 |
+
for i, event in enumerate(event_templates, 1):
|
| 115 |
+
prompt += f"{i}. {event.strip()}\n"
|
| 116 |
+
prompt += "\nBased on the above log events, identify the most likely root cause of the issue.\n"
|
| 117 |
+
prompt += "Explain the cause in one or two sentences, using technical reasoning if possible.\n"
|
| 118 |
+
return prompt
|
| 119 |
+
|
| 120 |
+
def call_mistral(prompt, tokenizer, model, device):
|
| 121 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 122 |
+
outputs = model.generate(
|
| 123 |
+
**inputs,
|
| 124 |
+
max_length=inputs['input_ids'].shape[1] + MAX_RCA_TOKENS,
|
| 125 |
+
do_sample=False,
|
| 126 |
+
top_k=50,
|
| 127 |
+
pad_token_id=tokenizer.eos_token_id
|
| 128 |
+
)
|
| 129 |
+
return tokenizer.decode(outputs[0], skip_special_tokens=True)[len(prompt):].strip()
|
| 130 |
+
|
| 131 |
+
def compute_logkey_anomaly(masked_output, masked_label, top_k=5):
|
| 132 |
+
num_undetected = 0
|
| 133 |
+
for i, token in enumerate(masked_label):
|
| 134 |
+
if token not in torch.argsort(-masked_output[i])[:top_k]:
|
| 135 |
+
num_undetected += 1
|
| 136 |
+
return num_undetected, len(masked_label)
|
| 137 |
+
|
| 138 |
+
# === API-Compatible RCA Pipeline ===
|
| 139 |
+
def detect_anomalies_and_explain(input_log_path):
|
| 140 |
+
log_file = os.path.basename(input_log_path)
|
| 141 |
+
input_dir = os.path.dirname(input_log_path)
|
| 142 |
+
output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "model", "bert"))
|
| 143 |
+
|
| 144 |
+
log_structured_file = os.path.join(output_dir, log_file + "_structured.csv")
|
| 145 |
+
log_templates_file = os.path.join(output_dir, log_file + "_templates.csv")
|
| 146 |
+
log_sequence_file = os.path.join(output_dir, "rca_abnormal_sequence.csv")
|
| 147 |
+
PARAMS_FILE = os.path.join(output_dir, "bert", "parameters.txt")
|
| 148 |
+
CENTER_PATH = os.path.join(output_dir, "bert", "best_center.pt")
|
| 149 |
+
TRAIN_FILE = os.path.join(output_dir, "train")
|
| 150 |
+
|
| 151 |
+
# Step 1: Preprocess Logs
|
| 152 |
+
parse_log_with_drain(log_file, input_dir, output_dir)
|
| 153 |
+
hadoop_sampling(log_structured_file, log_sequence_file)
|
| 154 |
+
|
| 155 |
+
# Step 2: Load Models and Parameters
|
| 156 |
+
options = load_parameters(PARAMS_FILE)
|
| 157 |
+
options["device"] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 158 |
+
|
| 159 |
+
# tokenizer = AutoTokenizer.from_pretrained(MISTRAL_MODEL)
|
| 160 |
+
# model_mistral = AutoModelForCausalLM.from_pretrained(MISTRAL_MODEL, torch_dtype=torch.float32).to(options["device"])
|
| 161 |
+
# model_mistral.eval()
|
| 162 |
+
|
| 163 |
+
vocab = WordVocab.load_vocab(options["vocab_path"])
|
| 164 |
+
model = load_logbert_model(options, vocab).to(options["device"]).eval()
|
| 165 |
+
center = load_center(CENTER_PATH, options["device"])
|
| 166 |
+
|
| 167 |
+
# Step 3: Prepare Data
|
| 168 |
+
test_sequences, app_ids = extract_sequences(log_sequence_file, options["min_len"])
|
| 169 |
+
test_loader = prepare_dataloader(test_sequences, vocab, options)
|
| 170 |
+
|
| 171 |
+
train_sequences = [line.strip().split() for line in open(TRAIN_FILE) if len(line.strip().split()) >= options["min_len"]]
|
| 172 |
+
train_loader = prepare_dataloader(train_sequences, vocab, options)
|
| 173 |
+
mean, std = calculate_mean_std(train_loader, model, center, options["device"])
|
| 174 |
+
|
| 175 |
+
templates_df = pd.read_csv(log_templates_file)
|
| 176 |
+
event_template_dict = dict(zip(templates_df["EventId"], templates_df["EventTemplate"]))
|
| 177 |
+
|
| 178 |
+
# Step 4: Analyze & Explain Anomalies
|
| 179 |
+
results = []
|
| 180 |
+
for i, batch in enumerate(test_loader):
|
| 181 |
+
batch = {k: v.to(options["device"]) for k, v in batch.items()}
|
| 182 |
+
output = model(batch["bert_input"], batch["time_input"])
|
| 183 |
+
cls_output = output["cls_output"]
|
| 184 |
+
score = torch.norm(cls_output - center, dim=1).item()
|
| 185 |
+
z_score = (score - mean) / std
|
| 186 |
+
|
| 187 |
+
num_undetected, masked_total = compute_logkey_anomaly(output["logkey_output"][0], batch["bert_label"][0])
|
| 188 |
+
undetected_ratio = num_undetected / masked_total if masked_total else 0
|
| 189 |
+
|
| 190 |
+
status = "Abnormal" if z_score > 2 or undetected_ratio > 0.5 else "Normal"
|
| 191 |
+
if status == "Normal":
|
| 192 |
+
continue
|
| 193 |
+
|
| 194 |
+
top_eids = test_sequences[i][:TOP_EVENTS]
|
| 195 |
+
event_templates = [event_template_dict.get(eid, f"[Missing Event {eid}]") for eid in top_eids]
|
| 196 |
+
#prompt = ''#generate_prompt(event_templates)
|
| 197 |
+
#explanation = ''#call_mistral(prompt, tokenizer, model_mistral, options["device"])
|
| 198 |
+
|
| 199 |
+
results.append({
|
| 200 |
+
"AppId": app_ids[i],
|
| 201 |
+
"Score": score,
|
| 202 |
+
"z_score": z_score,
|
| 203 |
+
"UndetectedRatio": undetected_ratio,
|
| 204 |
+
"status":status,
|
| 205 |
+
"Events": event_templates,
|
| 206 |
+
"Explanation": None
|
| 207 |
+
})
|
| 208 |
+
|
| 209 |
+
return results
|
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
boto3
|
| 4 |
+
botocore
|
| 5 |
+
redis
|
| 6 |
+
python-dotenv
|
| 7 |
+
python-multipart
|
| 8 |
+
torch
|
| 9 |
+
transformers
|
| 10 |
+
tqdm
|
| 11 |
+
pandas
|
| 12 |
+
numpy
|
| 13 |
+
scikit-learn
|
| 14 |
+
databases
|
| 15 |
+
sqlalchemy
|
| 16 |
+
asyncpg
|
| 17 |
+
logparser
|
| 18 |
+
bert_pytorch
|
| 19 |
+
seaborn
|