# Authentica / detree / cli / merge_databases.py
# MAS-AI-0000's picture
# Upload 4 files
# 6241c55 verified
"""Merge multiple embedding databases into a single unified database.
Combines text and image (or any number of) ``.pt`` embedding databases into
one file that can be loaded by the DeTree kNN evaluator or ``Detector``.
Labels are remapped to a binary scheme: **0 = AI / LLM**, **1 = Real / Human**.
All layers present across every database are preserved in the output. When
two databases both contain the same layer (e.g. the layer the image projector
was trained on), their embeddings are concatenated at that layer. Layers
that exist in only one database are passed through unchanged.
Because different layers can contain different numbers of entries, the output
stores ``labels`` and ``ids`` as dicts keyed by layer index rather than as
flat tensors. The ``Detector`` already supports this format.
Typical usage::

    python -m detree.cli.merge_databases \\
        --databases databases/text_compressed.pt databases/image_embeddings.pt \\
        --output databases/merged.pt
"""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Iterable, List, Optional
import torch
# ======================================================================
# Argument parser
# ======================================================================
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the database merge tool.

    Returns:
        A parser with two required options: ``--databases`` (one or more
        paths) and ``--output`` (a single path). Both are converted to
        ``pathlib.Path`` objects.
    """
    summary = (
        "Merge two or more .pt embedding databases (e.g. text + image) "
        "into a single unified database for kNN evaluation. "
        "All layers present in any input database are preserved."
    )
    cli = argparse.ArgumentParser(
        description=summary,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Each entry is (flag, keyword arguments for ``add_argument``).
    option_specs = (
        (
            "--databases",
            {
                "type": Path,
                "nargs": "+",
                "required": True,
                "help": "Paths to the .pt databases to merge (order does not matter).",
            },
        ),
        (
            "--output",
            {
                "type": Path,
                "required": True,
                "help": "Output path for the merged database (.pt).",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli
# ======================================================================
# Merge logic
# ======================================================================
def _layer_keyed(mapping: dict) -> dict:
    """Return a copy of *mapping* with every key coerced to ``int``."""
    return {int(key): value for key, value in mapping.items()}


def _binary_labels(raw_labels, classes) -> torch.Tensor:
    """Remap class indices to the binary scheme 0 = AI / LLM, 1 = Real / Human.

    When ``classes`` contains ``"human"``, that index becomes 1 and every
    other index becomes 0. Otherwise the labels are only cast to ``long``
    (they are assumed to already follow the binary scheme — not verified).
    """
    raw = torch.as_tensor(raw_labels)
    if "human" in classes:
        return (raw == classes.index("human")).long()
    return raw.long()


def merge_databases(args: argparse.Namespace) -> None:
    """Merge the databases in ``args.databases`` and save to ``args.output``.

    Each input file must hold a dict with the keys ``"embeddings"`` (a dict
    keyed by layer index), ``"labels"`` (one tensor shared by all layers, or
    a dict keyed by layer index) and ``"classes"``. Layer keys may be ints or
    int-like strings; they are normalised to ``int``.

    For every layer, embeddings and binary labels from all databases are
    concatenated in the order the databases are listed. The saved dict has
    the keys ``"embeddings"``, ``"labels"`` and ``"ids"`` (each a dict keyed
    by layer) plus ``"classes" = ["llm", "human"]``. Ids are consecutive
    integers that keep counting from one layer to the next, so no id is
    repeated across layers.

    Args:
        args: Namespace with ``databases`` (paths of the ``.pt`` files to
            merge) and ``output`` (destination path; missing parent
            directories are created).

    Raises:
        ValueError: If a database's ``embeddings`` entry is not a dict, if a
            per-layer ``labels`` dict is empty although embeddings exist, or
            if a layer's label count differs from its embedding count.
    """
    # Structure: {layer_int: {"embeddings": [...], "labels": [...]}}
    layer_data: dict = {}
    print(f"Merging {len(args.databases)} database(s) — keeping all layers\n")

    for db_path in args.databases:
        # NOTE(security): torch.load unpickles the file; only merge databases
        # that come from a trusted source.
        data = torch.load(db_path, map_location="cpu")
        embeddings = data["embeddings"]
        labels = data["labels"]
        classes = data["classes"]

        if not isinstance(embeddings, dict):
            raise ValueError(
                f"{db_path}: 'embeddings' must be a dict keyed by layer index."
            )
        # Normalise the keys so that int and string layer indices both work
        # and line up with the (also normalised) label keys.
        embeddings = _layer_keyed(embeddings)

        if isinstance(labels, dict):
            label_per_layer = _layer_keyed(labels)
            if embeddings and not label_per_layer:
                raise ValueError(
                    f"{db_path}: 'labels' is an empty dict but embeddings "
                    f"exist for layers {sorted(embeddings)}."
                )
            # Layers without their own labels reuse those of the lowest layer.
            fallback_labels = (
                label_per_layer[min(label_per_layer)] if label_per_layer else None
            )
        else:
            # Flat tensor — shared across all layers in this database.
            label_per_layer = {}
            fallback_labels = labels

        available_layers = sorted(embeddings)
        print(f" {Path(db_path).name}: layers {available_layers}")

        for layer in available_layers:
            layer_emb = embeddings[layer]
            binary_labels = _binary_labels(
                label_per_layer.get(layer, fallback_labels), classes
            )

            n_entries = binary_labels.shape[0]
            if n_entries != layer_emb.shape[0]:
                # Writing mismatched rows would silently corrupt the output.
                raise ValueError(
                    f"{db_path}: layer {layer} has {layer_emb.shape[0]} "
                    f"embeddings but {n_entries} labels."
                )

            n_ai = int((binary_labels == 0).sum().item())
            n_real = int((binary_labels == 1).sum().item())
            print(
                f" layer {layer}: {n_entries} entries "
                f"({n_ai} AI, {n_real} Real), dim={layer_emb.shape[1]}"
            )

            bucket = layer_data.setdefault(layer, {"embeddings": [], "labels": []})
            bucket["embeddings"].append(layer_emb)
            bucket["labels"].append(binary_labels)

    # --- build merged tensors per layer ----------------------------------
    merged_emb_dict: dict = {}
    merged_label_dict: dict = {}
    merged_id_dict: dict = {}
    print()
    id_offset = 0
    for layer in sorted(layer_data):
        embs = torch.cat(layer_data[layer]["embeddings"], dim=0)
        labs = torch.cat(layer_data[layer]["labels"], dim=0)
        n_rows = embs.shape[0]
        # The offset carries over between layers, keeping ids distinct.
        ids = torch.arange(id_offset, id_offset + n_rows, dtype=torch.long)
        id_offset += n_rows

        merged_emb_dict[layer] = embs
        merged_label_dict[layer] = labs
        merged_id_dict[layer] = ids

        n_ai = int((labs == 0).sum().item())
        n_real = int((labs == 1).sum().item())
        print(
            f" Merged layer {layer}: {n_rows} entries "
            f"({n_ai} AI, {n_real} Real), dim={embs.shape[1]}"
        )

    emb_dict = {
        "embeddings": merged_emb_dict,
        "labels": merged_label_dict,
        "ids": merged_id_dict,
        "classes": ["llm", "human"],
    }
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(emb_dict, output_path)
    print(
        f"\nMerged database saved to {args.output}\n"
        f" {len(merged_emb_dict)} layers: {sorted(merged_emb_dict)}"
    )
# ======================================================================
# Entry-point
# ======================================================================
def main(argv: Optional[Iterable[str]] = None) -> None:
    """Command-line entry point: parse the arguments, then run the merge.

    Args:
        argv: Argument strings to parse. ``None`` is passed through to
            ``parse_args``, which then reads the process arguments.
    """
    cli_args = build_argument_parser().parse_args(argv)
    merge_databases(cli_args)
# Public names exported by ``from <module> import *``.
__all__ = ["build_argument_parser", "merge_databases", "main"]

# Run the CLI only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()