File size: 4,212 Bytes
398a289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python3
"""
Train Fusion MLP + beta only using CSV labeled data.
Required CSV columns: text/claim, evidence, label
"""

import argparse
import os
import shutil
import subprocess

import pandas as pd
from loguru import logger

from src.data.csv_loader import CSVLabeledLoader
from src.training.fusion_trainer import FusionTrainingConfig, train_fusion_from_dataframe
from src.utils import normalize_text


# Evidence cells may pack several articles joined by this separator.
EVIDENCE_SEPARATOR = "|||"
# Articles with this many characters or fewer (after stripping) are dropped.
MIN_ARTICLE_LENGTH = 10


def _detect_default_device():
    """Return "cuda" when a GPU appears usable, otherwise "cpu".

    Only called when ``--device`` is omitted, so no probe process is spawned
    when the user picks a device explicitly.
    """
    visible = os.getenv("CUDA_VISIBLE_DEVICES")
    if visible is not None:
        # An empty value, or a list whose first entry is an invalid index such
        # as "-1", hides every GPU from CUDA. Do not default to cuda then.
        first = visible.split(",", 1)[0].strip()
        if not first or first.startswith("-"):
            return "cpu"
        return "cuda"

    # Probe nvidia-smi without a shell: a portable PATH lookup, then an argv
    # list with output discarded. A timeout guards against a hung driver.
    nvidia_smi = shutil.which("nvidia-smi")
    if nvidia_smi is None:
        return "cpu"
    try:
        result = subprocess.run(
            [nvidia_smi],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
            timeout=10,
        )
    except (OSError, subprocess.TimeoutExpired):
        return "cpu"
    return "cuda" if result.returncode == 0 else "cpu"


def _parse_args(argv=None):
    """Parse command-line options; ``argv=None`` means ``sys.argv[1:]``.

    The device is auto-detected after parsing when ``--device`` is omitted.
    """
    parser = argparse.ArgumentParser(
        description="Train Fusion MLP + beta only using CSV labeled data."
    )
    parser.add_argument(
        "--labeled_csv",
        type=str,
        required=True,
        help="Path to the labeled CSV file (text,evidence,label)",
    )
    parser.add_argument(
        "--batch_size", type=int, default=8, help="Batch size for training"
    )
    parser.add_argument(
        "--llm_batch_size", type=int, default=8, help="Batch size for LLM"
    )
    parser.add_argument(
        "--epochs", type=int, default=3, help="Number of training epochs"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default=os.getenv("LORA_MODEL_PATH", "models/lora_llm"),
        help="Path to the LoRA-trained model (default: %(default)s)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use, e.g. cuda or cpu (default: auto-detect)",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default=os.getenv("FUSION_OUTPUT_PATH", "models/fusion_model.pt"),
        help="Path to save the fusion model (default: %(default)s)",
    )
    parser.add_argument(
        "--retriever_model",
        type=str,
        default=os.getenv("RETRIEVER_MODEL_PATH", "AITeamVN/Vietnamese_Embedding"),
        help="Name or path of the dense retrieval model (default: %(default)s)",
    )

    args = parser.parse_args(argv)
    if args.device is None:
        args.device = _detect_default_device()
    return args


def _build_knowledge_base(evidences, timestamps):
    """Split, filter and deduplicate evidence articles into knowledge-base docs.

    Args:
        evidences: One evidence cell per labeled row. Each cell may hold
            several articles joined by ``EVIDENCE_SEPARATOR``.
        timestamps: Per-row timestamps aligned with ``evidences``. None, or a
            pandas missing value, when unknown.

    Returns:
        A list of ``{"text", "timestamp", "source"}`` dicts, one per article
        unique under ``normalize_text``, in first-seen order. ``text`` keeps
        the original (un-normalized) article. ``timestamp`` is None when no
        occurrence of the article carried one.
    """
    unique_docs = {}
    for evidence, ts in zip(evidences, timestamps):
        # pandas yields NaN/NaT for empty cells. Those are "not None", which
        # would defeat the back-fill below, so map them to None.
        if pd.isna(ts):
            ts = None

        for article in str(evidence).split(EVIDENCE_SEPARATOR):
            article = article.strip()
            if len(article) <= MIN_ARTICLE_LENGTH:
                continue  # empty or too short to be a real article

            # Deduplicate on the normalized form, but store the original text.
            key = normalize_text(article)
            doc = unique_docs.get(key)
            if doc is None:
                unique_docs[key] = {
                    "text": article,
                    "timestamp": ts,
                    "source": "csv",
                }
            elif doc["timestamp"] is None and ts is not None:
                # Back-fill a missing timestamp from a later duplicate.
                doc["timestamp"] = ts

    return list(unique_docs.values())


def main():
    """Train the fusion model from a labeled CSV.

    Loads the CSV, builds a deduplicated knowledge base from its ``evidence``
    column (and the optional ``timestamp`` column), then calls
    ``train_fusion_from_dataframe`` and saves to ``--save_path``.
    """
    args = _parse_args()

    logger.info(f"Loading labeled data from {args.labeled_csv}...")
    labeled_df = CSVLabeledLoader(args.labeled_csv).load()
    logger.info(f"Labeled data: {len(labeled_df)} samples")

    evidences = labeled_df["evidence"].tolist()
    timestamps = (
        labeled_df["timestamp"].tolist()
        if "timestamp" in labeled_df.columns
        else [None] * len(evidences)
    )

    kb_docs = _build_knowledge_base(evidences, timestamps)
    logger.info(
        f"Knowledge base built: {len(kb_docs)} unique documents (deduplicated from {len(evidences)} evidence entries)"
    )

    fusion_config = FusionTrainingConfig(
        model_name=args.model_path,
        retriever_model=args.retriever_model,
        device=args.device,
        batch_size=args.batch_size,
        llm_batch_size=args.llm_batch_size,
        epochs=args.epochs,
    )

    train_fusion_from_dataframe(
        knowledge_base=kb_docs,
        labeled_df=labeled_df,
        config=fusion_config,
        save_path=args.save_path,
    )

    logger.info(f"Fusion training complete. Model saved to: {args.save_path}")


if __name__ == "__main__":
    main()