| """Training CLI for DETree.""" | |
| from __future__ import annotations | |
| import argparse | |
| import random | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Iterable, Optional | |
| import torch | |
| import torch.nn.functional as F # noqa: F401 # retained for backward compat with downstream imports | |
| import torch.optim as optim | |
| import yaml | |
| from lightning import Fabric | |
| from lightning.fabric.strategies import DeepSpeedStrategy, DDPStrategy | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data.dataloader import default_collate | |
| from torch.utils.tensorboard import SummaryWriter | |
| from tqdm import tqdm | |
| from transformers import AutoTokenizer | |
| from detree.model.simclr import SimCLR_Tree | |
| from detree.utils.dataset import SCLDataset, load_datapath | |

@dataclass
class ExperimentPaths:
    """Utility container describing where to store experiment artefacts."""

    root: Path
    runs: Path
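
# Example (paths are illustrative):
#   paths = ExperimentPaths(root=Path("runs/exp"), runs=Path("runs/exp/runs"))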

def _build_collate_fn(tokenizer, max_length: int):
    def collate_fn(batch: Iterable):
        text, label, write_model = default_collate(batch)
        encoded_batch = tokenizer.batch_encode_plus(
            text,
            return_tensors="pt",
            max_length=max_length,
            padding=True,
            truncation=True,
        )
        return encoded_batch, label, write_model

    return collate_fn
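
# Each collated batch is a (BatchEncoding, labels, write_model) triple; the
# BatchEncoding carries the padded, truncated input_ids and attention_mask
# tensors that train() unpacks into the encoder forward pass.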

def _prepare_output_dir(
    output_dir: Path, experiment_name: str, resume: bool, *, create_dirs: bool = True
) -> ExperimentPaths:
    output_dir = output_dir.expanduser().resolve()
    candidate = output_dir / experiment_name
    if candidate.exists() and not resume:
        suffix = 0
        while (output_dir / f"{experiment_name}_v{suffix}").exists():
            suffix += 1
        candidate = output_dir / f"{experiment_name}_v{suffix}"
    runs_dir = candidate / "runs"
    if create_dirs:
        candidate.mkdir(parents=True, exist_ok=True)
        runs_dir.mkdir(parents=True, exist_ok=True)
    return ExperimentPaths(root=candidate, runs=runs_dir)
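
# Example: with output_dir="runs" and experiment_name="detree_experiment",
# repeated launches without --resume resolve to runs/detree_experiment,
# then runs/detree_experiment_v0, runs/detree_experiment_v1, and so on.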

def build_argument_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Train DETree using the hierarchical contrastive objective",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--model-name", type=str, default="FacebookAI/roberta-large", help="Backbone encoder identifier.")
    parser.add_argument("--device-num", type=int, default=1, help="Number of CUDA devices to use.")
    parser.add_argument("--path", type=Path, required=True, help="Root directory of the dataset.")
    parser.add_argument("--dataset-name", type=str, default="all", help="Dataset configuration name.")
    parser.add_argument(
        "--dataset", type=str, default="train", choices=("train", "test", "extra"), help="Dataset split to consume."
    )
    parser.add_argument("--tree-txt", type=Path, required=True, help="Tree definition file as produced by the HAT pipeline.")
    parser.add_argument("--output-dir", type=Path, default=Path("runs"), help="Directory where experiment folders are saved.")
    parser.add_argument("--experiment-name", type=str, default="detree_experiment", help="Base name for the run directory.")
    parser.add_argument("--resume", action="store_true", help="Reuse the given experiment directory if it already exists.")
    parser.add_argument("--projection-size", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=0.07)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--per-gpu-batch-size", type=int, default=64)
    parser.add_argument("--per-gpu-eval-batch-size", type=int, default=16)
    parser.add_argument("--max-length", type=int, default=512, help="Maximum sequence length for the tokenizer.")
    parser.add_argument("--total-epoch", type=int, default=10)
    parser.add_argument("--warmup-steps", type=int, default=2000)
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("--min-lr", type=float, default=5e-6)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.99)
    parser.add_argument("--eps", type=float, default=1e-6)
    parser.add_argument("--adv-p", type=float, default=0.5, help="Probability of sampling adversarial data.")
    parser.add_argument("--num-workers-eval", type=int, default=8, help="Reserved for compatibility.")
    parser.add_argument("--lora-r", type=int, default=128)
    parser.add_argument("--lora-alpha", type=int, default=256)
    parser.add_argument("--lora-dropout", type=float, default=0.0)
    parser.add_argument("--freeze-layer", type=int, default=0, help="Number of initial encoder layers to freeze.")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--adapter-path", type=Path, default=None, help="Optional path to resume LoRA training from.")
    parser.add_argument("--pooling", type=str, default="max", choices=("max", "average", "cls"))
    parser.add_argument("--lora", dest="lora", action="store_true", help="Enable LoRA adapters.")
    parser.add_argument("--no-lora", dest="lora", action="store_false", help="Disable LoRA adapters.")
    parser.set_defaults(lora=True)
    parser.add_argument("--freeze-embedding-layer", dest="freeze_embedding_layer", action="store_true")
    parser.add_argument("--no-freeze-embedding-layer", dest="freeze_embedding_layer", action="store_false")
    parser.set_defaults(freeze_embedding_layer=True)
    parser.add_argument("--adversarial", dest="adversarial", action="store_true")
    parser.add_argument("--no-adversarial", dest="adversarial", action="store_false")
    parser.set_defaults(adversarial=True)
    parser.add_argument("--include-attack", dest="include_attack", action="store_true")
    parser.add_argument("--no-include-attack", dest="include_attack", action="store_false")
    parser.set_defaults(include_attack=True)
    parser.add_argument("--has-mix", dest="has_mix", action="store_true")
    parser.add_argument("--no-has-mix", dest="has_mix", action="store_false")
    parser.set_defaults(has_mix=True)
    parser.add_argument("--deepspeed", action="store_true", help="Use DeepSpeed strategy when multiple GPUs are available.")
    return parser
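
# Typical invocation (module path and data locations are illustrative):
#   python -m detree.train \
#       --path data/detree --tree-txt trees/hat_tree.txt \
#       --output-dir runs --experiment-name detree_experiment \
#       --device-num 2 --deepspeed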

def train(args: argparse.Namespace) -> None:
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.set_float32_matmul_precision("medium")
    if args.device_num > 1:
        if args.deepspeed:
            strategy = DeepSpeedStrategy()
        else:
            strategy = DDPStrategy(find_unused_parameters=True)
        fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num, strategy=strategy)
    else:
        fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num)
    fabric.launch()
    experiment_paths = ExperimentPaths(root=Path(args.output_dir), runs=Path(args.runs_dir))
    if fabric.global_rank == 0:
        experiment_paths.root.mkdir(parents=True, exist_ok=True)
        experiment_paths.runs.mkdir(parents=True, exist_ok=True)
    fabric.barrier()
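    # Only rank 0 creates the experiment directories; the barrier above keeps
    # the other ranks from touching these paths before they exist.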
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    collate_fn = _build_collate_fn(tokenizer, args.max_length)
    model = SimCLR_Tree(args, fabric).train()
    data_path = load_datapath(
        str(args.path),
        include_adversarial=args.adversarial,
        dataset_name=args.dataset_name,
        include_attack=args.include_attack,
    )[args.dataset]
    train_dataset = SCLDataset(
        data_path,
        fabric,
        tokenizer,
        name2id=model.names2id,
        has_mix=args.has_mix,
        adv_p=args.adv_p,
    )
    passages_dataloader = DataLoader(
        train_dataset,
        batch_size=args.per_gpu_batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
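    # batch_size is per process, so each optimizer step consumes roughly
    # per_gpu_batch_size * device_num samples; drop_last keeps every step
    # full-sized. Note that use_distributed_sampler=False below means each
    # rank shuffles the full dataset independently.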
    model.train()
    if args.freeze_embedding_layer:
        for name, param in model.model.named_parameters():
            if "emb" in name or "model.pooler" in name:
                param.requires_grad = False
            if args.freeze_layer > 0:
                for i in range(args.freeze_layer):
                    if f"encoder.layer.{i}." in name:
                        param.requires_grad = False
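    # As nested above, --freeze-layer only takes effect when
    # --freeze-embedding-layer is also enabled, because the parameter loop
    # lives inside that branch.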
    model = torch.compile(model)
    if fabric.global_rank == 0:
        print("Model has been initialized!")
        for name, param in model.model.named_parameters():
            print(name, param.requires_grad)
    passages_dataloader = fabric.setup_dataloaders(passages_dataloader, use_distributed_sampler=False)
    if fabric.global_rank == 0:
        print("DataLoader has been initialized!")
        writer = SummaryWriter(str(experiment_paths.runs))
        print(f"Save dir is {args.output_dir}")
        # Path objects are not YAML-serialisable with the default dumper, so
        # store them as plain strings in the saved config.
        opt_dict = {key: str(value) if isinstance(value, Path) else value for key, value in vars(args).items()}
        with open(Path(args.output_dir) / "config.yaml", "w", encoding="utf-8") as file:
            yaml.dump(opt_dict, file, sort_keys=False)
    else:
        writer = None
    experiment_dir = experiment_paths.root
    num_batches_per_epoch = len(passages_dataloader)
    warmup_steps = args.warmup_steps
    lr = args.lr
    total_steps = args.total_epoch * num_batches_per_epoch - warmup_steps
    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        eps=args.eps,
        weight_decay=args.weight_decay,
    )
    schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps, eta_min=args.min_lr)
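    # LR schedule: linear warmup for the first warmup_steps optimizer steps
    # (lr_t = lr * t / warmup_steps, applied manually inside the loop below),
    # then cosine annealing from lr down to min_lr over the remaining
    # total_epoch * num_batches_per_epoch - warmup_steps steps.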
    model, optimizer = fabric.setup(model, optimizer)
    if fabric.global_rank == 0:
        for name, param in model.named_parameters():
            if param.requires_grad:
                print(name, param.requires_grad)
    for epoch in range(args.total_epoch):
        model.train()
        avg_loss = 0.0
        iterator = enumerate(passages_dataloader)
        if fabric.global_rank == 0:
            iterator = tqdm(iterator, total=len(passages_dataloader))
            print(("\n" + "%11s" * 5) % ("Epoch", "GPU_mem", "loss1", "Avgloss", "lr"))
        for i, batch in iterator:
            current_step = epoch * num_batches_per_epoch + i
            if current_step < warmup_steps:
                current_lr = lr * current_step / max(warmup_steps, 1)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = current_lr
            current_lr = optimizer.param_groups[0]["lr"]
            encoded_batch, label, write_model = batch
            loss, loss_classify = model(encoded_batch, write_model)
            avg_loss = (avg_loss * i + loss.item()) / (i + 1)
            fabric.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            if current_step >= warmup_steps:
                schedule.step()
            mem = f"{torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0:.3g}G"
            if fabric.global_rank == 0:
                iterator.set_description(
                    ("%11s" * 2 + "%11.4g" * 3)
                    % (f"{epoch + 1}/{args.total_epoch}", mem, loss_classify.item(), avg_loss, current_lr)
                )
                if writer and current_step % 10 == 0:
                    writer.add_scalar("lr", current_lr, current_step)
                    writer.add_scalar("loss", loss.item(), current_step)
                    writer.add_scalar("avg_loss", avg_loss, current_step)
                    writer.add_scalar("loss_classify", loss_classify.item(), current_step)
        if fabric.global_rank == 0:
            checkpoint_dir = experiment_dir / f"epoch_{epoch:02d}"
            model.save_pretrained(str(checkpoint_dir), save_tokenizer=(epoch == 0))
            print(f"Saved adapter checkpoint to {checkpoint_dir}", flush=True)
            last_dir = experiment_dir / "last"
            model.save_pretrained(str(last_dir), save_tokenizer=False)
            print(f"Updated latest checkpoint at {last_dir}", flush=True)
        fabric.barrier()
    if writer:
        writer.flush()
        writer.close()
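
# Resulting experiment layout (illustrative):
#   <output_dir>/<experiment_name>/
#       config.yaml                 # training configuration dumped by rank 0
#       runs/                       # TensorBoard event files
#       epoch_00/ ... epoch_NN/     # per-epoch adapter checkpoints
#       last/                       # always-updated latest checkpoint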

def main(argv: Optional[Iterable[str]] = None) -> None:
    parser = build_argument_parser()
    args = parser.parse_args(argv)
    experiment_paths = _prepare_output_dir(
        args.output_dir, args.experiment_name, resume=args.resume, create_dirs=False
    )
    args.output_dir = str(experiment_paths.root)
    args.runs_dir = str(experiment_paths.runs)
    train(args)


__all__ = ["build_argument_parser", "main", "train"]


if __name__ == "__main__":
    main()