MAS-AI-0000 committed
Commit 4d939fc · verified
1 Parent(s): 1736cee

Upload 9 files

detree/cli/database.py ADDED
@@ -0,0 +1,117 @@
1
+ """Generate clustered prototype databases from embeddings."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+ from typing import Iterable, Optional
8
+
9
+ import faiss
10
+ import numpy as np
11
+ import torch
12
+
13
+
14
+ class GPUKMeansClusterer:
15
+ def __init__(self, dim: int, n_clusters: int = 500, n_iter: int = 20, n_gpu: int = 1):
16
+ self.clus = faiss.Clustering(dim, n_clusters)
17
+ self.clus.verbose = True
18
+ self.clus.niter = n_iter
19
+ self.dim = dim
20
+ self.n_clusters = n_clusters
21
+ self.clus.update_index = True
22
+
23
+ res = [faiss.StandardGpuResources() for _ in range(n_gpu)]
24
+ flat_config = []
25
+ for i in range(n_gpu):
26
+ cfg = faiss.GpuIndexFlatConfig()
27
+ cfg.useFloat16 = False
28
+ cfg.device = i
29
+ flat_config.append(cfg)
30
+
31
+ if n_gpu == 1:
32
+ self.index = faiss.GpuIndexFlatL2(res[0], self.dim, flat_config[0])
33
+ else:
34
+ indexes = [faiss.GpuIndexFlatL2(res[i], self.dim, flat_config[i]) for i in range(n_gpu)]
35
+ self.index = faiss.IndexReplicas()
36
+ for sub_index in indexes:
37
+ self.index.addIndex(sub_index)
38
+
39
+ def fit(self, embeddings_np: np.ndarray) -> np.ndarray:
40
+ self.index.reset()
41
+ self.clus.train(embeddings_np, self.index)
42
+ centroids = faiss.vector_float_to_array(self.clus.centroids)
43
+ centroids = centroids.reshape(self.n_clusters, self.dim)
44
+ return centroids
45
+
46
+
47
+ def gen_data(dict_data):
48
+ embeddings = dict_data["embeddings"]
49
+ labels = dict_data["labels"]
50
+ ids = dict_data["ids"]
51
+ classes = dict_data["classes"]
52
+ return embeddings, labels, ids, classes
53
+
54
+
55
+ def build_argument_parser() -> argparse.ArgumentParser:
56
+ parser = argparse.ArgumentParser(
57
+ description="Cluster embeddings into prototype databases using GPU K-Means.",
58
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
59
+ )
60
+ parser.add_argument("--database", type=Path, required=True, help="Input embedding database (.pt).")
61
+ parser.add_argument("--output", type=Path, required=True, help="Output path for the clustered database.")
62
+ parser.add_argument("--clusters", type=int, default=10000)
63
+ parser.add_argument("--dimension", type=int, default=1024)
64
+ parser.add_argument("--iterations", type=int, default=100)
65
+ parser.add_argument("--gpus", type=int, default=1)
66
+ parser.add_argument("--human-class-name", type=str, default="human", help="Label representing humans in the class list.")
67
+ return parser
68
+
69
+
70
+ def cluster_database(args: argparse.Namespace) -> None:
71
+ data_emb, data_labels, data_ids, data_classes = gen_data(torch.load(args.database))
72
+ human_idx = data_classes.index(args.human_class_name)
73
+ datapos = (data_labels == human_idx).long()
74
+ pos2cnt = {0: args.clusters, 1: args.clusters}
75
+ pos2name = {0: ["llm"], 1: ["human"]}
76
+
77
+ datapos_np = datapos.cpu().numpy()
78
+ kmeans = GPUKMeansClusterer(args.dimension, n_clusters=args.clusters, n_iter=args.iterations, n_gpu=args.gpus)
79
+ all_centers = {}
80
+ save_labels = None
81
+ for key in data_emb:
82
+ now_emb = data_emb[key].float().cpu().numpy()
83
+ all_center = []
84
+ all_labels = []
85
+ for pos in pos2cnt:
86
+ pos_emb = now_emb[datapos_np == pos]
87
+ pos_center = kmeans.fit(pos_emb)
88
+ all_center.append(pos_center)
89
+ all_labels.append(np.full((pos_center.shape[0],), pos))
90
+ all_center = np.concatenate(all_center, axis=0)
91
+ all_labels = np.concatenate(all_labels, axis=0)
92
+ all_center = torch.from_numpy(all_center).to(dtype=torch.bfloat16)
93
+ all_labels = torch.from_numpy(all_labels).to(dtype=torch.long)
94
+ all_centers[key] = all_center
95
+ save_labels = all_labels
96
+
97
+ save_ids = torch.arange(save_labels.shape[0], dtype=torch.long)
98
+ classes = [None] * len(pos2name.keys())
99
+ for pos in pos2name:
100
+ classes[pos] = ','.join(pos2name[pos])
101
+
102
+ emb_dict = {"embeddings": all_centers, "labels": save_labels, "ids": save_ids, "classes": classes}
103
+ args.output.parent.mkdir(parents=True, exist_ok=True)
104
+ torch.save(emb_dict, args.output)
105
+ print(f"All centers saved to: {args.output}")
106
+
107
+
108
+ def main(argv: Optional[Iterable[str]] = None) -> None:
109
+ parser = build_argument_parser()
110
+ args = parser.parse_args(argv)
111
+ cluster_database(args)
112
+
113
+
114
+ if __name__ == "__main__":
115
+ main()
116
+
117
+ __all__ = ["build_argument_parser", "cluster_database", "main"]
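
A minimal usage sketch for this module (paths below are hypothetical placeholders): the input .pt file is an embedding database such as the one produced by detree/cli/embeddings.py, and a GPU build of FAISS is assumed by GPUKMeansClusterer.

from detree.cli.database import main

# Hypothetical paths; the input database comes from detree.cli.embeddings.
main([
    "--database", "runs/embeddings/train_embeddings.pt",
    "--output", "runs/databases/clustered.pt",
    "--clusters", "1000",
    "--dimension", "1024",
    "--iterations", "50",
    "--gpus", "1",
])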
detree/cli/embeddings.py ADDED
@@ -0,0 +1,200 @@
1
+ """Embedding generation CLI for DETree."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+ from typing import Iterable, Literal, Optional
8
+
9
+ import pandas as pd
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from lightning import Fabric
13
+ from torch.utils.data import DataLoader, Dataset
14
+ from tqdm import tqdm
15
+
16
+ from detree.model.text_embedding import TextEmbeddingModel
17
+ from detree.utils.dataset import SCLDataset, load_datapath
18
+
19
+
20
+ def infer(passages_dataloader, fabric, tokenizer, model, args):
21
+ if fabric.global_rank == 0:
22
+ passages_dataloader = tqdm(passages_dataloader)
23
+ all_ids, all_embeddings, all_labels = [], {}, []
24
+ for layer in args.need_layer:
25
+ all_embeddings[layer] = []
26
+ with torch.no_grad():
27
+ for batch in passages_dataloader:
28
+ text, label, write_model, ids = batch
29
+ encoded_batch = tokenizer.batch_encode_plus(
30
+ text,
31
+ return_tensors="pt",
32
+ max_length=args.max_length,
33
+ padding="max_length",
34
+ truncation=True,
35
+ )
36
+ encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
37
+ embeddings = model(encoded_batch, hidden_states=True)
38
+ embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(-2), embeddings.size(-1))
39
+ label = fabric.all_gather(write_model).view(-1)
40
+ ids = fabric.all_gather(ids).view(-1)
41
+ if fabric.global_rank == 0:
42
+ embeddings = F.normalize(embeddings, dim=-1).cpu().to(torch.bfloat16)
43
+ for layer in args.need_layer:
44
+ all_embeddings[layer].append(embeddings[:, layer, :].clone())
45
+ all_ids.extend(ids.cpu().tolist())
46
+ all_labels.extend(label.cpu().tolist())
47
+ del embeddings, label, ids
48
+ if fabric.global_rank == 0:
49
+ for layer in args.need_layer:
50
+ all_embeddings[layer] = torch.cat(all_embeddings[layer], dim=0)
51
+ return torch.tensor(all_ids), all_embeddings, torch.tensor(all_labels)
52
+ return [], [], []
53
+
54
+
55
+ def stable_long_hash(input_string: str) -> int:
56
+ import hashlib
57
+
58
+ hash_object = hashlib.sha256(input_string.encode())
59
+ hex_digest = hash_object.hexdigest()
60
+ int_hash = int(hex_digest, 16)
61
+ return int_hash & ((1 << 63) - 1)
62
+
63
+
64
+ def load_data(split: Literal["train", "test", "extra"], include_adversarial: bool, fp: Path) -> pd.DataFrame:
65
+ if split not in ("train", "test", "extra"):
66
+ raise ValueError("`split` must be one of (\"train\", \"test\", \"extra\")")
67
+
68
+ fname = f"{split}.csv" if include_adversarial else f"{split}_none.csv"
69
+ fp = fp / fname
70
+ return pd.read_csv(fp)
71
+
72
+
73
+ class PassagesDataset(Dataset):
74
+ def __init__(self, data):
75
+ self.passages = []
76
+ for item in data:
77
+ if item["attack"] not in ("none", "paraphrase") and stable_long_hash(item["generation"]) % 10 < 5:
78
+ continue
79
+ self.passages.append(item)
80
+ classes = sorted({item["model"] for item in data})
81
+ self.classes = list(classes)
82
+ self.human_id = self.classes.index("human")
83
+
84
+ def __len__(self):
85
+ return len(self.passages)
86
+
87
+ def __getitem__(self, idx):
88
+ data_now = self.passages[idx]
89
+ text = data_now["generation"]
90
+ model = self.classes.index(data_now["model"])
91
+ label = int(model == self.human_id)
92
+ ids = stable_long_hash(text)
93
+ return text, int(label), int(model), int(ids)
94
+
95
+
96
+ def build_argument_parser() -> argparse.ArgumentParser:
97
+ parser = argparse.ArgumentParser(
98
+ description="Generate embedding databases for DETree evaluators",
99
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
100
+ )
101
+ parser.add_argument("--device-num", type=int, default=1)
102
+ parser.add_argument("--batch-size", type=int, default=64)
103
+ parser.add_argument("--num-workers", type=int, default=8)
104
+ parser.add_argument("--max-length", type=int, default=512)
105
+
106
+ parser.add_argument("--path", type=Path, required=True, help="Dataset root directory or JSONL file path.")
107
+ parser.add_argument("--database-name", type=str, default="M4_monolingual")
108
+ parser.add_argument(
109
+ "--model-name",
110
+ type=str,
111
+ default="FacebookAI/roberta-large",
112
+ help=(
113
+ "Model identifier for embeddings generation. Accepts either a Hugging Face "
114
+ "model hub name or a local path to a directory in Hugging Face format."
115
+ ),
116
+ )
117
+
118
+ parser.add_argument("--pooling", type=str, default="max", choices=("max", "average", "cls"))
119
+ parser.add_argument("--need-layer", type=int, nargs="+", default=[16, 17, 18, 19, 22, 23])
120
+
121
+ parser.add_argument("--adversarial", dest="adversarial", action="store_true")
122
+ parser.add_argument("--no-adversarial", dest="adversarial", action="store_false")
123
+ parser.set_defaults(adversarial=True)
124
+
125
+ parser.add_argument("--has-mix", dest="has_mix", action="store_true")
126
+ parser.add_argument("--no-has-mix", dest="has_mix", action="store_false")
127
+ parser.set_defaults(has_mix=False)
128
+
129
+ parser.add_argument("--savedir", type=Path, required=True, help="Output directory for the embedding database.")
130
+ parser.add_argument("--name", type=str, required=True, help="Filename (without extension) for the saved embeddings.")
131
+ parser.add_argument("--split", type=str, default="train", choices=("train", "test", "extra"))
132
+
133
+ return parser
134
+
135
+
136
+ def generate_embeddings(args: argparse.Namespace) -> None:
137
+ if args.device_num > 1:
138
+ fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num, strategy="ddp")
139
+ else:
140
+ fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num)
141
+ fabric.launch()
142
+
143
+ model = TextEmbeddingModel(
144
+ args.model_name,
145
+ output_hidden_states=True,
146
+ infer=True,
147
+ use_pooling=args.pooling,
148
+ ).cuda()
149
+ tokenizer = model.tokenizer
150
+ model.eval()
151
+
152
+ path_str = str(args.path)
153
+ if "LLM_detect_data" in path_str:
154
+ now_data = load_data(args.split, include_adversarial=args.adversarial, fp=args.path)
155
+ now_data = now_data.to_dict(orient="records")
156
+ dataset = PassagesDataset(now_data)
157
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
158
+ dataloader = fabric.setup_dataloaders(dataloader)
159
+ elif path_str.endswith(".jsonl"):
160
+ dataset = SCLDataset([path_str], fabric, tokenizer, need_ids=True, adv_p=0)
161
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
162
+ dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False)
163
+ else:
164
+ data_path = load_datapath(
165
+ path_str,
166
+ include_adversarial=args.adversarial,
167
+ dataset_name=args.database_name,
168
+ )[args.split]
169
+ dataset = SCLDataset(data_path, fabric, tokenizer, need_ids=True, adv_p=0, has_mix=args.has_mix)
170
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
171
+ dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False)
172
+
173
+ model = fabric.setup(model)
174
+ classes = dataset.classes
175
+ train_ids, train_embeddings, train_labels = infer(dataloader, fabric, tokenizer, model, args)
176
+
177
+ torch.cuda.empty_cache()
178
+ if fabric.global_rank == 0:
179
+ args.savedir.mkdir(parents=True, exist_ok=True)
180
+ emb_dict = {
181
+ "embeddings": train_embeddings,
182
+ "labels": train_labels,
183
+ "ids": train_ids,
184
+ "classes": classes,
185
+ }
186
+ output_path = args.savedir / f"{args.name}.pt"
187
+ torch.save(emb_dict, output_path)
188
+ print(f"Saved embedding database to {output_path}")
189
+
190
+
191
+ def main(argv: Optional[Iterable[str]] = None) -> None:
192
+ parser = build_argument_parser()
193
+ args = parser.parse_args(argv)
194
+ generate_embeddings(args)
195
+
196
+
197
+ if __name__ == "__main__":
198
+ main()
199
+
200
+ __all__ = ["build_argument_parser", "generate_embeddings", "main"]
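
A minimal usage sketch (hypothetical paths; a CUDA device is assumed, since the model and Fabric are configured for GPU). A .jsonl input takes the SCLDataset branch above, and the saved .pt file feeds the clustering and similarity CLIs.

from detree.cli.embeddings import main

# Hypothetical dataset and output paths.
main([
    "--path", "data/train.jsonl",
    "--model-name", "FacebookAI/roberta-large",
    "--pooling", "max",
    "--need-layer", "16", "17", "18",
    "--savedir", "runs/embeddings",
    "--name", "train_embeddings",
    "--split", "train",
])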
detree/cli/gen_tree.py ADDED
@@ -0,0 +1,86 @@
1
+ """Tree generation CLI utilities for DETree."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+ from typing import Iterable, Sequence, Set
9
+
10
+ from detree.utils.dataset import load_datapath, model_alias_mapping
11
+
12
+
13
+ def _str2bool(value: str) -> bool:
14
+ """Parse common textual boolean representations used by legacy scripts."""
15
+
16
+ if isinstance(value, bool):
17
+ return value
18
+ lowered = value.lower()
19
+ if lowered in {"true", "1", "yes", "y"}:
20
+ return True
21
+ if lowered in {"false", "0", "no", "n"}:
22
+ return False
23
+ raise argparse.ArgumentTypeError(f"Boolean value expected, got: {value}")
24
+
25
+
26
+ def get_data_model(data_path: Iterable[Path], has_mix: bool = True) -> Set[str]:
27
+ """Collect all model identifiers present in the provided dataset paths."""
28
+
29
+ llm_name: Set[str] = set()
30
+ cnt = 0
31
+ for path in data_path:
32
+ print(f"reading {path}")
33
+ with path.open(mode="r", encoding="utf-8") as jsonl_file:
34
+ for line in jsonl_file:
35
+ now = json.loads(line)
36
+ if now["src"] not in model_alias_mapping:
37
+ model_alias_mapping[now["src"]] = now["src"]
38
+ now["src"] = model_alias_mapping[now["src"]]
39
+ if not has_mix and "human" in now["src"] and now["src"] != "human":
40
+ continue
41
+ if now["src"] not in llm_name:
42
+ llm_name.add(now["src"])
43
+ cnt += 1
44
+ print(cnt)
45
+ return llm_name
46
+
47
+
48
+ def build_argument_parser() -> argparse.ArgumentParser:
49
+ """Create the argument parser for the tree generation CLI."""
50
+
51
+ parser = argparse.ArgumentParser(
52
+ description="Generate DETree-compatible tree definitions from dataset files.",
53
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
54
+ )
55
+ parser.add_argument("--path", type=Path, default=Path("/opt/AI-text-Dataset"), help="Root directory of the dataset.")
56
+ parser.add_argument("--dataset_name", type=str, default="all", help="Dataset configuration name.")
57
+ parser.add_argument(
58
+ "--mode",
59
+ type=str,
60
+ choices=("train", "test", "extra"),
61
+ default="train",
62
+ help="Dataset split to consume.",
63
+ )
64
+ parser.add_argument("--tree_txt", type=Path, default=Path("output/Tree_RAID_pcl.txt"), help="Output tree definition path.")
65
+ parser.add_argument("--adversarial", type=_str2bool, default=True, help="Whether to include adversarial data splits.")
66
+ parser.add_argument("--has_mix", type=_str2bool, default=True, help="Whether to keep mixed human/model generations.")
67
+ return parser
68
+
69
+
70
+ def main(args: argparse.Namespace) -> None:
71
+ """Entry point for building DETree-compatible tree structures."""
72
+
73
+ dataset_paths: Sequence[str] = load_datapath(args.path, args.adversarial, args.dataset_name)[args.mode]
74
+ print(f"data_path: {dataset_paths}")
75
+ llm_name = sorted(get_data_model((Path(p) for p in dataset_paths), args.has_mix))
76
+ root = len(llm_name)
77
+ args.tree_txt.parent.mkdir(parents=True, exist_ok=True)
78
+ with args.tree_txt.open("w", encoding="utf-8") as f:
79
+ for i, item in enumerate(llm_name):
80
+ f.write(f"{i} {root} {item}\n")
81
+ f.write(f"{root} -1 none\n")
82
+
83
+
84
+ if __name__ == "__main__":
85
+ parser = build_argument_parser()
86
+ main(parser.parse_args())
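
A minimal usage sketch (dataset root and output path are placeholders); note that main takes the parsed Namespace directly rather than an argv list.

from detree.cli.gen_tree import build_argument_parser, main

parser = build_argument_parser()
# Hypothetical dataset root; writes a flat one-level tree definition.
main(parser.parse_args([
    "--path", "/data/AI-text-Dataset",
    "--dataset_name", "all",
    "--mode", "train",
    "--tree_txt", "output/Tree_example.txt",
]))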
detree/cli/hierarchical_clustering.py ADDED
@@ -0,0 +1,497 @@
1
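+ """Construct the DETree HAT tree from a class similarity matrix."""
+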
+ import argparse
2
+ import random
3
+ from pathlib import Path
4
+ from typing import Iterable, Optional
5
+
6
+ import matplotlib.cm as cm
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from scipy.cluster.hierarchy import dendrogram, linkage
10
+ from scipy.spatial.distance import euclidean, squareform
11
+ from sklearn.metrics import silhouette_score
12
+
13
+ def read_similarity_matrix(file_path: Path):
14
+ with file_path.open('r', encoding='utf-8') as f:
15
+ lines = f.readlines()
16
+ names = lines[0].strip().split()
17
+ matrix = []
18
+
19
+ for line in lines[1:]:
20
+ row = line.strip().split()[1:]
21
+ matrix.append([float(x) for x in row])
22
+
23
+ similarity_matrix = np.array(matrix)
24
+ return names, similarity_matrix
25
+
26
+ class TreeNode:
27
+ def __init__(self, name=None):
28
+
29
+ self.name = name
30
+ self.children = []
31
+ self.value = 0
32
+ self.split = True
33
+
34
+ def add_child(self, child):
35
+ self.children.append(child)
36
+
37
+ def build_tree(Z, names):
38
+ nodes = [TreeNode(name) for name in names]
39
+ for i, link in enumerate(Z):
40
+ node = TreeNode()
41
+ node.value = link[2]
42
+ node.add_child(int(link[0]))
43
+ node.add_child(int(link[1]))
44
+ nodes.append(node)
45
+ return nodes
46
+
47
+ def find_best_thold(node_idx,nodes, distance_matrix,min_score=0,max_score=1):
48
+ node = nodes[node_idx]
49
+ threshold_range = np.linspace(min_score * node.value, max_score * node.value, 50)
50
+ silhouette_scores = []
51
+ all_n_clusters = []
52
+
53
+ for threshold in threshold_range:
54
+ labels,_ = gen_label_from_node(node_idx,nodes,threshold)
55
+ labels = sorted(labels,key=lambda x:x[1])
56
+ labels = [x[0] for x in labels]
57
+ n_clusters = len(np.unique(labels))
58
+ if n_clusters > 1 and n_clusters < len(distance_matrix):
59
+ score = silhouette_score(distance_matrix, labels, metric='precomputed')
60
+ else:
61
+ score = -1
62
+ silhouette_scores.append(score)
63
+ all_n_clusters.append(n_clusters)
64
+ best_threshold_idx = np.argmax(silhouette_scores)
65
+ best_threshold = threshold_range[best_threshold_idx]
66
+ best_score = silhouette_scores[best_threshold_idx]
67
+ return best_threshold, best_score
68
+
69
+ def gen_label_from_node(node_idx,nodes,thd,now_label=0):
70
+ node = nodes[node_idx]
71
+ if len(node.children)==0:
72
+ return [(now_label,node_idx)],now_label
73
+ else:
74
+ if node.value>thd:
75
+ label_list = []
76
+ for child in node.children:
77
+ now_label_list,now_label = gen_label_from_node(child,nodes,thd,now_label)
78
+ now_label+=1
79
+ label_list+=now_label_list
80
+ return label_list,now_label
81
+ else:
82
+ label_list = []
83
+ for child in node.children:
84
+ now_label_list,now_label = gen_label_from_node(child,nodes,thd,now_label)
85
+ label_list+=now_label_list
86
+ return label_list,now_label
87
+
88
+ def find_new_root(node_idx,nodes,thd):
89
+ node = nodes[node_idx]
90
+ if node.value<=thd:
91
+ return [node_idx]
92
+
93
+ new_root = []
94
+ for child in node.children:
95
+ new_root+=find_new_root(child,nodes,thd)
96
+ return new_root
97
+
98
+ def get_leaf(node_idx,nodes):
99
+ node = nodes[node_idx]
100
+ if len(node.children)==0:
101
+ return [node_idx]
102
+
103
+ leaf_list = []
104
+ for child in node.children:
105
+ leaf_list+=get_leaf(child,nodes)
106
+ return leaf_list
107
+
108
+ def merge_tree(node_idx,nodes,distance_matrix,deep=0,end_thd=0.25):
109
+ if len(nodes[node_idx].children)==0:
110
+ return
111
+ print(f"Node {node_idx}: Value: {nodes[node_idx].value}, Depth: {deep}")
112
+ if nodes[node_idx].value<=end_thd or deep>=5:
113
+ nodes[node_idx].children = get_leaf(node_idx,nodes)
114
+ nodes[node_idx].split = False
115
+ return
116
+ leaf_list = np.array(sorted(get_leaf(node_idx,nodes)))
117
+ new_distance_matrix = distance_matrix[leaf_list][:,leaf_list]
118
+ best_threshold, best_score = find_best_thold(node_idx, nodes, new_distance_matrix,min_score=0)
119
+ if best_score==-1:
120
+ nodes[node_idx].children = get_leaf(node_idx,nodes)
121
+ return
122
+ new_root = find_new_root(node_idx,nodes,best_threshold)
123
+ nodes[node_idx].children = new_root
124
+
125
+ for child in new_root:
126
+ merge_tree(child,nodes,distance_matrix,deep=deep+1,end_thd=end_thd)
127
+
128
+ def merge_dict(a,b):
129
+ for key in b.keys():
130
+ if key in a.keys():
131
+ a[key]+=b[key]
132
+ else:
133
+ a[key] = b[key]
134
+ return a
135
+
136
+ def update_tree(node_idx, nodes, edge_list, fa=-1, deep=0):
137
+ node = nodes[node_idx]
138
+
139
+ if len(node.children)==0:
140
+ edge_list.append((fa,node_idx,[nodes[node_idx].name]))
141
+ return {deep:[[node_idx]]}
142
+
143
+ if node.split==False:
144
+ leafs = get_leaf(node_idx,nodes)
145
+ edge_list.append((fa,node_idx,[nodes[idx].name for idx in leafs]))
146
+ return {deep:[leafs]}
147
+
148
+ edge_list.append((fa,node_idx,[]))
149
+ new_tree = {}
150
+ for child in node.children:
151
+ new_tree = merge_dict(
152
+ new_tree,
153
+ update_tree(child, nodes, edge_list, node_idx, deep=deep+1),
154
+ )
155
+ if deep not in new_tree.keys():
156
+ new_tree[deep] = []
157
+ new_tree[deep].append(get_leaf(node_idx,nodes))
158
+
159
+ return new_tree
160
+
161
+ def color_distance(c1, c2):
162
+ return euclidean(c1[:3], c2[:3]) # only consider the RGB components
163
+
164
+ def ensure_color_diversity(colors, min_distance=0.2):
165
+ random.shuffle(colors)
166
+ for i in range(1, len(colors)):
167
+ if color_distance(colors[i], colors[i-1]) < min_distance:
168
+ for j in range(i + 1, len(colors)):
169
+ if color_distance(colors[i], colors[j]) > min_distance:
170
+ colors[i], colors[j] = colors[j], colors[i]
171
+ break
172
+ return colors
173
+
174
+
175
+ def draw_table(new_tree, names, max_deep=3, save_path='fig/E/test.pdf'):
176
+ base_list = new_tree[0][0]
177
+ data = [base_list]
178
+ cmap = cm.get_cmap('tab20c', 2048)
179
+ cmap = [cmap(i) for i in range(2048)]
180
+ cmap = ensure_color_diversity(cmap)
181
+ cell_colours = [['#FFDDC1' for _ in base_list]]
182
+ color_start=0
183
+
184
+ for i in range(1,max_deep+1):
185
+ if i not in new_tree.keys():
186
+ print(f"Level {i} not in new_tree")
187
+ continue
188
+ data.append([names[base] for base in base_list])
189
+ color_list = []
190
+ for k,base in enumerate(base_list):
191
+ color_id = -1
192
+ for j in range(len(new_tree[i])):
193
+ if base in new_tree[i][j]:
194
+ color_id = j
195
+ break
196
+ if color_id==-1:
197
+ color_list.append(cell_colours[-1][k])
198
+ else:
199
+ color_list.append(cmap[color_start+color_id])
200
+ cell_colours.append(color_list)
201
+ color_start+=len(new_tree[i])
202
+
203
+ data = list(zip(*data))
204
+ cell_colours = list(zip(*cell_colours))
205
+ columns = ['Node ID']+['Level {}'.format(i) for i in range(1,max_deep+1)]
206
+ plt.figure(figsize=(30, 40))
207
+ table = plt.table(cellText=data, colLabels=columns, loc='center', cellLoc='center',
208
+ colColours=['#f5f5f5']*len(columns),cellColours=cell_colours)
209
+ table.auto_set_column_width([0, 1])
210
+ plt.axis('off')
211
+ plt.savefig(save_path, format='pdf' ,bbox_inches='tight',pad_inches=0.01)
212
+
213
+ def fix_asymmetry(matrix):
214
+ matrix = (matrix + matrix.T) / 2
215
+ return matrix
216
+
217
+ def rename(edge):
218
+ cnt=0
219
+ reid={}
220
+ du={}
221
+ edge_dict={}
222
+ queue=[]
223
+ for i in range(len(edge)):
224
+ du[edge[i][0]]=du.get(edge[i][0],0)+1
225
+ edge_dict[edge[i][1]]=edge[i]
226
+ if edge[i][2] != []:
227
+ queue.append(edge[i][1])
228
+ while len(queue)>0:
229
+ now = queue.pop(0)
230
+ if now==-1:
231
+ reid[now]=-1
232
+ continue
233
+ if now not in reid.keys():
234
+ reid[now]=cnt
235
+ cnt+=1
236
+ now_edge = edge_dict[now]
237
+ du[now_edge[0]]-=1
238
+ if du[now_edge[0]]==0:
239
+ queue.append(now_edge[0])
240
+ new_edge = [(reid[x[0]],reid[x[1]],x[2]) for x in edge]
241
+ return new_edge
242
+
243
+ def save_edge(edge,save_path):
244
+ with open(save_path,'w') as f:
245
+ for e in edge:
246
+ if e[2]:
247
+ name_str = ','.join(e[2])
248
+ else:
249
+ name_str = 'none'
250
+ f.write(f"{e[1]} {e[0]} {name_str}\n")
251
+
252
+ def filter_class(names, similarity_matrix):
253
+ choose_idx = []
254
+ for i in range(len(names)):
255
+ if 'extend' not in names[i] and 'polish' not in names[i] and\
256
+ 'translate' not in names[i] and 'paraphrase' not in names[i]:
257
+ if 'B' in names[i] or 'human' in names[i]:
258
+ choose_idx.append(i)
259
+ else:
260
+ if random.random()<0.3:
261
+ choose_idx.append(i)
262
+ elif 'human' in names[i]:
263
+ if random.random()<0.3:
264
+ choose_idx.append(i)
265
+ elif random.random()<0.15:
266
+ choose_idx.append(i)
267
+ new_names = [names[i] for i in choose_idx]
268
+ choose_idx = np.array(choose_idx)
269
+ new_similarity_matrix = similarity_matrix[choose_idx][:,choose_idx]
270
+ return new_names, new_similarity_matrix
271
+
272
+ def filter(names, similarity_matrix,filter_human=False,filter_llm=False,filter_mix=False):
273
+ choose_idx = []
274
+ for i in range(len(names)):
275
+ if names[i] == 'human' and filter_human:
276
+ continue
277
+ if filter_llm and 'human' not in names[i]:
278
+ continue
279
+ if filter_mix and 'human' in names[i] and names[i]!='human':
280
+ continue
281
+ choose_idx.append(i)
282
+ new_names = [names[i] for i in choose_idx]
283
+ choose_idx = np.array(choose_idx)
284
+ new_similarity_matrix = similarity_matrix[choose_idx][:,choose_idx]
285
+ return new_names, new_similarity_matrix
286
+
287
+ def reid_tree_dict(tree_dict, nodes, names):
288
+ name_to_index = {name: idx for idx, name in enumerate(names)}
289
+ for deep,values in tree_dict.items():
290
+ rename_now = []
291
+ # print(values,len(values))
292
+ for list_ in values:
293
+ now_list = []
294
+ for idx in list_:
295
+ name = nodes[idx].name
296
+ if name not in name_to_index:
297
+ name_to_index[name] = len(names)
298
+ names.append(name)
299
+ name_idx = name_to_index[name]
300
+ now_list.append(name_idx)
301
+ rename_now.append(now_list)
302
+ tree_dict[deep] = rename_now
303
+ return tree_dict
304
+
305
+ def gen_tree(similarity_matrix,names,opt):
306
+ distance_matrix = 1 - similarity_matrix
307
+ np.fill_diagonal(distance_matrix, 0)
308
+ condensed_distance_matrix = squareform(distance_matrix)
309
+ Z = linkage(condensed_distance_matrix, method='weighted') # alternative methods include 'single', 'complete', or 'ward'
310
+ if opt.save_drg:
311
+ plt.figure(figsize=(30, 47))
312
+ dendrogram(Z, labels=names, orientation='right',leaf_font_size=16) # rotate the dendrogram so the root is on the right
313
+ plt.savefig(opt.dendrogram_path, format='pdf' ,bbox_inches='tight')
314
+ nodes = build_tree(Z, names)
315
+ merge_tree(len(nodes)-1,nodes,distance_matrix,end_thd=opt.end_score)
316
+
317
+ return nodes
318
+
319
+ def chage_tree_priori1(nodes):
320
+ human_node = TreeNode(name='human')
321
+ root = TreeNode()
322
+ root.add_child(len(nodes))
323
+ root.add_child(len(nodes)-1)
324
+ nodes.append(human_node)
325
+ nodes.append(root)
326
+ return nodes
327
+
328
+ def chage_tree_priori2(human_nodes,llm_nodes):
329
+ root = TreeNode()
330
+ root.add_child(len(human_nodes)-1)
331
+ root.add_child(len(human_nodes)+len(llm_nodes)-1)
332
+ for i in range(len(llm_nodes)):
333
+ llm_nodes[i].children = [len(human_nodes)+x for x in llm_nodes[i].children]
334
+ nodes = human_nodes+llm_nodes
335
+ nodes.append(root)
336
+ return nodes
337
+
338
+ def chage_tree_priori3(co_nodes,llm_nodes):
339
+ human_node = TreeNode(name='human')
340
+ root = TreeNode()
341
+ root.add_child(len(co_nodes)+len(llm_nodes))
342
+ root.add_child(len(co_nodes)-1)
343
+ root.add_child(len(co_nodes)+len(llm_nodes)-1)
344
+ for i in range(len(llm_nodes)):
345
+ llm_nodes[i].children = [len(co_nodes)+x for x in llm_nodes[i].children]
346
+ nodes = co_nodes+llm_nodes
347
+ nodes.append(human_node)
348
+ nodes.append(root)
349
+ return nodes
350
+
351
+ def randmo_filter(names, similarity_matrix):
352
+ choose_idx = []
353
+ for i in range(len(names)):
354
+ if 'human' in names[i]:
355
+ choose_idx.append(i)
356
+ elif 'fair' in names[i] or 'pplm' in names[i] or 'gpt2-pytorch' in names[i] or ' transfo' in names[i] or 'ctrl' in names[i]:
357
+ continue
358
+ elif 'xlnet' in names[i] or 'grover' in names[i]:
359
+ if random.random()<0.07:
360
+ choose_idx.append(i)
361
+ elif random.random()<0.22:
362
+ choose_idx.append(i)
363
+ new_names = []
364
+ for i in choose_idx:
365
+ if names[i].startswith('7B') or names[i].startswith('13B') or names[i].startswith('30B') or names[i].startswith('65B'):
366
+ new_names.append('LLaMA_'+names[i])
367
+ else:
368
+ new_names.append(names[i])
369
+ choose_idx = np.array(choose_idx)
370
+ new_similarity_matrix = similarity_matrix[choose_idx][:,choose_idx]
371
+ return new_names, new_similarity_matrix
372
+
373
+ def ishuman(name):
374
+ return ('human' in name)
375
+ def ismachine(name):
376
+ return ('machine' in name or 'rephrase' in name)
377
+
378
+ def get_llm(x):
379
+ if 'gpt-3.5-turbo' in x:
380
+ return 'gpt-3.5-turbo'
381
+ elif 'gpt-4o' in x:
382
+ return 'gpt-4o'
383
+ elif 'llama-3.3-70b' in x:
384
+ return 'llama-3.3-70b'
385
+ elif 'gemini-1.5-pro' in x:
386
+ return 'gemini-1.5-pro'
387
+ elif 'claude-3-5-sonnet' in x:
388
+ return 'claude-3-5-sonnet'
389
+ elif 'qwen2.5-72b' in x:
390
+ return 'qwen2.5-72b'
391
+ else:
392
+ raise ValueError(f"Invalid class name: {x}")
393
+
394
+ def get_name(name):
395
+ name = name.split('_')
396
+ assert len(name) == 2
397
+ if ishuman(name[0]):
398
+ if name[1]=='humanize:human' or name[1]=='human':
399
+ return 'human'
400
+ elif name[1]=='humanize:tool':
401
+ return 'human_humanize_tool'
402
+ else:
403
+ llm_name = get_llm(name[1])
404
+ return f'human_rephrase_{llm_name}'
405
+ elif ismachine(name[0]):
406
+ llm_name = get_llm(name[0])
407
+ if name[1]=='humanize:human' or name[1]=='human':
408
+ return f'{llm_name}_humanize_human'
409
+ elif name[1]=='humanize:tool':
410
+ return f'{llm_name}_humanize_tool'
411
+ elif 'humanize:' in name[1]:
412
+ llm_name2 = get_llm(name[1])
413
+ return f'{llm_name}_humanize_{llm_name2}'
414
+ else:
415
+ return llm_name
416
+
417
+ def clear_names(names):
418
+ new_names = []
419
+ for name in names:
420
+ new_names.append(get_name(name))
421
+ return new_names
422
+
423
+ def build_argument_parser() -> argparse.ArgumentParser:
424
+ parser = argparse.ArgumentParser(
425
+ description="Construct the HAT tree from a similarity matrix.",
426
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
427
+ )
428
+ parser.add_argument('--file-path', type=Path, required=True, help='Input similarity matrix text file.')
429
+ parser.add_argument('--priori',type=int,default=1,choices=[0,1,2,3])
430
+ parser.add_argument('--save-txt-path', type=Path, required=True, help='Destination path for the tree definition.')
431
+ parser.add_argument('--save-table-path', type=Path, required=True, help='Destination path for the visualised table.')
432
+ parser.add_argument('--dendrogram-path', type=Path, default=None, help='Optional path for the dendrogram PDF when saved.')
433
+ parser.add_argument('--save-drg', action='store_true', help='Persist the dendrogram PDF alongside the tree.')
434
+ parser.add_argument('--no-save-drg', dest='save_drg', action='store_false')
435
+ parser.set_defaults(save_drg=True)
436
+ parser.add_argument('--save-max-dep', type=int, default=5)
437
+ parser.add_argument('--end-score', type=float, default=0.1)
438
+ parser.add_argument('--randmo-filter', action='store_true', help='Randomly subsample similarity entries.')
439
+ return parser
440
+
441
+
442
+ def main(argv: Optional[Iterable[str]] = None) -> None:
443
+ parser = build_argument_parser()
444
+ opt = parser.parse_args(argv)
445
+
446
+ names, similarity_matrix = read_similarity_matrix(opt.file_path)
447
+ if opt.save_drg:
448
+ if opt.dendrogram_path is None:
449
+ opt.dendrogram_path = opt.save_table_path.with_name(
450
+ f"{opt.save_table_path.stem}_dendrogram.pdf"
451
+ )
452
+ opt.dendrogram_path.parent.mkdir(parents=True, exist_ok=True)
453
+ else:
454
+ opt.dendrogram_path = None
455
+ similarity_matrix = fix_asymmetry(similarity_matrix)
456
+ if opt.randmo_filter:
457
+ names, similarity_matrix = randmo_filter(names, similarity_matrix)
458
+ # names = clear_names(names)
459
+ if opt.priori==1:
460
+ llm_names, llm_similarity_matrix = filter(names, similarity_matrix,filter_human=True)
461
+ nodes = gen_tree(llm_similarity_matrix,llm_names,opt)
462
+ nodes = chage_tree_priori1(nodes)
463
+
464
+ elif opt.priori==2:
465
+ human_names, human_similarity_matrix = filter(names, similarity_matrix,filter_llm=True)
466
+ human_nodes = gen_tree(human_similarity_matrix,human_names,opt)
467
+ llm_names, llm_similarity_matrix = filter(names, similarity_matrix,filter_human=True,filter_mix=True)
468
+ llm_nodes = gen_tree(llm_similarity_matrix,llm_names,opt)
469
+ nodes = chage_tree_priori2(human_nodes,llm_nodes)
470
+
471
+ elif opt.priori==3:
472
+ co_names, co_similarity_matrix = filter(names, similarity_matrix,filter_llm=True,filter_human=True)
473
+ co_nodes = gen_tree(co_similarity_matrix,co_names,opt)
474
+ llm_names, llm_similarity_matrix = filter(names, similarity_matrix,filter_human=True,filter_mix=True)
475
+ llm_nodes = gen_tree(llm_similarity_matrix,llm_names,opt)
476
+ nodes = chage_tree_priori3(co_nodes,llm_nodes)
477
+
478
+ elif opt.priori==0:
479
+ nodes = gen_tree(similarity_matrix,names,opt)
480
+ else:
481
+ raise ValueError("Invalid value for --priori. Choose from 0, 1, 2, or 3.")
482
+
483
+ edge=[]
484
+ tree_dict = update_tree(len(nodes)-1, nodes, edge)
485
+ edge = rename(edge)
486
+ opt.save_txt_path.parent.mkdir(parents=True, exist_ok=True)
487
+ opt.save_table_path.parent.mkdir(parents=True, exist_ok=True)
488
+ save_edge(edge,opt.save_txt_path)
489
+ tree_dict = reid_tree_dict(tree_dict, nodes, names)
490
+ draw_table(tree_dict, names, max_deep=opt.save_max_dep, save_path=opt.save_table_path)
491
+
492
+
493
+ if __name__ == "__main__":
494
+ main()
495
+
496
+
497
+ __all__ = ["build_argument_parser", "main", "read_similarity_matrix", "gen_tree"]
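
A minimal usage sketch (paths are placeholders): the input text file follows the format written by detree/cli/similarity_matrix.py, a header row of class names followed by one named row per class.

from detree.cli.hierarchical_clustering import main

# Hypothetical input/output paths; --priori 1 keeps "human" as its own branch under the root.
main([
    "--file-path", "runs/similarity/similarity_layer_18.txt",
    "--priori", "1",
    "--save-txt-path", "output/tree.txt",
    "--save-table-path", "output/tree_table.pdf",
    "--no-save-drg",
])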
detree/cli/merge_lora.py ADDED
@@ -0,0 +1,52 @@
1
+ """Merge LoRA adapters into base models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+ from typing import Iterable, Optional
8
+
9
+ from peft import PeftModel
10
+ from transformers import AutoModel, AutoTokenizer
11
+
12
+
13
+ def merge_lora_adapter(base_model: str, adapter_path: Path, output_dir: Path, safe_serialization: bool = True) -> None:
14
+ output_dir.mkdir(parents=True, exist_ok=True)
15
+
16
+ model = AutoModel.from_pretrained(base_model, trust_remote_code=True)
17
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
18
+
19
+ peft_model = PeftModel.from_pretrained(model, str(adapter_path))
20
+ merged_model = peft_model.merge_and_unload()
21
+ merged_model.save_pretrained(output_dir, safe_serialization=safe_serialization)
22
+ tokenizer.save_pretrained(output_dir)
23
+
24
+
25
+ def build_argument_parser() -> argparse.ArgumentParser:
26
+ parser = argparse.ArgumentParser(description="Merge a LoRA adapter into its base Hugging Face model.")
27
+ parser.add_argument("--base-model", type=str, required=True, help="Base model name or path.")
28
+ parser.add_argument("--adapter-path", type=Path, required=True, help="Directory containing the LoRA adapter weights.")
29
+ parser.add_argument("--output-dir", type=Path, required=True, help="Directory to store the merged model.")
30
+ parser.add_argument(
31
+ "--no-safe-serialization",
32
+ action="store_true",
33
+ help="Disable safetensors when saving the merged model.",
34
+ )
35
+ return parser
36
+
37
+
38
+ def main(argv: Optional[Iterable[str]] = None) -> None:
39
+ parser = build_argument_parser()
40
+ args = parser.parse_args(argv)
41
+ merge_lora_adapter(
42
+ args.base_model,
43
+ args.adapter_path,
44
+ args.output_dir,
45
+ safe_serialization=not args.no_safe_serialization,
46
+ )
47
+
48
+
49
+ if __name__ == "__main__":
50
+ main()
51
+
52
+ __all__ = ["build_argument_parser", "merge_lora_adapter", "main"]
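
A minimal usage sketch (model name and directories are placeholders); the merged model and tokenizer are written to --output-dir, in safetensors format unless --no-safe-serialization is passed.

from detree.cli.merge_lora import main

# Hypothetical base model and adapter locations.
main([
    "--base-model", "FacebookAI/roberta-large",
    "--adapter-path", "checkpoints/lora_adapter",
    "--output-dir", "checkpoints/merged_model",
])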
detree/cli/similarity_matrix.py ADDED
@@ -0,0 +1,77 @@
1
+ """Compute similarity matrices from embedding databases."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+ from typing import Iterable, Optional
8
+
9
+ import matplotlib.pyplot as plt
10
+ import torch
11
+
12
+
13
+ def gen_data(dict_data):
14
+ embeddings = dict_data["embeddings"]
15
+ labels = dict_data["labels"]
16
+ ids = dict_data["ids"]
17
+ classes = dict_data["classes"]
18
+ return embeddings, labels, ids, classes
19
+
20
+
21
+ def build_argument_parser() -> argparse.ArgumentParser:
22
+ parser = argparse.ArgumentParser(
23
+ description="Generate class similarity matrices for DETree.",
24
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
25
+ )
26
+ parser.add_argument("--database", type=Path, required=True, help="Path to the embedding database (.pt).")
27
+ parser.add_argument("--output-dir", type=Path, required=True, help="Directory to store the similarity outputs.")
28
+ parser.add_argument("--layers", type=int, nargs="*", default=None, help="Specific layers to export. Defaults to all.")
29
+ return parser
30
+
31
+
32
+ def compute_similarity(database: Path, output_dir: Path, layers: Optional[Iterable[int]]) -> None:
33
+ output_dir.mkdir(parents=True, exist_ok=True)
34
+ data_emb, data_labels, data_ids, data_classes = gen_data(torch.load(database))
35
+
36
+ if layers is None:
37
+ layers = list(data_emb.keys())
38
+
39
+ for layer in layers:
40
+ center = []
41
+ for item in data_classes:
42
+ index = data_classes.index(item)
43
+ now_emb = data_emb[layer][data_labels == index]
44
+ center.append(torch.mean(now_emb, dim=0))
45
+ center = torch.stack(center)
46
+ similarity = center @ center.T
47
+ similarity_np = similarity.cpu().float().numpy()
48
+
49
+ txt_path = output_dir / f"similarity_layer_{layer}.txt"
50
+ with txt_path.open("w", encoding="utf-8") as f:
51
+ f.write(" ".join(data_classes) + "\n")
52
+ for i, class_name in enumerate(data_classes):
53
+ row = " ".join(f"{similarity_np[i, j]:.4f}" for j in range(len(data_classes)))
54
+ f.write(f"{class_name} {row}\n")
55
+
56
+ plt.figure(figsize=(30, 30))
57
+ plt.imshow(similarity_np, cmap="viridis")
58
+ plt.colorbar()
59
+ plt.xticks(range(len(data_classes)), data_classes, rotation=45, fontsize=12)
60
+ plt.yticks(range(len(data_classes)), data_classes, fontsize=12)
61
+ plt.title(f"Similarity Matrix (layer {layer})", fontsize=20)
62
+ fig_path = output_dir / f"similarity_layer_{layer}.png"
63
+ plt.savefig(fig_path, dpi=300, bbox_inches="tight")
64
+ plt.close()
65
+ print(f"Saved similarity matrix for layer {layer} to {txt_path} and {fig_path}")
66
+
67
+
68
+ def main(argv: Optional[Iterable[str]] = None) -> None:
69
+ parser = build_argument_parser()
70
+ args = parser.parse_args(argv)
71
+ compute_similarity(args.database, args.output_dir, args.layers)
72
+
73
+
74
+ if __name__ == "__main__":
75
+ main()
76
+
77
+ __all__ = ["build_argument_parser", "compute_similarity", "main"]
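
A minimal usage sketch (paths are placeholders); the requested layers must be keys of the embedding database, i.e. layers exported via --need-layer when the embeddings were generated.

from detree.cli.similarity_matrix import main

# Hypothetical paths; omit --layers to export every stored layer.
main([
    "--database", "runs/embeddings/train_embeddings.pt",
    "--output-dir", "runs/similarity",
    "--layers", "16", "18",
])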
detree/cli/test_database_score_knn.py ADDED
@@ -0,0 +1,247 @@
1
+ """kNN evaluation using pre-computed embedding databases."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import os
8
+ from multiprocessing import Pool, cpu_count
9
+ from pathlib import Path
10
+ from typing import Iterable, List, Optional, Sequence, Tuple
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from lightning import Fabric
16
+ from torch.nn.functional import softmax as F_softmax
17
+ from torch.utils.data import DataLoader, Dataset
18
+ from tqdm import tqdm
19
+
20
+ from detree.model.text_embedding import TextEmbeddingModel
21
+ from detree.utils.index import Indexer
22
+ from detree.utils.utils import evaluate_metrics
23
+
24
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")
25
+
26
+
27
+ def load_jsonl(file_path: Path) -> List[dict]:
28
+ out = []
29
+ with file_path.open(mode="r", encoding="utf-8") as jsonl_file:
30
+ for line in jsonl_file:
31
+ item = json.loads(line)
32
+ out.append(item)
33
+ print(f"Loaded {len(out)} examples from {file_path}")
34
+ return out
35
+
36
+
37
+ def gen_data(dict_data):
38
+ embeddings = dict_data["embeddings"]
39
+ labels = dict_data["labels"]
40
+ ids = dict_data["ids"]
41
+ classes = dict_data["classes"]
42
+ return embeddings, labels, ids, classes
43
+
44
+
45
+ class PassagesDataset(Dataset):
46
+ def __init__(self, data: Sequence[dict]):
47
+ self.passages = list(data)
48
+
49
+ def __len__(self) -> int:
50
+ return len(self.passages)
51
+
52
+ def __getitem__(self, idx: int):
53
+ data_now = self.passages[idx]
54
+ text = data_now["text"]
55
+ label = data_now["label"]
56
+ ids = data_now["id"]
57
+ return text, int(label), int(ids)
58
+
59
+
60
+ def infer(passages_dataloader, fabric, tokenizer, model, need_layers: Sequence[int], max_length: int = 512):
61
+ if fabric.global_rank == 0:
62
+ passages_dataloader = tqdm(passages_dataloader)
63
+ all_ids: List[int] = []
64
+ all_embeddings: List[torch.Tensor] = []
65
+ all_labels: List[int] = []
66
+ with torch.no_grad():
67
+ for batch in passages_dataloader:
68
+ text, label, ids = batch
69
+ encoded_batch = tokenizer.batch_encode_plus(
70
+ text,
71
+ return_tensors="pt",
72
+ max_length=max_length,
73
+ padding="max_length",
74
+ truncation=True,
75
+ )
76
+ encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
77
+ embeddings = model(encoded_batch, hidden_states=True)
78
+ embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(-2), embeddings.size(-1))
79
+ label = fabric.all_gather(label).view(-1)
80
+ ids = fabric.all_gather(ids).view(-1)
81
+ if fabric.global_rank == 0:
82
+ all_embeddings.append(embeddings.cpu())
83
+ all_ids.extend(ids.cpu().tolist())
84
+ all_labels.extend(label.cpu().tolist())
85
+ if fabric.global_rank == 0:
86
+ embeddings_tensor = torch.cat(all_embeddings, dim=0)
87
+ embeddings_tensor = F.normalize(embeddings_tensor, dim=-1).permute(1, 0, 2).numpy()
88
+ embeddings_tensor = {layer: embeddings_tensor[layer] for layer in need_layers}
89
+ return all_ids, embeddings_tensor, all_labels
90
+ return [], [], []
91
+
92
+
93
+ def dict2str(metrics: dict) -> str:
94
+ out_str = ""
95
+ if "layer" in metrics:
96
+ out_str += f"layer:{metrics['layer']} "
97
+ if "k" in metrics:
98
+ out_str += f"k:{metrics['k']} "
99
+ for key, value in metrics.items():
100
+ if key not in {"layer", "k"}:
101
+ out_str += f"{key}:{value} "
102
+ return out_str.strip()
103
+
104
+
105
+ def process_element(args: Tuple[Sequence[int], Sequence[float], Sequence[int], float]):
106
+ ids, scores, labels, temperature = args
107
+ now_score = torch.zeros(2)
108
+ sorted_indices = np.argsort(scores)[::-1]
109
+ element_preds = {}
110
+
111
+ for k, idx in enumerate(sorted_indices):
112
+ label = labels[idx]
113
+ now_score[label] += scores[idx] * temperature
114
+ prob = F_softmax(now_score, dim=-1)[1].item()
115
+ element_preds[k + 1] = prob
116
+
117
+ return element_preds
118
+
119
+
120
+ def build_argument_parser() -> argparse.ArgumentParser:
121
+ parser = argparse.ArgumentParser(
122
+ description="Evaluate DETree with a precomputed embedding database.",
123
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
124
+ )
125
+ parser.add_argument("--device-num", type=int, default=1)
126
+ parser.add_argument("--batch-size", type=int, default=32)
127
+ parser.add_argument("--num-workers", type=int, default=8)
128
+ parser.add_argument("--max-length", type=int, default=512)
129
+
130
+ parser.add_argument("--database-path", type=Path, required=True, help="Path to the saved embedding database (.pt).")
131
+ parser.add_argument("--test-dataset-path", type=Path, required=True, help="Evaluation JSONL file.")
132
+ parser.add_argument("--model-name-or-path", type=str, required=True)
133
+ parser.add_argument("--temperature", type=float, default=0.05)
134
+
135
+ parser.add_argument("--max-k", type=int, default=51, dest="max_K")
136
+ parser.add_argument("--pooling", type=str, default="max", choices=("max", "average", "cls"))
137
+
138
+ parser.add_argument("--embedding-dim", type=int, default=1024)
139
+ parser.add_argument("--pool-workers", type=int, default=min(32, cpu_count()))
140
+ parser.add_argument("--log-file", type=Path, default=Path("runs/val.txt"))
141
+
142
+ return parser
143
+
144
+
145
+ def evaluate(args: argparse.Namespace) -> None:
146
+ if args.device_num > 1:
147
+ fabric = Fabric(accelerator="cuda", devices=args.device_num, strategy="ddp", precision="bf16-mixed")
148
+ else:
149
+ fabric = Fabric(accelerator="cuda", devices=args.device_num, precision="bf16-mixed")
150
+ fabric.launch()
151
+
152
+ model = TextEmbeddingModel(
153
+ args.model_name_or_path,
154
+ output_hidden_states=True,
155
+ infer=True,
156
+ use_pooling=args.pooling,
157
+ ).cuda()
158
+ tokenizer = model.tokenizer
159
+ model.eval()
160
+
161
+ if fabric.global_rank == 0:
162
+ db_embeddings, db_labels, db_ids, classes = gen_data(torch.load(args.database_path))
163
+ need_layers = list(db_embeddings.keys())
164
+ else:
165
+ db_embeddings = db_labels = db_ids = classes = None
166
+ need_layers = []
167
+ need_layers = fabric.broadcast(need_layers)
168
+
169
+ test_database = load_jsonl(args.test_dataset_path)
170
+ test_dataset = PassagesDataset(test_database)
171
+ test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
172
+ test_dataloader = fabric.setup_dataloaders(test_dataloader)
173
+ model = fabric.setup(model)
174
+ test_ids, test_embeddings, test_labels = infer(test_dataloader, fabric, tokenizer, model, need_layers, args.max_length)
175
+
176
+ torch.cuda.empty_cache()
177
+ if fabric.global_rank != 0:
178
+ return
179
+
180
+ test_labels = [int(label) for label in test_labels]
181
+ index = Indexer(args.embedding_dim)
182
+ human_idx = classes.index("human")
183
+
184
+ all_details = []
185
+ with Pool(processes=args.pool_workers) as pool:
186
+ for layer in need_layers:
187
+ now_best_metrics = None
188
+ label_dict = {}
189
+ train_embeddings = db_embeddings[layer].float().numpy()
190
+ if isinstance(db_labels, dict):
191
+ train_labels = db_labels[layer].tolist()
192
+ train_ids = db_ids[layer].tolist()
193
+ else:
194
+ train_labels = db_labels.tolist()
195
+ train_ids = db_ids.tolist()
196
+
197
+ for i in range(len(train_ids)):
198
+ label_dict[int(train_ids[i])] = int(train_labels[i] == human_idx)
199
+ index.label_dict = label_dict
200
+ index.reset()
201
+ index.index_data(train_ids, train_embeddings)
202
+ preds = {k: [] for k in range(1, args.max_K + 1)}
203
+ top_ids_and_scores = index.search_knn(test_embeddings[layer], args.max_K, index_batch_size=128)
204
+
205
+ args_list = [
206
+ (ids, scores, labels, args.temperature)
207
+ for ids, scores, labels in top_ids_and_scores
208
+ ]
209
+ for result in tqdm(pool.imap(process_element, args_list), total=len(args_list)):
210
+ for k, value in result.items():
211
+ preds[k].append(value)
212
+
213
+ for k in range(1, args.max_K + 1):
214
+ metric = evaluate_metrics(test_labels, preds[k], threshold_param=-1)
215
+ if now_best_metrics is None or now_best_metrics["auroc"] < metric["auroc"]:
216
+ now_best_metrics = metric
217
+ now_best_metrics["k"] = k
218
+ now_best_metrics["layer"] = layer
219
+
220
+ if now_best_metrics:
221
+ print(dict2str(now_best_metrics))
222
+ all_details.append(now_best_metrics)
223
+
224
+ if not all_details:
225
+ return
226
+
227
+ max_ids = max(range(len(all_details)), key=lambda idx: all_details[idx]["auroc"])
228
+ best_metrics = all_details[max_ids]
229
+ print("Best " + dict2str(best_metrics))
230
+ args.log_file.parent.mkdir(parents=True, exist_ok=True)
231
+ with args.log_file.open("a+", encoding="utf-8") as fp:
232
+ fp.write(f"test model:{args.model_name_or_path} mode:{args.test_dataset_path} database_path:{args.database_path}\n")
233
+ fp.write(f"Last {dict2str(all_details[-1])}\n")
234
+ fp.write(f"Best {dict2str(best_metrics)}\n")
235
+ fp.write("------------------------------------------\n")
236
+
237
+
238
+ def main(argv: Optional[Iterable[str]] = None) -> None:
239
+ parser = build_argument_parser()
240
+ args = parser.parse_args(argv)
241
+ evaluate(args)
242
+
243
+
244
+ if __name__ == "__main__":
245
+ main()
246
+
247
+ __all__ = ["build_argument_parser", "evaluate", "main"]
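
A minimal usage sketch (paths and the checkpoint identifier are placeholders): the database is a .pt file whose classes list contains "human" (e.g. the output of detree/cli/database.py), the test set is a JSONL file with text, label and id fields, and a CUDA device is assumed.

from detree.cli.test_database_score_knn import main

# Hypothetical database, test set, and checkpoint paths.
main([
    "--database-path", "runs/databases/clustered.pt",
    "--test-dataset-path", "data/test.jsonl",
    "--model-name-or-path", "checkpoints/merged_model",
    "--temperature", "0.05",
    "--max-k", "51",
    "--log-file", "runs/val.txt",
])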
detree/cli/test_score_knn.py ADDED
@@ -0,0 +1,267 @@
1
+ """kNN evaluation against raw text datasets."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import os
8
+ from multiprocessing import Pool, cpu_count
9
+ from pathlib import Path
10
+ from typing import Iterable, List, Optional, Sequence, Tuple
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from lightning import Fabric
16
+ from torch.nn.functional import softmax as F_softmax
17
+ from torch.utils.data import DataLoader, Dataset
18
+ from tqdm import tqdm
19
+
20
+ from detree.model.text_embedding import TextEmbeddingModel
21
+ from detree.utils.index import Indexer
22
+ from detree.utils.utils import evaluate_metrics
23
+
24
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")
25
+
26
+
27
+ def load_jsonl(file_path: Path) -> List[dict]:
28
+ out = []
29
+ with file_path.open(mode="r", encoding="utf-8") as jsonl_file:
30
+ for line in jsonl_file:
31
+ item = json.loads(line)
32
+ out.append(item)
33
+ print(f"Loaded {len(out)} examples from {file_path}")
34
+ return out
35
+
36
+
37
+ class PassagesDataset(Dataset):
38
+ def __init__(self, data: Sequence[dict]):
39
+ self.passages = list(data)
40
+
41
+ def __len__(self) -> int:
42
+ return len(self.passages)
43
+
44
+ def __getitem__(self, idx: int):
45
+ data_now = self.passages[idx]
46
+ text = data_now["text"]
47
+ label = data_now["label"]
48
+ ids = data_now["id"]
49
+ return text, int(label), int(ids)
50
+
51
+
52
+ def infer(passages_dataloader, fabric, tokenizer, model, max_length: int = 512):
53
+ if fabric.global_rank == 0:
54
+ passages_dataloader = tqdm(passages_dataloader)
55
+ all_ids: List[int] = []
56
+ all_embeddings: List[torch.Tensor] = []
57
+ all_labels: List[int] = []
58
+ with torch.no_grad():
59
+ for batch in passages_dataloader:
60
+ text, label, ids = batch
61
+ encoded_batch = tokenizer.batch_encode_plus(
62
+ text,
63
+ return_tensors="pt",
64
+ max_length=max_length,
65
+ padding="max_length",
66
+ truncation=True,
67
+ )
68
+ encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
69
+ embeddings = model(encoded_batch, hidden_states=True)
70
+ embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(-2), embeddings.size(-1))
71
+ label = fabric.all_gather(label).view(-1)
72
+ ids = fabric.all_gather(ids).view(-1)
73
+ if fabric.global_rank == 0:
74
+ all_embeddings.append(embeddings.cpu())
75
+ all_ids.extend(ids.cpu().tolist())
76
+ all_labels.extend(label.cpu().tolist())
77
+ if fabric.global_rank == 0:
78
+ embeddings_tensor = torch.cat(all_embeddings, dim=0)
79
+ embeddings_tensor = F.normalize(embeddings_tensor, dim=-1).permute(1, 0, 2)
80
+ return all_ids, embeddings_tensor.numpy(), all_labels
81
+ return [], [], []
82
+
83
+
84
+ def save_pt(train_embeddings, all_labels, train_ids, args, best_layer):
85
+ save_layer = [best_layer, train_embeddings.shape[0] - 1]
86
+ all_embeddings = {i: torch.tensor(train_embeddings[i]) for i in save_layer}
87
+ emb_dict = {
88
+ "embeddings": all_embeddings,
89
+ "labels": torch.tensor(all_labels),
90
+ "ids": torch.tensor(train_ids),
91
+ "classes": ["llm", "human"],
92
+ }
93
+ args.savedir.mkdir(parents=True, exist_ok=True)
94
+ output_path = args.savedir / f"{args.name}.pt"
95
+ torch.save(emb_dict, output_path)
96
+ print(f"Saved embedding snapshot to {output_path}")
97
+
98
+
99
+ def dict2str(metrics: dict) -> str:
100
+ out_str = ""
101
+ if "layer" in metrics:
102
+ out_str += f"layer:{metrics['layer']} "
103
+ if "k" in metrics:
104
+ out_str += f"k:{metrics['k']} "
105
+ for key, value in metrics.items():
106
+ if key not in {"layer", "k"}:
107
+ out_str += f"{key}:{value} "
108
+ return out_str.strip()
109
+
110
+
111
+ def process_element(args: Tuple[Sequence[int], Sequence[float], Sequence[int], float]):
112
+ ids, scores, labels, temperature = args
113
+ now_score = torch.zeros(2)
114
+ sorted_indices = np.argsort(scores)[::-1]
115
+ element_preds = {}
116
+
117
+ for k, idx in enumerate(sorted_indices):
118
+ label = labels[idx]
119
+ now_score[label] += scores[idx] * temperature
120
+ prob = F_softmax(now_score, dim=-1)[1].item()
121
+ element_preds[k + 1] = prob
122
+
123
+ return element_preds
124
+
125
+
126
+ def build_argument_parser() -> argparse.ArgumentParser:
127
+ parser = argparse.ArgumentParser(
128
+ description="Evaluate DETree checkpoints using a kNN classifier over hidden states.",
129
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
130
+ )
131
+ parser.add_argument("--device-num", type=int, default=1)
132
+ parser.add_argument("--batch-size", type=int, default=32)
133
+ parser.add_argument("--num-workers", type=int, default=8)
134
+ parser.add_argument("--max-length", type=int, default=512)
135
+
136
+ parser.add_argument("--database-path", type=Path, required=True, help="Training set JSONL file.")
137
+ parser.add_argument("--test-dataset-path", type=Path, required=True, help="Evaluation set JSONL file.")
138
+ parser.add_argument(
139
+ "--model-name-or-path",
140
+ type=str,
141
+ required=True,
142
+ help="Model identifier from Hugging Face or local path to a merged checkpoint.",
143
+ )
144
+ parser.add_argument("--temperature", type=float, default=0.05)
145
+
146
+ parser.add_argument("--max-k", type=int, default=50, dest="max_K", help="Maximum k to evaluate for kNN.")
147
+ parser.add_argument("--min-layer", type=int, default=15, help="Minimum hidden layer index to evaluate.")
148
+ parser.add_argument("--pooling", type=str, default="max", choices=("max", "average", "cls"))
149
+
150
+ parser.add_argument("--embedding-dim", type=int, default=1024)
151
+ parser.add_argument("--n-subquantizers", type=int, default=1)
152
+ parser.add_argument("--n-bits", type=int, default=8)
153
+
154
+ parser.add_argument("--savedir", type=Path, default=Path("runs"))
155
+ parser.add_argument("--name", type=str, default="database_knn_eval")
156
+ parser.add_argument("--pool-workers", type=int, default=min(32, cpu_count()))
157
+ parser.add_argument("--save-embeddings", action="store_true", help="Persist embeddings for the best-performing layer.")
158
+ parser.add_argument("--log-file", type=Path, default=Path("runs/val.txt"))
159
+
160
+ return parser
161
+
162
+
163
+ def evaluate(args: argparse.Namespace) -> None:
164
+ if args.device_num > 1:
165
+ fabric = Fabric(accelerator="cuda", devices=args.device_num, strategy="ddp", precision="bf16-mixed")
166
+ else:
167
+ fabric = Fabric(accelerator="cuda", devices=args.device_num, precision="bf16-mixed")
168
+ fabric.launch()
169
+
170
+ model = TextEmbeddingModel(
171
+ args.model_name_or_path,
172
+ output_hidden_states=True,
173
+ infer=True,
174
+ use_pooling=args.pooling,
175
+ ).cuda()
176
+ tokenizer = model.tokenizer
177
+ model.eval()
178
+
179
+ database = load_jsonl(args.database_path)
180
+ test_database = load_jsonl(args.test_dataset_path)
181
+
182
+ passages_dataset = PassagesDataset(database)
183
+ test_dataset = PassagesDataset(test_database)
184
+
185
+ passages_dataloader = DataLoader(
186
+ passages_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True
187
+ )
188
+ test_dataloader = DataLoader(
189
+ test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False
190
+ )
191
+
192
+ passages_dataloader, test_dataloader = fabric.setup_dataloaders(passages_dataloader, test_dataloader)
193
+ model = fabric.setup(model)
194
+
195
+ train_ids, train_embeddings, train_labels = infer(passages_dataloader, fabric, tokenizer, model, args.max_length)
196
+ test_ids, test_embeddings, test_labels = infer(test_dataloader, fabric, tokenizer, model, args.max_length)
197
+
198
+ torch.cuda.empty_cache()
199
+ if fabric.global_rank != 0:
200
+ return
201
+
202
+ layer_num = train_embeddings.shape[0]
203
+ test_labels = [int(label) for label in test_labels]
204
+
205
+ label_dict = {train_ids[i]: int(train_labels[i]) for i in range(len(train_ids))}
206
+
207
+ all_details = []
208
+ index = Indexer(args.embedding_dim, args.n_subquantizers, args.n_bits)
209
+ index.label_dict = label_dict
210
+
211
+ with Pool(processes=args.pool_workers) as pool:
212
+ for i in range(args.min_layer, layer_num):
213
+ now_best_metrics = None
214
+ index.reset()
215
+ index.index_data(train_ids, train_embeddings[i])
216
+ preds = {k: [] for k in range(1, args.max_K + 1)}
217
+ top_ids_and_scores = index.search_knn(test_embeddings[i], args.max_K, index_batch_size=128)
218
+
219
+ args_list = [
220
+ (ids, scores, labels, args.temperature)
221
+ for ids, scores, labels in top_ids_and_scores
222
+ ]
223
+ for result in tqdm(pool.imap(process_element, args_list), total=len(args_list)):
224
+ for k, value in result.items():
225
+ preds[k].append(value)
226
+
227
+ for k in range(2, args.max_K + 1):
228
+ metric = evaluate_metrics(test_labels, preds[k], threshold_param=-1)
229
+ if now_best_metrics is None or now_best_metrics["auroc"] < metric["auroc"]:
230
+ now_best_metrics = metric
231
+ now_best_metrics["k"] = k
232
+ now_best_metrics["layer"] = i
233
+
234
+ if now_best_metrics:
235
+ print(dict2str(now_best_metrics))
236
+ all_details.append(now_best_metrics)
237
+
238
+ if not all_details:
239
+ return
240
+
241
+ max_ids = max(range(len(all_details)), key=lambda idx: all_details[idx]["auroc"])
242
+ best_metrics = all_details[max_ids]
243
+
244
+ if args.save_embeddings:
245
+ save_pt(train_embeddings, train_labels, train_ids, args, best_metrics["layer"])
246
+
247
+ print("Best " + dict2str(best_metrics))
248
+ args.log_file.parent.mkdir(parents=True, exist_ok=True)
249
+ with args.log_file.open("a+", encoding="utf-8") as fp:
250
+ fp.write(
251
+ f"test model:{args.model_name_or_path} database_path:{args.database_path} mode:{args.test_dataset_path}\n"
252
+ )
253
+ fp.write(f"Last {dict2str(all_details[-1])}\n")
254
+ fp.write(f"Best {dict2str(best_metrics)}\n")
255
+ fp.write("------------------------------------------\n")
256
+
257
+
258
+ def main(argv: Optional[Iterable[str]] = None) -> None:
259
+ parser = build_argument_parser()
260
+ args = parser.parse_args(argv)
261
+ evaluate(args)
262
+
263
+
264
+ if __name__ == "__main__":
265
+ main()
266
+
267
+ __all__ = ["build_argument_parser", "evaluate", "main"]
detree/cli/train.py ADDED
@@ -0,0 +1,313 @@
1
+ """Training CLI for DETree."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import random
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Iterable, Optional
10
+
11
+ import torch
12
+ import torch.nn.functional as F # noqa: F401 # retained for backward compat with downstream imports
13
+ import torch.optim as optim
14
+ import yaml
15
+ from lightning import Fabric
16
+ from lightning.fabric.strategies import DeepSpeedStrategy, DDPStrategy
17
+ from torch.utils.data import DataLoader
18
+ from torch.utils.data.dataloader import default_collate
19
+ from torch.utils.tensorboard import SummaryWriter
20
+ from tqdm import tqdm
21
+ from transformers import AutoTokenizer
22
+
23
+ from detree.model.simclr import SimCLR_Tree
24
+ from detree.utils.dataset import SCLDataset, load_datapath
25
+
26
+
27
+ @dataclass
28
+ class ExperimentPaths:
29
+ """Utility container describing where to store experiment artefacts."""
30
+
31
+ root: Path
32
+ runs: Path
33
+
34
+
35
+ def _build_collate_fn(tokenizer, max_length: int):
36
+ def collate_fn(batch: Iterable):
37
+ text, label, write_model = default_collate(batch)
38
+ encoded_batch = tokenizer.batch_encode_plus(
39
+ text,
40
+ return_tensors="pt",
41
+ max_length=max_length,
42
+ padding=True,
43
+ truncation=True,
44
+ )
45
+ return encoded_batch, label, write_model
46
+
47
+ return collate_fn
48
+
49
+
50
+ def _prepare_output_dir(
51
+ output_dir: Path, experiment_name: str, resume: bool, *, create_dirs: bool = True
52
+ ) -> ExperimentPaths:
53
+ output_dir = output_dir.expanduser().resolve()
54
+
55
+ candidate = output_dir / experiment_name
56
+ if candidate.exists() and not resume:
57
+ suffix = 0
58
+ while (output_dir / f"{experiment_name}_v{suffix}").exists():
59
+ suffix += 1
60
+ candidate = output_dir / f"{experiment_name}_v{suffix}"
61
+
62
+ runs_dir = candidate / "runs"
63
+ if create_dirs:
64
+ candidate.mkdir(parents=True, exist_ok=True)
65
+ runs_dir.mkdir(parents=True, exist_ok=True)
66
+
67
+ return ExperimentPaths(root=candidate, runs=runs_dir)
68
+
69
+
70
+ def build_argument_parser() -> argparse.ArgumentParser:
71
+ parser = argparse.ArgumentParser(
72
+ description="Train DETree using the hierarchical contrastive objective",
73
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
74
+ )
75
+ parser.add_argument("--model-name", type=str, default="FacebookAI/roberta-large", help="Backbone encoder identifier.")
76
+ parser.add_argument("--device-num", type=int, default=1, help="Number of CUDA devices to use.")
77
+ parser.add_argument("--path", type=Path, required=True, help="Root directory of the dataset.")
78
+ parser.add_argument("--dataset-name", type=str, default="all", help="Dataset configuration name.")
79
+ parser.add_argument(
80
+ "--dataset", type=str, default="train", choices=("train", "test", "extra"), help="Dataset split to consume."
81
+ )
82
+ parser.add_argument("--tree-txt", type=Path, required=True, help="Tree definition file as produced by the HAT pipeline.")
83
+ parser.add_argument("--output-dir", type=Path, default=Path("runs"), help="Directory where experiment folders are saved.")
84
+ parser.add_argument("--experiment-name", type=str, default="detree_experiment", help="Base name for the run directory.")
85
+ parser.add_argument("--resume", action="store_true", help="Reuse the given experiment directory if it already exists.")
86
+
87
+ parser.add_argument("--projection-size", type=int, default=1024)
88
+ parser.add_argument("--temperature", type=float, default=0.07)
89
+ parser.add_argument("--num-workers", type=int, default=8)
90
+ parser.add_argument("--per-gpu-batch-size", type=int, default=64)
91
+ parser.add_argument("--per-gpu-eval-batch-size", type=int, default=16)
92
+ parser.add_argument("--max-length", type=int, default=512, help="Maximum sequence length for the tokenizer.")
93
+ parser.add_argument("--total-epoch", type=int, default=10)
94
+ parser.add_argument("--warmup-steps", type=int, default=2000)
95
+ parser.add_argument("--lr", type=float, default=3e-5)
96
+ parser.add_argument("--min-lr", type=float, default=5e-6)
97
+ parser.add_argument("--weight-decay", type=float, default=1e-4)
98
+ parser.add_argument("--beta1", type=float, default=0.9)
99
+ parser.add_argument("--beta2", type=float, default=0.99)
100
+ parser.add_argument("--eps", type=float, default=1e-6)
101
+ parser.add_argument("--adv-p", type=float, default=0.5, help="Probability of sampling adversarial data.")
102
+ parser.add_argument("--num-workers-eval", type=int, default=8, help="Reserved for compatibility.")
103
+
104
+ parser.add_argument("--lora-r", type=int, default=128)
105
+ parser.add_argument("--lora-alpha", type=int, default=256)
106
+ parser.add_argument("--lora-dropout", type=float, default=0.0)
107
+
108
+ parser.add_argument("--freeze-layer", type=int, default=0, help="Number of initial encoder layers to freeze.")
109
+ parser.add_argument("--seed", type=int, default=42)
110
+
111
+ parser.add_argument("--adapter-path", type=Path, default=None, help="Optional path to resume LoRA training from.")
112
+ parser.add_argument("--pooling", type=str, default="max", choices=("max", "average", "cls"))
113
+
114
+ parser.add_argument("--lora", dest="lora", action="store_true", help="Enable LoRA adapters.")
115
+ parser.add_argument("--no-lora", dest="lora", action="store_false", help="Disable LoRA adapters.")
116
+ parser.set_defaults(lora=True)
117
+
118
+ parser.add_argument("--freeze-embedding-layer", dest="freeze_embedding_layer", action="store_true")
119
+ parser.add_argument("--no-freeze-embedding-layer", dest="freeze_embedding_layer", action="store_false")
120
+ parser.set_defaults(freeze_embedding_layer=True)
121
+
122
+ parser.add_argument("--adversarial", dest="adversarial", action="store_true")
123
+ parser.add_argument("--no-adversarial", dest="adversarial", action="store_false")
124
+ parser.set_defaults(adversarial=True)
125
+
126
+ parser.add_argument("--include-attack", dest="include_attack", action="store_true")
127
+ parser.add_argument("--no-include-attack", dest="include_attack", action="store_false")
128
+ parser.set_defaults(include_attack=True)
129
+
130
+ parser.add_argument("--has-mix", dest="has_mix", action="store_true")
131
+ parser.add_argument("--no-has-mix", dest="has_mix", action="store_false")
132
+ parser.set_defaults(has_mix=True)
133
+
134
+ parser.add_argument("--deepspeed", action="store_true", help="Use DeepSpeed strategy when multiple GPUs are available.")
135
+
136
+ return parser
137
+
138
+
139
+ def train(args: argparse.Namespace) -> None:
140
+ torch.manual_seed(args.seed)
141
+ random.seed(args.seed)
142
+ torch.set_float32_matmul_precision("medium")
143
+
144
+ if args.device_num > 1:
145
+ if args.deepspeed:
146
+ strategy = DeepSpeedStrategy()
147
+ else:
148
+ strategy = DDPStrategy(find_unused_parameters=True)
149
+ fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num, strategy=strategy)
150
+ else:
151
+ fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num)
152
+
153
+ fabric.launch()
154
+
155
+ experiment_paths = ExperimentPaths(root=Path(args.output_dir), runs=Path(args.runs_dir))
156
+ if fabric.global_rank == 0:
157
+ experiment_paths.root.mkdir(parents=True, exist_ok=True)
158
+ experiment_paths.runs.mkdir(parents=True, exist_ok=True)
159
+ fabric.barrier()
160
+
161
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
162
+ collate_fn = _build_collate_fn(tokenizer, args.max_length)
163
+
164
+ model = SimCLR_Tree(args, fabric).train()
165
+
166
+ data_path = load_datapath(
167
+ str(args.path),
168
+ include_adversarial=args.adversarial,
169
+ dataset_name=args.dataset_name,
170
+ include_attack=args.include_attack,
171
+ )[args.dataset]
172
+
173
+ train_dataset = SCLDataset(
174
+ data_path,
175
+ fabric,
176
+ tokenizer,
177
+ name2id=model.names2id,
178
+ has_mix=args.has_mix,
179
+ adv_p=args.adv_p,
180
+ )
181
+
182
+ passages_dataloader = DataLoader(
183
+ train_dataset,
184
+ batch_size=args.per_gpu_batch_size,
185
+ num_workers=args.num_workers,
186
+ pin_memory=True,
187
+ shuffle=True,
188
+ drop_last=True,
189
+ collate_fn=collate_fn,
190
+ )
191
+
192
+ model.train()
193
+ if args.freeze_embedding_layer:
194
+ for name, param in model.model.named_parameters():
195
+ if "emb" in name or "model.pooler" in name:
196
+ param.requires_grad = False
197
+ if args.freeze_layer > 0:
198
+ for i in range(args.freeze_layer):
199
+ if f"encoder.layer.{i}." in name:
200
+ param.requires_grad = False
201
+
202
+ model = torch.compile(model)
203
+ if fabric.global_rank == 0:
204
+ print("Model has been initialized!")
205
+ for name, param in model.model.named_parameters():
206
+ print(name, param.requires_grad)
207
+
208
+ passages_dataloader = fabric.setup_dataloaders(passages_dataloader, use_distributed_sampler=False)
209
+ if fabric.global_rank == 0:
210
+ print("DataLoader has been initialized!")
211
+
212
+ if fabric.global_rank == 0:
213
+ writer = SummaryWriter(str(experiment_paths.runs))
214
+ print(f"Save dir is {args.output_dir}")
215
+ opt_dict = {k: str(v) if isinstance(v, Path) else v for k, v in vars(args).items()}  # Path values are not YAML-serializable by default
216
+ opt_dict["output_dir"] = str(args.output_dir)
217
+ with open(Path(args.output_dir) / "config.yaml", "w", encoding="utf-8") as file:
218
+ yaml.dump(opt_dict, file, sort_keys=False)
219
+ else:
220
+ writer = None
221
+
222
+ experiment_dir = experiment_paths.root
223
+
224
+ num_batches_per_epoch = len(passages_dataloader)
225
+ warmup_steps = args.warmup_steps
226
+ lr = args.lr
227
+ total_steps = args.total_epoch * num_batches_per_epoch - warmup_steps
228
+
229
+ optimizer = optim.AdamW(
230
+ filter(lambda p: p.requires_grad, model.parameters()),
231
+ lr=args.lr,
232
+ betas=(args.beta1, args.beta2),
233
+ eps=args.eps,
234
+ weight_decay=args.weight_decay,
235
+ )
236
+
237
+ schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps, eta_min=args.min_lr)
238
+ model, optimizer = fabric.setup(model, optimizer)
239
+
240
+ if fabric.global_rank == 0:
241
+ for name, param in model.named_parameters():
242
+ if param.requires_grad:
243
+ print(name, param.requires_grad)
244
+
245
+ for epoch in range(args.total_epoch):
246
+ model.train()
247
+ avg_loss = 0.0
248
+ iterator = enumerate(passages_dataloader)
249
+ if fabric.global_rank == 0:
250
+ iterator = tqdm(iterator, total=len(passages_dataloader))
251
+ print(("\n" + "%11s" * 5) % ("Epoch", "GPU_mem", "loss1", "Avgloss", "lr"))
252
+ for i, batch in iterator:
253
+ current_step = epoch * num_batches_per_epoch + i
254
+ if current_step < warmup_steps:
255
+ current_lr = lr * current_step / max(warmup_steps, 1)
256
+ for param_group in optimizer.param_groups:
257
+ param_group["lr"] = current_lr
258
+ current_lr = optimizer.param_groups[0]["lr"]
259
+
260
+ encoded_batch, label, write_model = batch
261
+ loss, loss_classify = model(encoded_batch, write_model)
262
+
263
+ avg_loss = (avg_loss * i + loss.item()) / (i + 1)
264
+ fabric.backward(loss)
265
+ optimizer.step()
266
+ optimizer.zero_grad()
267
+ if current_step >= warmup_steps:
268
+ schedule.step()
269
+
270
+ mem = f"{torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0:.3g}G"
271
+ if fabric.global_rank == 0:
272
+ iterator.set_description(
273
+ ("%11s" * 2 + "%11.4g" * 3)
274
+ % (f"{epoch + 1}/{args.total_epoch}", mem, loss_classify.item(), avg_loss, current_lr)
275
+ )
276
+ if writer and current_step % 10 == 0:
277
+ writer.add_scalar("lr", current_lr, current_step)
278
+ writer.add_scalar("loss", loss.item(), current_step)
279
+ writer.add_scalar("avg_loss", avg_loss, current_step)
280
+ writer.add_scalar("loss_classify", loss_classify.item(), current_step)
281
+
282
+ if fabric.global_rank == 0:
283
+ checkpoint_dir = experiment_dir / f"epoch_{epoch:02d}"
284
+ model.save_pretrained(str(checkpoint_dir), save_tokenizer=(epoch == 0))
285
+ print(f"Saved adapter checkpoint to {checkpoint_dir}", flush=True)
286
+
287
+ last_dir = experiment_dir / "last"
288
+ model.save_pretrained(str(last_dir), save_tokenizer=False)
289
+ print(f"Updated latest checkpoint at {last_dir}", flush=True)
290
+
291
+ fabric.barrier()
292
+
293
+ if writer:
294
+ writer.flush()
295
+ writer.close()
296
+
297
+
298
+ def main(argv: Optional[Iterable[str]] = None) -> None:
299
+ parser = build_argument_parser()
300
+ args = parser.parse_args(argv)
301
+ experiment_paths = _prepare_output_dir(
302
+ args.output_dir, args.experiment_name, resume=args.resume, create_dirs=False
303
+ )
304
+ args.output_dir = str(experiment_paths.root)
305
+ args.runs_dir = str(experiment_paths.runs)
306
+ train(args)
307
+
308
+
309
+ __all__ = ["build_argument_parser", "main", "train"]
310
+
311
+
312
+ if __name__ == "__main__":
313
+ main()
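Usage note (editorial): a minimal way to launch training through this CLI is a programmatic call to main(), as sketched below. The dataset root, tree file and experiment name are placeholders; --path and --tree-txt are the only required arguments and must point to data produced by the HAT pipeline.

    # Sketch only: all paths below are placeholder assumptions, not values from this diff.
    from detree.cli import train as train_cli

    train_cli.main([
        "--path", "datasets/detree",
        "--tree-txt", "datasets/detree/tree.txt",
        "--model-name", "FacebookAI/roberta-large",
        "--device-num", "1",
        "--total-epoch", "10",
        "--experiment-name", "detree_baseline",
    ])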