ayush2917 commited on
Commit
6a9e10a
·
verified ·
1 Parent(s): 8449cee

Create train.py

Browse files
Files changed (1) hide show
  1. src/train.py +26 -0
src/train.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # scripts/train.py
2
+ from src.preprocessing import load_and_preprocess_data
3
+ from src.feature_engineering import tokenize_texts
4
+ from src.model import train_model, evaluate_model
5
+ from src.utils import plot_confusion_matrix
6
+
7
+ def main():
8
+ # Load and preprocess data
9
+ train_df, test_df = load_and_preprocess_data(sample=True)
10
+
11
+ # Tokenize
12
+ train_encodings = tokenize_texts(train_df["text"])
13
+ test_encodings = tokenize_texts(test_df["text"])
14
+
15
+ # Train model
16
+ model, label_map = train_model(
17
+ train_encodings, train_df["category"], test_encodings, test_df["category"]
18
+ )
19
+
20
+ # Evaluate
21
+ report, cm = evaluate_model(model, test_encodings, test_df["category"])
22
+ print("Classification Report:\n", report)
23
+ plot_confusion_matrix(cm, list(label_map.keys()))
24
+
25
+ if __name__ == "__main__":
26
+ main()