| | """ |
| | Fine-tune Flan-T5 for cyber report generation. |
| | M3 — Model fine-tuning. |
| | """ |
| |
|
| | import argparse |
| | from pathlib import Path |
| |
|
| | |
| | from src.hf_auth import login as hf_login |
| |
|
| | hf_login() |
| |
|
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)


def load_csv_data(train_path: str, val_path: str):
    """Load train/val CSVs with input_text and target_report columns."""
    train_df = pd.read_csv(train_path, encoding="utf-8")
    val_df = pd.read_csv(val_path, encoding="utf-8")
    return Dataset.from_pandas(train_df), Dataset.from_pandas(val_df)
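
# The loaders above expect both CSVs to follow this schema (the row below is
# illustrative, not real data):
#
#   input_text,target_report
#   "alerts: 3 failed SSH logins from 203.0.113.7 ...","Brief: repeated failed ..."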


def preprocess_function(examples, tokenizer, max_input_length: int, max_target_length: int):
    """Tokenize inputs and targets for seq2seq training."""
    model_inputs = tokenizer(
        examples["input_text"],
        max_length=max_input_length,
        truncation=True,
        padding=False,
    )
    # Tokenize the targets as labels; padding is deferred to the data collator.
    labels = tokenizer(
        text_target=examples["target_report"],
        max_length=max_target_length,
        truncation=True,
        padding=False,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
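
# Quick sanity check of the preprocessing (illustrative, run in a REPL):
#   tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
#   batch = {"input_text": ["alert: port scan"], "target_report": ["Port scan detected."]}
#   out = preprocess_function(batch, tok, 512, 128)
#   assert set(out) >= {"input_ids", "attention_mask", "labels"}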


def main():
    parser = argparse.ArgumentParser(description="Fine-tune Flan-T5 for cyber report generation")
    parser.add_argument("--model_name", default="google/flan-t5-base")
    parser.add_argument("--train", default="data/train.csv")
    parser.add_argument("--val", default="data/val.csv")
    parser.add_argument("--output_dir", default="models/flan_t5_report")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=3e-5)
    parser.add_argument("--max_input_length", type=int, default=512)
    parser.add_argument("--max_target_length", type=int, default=128)
    args = parser.parse_args()

    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)

    print("Loading datasets...")
    train_ds, val_ds = load_csv_data(args.train, args.val)

    def tokenize_fn(examples):
        return preprocess_function(
            examples, tokenizer, args.max_input_length, args.max_target_length
        )

    # Drop the raw text columns so batches contain only tokenized fields.
    train_ds = train_ds.map(tokenize_fn, batched=True, remove_columns=train_ds.column_names)
    val_ds = val_ds.map(tokenize_fn, batched=True, remove_columns=val_ds.column_names)

    # Pad dynamically per batch; the collator pads labels with -100 so padded
    # positions are ignored by the loss.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer, model=model, padding=True, return_tensors="pt"
    )

    training_args = Seq2SeqTrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=0.1,
        logging_steps=20,
        # "eval_strategy" requires a recent transformers release; older
        # versions name this argument "evaluation_strategy".
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        report_to="none",
    )
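
    # Note: this setup tracks eval loss only. Generation metrics such as ROUGE
    # would additionally need predict_with_generate=True and a compute_metrics
    # function passed to the trainer.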

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=data_collator,
        # Recent transformers releases prefer processing_class= here;
        # tokenizer= still works but is deprecated.
        tokenizer=tokenizer,
    )

    print("Starting training...")
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Model saved to {args.output_dir}")


if __name__ == "__main__":
    main()
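
# Example invocation (the filename is an assumption; adjust to this script's
# actual name):
#   python finetune_flan_t5.py --train data/train.csv --val data/val.csv \
#       --output_dir models/flan_t5_report --epochs 3 --batch_size 8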