|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
# Directory on shared lab storage that is used as the Hugging Face cache root
# (also passed explicitly as ``cache_dir`` when the model is loaded).
cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm'

# Export both settings in one go.  This runs ahead of the torch / transformers
# imports below -- presumably so those libraries pick the values up when they
# are first imported; keep this ordering.
os.environ.update(
    HF_HOME=cache_dir,
    PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True",
)
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import logging |
|
|
import time |
|
|
import torch |
|
|
import json |
|
|
import torch.nn as nn |
|
|
from typing import Optional |
|
|
import pandas as pd |
|
|
from datasets import Dataset |
|
|
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training |
|
|
from dataclasses import dataclass, field |
|
|
from transformers import ( |
|
|
HfArgumentParser, |
|
|
AutoTokenizer, |
|
|
TrainingArguments, |
|
|
BitsAndBytesConfig, |
|
|
TrainerCallback, |
|
|
AutoModelForCausalLM |
|
|
) |
|
|
from trl import SFTTrainer |
|
|
import warnings |
|
|
|
|
|
|
|
|
|
|
|
# Silence every Python warning for the remainder of the run.
warnings.filterwarnings("ignore")

# Root logger at INFO, plus a module-level logger for this script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Release any cached CUDA memory before the model is loaded.
torch.cuda.empty_cache()

# Use the GPU whenever one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
|
|
|
# ---- Experiment configuration -------------------------------------------
# ``model_name`` selects both the dataset files and the base checkpoint;
# ``WM`` selects the dataset family inside ``get_file_paths``.
model_name = "Llama2"
WM = "TW"
num_data = 10000        # upper bound on records taken from each JSON file
num_epochs = 5
learning_rate_ = 1e-5

# Echo the configuration so it ends up in the job log.
for _label, _value in (
    ('Device', device),
    ('Model', model_name),
    ('WM', WM),
    ('Number of data', num_data),
    ('Number of epochs', num_epochs),
    ('Learning rate', learning_rate_),
):
    print(f'{_label}: {_value}')

# Wall-clock reference used for the "Done in" report after training.
start_time = time.time()
|
|
|
|
|
def load_data(file_path, num_data):
    """Read a JSON dataset and turn each record into one training text.

    The file must contain a JSON list of objects, each with an ``"article"``
    and a ``"Watermarked_summary"`` string.  Every record becomes a single
    prompt + answer string stored under the ``"text"`` key.

    Args:
        file_path: Path of the JSON file to read.
        num_data: Number of leading records to keep (``data[:num_data]``
            slice semantics, so a value larger than the file is harmless).

    Returns:
        A list of ``{"text": <prompt and summary>}`` dictionaries.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a record lacks ``"article"`` or ``"Watermarked_summary"``.
    """
    # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of
    # relying on the locale default, which differs between machines.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    prompt_prefix = "Now summarize the following text with maximum 60 words: "
    answer_marker = "\nThe summary is: "

    return [
        {
            "text": prompt_prefix
            + item["article"]
            + answer_marker
            + item['Watermarked_summary']
        }
        for item in data[:num_data]
    ]
|
|
|
|
|
|
|
|
|
|
|
def create_dataset(data):
    """Wrap a list of record dicts in a Hugging Face ``Dataset``.

    The records are tabulated as a pandas ``DataFrame`` first and that frame
    is then handed to ``Dataset.from_pandas``.
    """
    return Dataset.from_pandas(pd.DataFrame(data))
|
|
|
|
|
def get_file_paths(model_name, WM):
    """Return the (train, test) JSON paths for a model / watermark pair.

    Args:
        model_name: Key of the source model, e.g. ``'Llama2'`` or ``'DeepSeek'``.
            Not every model is available for every watermark scheme.
        WM: Watermark scheme key: one of ``"SafeSeal"``, ``"DTM"``, ``"KGW"``,
            ``"SIR"``, ``"SynthID"`` or ``"TW"``.

    Returns:
        A ``(train_path, test_path)`` tuple of absolute file paths.

    Raises:
        ValueError: If ``WM`` is not a known watermark scheme (previously this
            surfaced as an ``UnboundLocalError``).
        KeyError: If no files are registered for ``model_name`` under ``WM``.
    """
    base_path = '/network/rit/lab/Lai_ReSecureAI/kiel/Website/Stealing/'

    # scheme -> model -> (train file, test file).  File names are copied
    # verbatim; their naming is not uniform across schemes.
    files_by_scheme = {
        "SafeSeal": {
            'DeepSeek': ('DeepSeek_train_Summarization_Safeseal_top_3_threshold_0.8_Uniform_0_20000_20k.json', 'DeepSeek_test_Summarization_Safeseal_top_3_threshold_0.8_Uniform_0_1000_1000.json'),
            'Llama3': ('Llama3_train_Summarization_Safeseal_top_3_threshold_0.8_Uniform_0_20000_20k.json', 'Llama3_test_Summarization_Safeseal_top_3_threshold_0.8_Uniform_0_1000_1000.json'),
        },
        "DTM": {
            'Llama3': ('Llama3_DTM_Summarization_train__20000.json', 'Llama3_DTM_Summarization_test__1000.json'),
            'DeepSeek': ('DeepSeek_DTM_Summarization_train__20000.json', 'DeepSeek_DTM_Summarization_test__1000.json'),
            'Llama2': ('Llama2_DTM_Summarization_train_20k.json', 'Llama2_DTM_Summary_test_1000.json'),
            'Mistral': ('Mistral_DTM_Summarization_train_20k.json', 'Mistral_DTM_Summary_test_1000.json'),
        },
        "KGW": {
            'Llama3': ('Llama3_KGW_Summarization_train_0_20000_20000.json', 'Llama3_KGW_Summarization_test_0_1000_1000.json'),
            'DeepSeek': ('DeepSeek_KGW_Summarization_train_0_20000_20000.json', 'DeepSeek_KGW_Summarization_test_0_1000_1000.json'),
        },
        "SIR": {
            'DeepSeek': ('DeepSeek_SIR_Summarization_train_0_20000_20000.json', 'DeepSeek_SIR_Summarization_test_0_1000_1000.json'),
            'Llama3': ('Llama3_SIR_Summarization_train_0_20000_20000.json', 'Llama3_SIR_Summarization_test_0_1000_1000.json'),
        },
        "SynthID": {
            'DeepSeek': ('DeepSeek_SynthID_Summarization_train_0_20000_20000.json', 'DeepSeek_SynthID_Summarization_test_0_1000_1000.json'),
            'Llama3': ('Llama3_SynthID_Summarization_train_0_20000_20000.json', 'Llama3_SynthID_Summarization_test_0_1000_1000.json'),
        },
        "TW": {
            'DeepSeek': ('DeepSeek_TW_Summarization_train_20000.json', 'DeepSeek_TW_Summarization_test__1000.json'),
            'Llama3': ('Llama3_TW_Summarization_train__20000.json', 'Llama3_TW_Summarization_test__1000.json'),
            'Llama2': ('Llama2_TW_Summarization_train_20k.json', 'Llama2_TW_Summary_test_1000.json'),
            'Mistral': ('Mistral_TW_Summarization_train_20k.json', 'Mistral_TW_Summary_Test_1000.json'),
        },
    }

    try:
        paths = files_by_scheme[WM]
    except KeyError:
        raise ValueError(
            f"Unknown watermark scheme {WM!r}; expected one of {sorted(files_by_scheme)}"
        ) from None

    try:
        train_name, test_name = paths[model_name]
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError(
            f"No {WM} dataset registered for model {model_name!r}; "
            f"available models: {sorted(paths)}"
        ) from None

    return base_path + train_name, base_path + test_name
|
|
|
|
|
def get_new_model_path(model_name, WM, num_epochs, learning_rate_, num_data):
    """Compose the output directory name for a fine-tuning run.

    The path lives under ``./adversary_models/`` and encodes the model key,
    the watermark key, the epoch count, the learning rate and the data size.
    The trailing underscore is part of the name.
    """
    template = "./adversary_models/{model_name}_{WM}_epoch{num_epochs}_lr{learning_rate_}_data{num_data}_"
    return template.format(
        model_name=model_name,
        WM=WM,
        num_epochs=num_epochs,
        learning_rate_=learning_rate_,
        num_data=num_data,
    )
|
|
|
|
|
|
|
|
# Locate the dataset files for the configured model / watermark pair.
train_file, test_file = get_file_paths(model_name, WM)

# Load both splits (train first, then test); the same ``num_data`` cap is
# applied to each of them.
train_data, test_data = [
    load_data(path, num_data) for path in (train_file, test_file)
]

# Convert the prompt records into Hugging Face datasets.
train_dataset, test_dataset = [
    create_dataset(rows) for rows in (train_data, test_data)
]

# Directory that will receive logs, the adapter weights and the tokenizer.
new_model = get_new_model_path(model_name, WM, num_epochs, learning_rate_, num_data)
print(f'New model path: {new_model}')
|
|
|
|
|
|
|
|
@dataclass
class ScriptArguments:
    """Command-line options for the fine-tuning script.

    Each field carries its ``--help`` text in ``metadata["help"]``; the class
    is handed to ``HfArgumentParser`` to build the command-line parser.
    """

    # --- quantization ---
    use_8_bit: Optional[bool] = field(
        default=False,
        metadata={"help": "use 8 bit precision"},
    )
    use_4_bit: Optional[bool] = field(
        default=False,
        metadata={"help": "use 4 bit precision"},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "precise the quantization type (fp4 or nf4)"},
    )
    use_bnb_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "use nested quantization"},
    )

    # --- placement / adapters ---
    use_multi_gpu: Optional[bool] = field(
        default=True,
        metadata={"help": "use multi GPU"},
    )
    use_adapters: Optional[bool] = field(
        default=True,
        metadata={"help": "use adapters"},
    )

    # --- training shape ---
    batch_size: Optional[int] = field(
        default=8,
        metadata={"help": "input batch size"},
    )
    max_seq_length: Optional[int] = field(
        default=400,
        metadata={"help": "max sequence length"},
    )
    optimizer_name: Optional[str] = field(
        default="adamw_hf",
        metadata={"help": "Optimizer name"},
    )
|
|
|
|
|
# Parse the command line into a ScriptArguments instance.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# NOTE(review): this value is computed but the visible script never passes it
# on -- the model is loaded with ``device_map={"": 0}`` -- so ``--use_multi_gpu``
# currently has no effect.  Kept so the module-level name stays available.
device_map = "auto" if script_args.use_multi_gpu else "cpu"

# The two quantization modes are mutually exclusive.
if script_args.use_8_bit and script_args.use_4_bit:
    raise ValueError("You can't use 8 bit and 4 bit precision at the same time")

# Build the bitsandbytes configuration for whichever mode was requested.
# Previously ``--use_8_bit`` was validated above but then silently ignored
# (``bnb_config`` stayed ``None``); it now produces an 8-bit configuration.
if script_args.use_4_bit:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type=script_args.bnb_4bit_quant_type,
        bnb_4bit_use_double_quant=script_args.use_bnb_nested_quant,
    )
elif script_args.use_8_bit:
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
else:
    # No quantization requested: the model is loaded without a quantization config.
    bnb_config = None
|
|
|
|
|
|
|
|
# Hub checkpoint for each supported ``model_name``.  A single lookup is shared
# by the model and the tokenizer so the two can never drift apart.  Any other
# name keeps falling back to the DeepSeek base model, as before.
_hub_ids = {
    'Llama3': "meta-llama/Meta-Llama-3-8B",
    'Llama2': "meta-llama/Llama-2-7b-chat-hf",
    'Mistral': "mistralai/Mistral-7B-Instruct-v0.2",
}
base_model_id = _hub_ids.get(model_name, "deepseek-ai/deepseek-llm-7b-base")

model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    cache_dir=cache_dir,
    quantization_config=bnb_config,
    device_map={"": 0},  # place the whole model on device 0
)

# The generation cache is switched off for training.
model.config.use_cache = False
# NOTE(review): presumably a Llama-specific tensor-parallel setting -- confirm
# it is meaningful for the other model families.
model.config.pretraining_tp = 1
model = prepare_model_for_kbit_training(model)

tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=False)

# Reuse EOS as the padding token.  The earlier
# ``add_special_tokens({'pad_token': '[PAD]'})`` call was removed: its pad
# token was overridden on the very next line, yet it still left an extra
# vocabulary entry whose id lies beyond the (never resized) embedding matrix.
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
|
|
|
# Modules that receive LoRA adapters: the attention projections, the MLP
# projections and the output head.
_lora_targets = [
    'q_proj', 'k_proj', 'v_proj', 'o_proj',
    'gate_proj', 'down_proj', 'up_proj',
    'lm_head',
]

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=_lora_targets,
)

# Make sure the output directory exists before training starts writing its
# log file into it.
os.makedirs(new_model, exist_ok=True)
|
|
|
|
|
|
|
|
class LoggingCallback(TrainerCallback):
    """Trainer callback that mirrors every log record into a JSON-lines file."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append ``logs`` as one JSON line to ``<output_dir>/train_results.json``.

        Nothing is written when ``logs`` is empty or ``None``.  The file is
        opened in append mode, so records accumulate across calls.
        """
        if not logs:
            return

        log_path = os.path.join(args.output_dir, "train_results.json")
        with open(log_path, "a") as sink:
            sink.write(json.dumps(logs) + "\n")
|
|
|
|
|
|
|
|
training_arguments = TrainingArguments(
    # --- where results go ---
    output_dir=new_model,

    # --- schedule length ---
    num_train_epochs=num_epochs,
    max_steps=-1,

    # --- batching ---
    per_device_train_batch_size=script_args.batch_size,
    per_device_eval_batch_size=script_args.batch_size,
    gradient_accumulation_steps=4,

    # --- optimisation ---
    optim=script_args.optimizer_name,
    learning_rate=learning_rate_,
    weight_decay=0.001,
    max_grad_norm=0.3,
    lr_scheduler_type="constant",
    # NOTE(review): with a plain "constant" schedule this warmup ratio
    # presumably has no effect -- confirm against the scheduler in use.
    warmup_ratio=0.03,
    fp16=True,

    # --- evaluation and logging cadence (both every 500 steps) ---
    evaluation_strategy="steps",
    eval_steps=500,
    logging_strategy="steps",
    logging_steps=500,
    log_level='info',

    # --- checkpointing ---
    # NOTE(review): -1 looks like a way to switch periodic checkpoints off;
    # verify with the installed transformers version.
    save_steps=-1,
    save_total_limit=1,
)
|
|
|
|
|
# Supervised fine-tuning on the "text" column of both datasets; the LoRA
# configuration is passed in alongside the model, and every log record is
# mirrored to disk by LoggingCallback.
_trainer_callbacks = [LoggingCallback()]

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    tokenizer=tokenizer,
    peft_config=peft_config,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    dataset_text_field="text",
    max_seq_length=script_args.max_seq_length,
    callbacks=_trainer_callbacks,
)
|
|
|
|
|
# Run the fine-tuning loop.
trainer.train()

# Persist the trained model first, then the tokenizer, into the run directory.
for _artifact in (trainer.model, trainer.tokenizer):
    _artifact.save_pretrained(new_model)

# Report the elapsed wall-clock time in seconds since ``start_time``.
_elapsed = time.time() - start_time
print('Done in ', _elapsed)
|
|
|
|
|
|
|
|
# ---- Rebuild and plot the loss curves from the JSON-lines training log ----
# Each curve keeps its OWN x-values.  The previous version shared a single
# ``epochs`` list between both curves and sliced it by length, which only
# lined up when train and eval records happened to alternate one-to-one.
train_epochs, train_losses = [], []
eval_epochs, eval_losses = [], []

eval_results_file = os.path.join(new_model, "train_results.json")
with open(eval_results_file, "r") as f:
    last_epoch = None  # most recent epoch value seen in the log
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines
        data = json.loads(line)

        # Records without an 'epoch' reuse the last known one; until one has
        # been seen there is no x-coordinate, so such records are skipped
        # (this used to raise NameError).
        last_epoch = data.get('epoch', last_epoch)
        if last_epoch is None:
            continue

        if 'loss' in data:
            train_epochs.append(last_epoch)
            train_losses.append(data['loss'])
        if 'eval_loss' in data:
            eval_epochs.append(last_epoch)
            eval_losses.append(data['eval_loss'])

plt.figure(figsize=(10, 5))
plt.plot(train_epochs, train_losses, label='Train Loss', color='blue')
plt.plot(eval_epochs, eval_losses, label='Eval Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Evaluation Loss', fontsize=10)
plt.legend()
plt.tight_layout()

# The figure is stored next to the model artifacts.
plot_path = os.path.join(new_model, 'training_evaluation_loss_plot.png')
plt.savefig(plot_path)
plt.close()

# Report the real location; the old message claimed "the current directory".
print(f"Plot saved as '{plot_path}'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|