File size: 1,960 Bytes
9cbb56b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""

Main pipeline for DistilBERT sentiment analysis project

File: main.py

"""

import os
import argparse
from train import (
    load_imdb_data, 
    preprocess_data, 
    load_model, 
    setup_trainer,
    train_model,
    evaluate_model,
    save_model
)
# The app (app.py) is intentionally not imported here; it is run separately.

def train_pipeline(subset_size=None):
    """Run the end-to-end training workflow and return the evaluation results.

    The steps run strictly in this order: load the data, preprocess it,
    load the model, build the trainer, train, evaluate, then save.

    Args:
        subset_size: Forwarded unchanged to ``load_imdb_data`` as its
            ``subset_size`` keyword argument (defaults to ``None``).

    Returns:
        Whatever ``evaluate_model`` returns for the trained trainer.
    """
    print("=== Starting Training Pipeline ===")

    # Data: raw dataset -> preprocessed splits plus the tokenizer used.
    raw_data = load_imdb_data(subset_size=subset_size)
    encoded, tokenizer = preprocess_data(raw_data)

    # Model and trainer wiring; "train" feeds training, "test" feeds evaluation.
    classifier = load_model()
    train_split = encoded["train"]
    eval_split = encoded["test"]
    trainer = setup_trainer(classifier, tokenizer, train_split, eval_split)

    # Fit, measure, and only then persist the model and tokenizer.
    train_model(trainer)
    metrics = evaluate_model(trainer)
    save_model(trainer, tokenizer)

    print("=== Training Pipeline Completed ===")
    return metrics

def main():
    """Parse command-line options and run training, guarding an existing model.

    Recognizes one option, ``--subset`` (int, default ``None``), which is
    forwarded to ``train_pipeline`` as ``subset_size``.

    If ``./model/config.json`` already exists, the user is asked whether to
    retrain. Only ``y`` or ``yes`` (case-insensitive, surrounding whitespace
    ignored) starts a new training run; any other answer skips it. When no
    answer can be read because stdin is closed or exhausted, training is
    skipped as well, so a retrain never happens without explicit confirmation.
    """
    parser = argparse.ArgumentParser(description="DistilBERT Sentiment Analysis - Training Only")
    parser.add_argument("--subset", type=int, default=None,
                        help="Use subset of data for training (for testing)")

    args = parser.parse_args()

    # config.json can only exist if ./model itself exists, so a single
    # check on the file path is sufficient.
    if os.path.exists("./model/config.json"):
        try:
            response = input("Model already exists. Retrain? (y/n): ")
        except EOFError:
            # Non-interactive run (no stdin to read from): finish the prompt
            # line and fall through to the "do not retrain" branch.
            print()
            response = ""

        # Normalize so that answers like "Y", " y " or "yes" are not
        # silently interpreted as "no".
        if response.strip().lower() not in ("y", "yes"):
            print("Skipping training...")
            print("To run the app: python app.py")
            return

    # Train the model
    train_pipeline(subset_size=args.subset)

    print("\n🎉 Training completed!")
    print("To run the app: python app.py")

# Run the command-line entry point only when this file is executed as a
# script; importing the module does not start training.
if __name__ == "__main__":
    main()