|
|
|
|
|
"""Training script for PHDM 21D Embedding Model. |
|
|
|
|
|
Trains a sentence-transformers embedding model on the SCBE-AETHERMOORE |
|
|
knowledge base, projecting into 21-dimensional Poincare Ball space. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations

import argparse
import os
from pathlib import Path

import numpy as np
import torch
from datasets import Dataset, load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)
from sentence_transformers.training_args import BatchSamplers
|
|
|
|
|
|
|
|
|
|
|
PHDM_DIM = 21 |
|
|
NEUROTRANSMITTER_WEIGHTS = { |
|
|
"KO": 1.0, |
|
|
"AV": 1.62, |
|
|
"RU": 2.62, |
|
|
"CA": 4.24, |
|
|
"UM": 6.85, |
|
|
"DR": 11.09, |
|
|
} |
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace: |
|
|
parser = argparse.ArgumentParser( |
|
|
description="Train PHDM 21D embedding model on knowledge base." |
|
|
) |
|
|
parser.add_argument( |
|
|
"--base-model", |
|
|
default="sentence-transformers/all-MiniLM-L6-v2", |
|
|
help="Base sentence transformer model to fine-tune.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--dataset-id", |
|
|
default="issdandavis/scbe-aethermoore-knowledge-base", |
|
|
help="HuggingFace dataset ID for training data.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--output-dir", |
|
|
default="./phdm-model-output", |
|
|
help="Directory for model checkpoints.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--epochs", |
|
|
type=int, |
|
|
default=3, |
|
|
help="Number of training epochs.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--batch-size", |
|
|
type=int, |
|
|
default=16, |
|
|
help="Training batch size.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--learning-rate", |
|
|
type=float, |
|
|
default=2e-5, |
|
|
help="Learning rate.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--token", |
|
|
default=os.environ.get("HF_TOKEN"), |
|
|
help="HuggingFace token. Defaults to HF_TOKEN env var.", |
|
|
) |
|
|
return parser.parse_args() |
|
|
|
|
|
|
|
|
def prepare_training_pairs(dataset): |
|
|
"""Prepare (anchor, positive) pairs from knowledge base records.""" |
|
|
pairs = [] |
|
|
for record in dataset: |
|
|
title = record.get("title", "") |
|
|
text = record.get("text", "") |
|
|
if title and text and len(text) > 50: |
|
|
|
|
|
pairs.append({"anchor": title, "positive": text[:512]}) |
|
|
return pairs |
|
|
|
|
|
|
|
|
def main() -> None: |
|
|
args = parse_args() |
|
|
|
|
|
print(f"Loading base model: {args.base_model}") |
|
|
model = SentenceTransformer(args.base_model) |
|
|
|
|
|
print(f"Loading dataset: {args.dataset_id}") |
|
|
try: |
|
|
dataset = load_dataset( |
|
|
args.dataset_id, |
|
|
split="train", |
|
|
token=args.token, |
|
|
) |
|
|
print(f"Loaded {len(dataset)} records") |
|
|
except Exception as exc: |
|
|
raise SystemExit(f"Failed to load dataset: {exc}") from exc |
|
|
|
|
|
|
|
|
print("Preparing training pairs...") |
|
|
train_pairs = prepare_training_pairs(dataset) |
|
|
print(f"Created {len(train_pairs)} training pairs") |
|
|
|
|
|
if not train_pairs: |
|
|
raise SystemExit("No valid training pairs found in dataset.") |
|
|
|
|
|
|
|
|
training_args = SentenceTransformerTrainingArguments( |
|
|
output_dir=args.output_dir, |
|
|
num_train_epochs=args.epochs, |
|
|
per_device_train_batch_size=args.batch_size, |
|
|
learning_rate=args.learning_rate, |
|
|
warmup_ratio=0.1, |
|
|
fp16=torch.cuda.is_available(), |
|
|
batch_sampler=BatchSamplers.NO_DUPLICATES, |
|
|
eval_strategy="no", |
|
|
save_strategy="epoch", |
|
|
logging_steps=100, |
|
|
save_total_limit=2, |
|
|
) |
|
|
|
|
|
|
|
|
loss = losses.MultipleNegativesRankingLoss(model) |
|
|
|
|
|
|
|
|
trainer = SentenceTransformerTrainer( |
|
|
model=model, |
|
|
args=training_args, |
|
|
train_dataset=train_pairs, |
|
|
loss=loss, |
|
|
) |
|
|
|
|
|
print("Starting training...") |
|
|
trainer.train() |
|
|
|
|
|
|
|
|
final_path = Path(args.output_dir) / "final" |
|
|
model.save(str(final_path)) |
|
|
print(f"Model saved to: {final_path}") |
|
|
|
|
|
|
|
|
if args.token: |
|
|
print("Pushing model to HuggingFace Hub...") |
|
|
model.push_to_hub( |
|
|
"issdandavis/phdm-21d-embedding", |
|
|
token=args.token, |
|
|
commit_message="feat: update model weights from training", |
|
|
) |
|
|
print("Model pushed to Hub!") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |