#!/usr/bin/env python3 """ Train Fusion MLP + beta only using CSV labeled data. Required CSV columns: text/claim, evidence, label """ import argparse import os from loguru import logger from src.data.csv_loader import CSVLabeledLoader from src.training.fusion_trainer import FusionTrainingConfig, train_fusion_from_dataframe from src.utils import normalize_text def main(): parser = argparse.ArgumentParser( description="Train Fusion MLP + beta only using CSV labeled data." ) parser.add_argument( "--labeled_csv", type=str, required=True, help="Path to the labeled CSV file (text,evidence,label)", ) parser.add_argument( "--batch_size", type=int, default=8, help="Batch size for training" ) parser.add_argument( "--llm_batch_size", type=int, default=8, help="Batch size for LLM" ) parser.add_argument( "--epochs", type=int, default=3, help="Number of training epochs" ) parser.add_argument( "--model_path", type=str, default=os.getenv("LORA_MODEL_PATH", "models/lora_llm"), help="Path to the LoRA-trained model (default: models/lora_llm)", ) parser.add_argument( "--device", type=str, default="cuda" if os.getenv("CUDA_VISIBLE_DEVICES") or os.system("nvidia-smi > /dev/null 2>&1") == 0 else "cpu", help="Device to use (cuda/cpu)", ) parser.add_argument( "--save_path", type=str, default=os.getenv("FUSION_OUTPUT_PATH", "models/fusion_model.pt"), help="Path to save the fusion model", ) parser.add_argument( "--retriever_model", type=str, default=os.getenv("RETRIEVER_MODEL_PATH", "AITeamVN/Vietnamese_Embedding"), help="Path to trained dense retrieval model (default: models/retriever_model)", ) args = parser.parse_args() logger.info(f"Loading labeled data from {args.labeled_csv}...") labeled_df = CSVLabeledLoader(args.labeled_csv).load() logger.info(f"Labeled data: {len(labeled_df)} samples") # Extract evidence and timestamps from dataframe evidences = labeled_df["evidence"].tolist() timestamps = ( labeled_df["timestamp"].tolist() if "timestamp" in labeled_df.columns else [None] * len(evidences) ) # Use dict to deduplicate by normalized text, keeping original text unique_docs = {} for evidence, ts in zip(evidences, timestamps): # Split evidence into individual articles # Evidence articles are separated by ||| evidence_str = str(evidence) articles = evidence_str.split("|||") for article in articles: article = article.strip() if len(article) > 10: # Filter out empty or very short strings # Normalize for deduplication key, but store original text norm_key = normalize_text(article) if norm_key not in unique_docs: unique_docs[norm_key] = { "text": article, # Keep original text "timestamp": ts, "source": "csv", } else: # If duplicate, keep the document with non-None timestamp if ts is not None and unique_docs[norm_key]["timestamp"] is None: unique_docs[norm_key]["timestamp"] = ts kb_docs = list(unique_docs.values()) logger.info( f"Knowledge base built: {len(kb_docs)} unique documents (deduplicated from {len(evidences)} evidence entries)" ) fusion_config = FusionTrainingConfig( model_name=args.model_path, retriever_model=args.retriever_model, device=args.device, batch_size=args.batch_size, llm_batch_size=args.llm_batch_size, epochs=args.epochs, ) train_fusion_from_dataframe( knowledge_base=kb_docs, labeled_df=labeled_df, config=fusion_config, save_path=args.save_path, ) logger.info(f"Fusion training complete. Model saved to: {args.save_path}") if __name__ == "__main__": main()