Spaces:
Running
Running
Upload 9 files
Browse files- detree/cli/database.py +117 -0
- detree/cli/embeddings.py +200 -0
- detree/cli/gen_tree.py +86 -0
- detree/cli/hierarchical_clustering.py +497 -0
- detree/cli/merge_lora.py +52 -0
- detree/cli/similarity_matrix.py +77 -0
- detree/cli/test_database_score_knn.py +247 -0
- detree/cli/test_score_knn.py +267 -0
- detree/cli/train.py +313 -0
detree/cli/database.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate clustered prototype databases from embeddings."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Iterable, Optional
|
| 8 |
+
|
| 9 |
+
import faiss
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class GPUKMeansClusterer:
    """K-Means clustering on one or more GPUs via FAISS.

    Wraps ``faiss.Clustering`` around a flat L2 GPU index; :meth:`fit`
    trains on a float32 array and returns the learned centroids.
    """

    def __init__(self, dim: int, n_clusters: int = 500, n_iter: int = 20, n_gpu: int = 1):
        # FAISS clustering object: verbose, fixed number of iterations.
        self.clus = faiss.Clustering(dim, n_clusters)
        self.clus.verbose = True
        self.clus.niter = n_iter
        self.dim = dim
        self.n_clusters = n_clusters
        # Re-add the centroids to the search index between iterations.
        self.clus.update_index = True

        # One GPU resource + flat-index config per requested device.
        res = [faiss.StandardGpuResources() for _ in range(n_gpu)]
        flat_config = []
        for i in range(n_gpu):
            cfg = faiss.GpuIndexFlatConfig()
            cfg.useFloat16 = False
            cfg.device = i
            flat_config.append(cfg)

        if n_gpu == 1:
            self.index = faiss.GpuIndexFlatL2(res[0], self.dim, flat_config[0])
        else:
            # Replicate a flat L2 index across all requested GPUs.
            indexes = [faiss.GpuIndexFlatL2(res[i], self.dim, flat_config[i]) for i in range(n_gpu)]
            self.index = faiss.IndexReplicas()
            for sub_index in indexes:
                self.index.addIndex(sub_index)

    def fit(self, embeddings_np: np.ndarray) -> np.ndarray:
        """Train K-Means on *embeddings_np* (N, dim) and return centroids
        of shape (n_clusters, dim)."""
        self.index.reset()
        self.clus.train(embeddings_np, self.index)
        centroids = faiss.vector_float_to_array(self.clus.centroids)
        centroids = centroids.reshape(self.n_clusters, self.dim)
        return centroids
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def gen_data(dict_data):
    """Unpack an embedding-database dict into its four components.

    Returns a tuple ``(embeddings, labels, ids, classes)`` read from the
    corresponding keys of *dict_data*.
    """
    return (
        dict_data["embeddings"],
        dict_data["labels"],
        dict_data["ids"],
        dict_data["classes"],
    )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def build_argument_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the prototype-clustering tool."""
    cli = argparse.ArgumentParser(
        description="Cluster embeddings into prototype databases using GPU K-Means.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument("--database", type=Path, required=True, help="Input embedding database (.pt).")
    cli.add_argument("--output", type=Path, required=True, help="Output path for the clustered database.")
    # Plain integer knobs share the same shape, so register them in one pass.
    for flag, default in (("--clusters", 10000), ("--dimension", 1024), ("--iterations", 100), ("--gpus", 1)):
        cli.add_argument(flag, type=int, default=default)
    cli.add_argument("--human-class-name", type=str, default="human", help="Label representing humans in the class list.")
    return cli
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def cluster_database(args: argparse.Namespace) -> None:
    """Cluster an embedding database into per-layer prototype centroids.

    Loads the database at ``args.database``, splits samples into a binary
    human/LLM partition, runs GPU K-Means separately on each partition for
    every embedding layer, and saves the concatenated bfloat16 centroids
    (with binary labels, sequential ids, and class names) to ``args.output``.
    """
    data_emb, data_labels, data_ids, data_classes = gen_data(torch.load(args.database))
    human_idx = data_classes.index(args.human_class_name)
    # 1 = human, 0 = anything else (LLM-generated).
    datapos = (data_labels == human_idx).long()
    pos2cnt = {0: args.clusters, 1: args.clusters}
    pos2name = {0: ["llm"], 1: ["human"]}

    datapos_np = datapos.cpu().numpy()
    kmeans = GPUKMeansClusterer(args.dimension, n_clusters=args.clusters, n_iter=args.iterations, n_gpu=args.gpus)
    all_centers = {}
    save_labels = None
    # data_emb maps a layer key to an (N, dim) embedding tensor.
    for key in data_emb:
        now_emb = data_emb[key].float().cpu().numpy()
        all_center = []
        all_labels = []
        for pos in pos2cnt:
            # Cluster the LLM (0) and human (1) partitions independently.
            pos_emb = now_emb[datapos_np == pos]
            pos_center = kmeans.fit(pos_emb)
            all_center.append(pos_center)
            all_labels.append(np.full((pos_center.shape[0],), pos))
        all_center = np.concatenate(all_center, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)
        all_center = torch.from_numpy(all_center).to(dtype=torch.bfloat16)
        all_labels = torch.from_numpy(all_labels).to(dtype=torch.long)
        all_centers[key] = all_center
        # Labels are identical for every layer, so the last copy suffices.
        save_labels = all_labels

    save_ids = torch.arange(save_labels.shape[0], dtype=torch.long)
    classes = [None] * len(pos2name.keys())
    for pos in pos2name:
        classes[pos] = ','.join(pos2name[pos])

    emb_dict = {"embeddings": all_centers, "labels": save_labels, "ids": save_ids, "classes": classes}
    args.output.parent.mkdir(parents=True, exist_ok=True)
    torch.save(emb_dict, args.output)
    print(f"All centers saved to: {args.output}")
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def main(argv: Optional[Iterable[str]] = None) -> None:
    """CLI entry point: parse *argv* (or ``sys.argv``) and run clustering."""
    cluster_database(build_argument_parser().parse_args(argv))
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
if __name__ == "__main__":
    # Allow running this module directly as a script.
    main()

__all__ = ["build_argument_parser", "cluster_database", "main"]
|
detree/cli/embeddings.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Embedding generation CLI for DETree."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Iterable, Literal, Optional
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from lightning import Fabric
|
| 13 |
+
from torch.utils.data import DataLoader, Dataset
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
from detree.model.text_embedding import TextEmbeddingModel
|
| 17 |
+
from detree.utils.dataset import SCLDataset, load_datapath
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def infer(passages_dataloader, fabric, tokenizer, model, args):
    """Run the embedding model over a dataloader, gathering results on rank 0.

    Returns ``(ids, embeddings, labels)`` on the rank-0 process, where
    ``embeddings`` maps each layer index in ``args.need_layer`` to a single
    concatenated bfloat16 tensor; every other rank returns three empty lists.
    """
    if fabric.global_rank == 0:
        # Progress bar only on the coordinating rank.
        passages_dataloader = tqdm(passages_dataloader)
    all_ids, all_embeddings, all_labels = [], {}, []
    for layer in args.need_layer:
        all_embeddings[layer] = []
    with torch.no_grad():
        for batch in passages_dataloader:
            text, label, write_model, ids = batch
            encoded_batch = tokenizer.batch_encode_plus(
                text,
                return_tensors="pt",
                max_length=args.max_length,
                padding="max_length",
                truncation=True,
            )
            encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
            embeddings = model(encoded_batch, hidden_states=True)
            # Gather across ranks, then flatten the gather dimension away.
            embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(-2), embeddings.size(-1))
            # NOTE(review): "label" is rebound to the gathered write_model
            # (model index), not the batch's binary label — confirm intended.
            label = fabric.all_gather(write_model).view(-1)
            ids = fabric.all_gather(ids).view(-1)
            if fabric.global_rank == 0:
                embeddings = F.normalize(embeddings, dim=-1).cpu().to(torch.bfloat16)
                for layer in args.need_layer:
                    # Keep only the requested hidden layers; clone to drop
                    # the reference to the full gathered tensor.
                    all_embeddings[layer].append(embeddings[:, layer, :].clone())
                all_ids.extend(ids.cpu().tolist())
                all_labels.extend(label.cpu().tolist())
            del embeddings, label, ids
    if fabric.global_rank == 0:
        for layer in args.need_layer:
            all_embeddings[layer] = torch.cat(all_embeddings[layer], dim=0)
        return torch.tensor(all_ids), all_embeddings, torch.tensor(all_labels)
    return [], [], []
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def stable_long_hash(input_string: str) -> int:
    """Map a string to a deterministic non-negative 63-bit integer.

    Uses SHA-256 (stable across processes, unlike built-in ``hash``) and
    masks the result to 63 bits.
    """
    import hashlib

    digest = hashlib.sha256(input_string.encode()).hexdigest()
    return int(digest, 16) & ((1 << 63) - 1)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def load_data(split: Literal["train", "test", "extra"], include_adversarial: bool, fp: Path) -> pd.DataFrame:
    """Load one CSV split from directory *fp*.

    ``<split>.csv`` is read when *include_adversarial* is true, otherwise the
    attack-free variant ``<split>_none.csv``. Raises ``ValueError`` for an
    unknown split name.
    """
    if split not in ("train", "test", "extra"):
        raise ValueError("`split` must be one of (\"train\", \"test\", \"extra\")")

    suffix = "" if include_adversarial else "_none"
    return pd.read_csv(fp / f"{split}{suffix}.csv")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class PassagesDataset(Dataset):
    """Dataset over passage records with hash-based downsampling of attacks.

    Records whose ``attack`` is neither "none" nor "paraphrase" are kept only
    when their generation hashes into the upper half of the 0-9 bucket range.
    The class list is derived from the FULL input, not the filtered subset.
    """

    def __init__(self, data):
        self.passages = [
            item
            for item in data
            if item["attack"] in ("none", "paraphrase")
            or stable_long_hash(item["generation"]) % 10 >= 5
        ]
        self.classes = sorted({item["model"] for item in data})
        self.human_id = self.classes.index("human")

    def __len__(self):
        return len(self.passages)

    def __getitem__(self, idx):
        record = self.passages[idx]
        text = record["generation"]
        model_idx = self.classes.index(record["model"])
        # Binary label: 1 when the passage is human-written.
        is_human = int(model_idx == self.human_id)
        return text, is_human, int(model_idx), int(stable_long_hash(text))
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def build_argument_parser() -> argparse.ArgumentParser:
    """Assemble the CLI parser for embedding-database generation."""
    p = argparse.ArgumentParser(
        description="Generate embedding databases for DETree evaluators",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Runtime / batching knobs.
    p.add_argument("--device-num", type=int, default=1)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--num-workers", type=int, default=8)
    p.add_argument("--max-length", type=int, default=512)

    # Input dataset and embedding model.
    p.add_argument("--path", type=Path, required=True, help="Dataset root directory or JSONL file path.")
    p.add_argument("--database-name", type=str, default="M4_monolingual")
    p.add_argument(
        "--model-name",
        type=str,
        default="FacebookAI/roberta-large",
        help=(
            "Model identifier for embeddings generation. Accepts either a Hugging Face "
            "model hub name or a local path to a directory in Hugging Face format."
        ),
    )

    # Embedding extraction options.
    p.add_argument("--pooling", type=str, default="max", choices=("max", "average", "cls"))
    p.add_argument("--need-layer", type=int, nargs="+", default=[16, 17, 18, 19, 22, 23])

    # Paired on/off boolean flags.
    p.add_argument("--adversarial", dest="adversarial", action="store_true")
    p.add_argument("--no-adversarial", dest="adversarial", action="store_false")
    p.set_defaults(adversarial=True)

    p.add_argument("--has-mix", dest="has_mix", action="store_true")
    p.add_argument("--no-has-mix", dest="has_mix", action="store_false")
    p.set_defaults(has_mix=False)

    # Output location.
    p.add_argument("--savedir", type=Path, required=True, help="Output directory for the embedding database.")
    p.add_argument("--name", type=str, required=True, help="Filename (without extension) for the saved embeddings.")
    p.add_argument("--split", type=str, default="train", choices=("train", "test", "extra"))

    return p
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def generate_embeddings(args: argparse.Namespace) -> None:
    """Embed a dataset with a text-embedding model and save the database.

    Picks a dataset loader from the input path layout, runs (optionally
    distributed) inference via Lightning Fabric, and — on rank 0 only —
    saves a dict of {"embeddings", "labels", "ids", "classes"} to
    ``args.savedir / f"{args.name}.pt"``.
    """
    if args.device_num > 1:
        fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num, strategy="ddp")
    else:
        fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num)
    fabric.launch()

    model = TextEmbeddingModel(
        args.model_name,
        output_hidden_states=True,
        infer=True,
        use_pooling=args.pooling,
    ).cuda()
    tokenizer = model.tokenizer
    model.eval()

    path_str = str(args.path)
    if "LLM_detect_data" in path_str:
        # Legacy CSV layout: split files live directly under the path.
        now_data = load_data(args.split, include_adversarial=args.adversarial, fp=args.path)
        now_data = now_data.to_dict(orient="records")
        dataset = PassagesDataset(now_data)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        dataloader = fabric.setup_dataloaders(dataloader)
    elif path_str.endswith(".jsonl"):
        # A single JSONL file handled by the project's SCL dataset.
        dataset = SCLDataset([path_str], fabric, tokenizer, need_ids=True, adv_p=0)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False)
    else:
        # Dataset root directory resolved through load_datapath.
        data_path = load_datapath(
            path_str,
            include_adversarial=args.adversarial,
            dataset_name=args.database_name,
        )[args.split]
        dataset = SCLDataset(data_path, fabric, tokenizer, need_ids=True, adv_p=0, has_mix=args.has_mix)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False)

    model = fabric.setup(model)
    classes = dataset.classes
    train_ids, train_embeddings, train_labels = infer(dataloader, fabric, tokenizer, model, args)

    torch.cuda.empty_cache()
    if fabric.global_rank == 0:
        # Only rank 0 holds the gathered results (see infer).
        args.savedir.mkdir(parents=True, exist_ok=True)
        emb_dict = {
            "embeddings": train_embeddings,
            "labels": train_labels,
            "ids": train_ids,
            "classes": classes,
        }
        output_path = args.savedir / f"{args.name}.pt"
        torch.save(emb_dict, output_path)
        print(f"Saved embedding database to {output_path}")
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def main(argv: Optional[Iterable[str]] = None) -> None:
    """CLI entry point: parse *argv* (or ``sys.argv``) and run generation."""
    generate_embeddings(build_argument_parser().parse_args(argv))
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
if __name__ == "__main__":
    # Allow running this module directly as a script.
    main()

__all__ = ["build_argument_parser", "generate_embeddings", "main"]
|
detree/cli/gen_tree.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tree generation CLI utilities for DETree."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Iterable, Sequence, Set
|
| 9 |
+
|
| 10 |
+
from detree.utils.dataset import load_datapath, model_alias_mapping
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _str2bool(value: str) -> bool:
|
| 14 |
+
"""Parse common textual boolean representations used by legacy scripts."""
|
| 15 |
+
|
| 16 |
+
if isinstance(value, bool):
|
| 17 |
+
return value
|
| 18 |
+
lowered = value.lower()
|
| 19 |
+
if lowered in {"true", "1", "yes", "y"}:
|
| 20 |
+
return True
|
| 21 |
+
if lowered in {"false", "0", "no", "n"}:
|
| 22 |
+
return False
|
| 23 |
+
raise argparse.ArgumentTypeError(f"Boolean value expected, got: {value}")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_data_model(data_path: Iterable[Path], has_mix: bool = True) -> Set[str]:
    """Collect all model identifiers present in the provided dataset paths.

    Reads each JSONL file, canonicalises every record's ``src`` through the
    shared ``model_alias_mapping`` table (unknown sources map to themselves,
    mutating the table), optionally drops mixed human/model sources, and
    returns the distinct names. Prints progress and the final count.
    """
    llm_name: Set[str] = set()
    cnt = 0
    for path in data_path:
        print(f"reading {path}")
        with path.open(mode="r", encoding="utf-8") as jsonl_file:
            for line in jsonl_file:
                now = json.loads(line)
                # Unknown sources alias to themselves (mutates the shared table).
                if now["src"] not in model_alias_mapping:
                    model_alias_mapping[now["src"]] = now["src"]
                now["src"] = model_alias_mapping[now["src"]]
                if not has_mix and "human" in now["src"] and now["src"] != "human":
                    continue
                if now["src"] not in llm_name:
                    llm_name.add(now["src"])
                    cnt += 1
    print(cnt)
    return llm_name
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the tree generation CLI."""
    arg_parser = argparse.ArgumentParser(
        description="Generate DETree-compatible tree definitions from dataset files.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Dataset selection.
    arg_parser.add_argument("--path", type=Path, default=Path("/opt/AI-text-Dataset"), help="Root directory of the dataset.")
    arg_parser.add_argument("--dataset_name", type=str, default="all", help="Dataset configuration name.")
    arg_parser.add_argument(
        "--mode",
        type=str,
        choices=("train", "test", "extra"),
        default="train",
        help="Dataset split to consume.",
    )
    # Output and filtering options.
    arg_parser.add_argument("--tree_txt", type=Path, default=Path("output/Tree_RAID_pcl.txt"), help="Output tree definition path.")
    arg_parser.add_argument("--adversarial", type=_str2bool, default=True, help="Whether to include adversarial data splits.")
    arg_parser.add_argument("--has_mix", type=_str2bool, default=True, help="Whether to keep mixed human/model generations.")
    return arg_parser
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def main(args: argparse.Namespace) -> None:
    """Entry point for building DETree-compatible tree structures.

    Writes one "child parent name" line per model under a single root node,
    then the root line itself ("<root> -1 none").
    """
    dataset_paths: Sequence[str] = load_datapath(args.path, args.adversarial, args.dataset_name)[args.mode]
    print(f"data_path: {dataset_paths}")
    # Sort for a deterministic leaf numbering across runs.
    llm_name = sorted(get_data_model((Path(p) for p in dataset_paths), args.has_mix))
    # The root node id is one past the last leaf id.
    root = len(llm_name)
    args.tree_txt.parent.mkdir(parents=True, exist_ok=True)
    with args.tree_txt.open("w", encoding="utf-8") as f:
        for i, item in enumerate(llm_name):
            f.write(f"{i} {root} {item}\n")
        f.write(f"{root} -1 none\n")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
    # Script entry: parse CLI arguments and build the tree file.
    parser = build_argument_parser()
    main(parser.parse_args())
|
detree/cli/hierarchical_clustering.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import random
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Iterable, Optional
|
| 5 |
+
|
| 6 |
+
import matplotlib.cm as cm
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy.cluster.hierarchy import dendrogram, linkage
|
| 10 |
+
from scipy.spatial.distance import euclidean, squareform
|
| 11 |
+
from sklearn.metrics import silhouette_score
|
| 12 |
+
|
| 13 |
+
def read_similarity_matrix(file_path: Path):
    """Parse a whitespace-delimited similarity matrix file.

    The first line lists the column names; each subsequent line starts with a
    row label followed by that row's float values. Returns ``(names, matrix)``
    where *matrix* is a 2-D numpy array.
    """
    with file_path.open('r', encoding='utf-8') as f:
        header, *rows = f.readlines()
    names = header.strip().split()
    values = [[float(cell) for cell in row.strip().split()[1:]] for row in rows]
    return names, np.array(values)
|
| 25 |
+
|
| 26 |
+
class TreeNode:
    """A node in the flattened clustering hierarchy.

    Children are stored as integer indices into a flat node list, not as
    direct object references.
    """

    def __init__(self, name=None):
        self.name = name       # leaf label; None for internal nodes
        self.children = []     # child indices into the flat node list
        self.value = 0         # linkage distance at which this node formed
        self.split = True      # whether this node still splits further

    def add_child(self, child):
        """Append *child* (a node index) to this node's children."""
        self.children.append(child)
|
| 36 |
+
|
| 37 |
+
def build_tree(Z, names):
    """Convert a SciPy linkage matrix into a flat list of TreeNode objects.

    The first ``len(names)`` entries are leaves; each linkage row appends one
    internal node whose ``value`` is the merge distance and whose children
    are the indices of the two merged clusters.

    Parameters:
        Z: linkage matrix (rows of [idx_a, idx_b, distance, size]) as
           returned by ``scipy.cluster.hierarchy.linkage``.
        names: leaf names in the order used to build the linkage.

    Returns:
        list[TreeNode]: leaves followed by internal nodes in linkage order.
    """
    nodes = [TreeNode(name) for name in names]
    # Fix: the original iterated with enumerate() but never used the index.
    for link in Z:
        merged = TreeNode()
        merged.value = link[2]
        merged.add_child(int(link[0]))
        merged.add_child(int(link[1]))
        nodes.append(merged)
    return nodes
|
| 46 |
+
|
| 47 |
+
def find_best_thold(node_idx, nodes, distance_matrix, min_socre=0, max_socre=1):
    """Search 50 thresholds in [min_socre, max_socre] * node.value and return
    the ``(threshold, silhouette)`` pair with the highest silhouette score.

    ``min_socre``/``max_socre`` keep their original (misspelled) names
    because callers pass them as keyword arguments.
    """
    node = nodes[node_idx]
    threshold_range = np.linspace(min_socre * node.value, max_socre * node.value, 50)
    silhouette_scores = []
    all_n_clusters = []

    for threshold in threshold_range:
        labels, _ = gen_label_from_node(node_idx, nodes, threshold)
        # Sort by node index so labels align with distance_matrix rows.
        labels = sorted(labels, key=lambda x: x[1])
        labels = [x[0] for x in labels]
        n_clusters = len(np.unique(labels))
        if n_clusters > 1 and n_clusters < len(distance_matrix):
            score = silhouette_score(distance_matrix, labels, metric='precomputed')
        else:
            # Degenerate clustering (all-in-one or all-singletons): invalid.
            score = -1
        silhouette_scores.append(score)
        all_n_clusters.append(n_clusters)
    best_threshold_idx = np.argmax(silhouette_scores)
    best_threshold = threshold_range[best_threshold_idx]
    best_score = silhouette_scores[best_threshold_idx]
    return best_threshold, best_score
|
| 68 |
+
|
| 69 |
+
def gen_label_from_node(node_idx, nodes, thd, now_label=0):
    """Assign flat cluster labels to the leaves under *node_idx*.

    Returns ``(pairs, next_label)`` where *pairs* is a list of
    ``(label, leaf_index)`` tuples. An internal node whose linkage value
    exceeds *thd* starts a fresh label after every child subtree; below the
    threshold all children share the running label.
    """
    node = nodes[node_idx]
    if not node.children:
        # Leaf: the current label applies directly to this node.
        return [(now_label, node_idx)], now_label

    bump = node.value > thd
    pairs = []
    for child in node.children:
        child_pairs, now_label = gen_label_from_node(child, nodes, thd, now_label)
        if bump:
            now_label += 1
        pairs += child_pairs
    return pairs, now_label
|
| 87 |
+
|
| 88 |
+
def find_new_root(node_idx, nodes, thd):
    """Return the highest-level nodes under *node_idx* whose linkage value is
    at or below *thd* (the roots of the post-cut subtrees)."""
    node = nodes[node_idx]
    if node.value <= thd:
        return [node_idx]

    roots = []
    for child in node.children:
        roots.extend(find_new_root(child, nodes, thd))
    return roots
|
| 97 |
+
|
| 98 |
+
def get_leaf(node_idx, nodes):
    """Return the indices of all leaf nodes in the subtree rooted at
    *node_idx*, in left-to-right traversal order."""
    node = nodes[node_idx]
    if not node.children:
        return [node_idx]

    leaves = []
    for child in node.children:
        leaves.extend(get_leaf(child, nodes))
    return leaves
|
| 107 |
+
|
| 108 |
+
def merge_tree(node_idx, nodes, distance_matrix, deep=0, end_thd=0.25):
    """Recursively flatten the binary linkage tree under *node_idx* into a
    shallow multi-way tree (depth capped at 5), choosing each split threshold
    by silhouette score. Mutates the nodes in place.
    """
    if len(nodes[node_idx].children) == 0:
        return
    print(f"Node {node_idx}: Value: {nodes[node_idx].value}, Depth: {deep}")
    if nodes[node_idx].value <= end_thd or deep >= 5:
        # Too tight or too deep: stop splitting and absorb all leaves.
        nodes[node_idx].children = get_leaf(node_idx, nodes)
        nodes[node_idx].split = False
        return
    leaf_list = np.array(sorted(get_leaf(node_idx, nodes)))
    # Restrict the distance matrix to this subtree's leaves.
    new_distance_matrix = distance_matrix[leaf_list][:, leaf_list]
    best_threshold, best_score = find_best_thold(node_idx, nodes, new_distance_matrix, min_socre=0)
    if best_score == -1:
        # No valid clustering found: keep the subtree as a flat set of leaves
        # (note: unlike the branch above, split stays True here).
        nodes[node_idx].children = get_leaf(node_idx, nodes)
        return
    # Replace the binary children with the roots of the threshold cut.
    new_root = find_new_root(node_idx, nodes, best_threshold)
    nodes[node_idx].children = new_root

    for child in new_root:
        merge_tree(child, nodes, distance_matrix, deep=deep + 1, end_thd=end_thd)
|
| 127 |
+
|
| 128 |
+
def merge_dict(a, b):
    """Merge *b* into *a* in place and return *a*.

    Values for keys present in both dicts are combined with ``+=`` (list
    concatenation for the depth->groups dicts used here).
    """
    for key, value in b.items():
        if key in a:
            a[key] += value
        else:
            a[key] = value
    return a
|
| 135 |
+
|
| 136 |
+
def update_tree(node_idx, nodes, edge_list, fa=-1, deep=0):
    """Walk the merged tree, appending ``(parent, node, leaf_names)`` edges
    to *edge_list* and returning a mapping of depth -> list of leaf-index
    groups.
    """
    node = nodes[node_idx]

    if len(node.children) == 0:
        # Leaf: emit its edge with its own name attached.
        edge_list.append((fa, node_idx, [nodes[node_idx].name]))
        return {deep: [[node_idx]]}

    if node.split == False:
        # Collapsed subtree: emit a single edge carrying all leaf names.
        leafs = get_leaf(node_idx, nodes)
        edge_list.append((fa, node_idx, [nodes[idx].name for idx in leafs]))
        return {deep: [leafs]}

    # Internal splitting node: its edge carries no names; recurse into
    # children and merge their depth maps.
    edge_list.append((fa, node_idx, []))
    new_tree = {}
    for child in node.children:
        new_tree = merge_dict(
            new_tree,
            update_tree(child, nodes, edge_list, node_idx, deep=deep + 1),
        )
    if deep not in new_tree.keys():
        new_tree[deep] = []
    # Record this node's full leaf set at its own depth.
    new_tree[deep].append(get_leaf(node_idx, nodes))

    return new_tree
|
| 160 |
+
|
| 161 |
+
def color_distance(c1, c2):
    """Euclidean distance between two colours over their RGB components
    (any alpha channel is ignored)."""
    rgb_a, rgb_b = c1[:3], c2[:3]
    return euclidean(rgb_a, rgb_b)
|
| 163 |
+
|
| 164 |
+
def ensure_color_diversity(colors, min_distance=0.2):
    """Shuffle *colors* in place, then swap entries forward so consecutive
    colours differ by at least *min_distance* where a far-enough candidate
    exists later in the list. Returns the (mutated) list."""
    random.shuffle(colors)
    for i in range(1, len(colors)):
        if color_distance(colors[i], colors[i - 1]) >= min_distance:
            continue
        # Too close to the previous pick: look ahead for a replacement.
        for j in range(i + 1, len(colors)):
            if color_distance(colors[i], colors[j]) > min_distance:
                colors[i], colors[j] = colors[j], colors[i]
                break
    return colors
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def draw_table(new_tree, names, max_deep=3, save_path='fig/E/test.pdf'):
    """Render the clustering hierarchy as a colour-coded table and save a PDF.

    *new_tree* maps depth -> list of leaf-index groups (from ``update_tree``);
    *names* maps leaf index -> display name.
    """
    base_list = new_tree[0][0]
    data = [base_list]
    # NOTE(review): cm.get_cmap is deprecated and removed in matplotlib 3.9+
    # (use matplotlib.colormaps there) — confirm the pinned version.
    cmap = cm.get_cmap('tab20c', 2048)
    cmap = [cmap(i) for i in range(2048)]
    cmap = ensure_color_diversity(cmap)
    cell_colours = [['#FFDDC1' for _ in base_list]]
    color_start = 0

    for i in range(1, max_deep + 1):
        if i not in new_tree.keys():
            print(f"Level {i} not in new_tree")
            continue
        data.append([names[base] for base in base_list])
        color_list = []
        for k, base in enumerate(base_list):
            # Find which group at this depth contains the leaf.
            color_id = -1
            for j in range(len(new_tree[i])):
                if base in new_tree[i][j]:
                    color_id = j
                    break
            if color_id == -1:
                # Not grouped at this depth: inherit the previous level's colour.
                color_list.append(cell_colours[-1][k])
            else:
                color_list.append(cmap[color_start + color_id])
        cell_colours.append(color_list)
        # Offset the palette so each depth uses distinct colours.
        color_start += len(new_tree[i])

    # Transpose so rows are leaves and columns are levels.
    data = list(zip(*data))
    cell_colours = list(zip(*cell_colours))
    columns = ['Node ID'] + ['Level {}'.format(i) for i in range(1, max_deep + 1)]
    plt.figure(figsize=(30, 40))
    table = plt.table(cellText=data, colLabels=columns, loc='center', cellLoc='center',
                      colColours=['#f5f5f5'] * len(columns), cellColours=cell_colours)
    table.auto_set_column_width([0, 1])
    plt.axis('off')
    plt.savefig(save_path, format='pdf', bbox_inches='tight', pad_inches=0.01)
|
| 212 |
+
|
| 213 |
+
def fix_asymmetry(matrix):
    """Symmetrise *matrix* by averaging it with its own transpose."""
    return (matrix + matrix.T) / 2
|
| 216 |
+
|
| 217 |
+
def rename(edge):
    """Relabel tree node ids by bottom-up traversal order.

    Each edge is (parent, child, names); edges whose names list is non-empty
    seed the queue (leaves).  A parent is enqueued once all of its children
    have been processed.  The virtual root -1 keeps the id -1.
    Returns the edge list with both endpoints remapped.
    """
    next_id = 0
    new_id = {}
    pending = {}        # parent -> number of children still unprocessed
    edge_by_child = {}  # child id -> its (parent, child, names) edge
    order = []          # FIFO processing order (leaves first)
    for parent, child, labels in edge:
        pending[parent] = pending.get(parent, 0) + 1
        edge_by_child[child] = (parent, child, labels)
        if labels != []:
            order.append(child)
    head = 0
    while head < len(order):
        node = order[head]
        head += 1
        if node == -1:
            new_id[node] = -1
            continue
        if node not in new_id:
            new_id[node] = next_id
            next_id += 1
        parent = edge_by_child[node][0]
        pending[parent] -= 1
        if pending[parent] == 0:
            order.append(parent)
    return [(new_id[p], new_id[c], labels) for p, c, labels in edge]
|
| 242 |
+
|
| 243 |
+
def save_edge(edge, save_path):
    """Write edges as "<child> <parent> <comma-joined names>" lines.

    An empty names list is serialised as the literal string 'none'.
    """
    lines = []
    for parent, child, labels in edge:
        label_str = ','.join(labels) if labels else 'none'
        lines.append(f"{child} {parent} {label_str}\n")
    with open(save_path, 'w') as f:
        f.writelines(lines)
|
| 251 |
+
|
| 252 |
+
def filter_class(names, similarity_matrix):
    """Randomly subsample classes for tree construction.

    Keeps every non-extend/polish/translate/paraphrase class whose name
    contains 'B' or 'human'; keeps the rest with fixed probabilities
    (0.3 / 0.3 / 0.15 depending on category).  Uses the global `random`
    state, so results vary unless the caller seeds it first.

    Returns (kept_names, kept_similarity_submatrix).
    """
    choose_idx = []
    for i in range(len(names)):
        if 'extend' not in names[i] and 'polish' not in names[i] and\
            'translate' not in names[i] and 'paraphrase' not in names[i]:
            if 'B' in names[i] or 'human' in names[i]:
                choose_idx.append(i)
            else:
                if random.random() < 0.3:
                    choose_idx.append(i)
        elif 'human' in names[i]:
            if random.random() < 0.3:
                choose_idx.append(i)
        elif random.random() < 0.15:
            choose_idx.append(i)
    new_names = [names[i] for i in choose_idx]
    # BUG FIX: dtype=int so an empty selection still fancy-indexes correctly
    # (np.array([]) defaults to float64, which raises on integer indexing).
    choose_idx = np.array(choose_idx, dtype=int)
    new_similarity_matrix = similarity_matrix[choose_idx][:, choose_idx]
    return new_names, new_similarity_matrix
|
| 271 |
+
|
| 272 |
+
def filter(names, similarity_matrix, filter_human=False, filter_llm=False, filter_mix=False):
    """Drop classes by category.  NOTE: shadows the builtin `filter` (name kept for callers).

    filter_human drops the exact 'human' class; filter_llm drops pure-LLM
    classes (no 'human' substring); filter_mix drops mixed classes that
    contain 'human' but are not exactly 'human'.
    Returns (kept_names, kept_similarity_submatrix).
    """
    kept = []
    for idx, name in enumerate(names):
        drop = (
            (filter_human and name == 'human')
            or (filter_llm and 'human' not in name)
            or (filter_mix and 'human' in name and name != 'human')
        )
        if not drop:
            kept.append(idx)
    kept_names = [names[idx] for idx in kept]
    kept = np.array(kept)
    return kept_names, similarity_matrix[kept][:, kept]
|
| 286 |
+
|
| 287 |
+
def reid_tree_dict(tree_dict, nodes, names):
    """Remap node indices in *tree_dict* to indices into *names*.

    Node names not yet present are appended to *names* (mutated in place).
    tree_dict is also mutated in place and returned.
    """
    lookup = {name: i for i, name in enumerate(names)}
    for depth, groups in tree_dict.items():
        remapped = []
        for group in groups:
            member_ids = []
            for node_idx in group:
                node_name = nodes[node_idx].name
                if node_name not in lookup:
                    lookup[node_name] = len(names)
                    names.append(node_name)
                member_ids.append(lookup[node_name])
            remapped.append(member_ids)
        tree_dict[depth] = remapped
    return tree_dict
|
| 304 |
+
|
| 305 |
+
def gen_tree(similarity_matrix, names, opt):
    """Hierarchically cluster classes from a similarity matrix; return tree nodes.

    Converts similarities to distances, runs weighted-linkage agglomerative
    clustering, optionally saves a dendrogram PDF, then builds and merges the
    tree via the project's build_tree/merge_tree helpers.
    """
    distance_matrix = 1 - similarity_matrix
    np.fill_diagonal(distance_matrix, 0)
    condensed = squareform(distance_matrix)
    # 'weighted' linkage; 'single', 'complete' or 'ward' are alternatives.
    Z = linkage(condensed, method='weighted')
    if opt.save_drg:
        plt.figure(figsize=(30, 47))
        # Root on the right for readability of long label lists.
        dendrogram(Z, labels=names, orientation='right', leaf_font_size=16)
        plt.savefig(opt.dendrogram_path, format='pdf', bbox_inches='tight')
    nodes = build_tree(Z, names)
    merge_tree(len(nodes) - 1, nodes, distance_matrix, end_thd=opt.end_score)
    return nodes
|
| 318 |
+
|
| 319 |
+
def chage_tree_priori1(nodes):
    """Attach a fresh 'human' leaf and a new root above the existing subtree."""
    human_leaf = TreeNode(name='human')
    new_root = TreeNode()
    new_root.add_child(len(nodes))      # index the human leaf will occupy
    new_root.add_child(len(nodes) - 1)  # current subtree root
    nodes.append(human_leaf)
    nodes.append(new_root)
    return nodes
|
| 327 |
+
|
| 328 |
+
def chage_tree_priori2(human_nodes, llm_nodes):
    """Join separately built human and LLM subtrees under a single new root."""
    new_root = TreeNode()
    new_root.add_child(len(human_nodes) - 1)
    new_root.add_child(len(human_nodes) + len(llm_nodes) - 1)
    # Shift LLM child indices past the human node block.
    for node in llm_nodes:
        node.children = [len(human_nodes) + c for c in node.children]
    merged = human_nodes + llm_nodes
    merged.append(new_root)
    return merged
|
| 337 |
+
|
| 338 |
+
def chage_tree_priori3(co_nodes, llm_nodes):
    """Join co-writing and LLM subtrees plus a fresh 'human' leaf under one root."""
    human_leaf = TreeNode(name='human')
    new_root = TreeNode()
    new_root.add_child(len(co_nodes) + len(llm_nodes))      # index of the human leaf
    new_root.add_child(len(co_nodes) - 1)                   # co-writing subtree root
    new_root.add_child(len(co_nodes) + len(llm_nodes) - 1)  # shifted LLM subtree root
    # Shift LLM child indices past the co-writing node block.
    for node in llm_nodes:
        node.children = [len(co_nodes) + c for c in node.children]
    merged = co_nodes + llm_nodes
    merged.append(human_leaf)
    merged.append(new_root)
    return merged
|
| 350 |
+
|
| 351 |
+
def randmo_filter(names, similarity_matrix):
|
| 352 |
+
choose_idx = []
|
| 353 |
+
for i in range(len(names)):
|
| 354 |
+
if 'human' in names[i]:
|
| 355 |
+
choose_idx.append(i)
|
| 356 |
+
elif 'fair' in names[i] or 'pplm' in names[i] or 'gpt2-pytorch' in names[i] or ' transfo' in names[i] or 'ctrl' in names[i]:
|
| 357 |
+
continue
|
| 358 |
+
elif 'xlnet' in names[i] or 'grover' in names[i]:
|
| 359 |
+
if random.random()<0.07:
|
| 360 |
+
choose_idx.append(i)
|
| 361 |
+
elif random.random()<0.22:
|
| 362 |
+
choose_idx.append(i)
|
| 363 |
+
new_names = []
|
| 364 |
+
for i in choose_idx:
|
| 365 |
+
if names[i].startswith('7B') or names[i].startswith('13B') or names[i].startswith('30B') or names[i].startswith('65B'):
|
| 366 |
+
new_names.append('LLaMA_'+names[i])
|
| 367 |
+
else:
|
| 368 |
+
new_names.append(names[i])
|
| 369 |
+
choose_idx = np.array(choose_idx)
|
| 370 |
+
new_similarity_matrix = similarity_matrix[choose_idx][:,choose_idx]
|
| 371 |
+
return new_names, new_similarity_matrix
|
| 372 |
+
|
| 373 |
+
def ishuman(name):
    """True when the class name contains the substring 'human'."""
    return 'human' in name
|
| 375 |
+
def ismachine(name):
    """True when the class name denotes machine-written or rephrased text."""
    return 'machine' in name or 'rephrase' in name
|
| 377 |
+
|
| 378 |
+
def get_llm(x):
    """Return the canonical LLM identifier embedded in *x*.

    Raises ValueError when *x* contains none of the known identifiers.
    Check order matches the original if/elif chain.
    """
    for llm in ('gpt-3.5-turbo', 'gpt-4o', 'llama-3.3-70b',
                'gemini-1.5-pro', 'claude-3-5-sonnet', 'qwen2.5-72b'):
        if llm in x:
            return llm
    raise ValueError(f"Invalid class name: {x}")
|
| 393 |
+
|
| 394 |
+
def get_name(name):
    """Normalise a raw '<source>_<transform>' label into a canonical class name.

    Human-sourced labels map to 'human', 'human_humanize_tool' or
    'human_rephrase_<llm>'; machine-sourced labels map to '<llm>' with an
    optional '_humanize_*' suffix.  Falls through to None for labels that
    are neither human- nor machine-sourced (original behaviour preserved).
    """
    parts = name.split('_')
    assert len(parts) == 2
    source, transform = parts
    if ishuman(source):
        if transform in ('humanize:human', 'human'):
            return 'human'
        if transform == 'humanize:tool':
            return 'human_humanize_tool'
        return f'human_rephrase_{get_llm(transform)}'
    elif ismachine(source):
        llm_name = get_llm(source)
        if transform in ('humanize:human', 'human'):
            return f'{llm_name}_humanize_human'
        if transform == 'humanize:tool':
            return f'{llm_name}_humanize_tool'
        if 'humanize:' in transform:
            return f'{llm_name}_humanize_{get_llm(transform)}'
        return llm_name
|
| 416 |
+
|
| 417 |
+
def clear_names(names):
    """Canonicalise every raw label via get_name."""
    return [get_name(name) for name in names]
|
| 422 |
+
|
| 423 |
+
def build_argument_parser() -> argparse.ArgumentParser:
    """Build the CLI for constructing the HAT tree from a similarity matrix."""
    p = argparse.ArgumentParser(
        description="Construct the HAT tree from a similarity matrix.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Required input/output paths.
    p.add_argument('--file-path', type=Path, required=True,
                   help='Input similarity matrix text file.')
    p.add_argument('--save-txt-path', type=Path, required=True,
                   help='Destination path for the tree definition.')
    p.add_argument('--save-table-path', type=Path, required=True,
                   help='Destination path for the visualised table.')
    # Tree-construction options.
    p.add_argument('--priori', type=int, default=1, choices=[0, 1, 2, 3])
    p.add_argument('--save-max-dep', type=int, default=5)
    p.add_argument('--end-score', type=float, default=0.1)
    p.add_argument('--randmo-filter', action='store_true',
                   help='Randomly subsample similarity entries.')
    # Dendrogram toggle (on by default) plus its optional output path.
    p.add_argument('--dendrogram-path', type=Path, default=None,
                   help='Optional path for the dendrogram PDF when saved.')
    p.add_argument('--save-drg', action='store_true',
                   help='Persist the dendrogram PDF alongside the tree.')
    p.add_argument('--no-save-drg', dest='save_drg', action='store_false')
    p.set_defaults(save_drg=True)
    return p
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def main(argv: Optional[Iterable[str]] = None) -> None:
    """CLI entry point: load a similarity matrix, build the HAT tree, export it.

    --priori selects how human vs. LLM classes are combined:
      0: cluster everything jointly;
      1: cluster LLM classes only, attach a 'human' leaf at the root;
      2: cluster human and LLM classes separately, join under a new root;
      3: cluster co-written and LLM classes separately plus a 'human' leaf.
    """
    parser = build_argument_parser()
    opt = parser.parse_args(argv)

    names, similarity_matrix = read_similarity_matrix(opt.file_path)
    if opt.save_drg:
        if opt.dendrogram_path is None:
            # Default next to the table output: <stem>_dendrogram.pdf
            opt.dendrogram_path = opt.save_table_path.with_name(
                f"{opt.save_table_path.stem}_dendrogram.pdf"
            )
        opt.dendrogram_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        opt.dendrogram_path = None
    # Hierarchical clustering needs a symmetric matrix.
    similarity_matrix = fix_asymmetry(similarity_matrix)
    if opt.randmo_filter:
        names, similarity_matrix = randmo_filter(names, similarity_matrix)
    # names = clear_names(names)
    if opt.priori==1:
        llm_names, llm_similarity_matrix = filter(names, similarity_matrix,filter_human=True)
        nodes = gen_tree(llm_similarity_matrix,llm_names,opt)
        nodes = chage_tree_priori1(nodes)

    elif opt.priori==2:
        human_names, human_similarity_matrix = filter(names, similarity_matrix,filter_llm=True)
        human_nodes = gen_tree(human_similarity_matrix,human_names,opt)
        llm_names, llm_similarity_matrix = filter(names, similarity_matrix,filter_human=True,filter_mix=True)
        llm_nodes = gen_tree(llm_similarity_matrix,llm_names,opt)
        nodes = chage_tree_priori2(human_nodes,llm_nodes)

    elif opt.priori==3:
        co_names, co_similarity_matrix = filter(names, similarity_matrix,filter_llm=True,filter_human=True)
        co_nodes = gen_tree(co_similarity_matrix,co_names,opt)
        llm_names, llm_similarity_matrix = filter(names, similarity_matrix,filter_human=True,filter_mix=True)
        llm_nodes = gen_tree(llm_similarity_matrix,llm_names,opt)
        nodes = chage_tree_priori3(co_nodes,llm_nodes)

    elif opt.priori==0:
        nodes = gen_tree(similarity_matrix,names,opt)
    else:
        raise ValueError("Invalid value for --priori. Choose from 0, 1, 2, or 3.")

    # Flatten the tree to edges, relabel ids bottom-up, and persist outputs.
    edge=[]
    tree_dict = update_tree(len(nodes)-1, nodes, edge)
    edge = rename(edge)
    opt.save_txt_path.parent.mkdir(parents=True, exist_ok=True)
    opt.save_table_path.parent.mkdir(parents=True, exist_ok=True)
    save_edge(edge,opt.save_txt_path)
    tree_dict = reid_tree_dict(tree_dict, nodes, names)
    draw_table(tree_dict, names, max_deep=opt.save_max_dep, save_path=opt.save_table_path)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
if __name__ == "__main__":
|
| 494 |
+
main()
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
__all__ = ["build_argument_parser", "main", "read_similarity_matrix", "gen_tree"]
|
detree/cli/merge_lora.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Merge LoRA adapters into base models."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Iterable, Optional
|
| 8 |
+
|
| 9 |
+
from peft import PeftModel
|
| 10 |
+
from transformers import AutoModel, AutoTokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def merge_lora_adapter(base_model: str, adapter_path: Path, output_dir: Path, safe_serialization: bool = True) -> None:
    """Fold LoRA adapter weights into *base_model* and save the merged model.

    The merged weights and the base tokenizer are written to *output_dir*,
    which is created if necessary.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    base = AutoModel.from_pretrained(base_model, trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(base_model)

    # merge_and_unload bakes the adapter deltas into the base weights.
    merged = PeftModel.from_pretrained(base, str(adapter_path)).merge_and_unload()
    merged.save_pretrained(output_dir, safe_serialization=safe_serialization)
    tok.save_pretrained(output_dir)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def build_argument_parser() -> argparse.ArgumentParser:
    """Build the CLI for merging a LoRA adapter into its base model."""
    p = argparse.ArgumentParser(description="Merge a LoRA adapter into its base Hugging Face model.")
    p.add_argument("--base-model", type=str, required=True, help="Base model name or path.")
    p.add_argument("--adapter-path", type=Path, required=True,
                   help="Directory containing the LoRA adapter weights.")
    p.add_argument("--output-dir", type=Path, required=True,
                   help="Directory to store the merged model.")
    p.add_argument("--no-safe-serialization", action="store_true",
                   help="Disable safetensors when saving the merged model.")
    return p
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main(argv: Optional[Iterable[str]] = None) -> None:
    """CLI entry point: parse arguments and perform the LoRA merge."""
    opts = build_argument_parser().parse_args(argv)
    merge_lora_adapter(
        opts.base_model,
        opts.adapter_path,
        opts.output_dir,
        safe_serialization=not opts.no_safe_serialization,
    )
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
main()
|
| 51 |
+
|
| 52 |
+
__all__ = ["build_argument_parser", "merge_lora_adapter", "main"]
|
detree/cli/similarity_matrix.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compute similarity matrices from embedding databases."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Iterable, Optional
|
| 8 |
+
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def gen_data(dict_data):
    """Unpack an embedding-database dict into (embeddings, labels, ids, classes)."""
    return (
        dict_data["embeddings"],
        dict_data["labels"],
        dict_data["ids"],
        dict_data["classes"],
    )
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def build_argument_parser() -> argparse.ArgumentParser:
    """Build the CLI for exporting class similarity matrices."""
    p = argparse.ArgumentParser(
        description="Generate class similarity matrices for DETree.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--database", type=Path, required=True,
                   help="Path to the embedding database (.pt).")
    p.add_argument("--output-dir", type=Path, required=True,
                   help="Directory to store the similarity outputs.")
    p.add_argument("--layers", type=int, nargs="*", default=None,
                   help="Specific layers to export. Defaults to all.")
    return p
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def compute_similarity(database: Path, output_dir: Path, layers: Optional[Iterable[int]]) -> None:
    """Compute per-layer class-centroid similarity matrices and export them.

    For each requested layer: average the embeddings of each class into a
    centroid, take the centroid Gram matrix, and write it both as a text
    table and as a heatmap PNG in *output_dir*.

    Args:
        database: path to a torch-saved dict with embeddings/labels/ids/classes.
        output_dir: destination directory (created if missing).
        layers: layer keys to export; None means every layer in the database.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    data_emb, data_labels, data_ids, data_classes = gen_data(torch.load(database))

    if layers is None:
        layers = list(data_emb.keys())

    for layer in layers:
        center = []
        # BUG FIX: the original recomputed the index via data_classes.index(item),
        # which is O(n^2) and returns the FIRST match — wrong when class names
        # repeat.  enumerate gives each class its own position directly.
        for index, item in enumerate(data_classes):
            now_emb = data_emb[layer][data_labels == index]
            center.append(torch.mean(now_emb, dim=0))
        center = torch.stack(center)
        similarity = center @ center.T
        similarity_np = similarity.cpu().float().numpy()

        txt_path = output_dir / f"similarity_layer_{layer}.txt"
        with txt_path.open("w", encoding="utf-8") as f:
            f.write(" ".join(data_classes) + "\n")
            for i, class_name in enumerate(data_classes):
                row = " ".join(f"{similarity_np[i, j]:.4f}" for j in range(len(data_classes)))
                f.write(f"{class_name} {row}\n")

        plt.figure(figsize=(30, 30))
        plt.imshow(similarity_np, cmap="viridis")
        plt.colorbar()
        plt.xticks(range(len(data_classes)), data_classes, rotation=45, fontsize=12)
        plt.yticks(range(len(data_classes)), data_classes, fontsize=12)
        plt.title(f"Similarity Matrix (layer {layer})", fontsize=20)
        fig_path = output_dir / f"similarity_layer_{layer}.png"
        plt.savefig(fig_path, dpi=300, bbox_inches="tight")
        plt.close()
        print(f"Saved similarity matrix for layer {layer} to {txt_path} and {fig_path}")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def main(argv: Optional[Iterable[str]] = None) -> None:
    """CLI entry point: parse arguments and export similarity matrices."""
    opts = build_argument_parser().parse_args(argv)
    compute_similarity(opts.database, opts.output_dir, opts.layers)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
|
| 75 |
+
main()
|
| 76 |
+
|
| 77 |
+
__all__ = ["build_argument_parser", "compute_similarity", "main"]
|
detree/cli/test_database_score_knn.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""kNN evaluation using pre-computed embedding databases."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from multiprocessing import Pool, cpu_count
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Iterable, List, Optional, Sequence, Tuple
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from lightning import Fabric
|
| 16 |
+
from torch.nn.functional import softmax as F_softmax
|
| 17 |
+
from torch.utils.data import DataLoader, Dataset
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
from detree.model.text_embedding import TextEmbeddingModel
|
| 21 |
+
from detree.utils.index import Indexer
|
| 22 |
+
from detree.utils.utils import evaluate_metrics
|
| 23 |
+
|
| 24 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def load_jsonl(file_path: Path) -> List[dict]:
    """Read one JSON object per line from *file_path* and report the count."""
    with file_path.open(mode="r", encoding="utf-8") as jsonl_file:
        records = [json.loads(line) for line in jsonl_file]
    print(f"Loaded {len(records)} examples from {file_path}")
    return records
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def gen_data(dict_data):
    """Split a saved database dict into embeddings, labels, ids, classes."""
    keys = ("embeddings", "labels", "ids", "classes")
    embeddings, labels, ids, classes = (dict_data[k] for k in keys)
    return embeddings, labels, ids, classes
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class PassagesDataset(Dataset):
    """Thin Dataset wrapper over records shaped {'text', 'label', 'id'}."""

    def __init__(self, data: Sequence[dict]):
        # Copy into a list so indexing is stable regardless of the input type.
        self.passages = list(data)

    def __len__(self) -> int:
        return len(self.passages)

    def __getitem__(self, idx: int):
        record = self.passages[idx]
        # Coerce label/id to int so downstream tensor collation is uniform.
        return record["text"], int(record["label"]), int(record["id"])
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def infer(passages_dataloader, fabric, tokenizer, model, need_layers: Sequence[int], max_length: int = 512):
    """Embed every passage in the dataloader and gather results on rank 0.

    Tokenises each batch to a fixed max_length, runs the model with hidden
    states, and all-gathers embeddings/labels/ids across ranks.  On global
    rank 0 returns (ids, {layer: L2-normalised embedding array}, labels);
    all other ranks return ([], [], []).
    """
    if fabric.global_rank == 0:
        passages_dataloader = tqdm(passages_dataloader)  # progress bar on rank 0 only
    all_ids: List[int] = []
    all_embeddings: List[torch.Tensor] = []
    all_labels: List[int] = []
    with torch.no_grad():
        for batch in passages_dataloader:
            text, label, ids = batch
            encoded_batch = tokenizer.batch_encode_plus(
                text,
                return_tensors="pt",
                max_length=max_length,
                padding="max_length",
                truncation=True,
            )
            encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
            embeddings = model(encoded_batch, hidden_states=True)
            # Gather shards from every rank and flatten back into the batch dim.
            embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(-2), embeddings.size(-1))
            label = fabric.all_gather(label).view(-1)
            ids = fabric.all_gather(ids).view(-1)
            if fabric.global_rank == 0:
                all_embeddings.append(embeddings.cpu())
                all_ids.extend(ids.cpu().tolist())
                all_labels.extend(label.cpu().tolist())
    if fabric.global_rank == 0:
        embeddings_tensor = torch.cat(all_embeddings, dim=0)
        # L2-normalise, reorder to (layer, sample, dim), keep only need_layers.
        embeddings_tensor = F.normalize(embeddings_tensor, dim=-1).permute(1, 0, 2).numpy()
        embeddings_tensor = {layer: embeddings_tensor[layer] for layer in need_layers}
        return all_ids, embeddings_tensor, all_labels
    return [], [], []
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def dict2str(metrics: dict) -> str:
    """Format metrics as space-separated 'key:value' pairs, layer and k first."""
    parts = []
    if "layer" in metrics:
        parts.append(f"layer:{metrics['layer']}")
    if "k" in metrics:
        parts.append(f"k:{metrics['k']}")
    parts.extend(
        f"{key}:{value}" for key, value in metrics.items() if key not in {"layer", "k"}
    )
    return " ".join(parts)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def process_element(args: Tuple[Sequence[int], Sequence[float], Sequence[int], float]):
    """Turn one query's kNN hits into P(label==1) for every cutoff k.

    args = (neighbour_ids, similarity_scores, binary_labels, temperature).
    Neighbour ids are unpacked but unused; scores are accumulated per class
    in descending-similarity order and softmaxed after each addition.
    Returns {k: probability} for k = 1..len(scores).
    """
    _ids, scores, labels, temperature = args
    class_scores = torch.zeros(2)
    ranked = np.argsort(scores)[::-1]  # highest similarity first
    probs_at_k = {}
    for rank, neighbour in enumerate(ranked, start=1):
        class_scores[labels[neighbour]] += scores[neighbour] * temperature
        probs_at_k[rank] = F_softmax(class_scores, dim=-1)[1].item()
    return probs_at_k
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def build_argument_parser() -> argparse.ArgumentParser:
    """Build the CLI for database-backed kNN evaluation."""
    p = argparse.ArgumentParser(
        description="Evaluate DETree with a precomputed embedding database.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Hardware / loading.
    p.add_argument("--device-num", type=int, default=1)
    p.add_argument("--batch-size", type=int, default=32)
    p.add_argument("--num-workers", type=int, default=8)
    p.add_argument("--max-length", type=int, default=512)

    # Inputs.
    p.add_argument("--database-path", type=Path, required=True,
                   help="Path to the saved embedding database (.pt).")
    p.add_argument("--test-dataset-path", type=Path, required=True,
                   help="Evaluation JSONL file.")
    p.add_argument("--model-name-or-path", type=str, required=True)
    p.add_argument("--temperature", type=float, default=0.05)

    # kNN scoring.
    p.add_argument("--max-k", type=int, default=51, dest="max_K")
    p.add_argument("--pooling", type=str, default="max", choices=("max", "average", "cls"))

    # Index / parallelism / logging.
    p.add_argument("--embedding-dim", type=int, default=1024)
    p.add_argument("--pool-workers", type=int, default=min(32, cpu_count()))
    p.add_argument("--log-file", type=Path, default=Path("runs/val.txt"))

    return p
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def evaluate(args: argparse.Namespace) -> None:
    """Run kNN evaluation of a text-embedding model against a stored database.

    Embeds the test set with Lightning Fabric (optionally multi-GPU DDP),
    then on global rank 0: per layer, indexes the database embeddings,
    searches max_K nearest neighbours per test sample, sweeps k = 1..max_K,
    scores binary human-vs-machine predictions with evaluate_metrics, and
    appends the best AUROC configuration to args.log_file.
    """
    if args.device_num > 1:
        fabric = Fabric(accelerator="cuda", devices=args.device_num, strategy="ddp", precision="bf16-mixed")
    else:
        fabric = Fabric(accelerator="cuda", devices=args.device_num, precision="bf16-mixed")
    fabric.launch()

    model = TextEmbeddingModel(
        args.model_name_or_path,
        output_hidden_states=True,
        infer=True,
        use_pooling=args.pooling,
    ).cuda()
    tokenizer = model.tokenizer
    model.eval()

    # Only rank 0 loads the database; the layer list is broadcast to all ranks.
    if fabric.global_rank == 0:
        db_embeddings, db_labels, db_ids, classes = gen_data(torch.load(args.database_path))
        need_layers = list(db_embeddings.keys())
    else:
        db_embeddings = db_labels = db_ids = classes = None
        need_layers = []
    need_layers = fabric.broadcast(need_layers)

    test_database = load_jsonl(args.test_dataset_path)
    test_dataset = PassagesDataset(test_database)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
    test_dataloader = fabric.setup_dataloaders(test_dataloader)
    model = fabric.setup(model)
    test_ids, test_embeddings, test_labels = infer(test_dataloader, fabric, tokenizer, model, need_layers, args.max_length)

    torch.cuda.empty_cache()
    if fabric.global_rank != 0:
        # Non-zero ranks hold no gathered results; nothing to score.
        return

    test_labels = [int(label) for label in test_labels]
    index = Indexer(args.embedding_dim)
    human_idx = classes.index("human")

    all_details = []
    with Pool(processes=args.pool_workers) as pool:
        for layer in need_layers:
            now_best_metrics = None
            label_dict = {}
            train_embeddings = db_embeddings[layer].float().numpy()
            # Databases may store labels/ids either per-layer (dict) or shared.
            if isinstance(db_labels, dict):
                train_labels = db_labels[layer].tolist()
                train_ids = db_ids[layer].tolist()
            else:
                train_labels = db_labels.tolist()
                train_ids = db_ids.tolist()

            for i in range(len(train_ids)):
                # Binary target: 1 iff the database entry is the 'human' class.
                label_dict[int(train_ids[i])] = int(train_labels[i] == human_idx)
            index.label_dict = label_dict
            index.reset()
            index.index_data(train_ids, train_embeddings)
            preds = {k: [] for k in range(1, args.max_K + 1)}
            top_ids_and_scores = index.search_knn(test_embeddings[layer], args.max_K, index_batch_size=128)

            # Score each query's neighbour list in a worker pool.
            args_list = [
                (ids, scores, labels, args.temperature)
                for ids, scores, labels in top_ids_and_scores
            ]
            for result in tqdm(pool.imap(process_element, args_list), total=len(args_list)):
                for k, value in result.items():
                    preds[k].append(value)

            # Pick the best k for this layer by AUROC.
            for k in range(1, args.max_K + 1):
                metric = evaluate_metrics(test_labels, preds[k], threshold_param=-1)
                if now_best_metrics is None or now_best_metrics["auroc"] < metric["auroc"]:
                    now_best_metrics = metric
                    now_best_metrics["k"] = k
                    now_best_metrics["layer"] = layer

            if now_best_metrics:
                print(dict2str(now_best_metrics))
                all_details.append(now_best_metrics)

    if not all_details:
        return

    # Best layer/k combination across all layers, again by AUROC.
    max_ids = max(range(len(all_details)), key=lambda idx: all_details[idx]["auroc"])
    best_metrics = all_details[max_ids]
    print("Best " + dict2str(best_metrics))
    args.log_file.parent.mkdir(parents=True, exist_ok=True)
    with args.log_file.open("a+", encoding="utf-8") as fp:
        fp.write(f"test model:{args.model_name_or_path} mode:{args.test_dataset_path} database_path:{args.database_path}\n")
        fp.write(f"Last {dict2str(all_details[-1])}\n")
        fp.write(f"Best {dict2str(best_metrics)}\n")
        fp.write("------------------------------------------\n")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def main(argv: Optional[Iterable[str]] = None) -> None:
    """CLI entry point: parse arguments and run the kNN evaluation."""
    evaluate(build_argument_parser().parse_args(argv))
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
if __name__ == "__main__":
|
| 245 |
+
main()
|
| 246 |
+
|
| 247 |
+
__all__ = ["build_argument_parser", "evaluate", "main"]
|
detree/cli/test_score_knn.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""kNN evaluation against raw text datasets."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from multiprocessing import Pool, cpu_count
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Iterable, List, Optional, Sequence, Tuple
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from lightning import Fabric
|
| 16 |
+
from torch.nn.functional import softmax as F_softmax
|
| 17 |
+
from torch.utils.data import DataLoader, Dataset
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
from detree.model.text_embedding import TextEmbeddingModel
|
| 21 |
+
from detree.utils.index import Indexer
|
| 22 |
+
from detree.utils.utils import evaluate_metrics
|
| 23 |
+
|
| 24 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def load_jsonl(file_path: Path) -> List[dict]:
    """Read a JSON-Lines file and return its records as a list of dicts."""
    with file_path.open(mode="r", encoding="utf-8") as handle:
        records = [json.loads(row) for row in handle]
    print(f"Loaded {len(records)} examples from {file_path}")
    return records
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class PassagesDataset(Dataset):
    """Map-style Dataset over pre-loaded JSONL records with ``text``/``label``/``id`` keys."""

    def __init__(self, data: Sequence[dict]):
        # Copy into a list so indexing is O(1) regardless of the input container.
        self.passages = list(data)

    def __len__(self) -> int:
        return len(self.passages)

    def __getitem__(self, idx: int):
        record = self.passages[idx]
        # Labels and ids are coerced to int so default_collate stacks them as tensors.
        return record["text"], int(record["label"]), int(record["id"])
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def infer(passages_dataloader, fabric, tokenizer, model, max_length: int = 512):
    """Embed every passage in the dataloader and gather results on rank 0.

    Returns ``(ids, embeddings, labels)`` on rank 0 and three empty lists on
    every other rank.  The embedding array is layer-major after the final
    ``permute(1, 0, 2)`` — i.e. indexed as ``embeddings[layer][sample]``.
    # NOTE(review): assumes the model returns per-sample hidden states shaped
    # (batch, num_layers, dim) when called with hidden_states=True — confirm
    # against TextEmbeddingModel.
    """
    if fabric.global_rank == 0:
        # Only rank 0 shows a progress bar to keep multi-GPU logs readable.
        passages_dataloader = tqdm(passages_dataloader)
    all_ids: List[int] = []
    all_embeddings: List[torch.Tensor] = []
    all_labels: List[int] = []
    with torch.no_grad():
        for batch in passages_dataloader:
            text, label, ids = batch
            encoded_batch = tokenizer.batch_encode_plus(
                text,
                return_tensors="pt",
                max_length=max_length,
                padding="max_length",
                truncation=True,
            )
            encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
            embeddings = model(encoded_batch, hidden_states=True)
            # all_gather adds a leading world-size dimension; flatten it back
            # into the batch dimension while preserving the last two dims.
            embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(-2), embeddings.size(-1))
            label = fabric.all_gather(label).view(-1)
            ids = fabric.all_gather(ids).view(-1)
            if fabric.global_rank == 0:
                # Accumulate on CPU to avoid holding every batch on the GPU.
                all_embeddings.append(embeddings.cpu())
                all_ids.extend(ids.cpu().tolist())
                all_labels.extend(label.cpu().tolist())
    if fabric.global_rank == 0:
        embeddings_tensor = torch.cat(all_embeddings, dim=0)
        # L2-normalize so inner-product search equals cosine similarity,
        # then reorder to (layer, sample, dim) for per-layer indexing.
        embeddings_tensor = F.normalize(embeddings_tensor, dim=-1).permute(1, 0, 2)
        return all_ids, embeddings_tensor.numpy(), all_labels
    return [], [], []
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def save_pt(train_embeddings, all_labels, train_ids, args, best_layer):
    """Persist the best-performing and final layers of embeddings to ``<savedir>/<name>.pt``."""
    layers_to_keep = [best_layer, train_embeddings.shape[0] - 1]
    snapshot = {
        "embeddings": {layer: torch.tensor(train_embeddings[layer]) for layer in layers_to_keep},
        "labels": torch.tensor(all_labels),
        "ids": torch.tensor(train_ids),
        "classes": ["llm", "human"],
    }
    args.savedir.mkdir(parents=True, exist_ok=True)
    destination = args.savedir / f"{args.name}.pt"
    torch.save(snapshot, destination)
    print(f"Saved embedding snapshot to {destination}")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def dict2str(metrics: dict) -> str:
    """Format a metrics dict as space-separated ``key:value`` pairs, listing layer and k first."""
    parts = []
    for special in ("layer", "k"):
        if special in metrics:
            parts.append(f"{special}:{metrics[special]}")
    parts.extend(f"{key}:{value}" for key, value in metrics.items() if key not in {"layer", "k"})
    return " ".join(parts)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def process_element(args: Tuple[Sequence[int], Sequence[float], Sequence[int], float]):
    """Accumulate neighbour votes and return P(class 1) for every cutoff k.

    Neighbours are visited from highest to lowest score; after each one the
    running class scores are softmaxed, yielding one probability per k.
    """
    ids, scores, labels, temperature = args
    # NOTE(review): scores are *multiplied* by temperature before the softmax;
    # the usual kNN soft-vote divides by it — confirm this is intentional.
    class_scores = torch.zeros(2)
    ranked = np.argsort(scores)[::-1]
    preds_at_k = {}
    for rank, neighbour in enumerate(ranked, start=1):
        class_scores[labels[neighbour]] += scores[neighbour] * temperature
        preds_at_k[rank] = F_softmax(class_scores, dim=-1)[1].item()
    return preds_at_k
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def build_argument_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the raw-text kNN evaluation script."""
    parser = argparse.ArgumentParser(
        description="Evaluate DETree checkpoints using a kNN classifier over hidden states.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Hardware / batching.
    parser.add_argument("--device-num", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--max-length", type=int, default=512)

    # Data and model inputs.
    parser.add_argument("--database-path", type=Path, required=True, help="Training set JSONL file.")
    parser.add_argument("--test-dataset-path", type=Path, required=True, help="Evaluation set JSONL file.")
    parser.add_argument(
        "--model-name-or-path",
        type=str,
        required=True,
        help="Model identifier from Hugging Face or local path to a merged checkpoint.",
    )
    parser.add_argument("--temperature", type=float, default=0.05)

    # kNN sweep controls.
    parser.add_argument("--max-k", type=int, default=50, dest="max_K", help="Maximum k to evaluate for kNN.")
    parser.add_argument("--min-layer", type=int, default=15, help="Minimum hidden layer index to evaluate.")
    parser.add_argument("--pooling", type=str, default="max", choices=("max", "average", "cls"))

    # FAISS index configuration.
    parser.add_argument("--embedding-dim", type=int, default=1024)
    parser.add_argument("--n-subquantizers", type=int, default=1)
    parser.add_argument("--n-bits", type=int, default=8)

    # Outputs and parallelism.
    parser.add_argument("--savedir", type=Path, default=Path("runs"))
    parser.add_argument("--name", type=str, default="database_knn_eval")
    # NOTE: default is evaluated once at parser construction time.
    parser.add_argument("--pool-workers", type=int, default=min(32, cpu_count()))
    parser.add_argument("--save-embeddings", action="store_true", help="Persist embeddings for the best-performing layer.")
    parser.add_argument("--log-file", type=Path, default=Path("runs/val.txt"))

    return parser
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def evaluate(args: argparse.Namespace) -> None:
    """Embed both JSONL splits, sweep layers/k with a kNN classifier, and log the best AUROC.

    All post-embedding work (indexing, scoring, logging) runs on rank 0 only;
    other ranks return right after inference.
    """
    if args.device_num > 1:
        fabric = Fabric(accelerator="cuda", devices=args.device_num, strategy="ddp", precision="bf16-mixed")
    else:
        fabric = Fabric(accelerator="cuda", devices=args.device_num, precision="bf16-mixed")
    fabric.launch()

    model = TextEmbeddingModel(
        args.model_name_or_path,
        output_hidden_states=True,
        infer=True,
        use_pooling=args.pooling,
    ).cuda()
    tokenizer = model.tokenizer
    model.eval()

    database = load_jsonl(args.database_path)
    test_database = load_jsonl(args.test_dataset_path)

    passages_dataset = PassagesDataset(database)
    test_dataset = PassagesDataset(test_database)

    # The reference (database) split is shuffled; the test split keeps file order.
    passages_dataloader = DataLoader(
        passages_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False
    )

    passages_dataloader, test_dataloader = fabric.setup_dataloaders(passages_dataloader, test_dataloader)
    model = fabric.setup(model)

    # Layer-major embeddings: embeddings[layer][sample] (see infer()).
    train_ids, train_embeddings, train_labels = infer(passages_dataloader, fabric, tokenizer, model, args.max_length)
    test_ids, test_embeddings, test_labels = infer(test_dataloader, fabric, tokenizer, model, args.max_length)

    torch.cuda.empty_cache()
    if fabric.global_rank != 0:
        return

    layer_num = train_embeddings.shape[0]
    test_labels = [int(label) for label in test_labels]

    # Map passage id -> class label for neighbour-label lookup during search.
    label_dict = {train_ids[i]: int(train_labels[i]) for i in range(len(train_ids))}

    all_details = []
    index = Indexer(args.embedding_dim, args.n_subquantizers, args.n_bits)
    index.label_dict = label_dict

    with Pool(processes=args.pool_workers) as pool:
        # One index build + kNN sweep per hidden layer.
        for i in range(args.min_layer, layer_num):
            now_best_metrics = None
            index.reset()
            index.index_data(train_ids, train_embeddings[i])
            preds = {k: [] for k in range(1, args.max_K + 1)}
            top_ids_and_scores = index.search_knn(test_embeddings[i], args.max_K, index_batch_size=128)

            args_list = [
                (ids, scores, labels, args.temperature)
                for ids, scores, labels in top_ids_and_scores
            ]
            # imap preserves input order, so predictions stay aligned with test_labels.
            for result in tqdm(pool.imap(process_element, args_list), total=len(args_list)):
                for k, value in result.items():
                    preds[k].append(value)

            # NOTE(review): starts at k=2, skipping the k=1 predictions collected
            # above (the sibling database script sweeps from k=1) — confirm intent.
            for k in range(2, args.max_K + 1):
                metric = evaluate_metrics(test_labels, preds[k], threshold_param=-1)
                if now_best_metrics is None or now_best_metrics["auroc"] < metric["auroc"]:
                    now_best_metrics = metric
                    now_best_metrics["k"] = k
                    now_best_metrics["layer"] = i

            if now_best_metrics:
                print(dict2str(now_best_metrics))
                all_details.append(now_best_metrics)

    if not all_details:
        return

    # Pick the layer whose best-k AUROC is highest overall.
    max_ids = max(range(len(all_details)), key=lambda idx: all_details[idx]["auroc"])
    best_metrics = all_details[max_ids]

    if args.save_embeddings:
        save_pt(train_embeddings, train_labels, train_ids, args, best_metrics["layer"])

    print("Best " + dict2str(best_metrics))
    args.log_file.parent.mkdir(parents=True, exist_ok=True)
    with args.log_file.open("a+", encoding="utf-8") as fp:
        fp.write(
            f"test model:{args.model_name_or_path} database_path:{args.database_path} mode:{args.test_dataset_path}\n"
        )
        fp.write(f"Last {dict2str(all_details[-1])}\n")
        fp.write(f"Best {dict2str(best_metrics)}\n")
        fp.write("------------------------------------------\n")
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def main(argv: Optional[Iterable[str]] = None) -> None:
    """CLI entry point: parse arguments and run the kNN evaluation."""
    evaluate(build_argument_parser().parse_args(argv))
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
if __name__ == "__main__":
|
| 265 |
+
main()
|
| 266 |
+
|
| 267 |
+
__all__ = ["build_argument_parser", "evaluate", "main"]
|
detree/cli/train.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training CLI for DETree."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import random
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Iterable, Optional
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F # noqa: F401 # retained for backward compat with downstream imports
|
| 13 |
+
import torch.optim as optim
|
| 14 |
+
import yaml
|
| 15 |
+
from lightning import Fabric
|
| 16 |
+
from lightning.fabric.strategies import DeepSpeedStrategy, DDPStrategy
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
from torch.utils.data.dataloader import default_collate
|
| 19 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
from transformers import AutoTokenizer
|
| 22 |
+
|
| 23 |
+
from detree.model.simclr import SimCLR_Tree
|
| 24 |
+
from detree.utils.dataset import SCLDataset, load_datapath
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class ExperimentPaths:
    """Utility container describing where to store experiment artefacts."""

    # Root directory of this experiment run (checkpoints and config.yaml live here).
    root: Path
    # TensorBoard event-file directory (``<root>/runs``).
    runs: Path
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _build_collate_fn(tokenizer, max_length: int):
|
| 36 |
+
def collate_fn(batch: Iterable):
|
| 37 |
+
text, label, write_model = default_collate(batch)
|
| 38 |
+
encoded_batch = tokenizer.batch_encode_plus(
|
| 39 |
+
text,
|
| 40 |
+
return_tensors="pt",
|
| 41 |
+
max_length=max_length,
|
| 42 |
+
padding=True,
|
| 43 |
+
truncation=True,
|
| 44 |
+
)
|
| 45 |
+
return encoded_batch, label, write_model
|
| 46 |
+
|
| 47 |
+
return collate_fn
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _prepare_output_dir(
    output_dir: Path, experiment_name: str, resume: bool, *, create_dirs: bool = True
) -> ExperimentPaths:
    """Resolve (and optionally create) the experiment directory under *output_dir*.

    If the named directory already exists and *resume* is False, the first
    unused ``<name>_v<N>`` suffix is chosen instead of reusing it.
    """
    base = output_dir.expanduser().resolve()

    target = base / experiment_name
    if target.exists() and not resume:
        version = 0
        while (base / f"{experiment_name}_v{version}").exists():
            version += 1
        target = base / f"{experiment_name}_v{version}"

    tb_dir = target / "runs"
    if create_dirs:
        target.mkdir(parents=True, exist_ok=True)
        tb_dir.mkdir(parents=True, exist_ok=True)

    return ExperimentPaths(root=target, runs=tb_dir)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def build_argument_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for DETree training."""
    parser = argparse.ArgumentParser(
        description="Train DETree using the hierarchical contrastive objective",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Model, data, and experiment bookkeeping.
    parser.add_argument("--model-name", type=str, default="FacebookAI/roberta-large", help="Backbone encoder identifier.")
    parser.add_argument("--device-num", type=int, default=1, help="Number of CUDA devices to use.")
    parser.add_argument("--path", type=Path, required=True, help="Root directory of the dataset.")
    parser.add_argument("--dataset-name", type=str, default="all", help="Dataset configuration name.")
    parser.add_argument(
        "--dataset", type=str, default="train", choices=("train", "test", "extra"), help="Dataset split to consume."
    )
    parser.add_argument("--tree-txt", type=Path, required=True, help="Tree definition file as produced by the HAT pipeline.")
    parser.add_argument("--output-dir", type=Path, default=Path("runs"), help="Directory where experiment folders are saved.")
    parser.add_argument("--experiment-name", type=str, default="detree_experiment", help="Base name for the run directory.")
    parser.add_argument("--resume", action="store_true", help="Reuse the given experiment directory if it already exists.")

    # Optimization hyperparameters.
    parser.add_argument("--projection-size", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=0.07)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--per-gpu-batch-size", type=int, default=64)
    parser.add_argument("--per-gpu-eval-batch-size", type=int, default=16)
    parser.add_argument("--max-length", type=int, default=512, help="Maximum sequence length for the tokenizer.")
    parser.add_argument("--total-epoch", type=int, default=10)
    parser.add_argument("--warmup-steps", type=int, default=2000)
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("--min-lr", type=float, default=5e-6)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.99)
    parser.add_argument("--eps", type=float, default=1e-6)
    parser.add_argument("--adv-p", type=float, default=0.5, help="Probability of sampling adversarial data.")
    parser.add_argument("--num-workers-eval", type=int, default=8, help="Reserved for compatibility.")

    # LoRA adapter configuration.
    parser.add_argument("--lora-r", type=int, default=128)
    parser.add_argument("--lora-alpha", type=int, default=256)
    parser.add_argument("--lora-dropout", type=float, default=0.0)

    parser.add_argument("--freeze-layer", type=int, default=0, help="Number of initial encoder layers to freeze.")
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument("--adapter-path", type=Path, default=None, help="Optional path to resume LoRA training from.")
    parser.add_argument("--pooling", type=str, default="max", choices=("max", "average", "cls"))

    # Paired on/off flags (default on).
    parser.add_argument("--lora", dest="lora", action="store_true", help="Enable LoRA adapters.")
    parser.add_argument("--no-lora", dest="lora", action="store_false", help="Disable LoRA adapters.")
    parser.set_defaults(lora=True)

    parser.add_argument("--freeze-embedding-layer", dest="freeze_embedding_layer", action="store_true")
    parser.add_argument("--no-freeze-embedding-layer", dest="freeze_embedding_layer", action="store_false")
    parser.set_defaults(freeze_embedding_layer=True)

    parser.add_argument("--adversarial", dest="adversarial", action="store_true")
    parser.add_argument("--no-adversarial", dest="adversarial", action="store_false")
    parser.set_defaults(adversarial=True)

    parser.add_argument("--include-attack", dest="include_attack", action="store_true")
    parser.add_argument("--no-include-attack", dest="include_attack", action="store_false")
    parser.set_defaults(include_attack=True)

    parser.add_argument("--has-mix", dest="has_mix", action="store_true")
    parser.add_argument("--no-has-mix", dest="has_mix", action="store_false")
    parser.set_defaults(has_mix=True)

    parser.add_argument("--deepspeed", action="store_true", help="Use DeepSpeed strategy when multiple GPUs are available.")

    return parser
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def train(args: argparse.Namespace) -> None:
    """Run the DETree contrastive training loop.

    Expects ``args.output_dir`` and ``args.runs_dir`` to already point at the
    resolved experiment directories (see ``main``).  Saves an adapter
    checkpoint per epoch plus a rolling ``last`` checkpoint.
    """
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.set_float32_matmul_precision("medium")

    if args.device_num > 1:
        if args.deepspeed:
            strategy = DeepSpeedStrategy()
        else:
            # find_unused_parameters: frozen/LoRA setups can leave params out of the graph.
            strategy = DDPStrategy(find_unused_parameters=True)
        fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num, strategy=strategy)
    else:
        fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num)

    fabric.launch()

    experiment_paths = ExperimentPaths(root=Path(args.output_dir), runs=Path(args.runs_dir))
    if fabric.global_rank == 0:
        # Only rank 0 creates directories; others wait at the barrier.
        experiment_paths.root.mkdir(parents=True, exist_ok=True)
        experiment_paths.runs.mkdir(parents=True, exist_ok=True)
    fabric.barrier()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    collate_fn = _build_collate_fn(tokenizer, args.max_length)

    model = SimCLR_Tree(args, fabric).train()

    data_path = load_datapath(
        str(args.path),
        include_adversarial=args.adversarial,
        dataset_name=args.dataset_name,
        include_attack=args.include_attack,
    )[args.dataset]

    train_dataset = SCLDataset(
        data_path,
        fabric,
        tokenizer,
        name2id=model.names2id,
        has_mix=args.has_mix,
        adv_p=args.adv_p,
    )

    passages_dataloader = DataLoader(
        train_dataset,
        batch_size=args.per_gpu_batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    model.train()
    if args.freeze_embedding_layer:
        # Freeze embedding/pooler weights (matched by substring) and, when
        # requested, the first ``freeze_layer`` encoder layers.
        # NOTE(review): the encoder-layer freeze is only reached when
        # freeze_embedding_layer is enabled — confirm that coupling is intended.
        for name, param in model.model.named_parameters():
            if "emb" in name or "model.pooler" in name:
                param.requires_grad = False
            if args.freeze_layer > 0:
                for i in range(args.freeze_layer):
                    if f"encoder.layer.{i}." in name:
                        param.requires_grad = False

    model = torch.compile(model)
    if fabric.global_rank == 0:
        print("Model has been initialized!")
        for name, param in model.model.named_parameters():
            print(name, param.requires_grad)

    # The dataset handles its own sharding, so the distributed sampler is disabled.
    passages_dataloader = fabric.setup_dataloaders(passages_dataloader, use_distributed_sampler=False)
    if fabric.global_rank == 0:
        print("DataLoader has been initialized!")

    if fabric.global_rank == 0:
        writer = SummaryWriter(str(experiment_paths.runs))
        print(f"Save dir is {args.output_dir}")
        # Snapshot the full CLI configuration next to the checkpoints.
        opt_dict = vars(args)
        opt_dict["output_dir"] = str(args.output_dir)
        with open(Path(args.output_dir) / "config.yaml", "w", encoding="utf-8") as file:
            yaml.dump(opt_dict, file, sort_keys=False)
    else:
        writer = None

    experiment_dir = experiment_paths.root

    num_batches_per_epoch = len(passages_dataloader)
    warmup_steps = args.warmup_steps
    lr = args.lr
    # Cosine schedule runs over the post-warmup steps only.
    total_steps = args.total_epoch * num_batches_per_epoch - warmup_steps

    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        eps=args.eps,
        weight_decay=args.weight_decay,
    )

    schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps, eta_min=args.min_lr)
    model, optimizer = fabric.setup(model, optimizer)

    if fabric.global_rank == 0:
        for name, param in model.named_parameters():
            if param.requires_grad:
                print(name, param.requires_grad)

    for epoch in range(args.total_epoch):
        model.train()
        avg_loss = 0.0
        iterator = enumerate(passages_dataloader)
        if fabric.global_rank == 0:
            iterator = tqdm(iterator, total=len(passages_dataloader))
            print(("\n" + "%11s" * 5) % ("Epoch", "GPU_mem", "loss1", "Avgloss", "lr"))
        for i, batch in iterator:
            current_step = epoch * num_batches_per_epoch + i
            # Linear warmup: scale LR manually until warmup_steps, then hand
            # control to the cosine scheduler below.
            if current_step < warmup_steps:
                current_lr = lr * current_step / max(warmup_steps, 1)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = current_lr
            current_lr = optimizer.param_groups[0]["lr"]

            encoded_batch, label, write_model = batch
            loss, loss_classify = model(encoded_batch, write_model)

            # Running mean of the loss over the epoch.
            avg_loss = (avg_loss * i + loss.item()) / (i + 1)
            fabric.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            if current_step >= warmup_steps:
                schedule.step()

            mem = f"{torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0:.3g}G"
            if fabric.global_rank == 0:
                iterator.set_description(
                    ("%11s" * 2 + "%11.4g" * 3)
                    % (f"{epoch + 1}/{args.total_epoch}", mem, loss_classify.item(), avg_loss, current_lr)
                )
                if writer and current_step % 10 == 0:
                    writer.add_scalar("lr", current_lr, current_step)
                    writer.add_scalar("loss", loss.item(), current_step)
                    writer.add_scalar("avg_loss", avg_loss, current_step)
                    writer.add_scalar("loss_classify", loss_classify.item(), current_step)

        if fabric.global_rank == 0:
            # Per-epoch checkpoint; the tokenizer is saved only once (epoch 0).
            checkpoint_dir = experiment_dir / f"epoch_{epoch:02d}"
            model.save_pretrained(str(checkpoint_dir), save_tokenizer=(epoch == 0))
            print(f"Saved adapter checkpoint to {checkpoint_dir}", flush=True)

            last_dir = experiment_dir / "last"
            model.save_pretrained(str(last_dir), save_tokenizer=False)
            print(f"Updated latest checkpoint at {last_dir}", flush=True)

        fabric.barrier()

    if writer:
        writer.flush()
        writer.close()
| 297 |
+
|
| 298 |
+
def main(argv: Optional[Iterable[str]] = None) -> None:
    """Parse CLI arguments, resolve the experiment directory, and launch training."""
    args = build_argument_parser().parse_args(argv)
    paths = _prepare_output_dir(
        args.output_dir, args.experiment_name, resume=args.resume, create_dirs=False
    )
    args.output_dir = str(paths.root)
    args.runs_dir = str(paths.runs)
    train(args)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
__all__ = ["build_argument_parser", "main", "train"]
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
if __name__ == "__main__":
|
| 313 |
+
main()
|