|
|
""" |
|
|
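Training script for the CBT binary classifier.

Run this script from the command line to train the model on your CSV data:
one CSV file of normal conversations (--normal_csv) and one CSV file of CBT
conversations (--cbt_csv).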
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import logging |
|
|
from binary_classifier import CBTBinaryClassifier |
|
|
|
|
|
|
|
|
# Configure the root logger once, at import time, so that log records are
# emitted at INFO level and above with a timestamp, the logger name and the
# severity in front of each message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
|
|
|
|
|
def _build_arg_parser():
    """Return the argument parser describing the training script's options."""
    parser = argparse.ArgumentParser(description='Train CBT Binary Classifier')
    parser.add_argument('--normal_csv', required=True,
                        help='Path to CSV file with normal conversations')
    parser.add_argument('--cbt_csv', required=True,
                        help='Path to CSV file with CBT conversations')
    parser.add_argument('--text_column', default='text',
                        help='Name of the text column in CSV files')
    parser.add_argument('--output_dir', default='./cbt_classifier',
                        help='Directory to save the trained model')
    parser.add_argument('--model_name', default='distilbert-base-uncased',
                        help='Pre-trained model to use (distilbert-base-uncased recommended for laptops)')
    return parser


def main(argv=None) -> None:
    """Run the full training pipeline from command-line arguments.

    The steps, in order, are: parse the arguments, build a
    ``CBTBinaryClassifier`` for the requested model name, then call its
    ``prepare_data``, ``tokenize_data``, ``split_data``, ``train_model`` and
    ``evaluate_model`` methods, printing a progress line before each stage.
    Finally a short usage snippet for inference is printed.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what happens when the script is run
            directly.

    Raises:
        SystemExit: Raised by argparse when required arguments are missing
            or invalid, or when ``--help`` is requested.
    """
    args = _build_arg_parser().parse_args(argv)

    classifier = CBTBinaryClassifier(model_name=args.model_name)

    print("Preparing data...")
    df = classifier.prepare_data(
        normal_csv_path=args.normal_csv,
        cbt_csv_path=args.cbt_csv,
        text_column=args.text_column,
    )

    print("Tokenizing data...")
    dataset = classifier.tokenize_data(df)

    print("Splitting data...")
    train_dataset, val_dataset, test_dataset = classifier.split_data(dataset)

    print("Training model...")
    print("Note: Training optimized for laptop performance (smaller batches, fewer epochs)")
    classifier.train_model(train_dataset, val_dataset, output_dir=args.output_dir)

    print("Evaluating model...")
    classifier.evaluate_model(test_dataset)

    print(f"\nTraining complete! Model saved to {args.output_dir}")
    print("\nTo use the model for inference:")
    print("from binary_classifier import CBTBinaryClassifier")
    print("classifier = CBTBinaryClassifier()")
    # Use repr() so the printed snippet stays a valid Python string literal
    # even when the path contains backslashes (Windows) or quote characters.
    print(f"classifier.load_model({args.output_dir!r})")
    print("result = classifier.predict('Your text here')")
|
|
|
|
|
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()