Spaces:
No application file
No application file
File size: 835 Bytes
6a9e10a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
# scripts/train.py
from src.preprocessing import load_and_preprocess_data
from src.feature_engineering import tokenize_texts
from src.model import train_model, evaluate_model
from src.utils import plot_confusion_matrix
def main(*, sample=True):
    """Train the text classifier and report its performance on the test split.

    Pipeline: load and preprocess the data, tokenize the ``text`` column of
    both splits, train a model on the ``category`` labels, then print a
    classification report and plot a confusion matrix for the test split.

    Args:
        sample: Forwarded unchanged to ``load_and_preprocess_data``. Defaults
            to ``True`` (the behavior of a plain ``main()`` call); presumably
            ``False`` loads the full dataset — confirm in ``src.preprocessing``.
    """
    # Load and preprocess data
    train_df, test_df = load_and_preprocess_data(sample=sample)

    # Tokenize the raw text of each split
    train_encodings = tokenize_texts(train_df["text"])
    test_encodings = tokenize_texts(test_df["text"])

    # Train model.
    # NOTE(review): the test split is also handed to train_model (presumably
    # as validation data). If it drives early stopping or model selection,
    # the evaluation below is optimistic — confirm against src.model.
    model, label_map = train_model(
        train_encodings, train_df["category"], test_encodings, test_df["category"]
    )

    # Evaluate on the test split
    report, cm = evaluate_model(model, test_encodings, test_df["category"])
    print("Classification Report:\n", report)

    # NOTE(review): assumes the iteration order of label_map's keys matches
    # the row/column order of cm — verify against how src.model builds both.
    plot_confusion_matrix(cm, list(label_map.keys()))
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()