|
|
|
|
|
|
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from langchain.memory import ConversationBufferWindowMemory |
|
|
from peft import PeftModel |
|
|
import torch |
|
|
|
|
|
import json |
|
|
import sys |
|
|
|
|
|
|
|
|
# Expect exactly one command-line argument: the path to the evaluation JSONL file.
if len(sys.argv) != 2:
    # sys.exit(<str>) prints the message to stderr and exits with status 1,
    # keeping usage errors out of stdout. argv[0] reports the script's real
    # name instead of a hard-coded (and misleading) "finetune.py".
    sys.exit(f"Usage: python {sys.argv[0]} <jsonl_file>")

jsonl_file_path = sys.argv[1]
|
|
|
|
|
|
|
|
# Load the base model and tokenizer, attach the adapter weights stored in
# ./qlora-out, switch to inference mode, and place the model on the GPU when
# one is available (CPU otherwise).
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
adapter_dir = "./qlora-out"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Register a pad token so tokenizer(..., padding=True) can be called.
# NOTE(review): this adds a brand-new token id but the model's embeddings are
# never resized; that is only safe while no pad id is fed to the model
# (single-sequence inputs never get padded). Confirm this matches how the
# adapter in ./qlora-out was trained.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

base_model = AutoModelForCausalLM.from_pretrained(model_name)
ft_model = PeftModel.from_pretrained(base_model, adapter_dir)

# Inference only: put the model in eval mode (disables dropout layers).
ft_model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ft_model.to(device)
|
|
|
|
|
|
|
|
# Run the fine-tuned model over every {"input": ..., "output": ...} record in
# the JSONL file and report a binary confusion matrix.
#
# Positive class = "attack traffic" (any label that does not contain "normal",
# e.g. attacker or victim). Predicted and actual classes are both derived the
# same way, from the generated text and from the ground-truth label
# respectively, so:
#   TP = attack predicted as attack      TN = normal predicted as normal
#   FP = normal predicted as attack      FN = attack predicted as normal
# Exact matches (case- and surrounding-whitespace-insensitive) are counted
# separately so attacker/victim confusions stay visible, and every mismatch is
# printed for inspection.
with open(jsonl_file_path, "r", encoding="utf-8") as f:
    tp, tn, fp, fn = 0, 0, 0, 0
    exact_matches = 0
    total = 0

    for line in f:
        # Skip blank lines (e.g. a trailing newline at EOF); json.loads("") raises.
        if not line.strip():
            continue

        data = json.loads(line)
        user_in = data["input"]
        expected = data["output"].strip().lower()

        # NOTE: the prompt must match the fine-tuning format byte-for-byte
        # (including the single "#" before "output"), so it is left unchanged.
        user_input = f"[INST] ###instruction: Check if the given traffic flow is normal or of an attacker or a victim\n###input: {user_in}\n#output: [/INST]"

        # Single sequence, so no padding is needed.
        encodings = tokenizer(user_input, return_tensors="pt").to(device)
        input_ids = encodings["input_ids"]
        attention_mask = encodings["attention_mask"]

        # Greedy decoding keeps the evaluation deterministic; sampling (even at
        # a low temperature) made the metrics change from run to run.
        output_ids = ft_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=1000,
            num_return_sequences=1,
            do_sample=False,
        )

        # generate() returns prompt + continuation; keep only the new tokens.
        generated_ids = output_ids[0, input_ids.shape[-1]:]
        response = tokenizer.decode(generated_ids, skip_special_tokens=True)
        response = response.strip().lower()

        predicted_attack = "normal" not in response
        actual_attack = "normal" not in expected

        if predicted_attack and actual_attack:
            tp += 1
        elif not predicted_attack and not actual_attack:
            tn += 1
        elif predicted_attack:
            # Flagged as attack, but the ground truth is normal.
            fp += 1
        else:
            # Reported as normal, but the ground truth is an attack.
            fn += 1

        total += 1
        if response == expected:
            exact_matches += 1
        else:
            print("Mismatch:")
            print(f"User input: {user_in}")
            print(f"Generated response: {response}")
            print(f"Expected response: {data['output']}")
            print()

print(f"TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}")
print(f"Exact matches: {exact_matches}/{total}")
|
|
|
|
|
|
|
|
|
|
|
|