Binary_classifier / train_classifier.py
SaitejaJate's picture
Upload 5 files
88b8fd6 verified
"""
Training script for CBT binary classifier.
Run this script to train the model on your CSV data.
"""
import argparse
import logging
from binary_classifier import CBTBinaryClassifier
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
def main():
parser = argparse.ArgumentParser(description='Train CBT Binary Classifier')
parser.add_argument('--normal_csv', required=True,
help='Path to CSV file with normal conversations')
parser.add_argument('--cbt_csv', required=True,
help='Path to CSV file with CBT conversations')
parser.add_argument('--text_column', default='text',
help='Name of the text column in CSV files')
parser.add_argument('--output_dir', default='./cbt_classifier',
help='Directory to save the trained model')
parser.add_argument('--model_name', default='distilbert-base-uncased',
help='Pre-trained model to use (distilbert-base-uncased recommended for laptops)')
args = parser.parse_args()
# Initialize classifier
classifier = CBTBinaryClassifier(model_name=args.model_name)
# Prepare data
print("Preparing data...")
df = classifier.prepare_data(
normal_csv_path=args.normal_csv,
cbt_csv_path=args.cbt_csv,
text_column=args.text_column
)
# Tokenize data
print("Tokenizing data...")
dataset = classifier.tokenize_data(df)
# Split data
print("Splitting data...")
train_dataset, val_dataset, test_dataset = classifier.split_data(dataset)
# Train model
print("Training model...")
print("Note: Training optimized for laptop performance (smaller batches, fewer epochs)")
classifier.train_model(train_dataset, val_dataset, output_dir=args.output_dir)
# Evaluate model
print("Evaluating model...")
classifier.evaluate_model(test_dataset)
print(f"\nTraining complete! Model saved to {args.output_dir}")
print("\nTo use the model for inference:")
print(f"from binary_classifier import CBTBinaryClassifier")
print(f"classifier = CBTBinaryClassifier()")
print(f"classifier.load_model('{args.output_dir}')")
print(f"result = classifier.predict('Your text here')")
if __name__ == "__main__":
main()