import os

import torch
from torch.utils.data import Dataset
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)


class CustomTextDataset(Dataset):
    """Flatten tokenized chunks into one token stream and serve fixed-length
    windows of block_size tokens."""

    def __init__(self, tokenizer, data_chunk, block_size):
        self.examples = []
        for chunk in data_chunk:
            tokenized_text = tokenizer.encode(chunk, add_special_tokens=True)
            self.examples.extend(tokenized_text)
        self.block_size = block_size

    def __len__(self):
        # Dense sliding window: one example per start offset. max() guards
        # against a token stream shorter than one block.
        return max(0, len(self.examples) - self.block_size)

    def __getitem__(self, i):
        # Consecutive indices overlap by all but one token.
        return torch.tensor(self.examples[i:i + self.block_size])

# Read every .txt file under the data folder and join the contents into one string.
folder_path = "data"
file_list = [f for f in os.listdir(folder_path) if f.endswith(".txt")]

all_text_data = []
for file_name in file_list:
    file_path = os.path.join(folder_path, file_name)
    with open(file_path, "r", encoding="utf-8") as f:
        all_text_data.append(f.read())

text = " ".join(all_text_data)

# Load the pretrained GPT-2 tokenizer and model.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 ships without a pad token; reuse EOS so the data collator can pad.
tokenizer.pad_token = tokenizer.eos_token
config = GPT2Config.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name, config=config)

# Split the raw text into 1024-character chunks so individual tokenizer calls
# stay short; the dataset re-blocks the resulting tokens afterwards.
max_sequence_length = 1024
chunks = [text[i:i + max_sequence_length] for i in range(0, len(text), max_sequence_length)]

# Serve 128-token sliding windows over the tokenized corpus.
dataset = CustomTextDataset(tokenizer=tokenizer, data_chunk=chunks, block_size=128)
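
# Optional sanity check (an addition of this sketch, not in the original flow):
# confirm the dataset actually yields windows of the expected shape before
# committing to a full training run.
if len(dataset) > 0:
    print(f"{len(dataset)} windows; first window shape: {tuple(dataset[0].shape)}")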

# Evaluation is left disabled: no eval_dataset is built above, and setting
# evaluation_strategy="epoch" without one makes Trainer raise at startup.
training_args = TrainingArguments(
    output_dir="./Cyber_LLM",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=32,
    save_steps=10_000,
    save_total_limit=2,
)

# mlm=False selects causal-LM collation: the collator copies input_ids into
# labels and the model handles the one-token shift internally.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    train_dataset=dataset,
)

trainer.train()

# Save the fine-tuned weights and the tokenizer together so the checkpoint can
# be reloaded with from_pretrained alone.
model.save_pretrained("./Cyber_LLM")
tokenizer.save_pretrained("./Cyber_LLM")

print("Training completed.")
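
# A minimal usage sketch, not part of the original script: reload the saved
# checkpoint and sample a continuation. The prompt is an illustrative placeholder.
fine_tuned = GPT2LMHeadModel.from_pretrained("./Cyber_LLM")
prompt_ids = tokenizer.encode("Network intrusions often begin with", return_tensors="pt")
output_ids = fine_tuned.generate(
    prompt_ids,
    max_length=50,
    do_sample=True,
    top_k=50,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))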