File size: 5,447 Bytes
5b6f681
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""Training script for fine-tuning transformer models."""

import os
import argparse
import json
from typing import Optional
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    TrainingArguments, 
    Trainer,
    EarlyStoppingCallback
)
from src.data_utils import load_config, load_and_prepare_dataset, prepare_labels_for_classification
from src.model_utils import compute_metrics, save_model_info, plot_training_history, get_model_size


def setup_training_args(config: dict, output_dir: str) -> TrainingArguments:
    """Build a TrainingArguments object from the config's "training" section.

    Args:
        config: Full configuration dict; must contain a "training" key whose
            value is a dict of TrainingArguments keyword arguments.
        output_dir: Directory where checkpoints and outputs are written;
            overrides any "output_dir" present in the config.

    Returns:
        A TrainingArguments instance built from the (copied) config section.
    """
    # Copy the section before injecting output_dir so the caller's config
    # dict is not mutated as a side effect (the original wrote into it).
    training_config = dict(config["training"])
    training_config["output_dir"] = output_dir

    return TrainingArguments(**training_config)


def train_model(
    config_path: str = "config.json",
    output_dir: str = "./results",
    resume_from_checkpoint: Optional[str] = None
):
    """
    Main training function.

    Loads config, model, tokenizer, and datasets; fine-tunes with the HF
    Trainer (early stopping, patience 2); saves the model/tokenizer to
    output_dir; plots the training history; and evaluates on the test set.

    Args:
        config_path: Path to configuration file
        output_dir: Output directory for model and results
        resume_from_checkpoint: Path to checkpoint to resume from

    Returns:
        Tuple of (trainer, test_results) where test_results is the metrics
        dict returned by trainer.evaluate on the test split.
    """
    # Load configuration
    # NOTE(review): config is assumed to have "model", "data", and "training"
    # sections with the keys read below — validated only by KeyError on miss.
    config = load_config(config_path)
    
    print("πŸš€ Starting training with configuration:")
    print(json.dumps(config, indent=2))
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Load model and tokenizer
    model_name = config["model"]["name"]
    num_labels = config["model"]["num_labels"]
    max_length = config["model"]["max_length"]
    
    print(f"πŸ“¦ Loading model: {model_name}")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, 
        num_labels=num_labels
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Print model information
    # get_model_size is a project helper; assumes it returns a dict with
    # 'param_count' and 'total_size_mb' keys — TODO confirm in src.model_utils.
    model_info = get_model_size(model)
    print(f"πŸ“Š Model info: {model_info['param_count']:,} parameters, {model_info['total_size_mb']:.1f} MB")
    
    # Load and prepare dataset
    data_config = config["data"]
    print(f"πŸ“š Loading dataset: {data_config['dataset_name']}")
    
    # Tokenization happens inside the helper (tokenizer_name + max_length),
    # so the three returned splits are already tokenized.
    train_dataset, eval_dataset, test_dataset = load_and_prepare_dataset(
        dataset_name=data_config["dataset_name"],
        tokenizer_name=model_name,
        train_size=data_config["train_size"],
        eval_size=data_config["eval_size"],
        test_size=data_config["test_size"],
        max_length=max_length
    )
    
    # Prepare labels
    train_dataset = prepare_labels_for_classification(train_dataset)
    eval_dataset = prepare_labels_for_classification(eval_dataset)
    test_dataset = prepare_labels_for_classification(test_dataset)
    
    print(f"πŸ“ˆ Dataset sizes - Train: {len(train_dataset)}, Eval: {len(eval_dataset)}, Test: {len(test_dataset)}")
    
    # Setup training arguments
    training_args = setup_training_args(config, output_dir)
    
    # Setup trainer
    # Early stopping requires the config's TrainingArguments to enable
    # periodic evaluation with load_best_model_at_end — TODO confirm config.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
    )
    
    # Train model
    print("🎯 Starting training...")
    if resume_from_checkpoint:
        print(f"πŸ”„ Resuming from checkpoint: {resume_from_checkpoint}")
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        trainer.train()
    
    # Save the model
    print("πŸ’Ύ Saving model...")
    trainer.save_model()
    tokenizer.save_pretrained(output_dir)
    
    # Plot training history
    # NOTE(review): TrainerState defines log_history, so this hasattr guard
    # appears to be always true — kept as defensive coding.
    if hasattr(trainer.state, 'log_history'):
        print("πŸ“Š Plotting training history...")
        plot_training_history(
            trainer.state.log_history, 
            os.path.join(output_dir, "training_history.png")
        )
    
    # Final evaluation on test set
    print("πŸ” Evaluating on test set...")
    test_results = trainer.evaluate(eval_dataset=test_dataset)
    
    print("βœ… Training completed!")
    print("πŸ“‹ Final test results:")
    # Assumes all metric values are numeric (":.4f" would raise otherwise).
    for key, value in test_results.items():
        print(f"  {key}: {value:.4f}")
    
    # Save model info and metrics
    save_model_info(output_dir, config, test_results)
    
    return trainer, test_results


def main():
    """CLI entry point for training.

    Parses arguments, configures GPU visibility, reports the compute device,
    and runs train_model with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Train a transformer model for sentiment analysis")
    parser.add_argument("--config", type=str, default="config.json", help="Path to config file")
    parser.add_argument("--output_dir", type=str, default="./results", help="Output directory")
    parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--gpu", action="store_true", help="Force GPU usage (if available)")
    
    args = parser.parse_args()
    
    # CUDA_VISIBLE_DEVICES must be exported BEFORE the CUDA context is
    # initialized; the original set it after torch.cuda.is_available(),
    # which can initialize CUDA and make the setting a no-op. setdefault
    # preserves any value the user already exported.
    if args.gpu:
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
    
    # Check GPU availability
    if torch.cuda.is_available():
        device = torch.cuda.get_device_name(0)
        print(f"πŸš€ GPU available: {device}")
    else:
        print("πŸ’» Running on CPU")
    
    # Run training (trainer/results returned for symmetry with train_model;
    # only used for the final status message here).
    trainer, results = train_model(
        config_path=args.config,
        output_dir=args.output_dir,
        resume_from_checkpoint=args.resume
    )
    
    print(f"πŸŽ‰ Training finished! Model saved to: {args.output_dir}")


# Run the CLI entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()