# Source: Hugging Face Hub - uploaded by ktoan911 using huggingface_hub (commit 398a289, verified).
#!/usr/bin/env python3
"""
Train Fusion MLP + beta only using CSV labeled data.
Required CSV columns: text/claim, evidence, label
"""
import argparse
import math
import os
import shutil
import subprocess

from loguru import logger

from src.data.csv_loader import CSVLabeledLoader
from src.training.fusion_trainer import FusionTrainingConfig, train_fusion_from_dataframe
from src.utils import normalize_text
# Articles of this many characters or fewer are treated as noise and dropped.
_MIN_ARTICLE_CHARS = 10

# Separator between individual articles inside one CSV ``evidence`` cell.
_ARTICLE_SEPARATOR = "|||"


def _detect_default_device() -> str:
    """Return the device to use when ``--device`` is not given.

    Resolution order:

    1. If ``CUDA_VISIBLE_DEVICES`` is set, it decides: an empty value or
       ``-1`` hides every GPU from CUDA, so ``"cpu"`` is returned; any other
       value returns ``"cuda"``.
    2. Otherwise ``nvidia-smi`` is probed; ``"cuda"`` is returned only when
       the executable exists and exits with status 0.

    Returns:
        ``"cuda"`` or ``"cpu"``.
    """
    visible = os.getenv("CUDA_VISIBLE_DEVICES")
    if visible is not None:
        # nvidia-smi ignores CUDA_VISIBLE_DEVICES, so a "hide all GPUs"
        # setting must be honoured here instead of falling through.
        return "cpu" if visible.strip() in ("", "-1") else "cuda"

    nvidia_smi = shutil.which("nvidia-smi")
    if nvidia_smi is None:
        return "cpu"
    try:
        # Argument list + DEVNULL: no shell involved and no reliance on the
        # POSIX-only ``/dev/null`` redirect.
        probe = subprocess.run(
            [nvidia_smi],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except OSError:
        return "cpu"
    return "cuda" if probe.returncode == 0 else "cpu"


def _normalize_timestamp(ts):
    """Map missing-value markers (``None`` or float NaN) to ``None``.

    A missing cell in a DataFrame column typically comes back from
    ``.tolist()`` as float NaN rather than ``None``; collapsing both to
    ``None`` lets the de-duplication step recognise "no timestamp" reliably.
    Any other value is returned unchanged.
    """
    if ts is None:
        return None
    if isinstance(ts, float) and math.isnan(ts):
        return None
    return ts


def _build_knowledge_base(evidences, timestamps, normalize=None) -> list:
    """Build a de-duplicated list of knowledge-base documents.

    Each evidence entry is converted to ``str`` and split on ``|||`` into
    articles. Articles are stripped, and those of ``_MIN_ARTICLE_CHARS``
    characters or fewer are discarded. Remaining articles are de-duplicated
    on their normalized form while the first-seen original text is kept.
    When a duplicate carries a timestamp and the stored document has none,
    the stored document adopts it.

    Args:
        evidences: Iterable of raw evidence cells (one per labeled row).
        timestamps: Iterable of per-row timestamps, parallel to ``evidences``;
            ``None`` / NaN mean "no timestamp".
        normalize: Callable producing the de-duplication key for an article.
            Defaults to ``normalize_text`` from ``src.utils``.

    Returns:
        List of dicts with keys ``"text"``, ``"timestamp"`` and ``"source"``
        (always ``"csv"``), in first-seen order.
    """
    if normalize is None:
        normalize = normalize_text

    unique_docs = {}
    for evidence, raw_ts in zip(evidences, timestamps):
        ts = _normalize_timestamp(raw_ts)
        for article in str(evidence).split(_ARTICLE_SEPARATOR):
            article = article.strip()
            if len(article) <= _MIN_ARTICLE_CHARS:
                continue

            key = normalize(article)
            existing = unique_docs.get(key)
            if existing is None:
                unique_docs[key] = {
                    "text": article,  # keep the original, un-normalized text
                    "timestamp": ts,
                    "source": "csv",
                }
            elif existing["timestamp"] is None and ts is not None:
                # Duplicate article: fill in a timestamp we did not have yet.
                existing["timestamp"] = ts
    return list(unique_docs.values())


def main() -> None:
    """CLI entry point: train the fusion model from a labeled CSV.

    Parses command-line arguments, loads the labeled CSV through
    ``CSVLabeledLoader``, builds a de-duplicated knowledge base from the
    ``evidence`` column (using ``timestamp`` when that column exists), and
    passes everything to ``train_fusion_from_dataframe``.
    """
    parser = argparse.ArgumentParser(
        description="Train Fusion MLP + beta only using CSV labeled data."
    )
    parser.add_argument(
        "--labeled_csv",
        type=str,
        required=True,
        help="Path to the labeled CSV file (text,evidence,label)",
    )
    parser.add_argument(
        "--batch_size", type=int, default=8, help="Batch size for training"
    )
    parser.add_argument(
        "--llm_batch_size", type=int, default=8, help="Batch size for LLM"
    )
    parser.add_argument(
        "--epochs", type=int, default=3, help="Number of training epochs"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default=os.getenv("LORA_MODEL_PATH", "models/lora_llm"),
        help="Path to the LoRA-trained model (default: models/lora_llm)",
    )
    parser.add_argument(
        "--device",
        type=str,
        # Resolved after parsing so the GPU probe only runs when needed.
        default=None,
        help="Device to use (cuda/cpu)",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default=os.getenv("FUSION_OUTPUT_PATH", "models/fusion_model.pt"),
        help="Path to save the fusion model",
    )
    parser.add_argument(
        "--retriever_model",
        type=str,
        default=os.getenv("RETRIEVER_MODEL_PATH", "AITeamVN/Vietnamese_Embedding"),
        help=(
            "Path to trained dense retrieval model "
            "(default: AITeamVN/Vietnamese_Embedding)"
        ),
    )
    args = parser.parse_args()
    device = args.device or _detect_default_device()

    logger.info(f"Loading labeled data from {args.labeled_csv}...")
    labeled_df = CSVLabeledLoader(args.labeled_csv).load()
    logger.info(f"Labeled data: {len(labeled_df)} samples")

    # Pull evidence (and timestamps, when the column exists) out of the frame.
    evidences = labeled_df["evidence"].tolist()
    if "timestamp" in labeled_df.columns:
        timestamps = labeled_df["timestamp"].tolist()
    else:
        timestamps = [None] * len(evidences)

    kb_docs = _build_knowledge_base(evidences, timestamps)
    logger.info(
        f"Knowledge base built: {len(kb_docs)} unique documents (deduplicated from {len(evidences)} evidence entries)"
    )

    fusion_config = FusionTrainingConfig(
        model_name=args.model_path,
        retriever_model=args.retriever_model,
        device=device,
        batch_size=args.batch_size,
        llm_batch_size=args.llm_batch_size,
        epochs=args.epochs,
    )
    train_fusion_from_dataframe(
        knowledge_base=kb_docs,
        labeled_df=labeled_df,
        config=fusion_config,
        save_path=args.save_path,
    )
    logger.info(f"Fusion training complete. Model saved to: {args.save_path}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()