#!/usr/bin/env python3
"""
Train Fusion MLP + beta only using CSV labeled data.
Required CSV columns: text/claim, evidence, label
"""
import argparse
import os
from loguru import logger
from src.data.csv_loader import CSVLabeledLoader
from src.training.fusion_trainer import FusionTrainingConfig, train_fusion_from_dataframe
from src.utils import normalize_text
def _default_device():
    """Return "cuda" when a GPU appears usable, otherwise "cpu".

    A non-empty CUDA_VISIBLE_DEVICES environment variable is taken as a
    request for CUDA. Otherwise ``nvidia-smi`` is probed directly (no shell,
    no ``/dev/null`` redirection) and a zero exit status selects "cuda".
    """
    import shutil
    import subprocess

    if os.getenv("CUDA_VISIBLE_DEVICES"):
        return "cuda"

    nvidia_smi = shutil.which("nvidia-smi")
    if nvidia_smi is None:
        return "cpu"

    try:
        probe = subprocess.run(
            [nvidia_smi],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=30,  # a wedged driver must not hang the CLI forever
            check=False,
        )
    except (OSError, subprocess.TimeoutExpired):
        return "cpu"
    return "cuda" if probe.returncode == 0 else "cpu"


def _build_parser():
    """Build the command-line parser for the fusion training script."""
    parser = argparse.ArgumentParser(
        description="Train Fusion MLP + beta only using CSV labeled data."
    )
    parser.add_argument(
        "--labeled_csv",
        type=str,
        required=True,
        help="Path to the labeled CSV file (text,evidence,label)",
    )
    parser.add_argument(
        "--batch_size", type=int, default=8, help="Batch size for training"
    )
    parser.add_argument(
        "--llm_batch_size", type=int, default=8, help="Batch size for LLM"
    )
    parser.add_argument(
        "--epochs", type=int, default=3, help="Number of training epochs"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default=os.getenv("LORA_MODEL_PATH", "models/lora_llm"),
        # %(default)s keeps the help accurate when LORA_MODEL_PATH is set.
        help="Path to the LoRA-trained model (default: %(default)s)",
    )
    parser.add_argument(
        "--device",
        type=str,
        # Resolved lazily in main() so the GPU probe only runs when the
        # user did not choose a device explicitly.
        default=None,
        help="Device to use (cuda/cpu)",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default=os.getenv("FUSION_OUTPUT_PATH", "models/fusion_model.pt"),
        help="Path to save the fusion model",
    )
    parser.add_argument(
        "--retriever_model",
        type=str,
        default=os.getenv("RETRIEVER_MODEL_PATH", "AITeamVN/Vietnamese_Embedding"),
        # The previous text advertised a default that was not the real one.
        help="Path to trained dense retrieval model (default: %(default)s)",
    )
    return parser


def _is_missing(value):
    """Return True for None and for NaN-like values (float NaN, NaT)."""
    if value is None:
        return True
    try:
        # NaN-like values are the only ones that compare unequal to themselves.
        return bool(value != value)
    except (TypeError, ValueError):
        return False


def _build_knowledge_base(evidences, timestamps, min_chars=10):
    """Collect unique evidence articles as knowledge-base documents.

    Each evidence entry may hold several articles separated by ``|||``.
    Articles of ``min_chars`` characters or fewer (after stripping) are
    dropped. Documents are deduplicated on ``normalize_text(article)`` while
    the original text of the first occurrence is kept. When a duplicate
    carries a timestamp and the stored document has none (None or NaN-like),
    the stored document adopts it.

    Returns:
        A list of dicts with keys ``text``, ``timestamp`` and ``source``.
    """
    unique_docs = {}
    for evidence, ts in zip(evidences, timestamps):
        for article in str(evidence).split("|||"):
            article = article.strip()
            if len(article) <= min_chars:
                continue  # empty or very short fragment

            norm_key = normalize_text(article)
            existing = unique_docs.get(norm_key)
            if existing is None:
                unique_docs[norm_key] = {
                    "text": article,  # keep original, un-normalized text
                    "timestamp": ts,
                    "source": "csv",
                }
            elif _is_missing(existing["timestamp"]) and not _is_missing(ts):
                existing["timestamp"] = ts
    return list(unique_docs.values())


def main():
    """Train the fusion model from a labeled CSV file.

    Parses the command line, loads the labeled data, builds a deduplicated
    knowledge base from the ``evidence`` column (using ``timestamp`` when that
    column is present), and hands both to ``train_fusion_from_dataframe``.
    """
    args = _build_parser().parse_args()
    device = args.device or _default_device()

    logger.info(f"Loading labeled data from {args.labeled_csv}...")
    labeled_df = CSVLabeledLoader(args.labeled_csv).load()
    logger.info(f"Labeled data: {len(labeled_df)} samples")

    # Extract evidence and timestamps from dataframe
    evidences = labeled_df["evidence"].tolist()
    if "timestamp" in labeled_df.columns:
        timestamps = labeled_df["timestamp"].tolist()
    else:
        timestamps = [None] * len(evidences)

    kb_docs = _build_knowledge_base(evidences, timestamps)
    logger.info(
        f"Knowledge base built: {len(kb_docs)} unique documents (deduplicated from {len(evidences)} evidence entries)"
    )

    fusion_config = FusionTrainingConfig(
        model_name=args.model_path,
        retriever_model=args.retriever_model,
        device=device,
        batch_size=args.batch_size,
        llm_batch_size=args.llm_batch_size,
        epochs=args.epochs,
    )
    train_fusion_from_dataframe(
        knowledge_base=kb_docs,
        labeled_df=labeled_df,
        config=fusion_config,
        save_path=args.save_path,
    )
    logger.info(f"Fusion training complete. Model saved to: {args.save_path}")
# Script entry point: run the training CLI only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|