File size: 835 Bytes
6a9e10a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# scripts/train.py
from src.preprocessing import load_and_preprocess_data
from src.feature_engineering import tokenize_texts
from src.model import train_model, evaluate_model
from src.utils import plot_confusion_matrix

def main():
    """Train and evaluate the text classifier end to end.

    Steps, in order:

    1. Load the train/test splits via ``load_and_preprocess_data`` with
       ``sample=True`` (presumably a reduced subset -- confirm in
       ``src.preprocessing``).
    2. Tokenize the ``"text"`` column of each split with ``tokenize_texts``.
    3. Call ``train_model``, which returns a model and ``label_map``.
    4. Evaluate on the test split with ``evaluate_model``, print the
       classification report, and plot the confusion matrix using the keys
       of ``label_map`` as class names.

    Returns:
        None. Results are reported via stdout and ``plot_confusion_matrix``.
    """
    text_col, label_col = "text", "category"

    train_split, test_split = load_and_preprocess_data(sample=True)

    train_tokens = tokenize_texts(train_split[text_col])
    test_tokens = tokenize_texts(test_split[text_col])

    # NOTE(review): the test split is passed to train_model here and is also
    # used for the final evaluation below. If train_model uses it for model
    # selection or early stopping, the reported metrics are optimistic --
    # confirm, and consider holding out a separate validation split.
    model, label_map = train_model(
        train_tokens,
        train_split[label_col],
        test_tokens,
        test_split[label_col],
    )

    report, cm = evaluate_model(model, test_tokens, test_split[label_col])
    print("Classification Report:\n", report)

    # NOTE(review): assumes the key order of label_map matches the row/column
    # order of cm -- TODO confirm against train_model / evaluate_model.
    class_names = list(label_map.keys())
    plot_confusion_matrix(cm, class_names)


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    main()