Spaces:
Sleeping
Sleeping
File size: 1,960 Bytes
"""
Main pipeline for the DistilBERT sentiment analysis project.

File: main.py
"""
import os
import argparse
from train import (
load_imdb_data,
preprocess_data,
load_model,
setup_trainer,
train_model,
evaluate_model,
save_model
)
# The app is not imported here; it is run separately (python app.py).
def train_pipeline(subset_size=None):
    """Run every training stage in order and return the evaluation results.

    The stages are delegated to the helpers imported from ``train``:
    ``load_imdb_data`` -> ``preprocess_data`` -> ``load_model`` ->
    ``setup_trainer`` -> ``train_model`` -> ``evaluate_model`` ->
    ``save_model``.

    Args:
        subset_size: Passed straight through to ``load_imdb_data``.
            Defaults to ``None``.

    Returns:
        The value returned by ``evaluate_model`` for the trained trainer.
    """
    print("=== Starting Training Pipeline ===")

    # Data: load, then tokenize (the tokenizer is reused for the trainer
    # and when saving).
    raw_dataset = load_imdb_data(subset_size=subset_size)
    splits, tokenizer = preprocess_data(raw_dataset)

    # Model + trainer, wired to the "train" and "test" splits.
    trainer = setup_trainer(
        load_model(),
        tokenizer,
        splits["train"],
        splits["test"],
    )

    # Fit, score, and save the model.
    train_model(trainer)
    metrics = evaluate_model(trainer)
    save_model(trainer, tokenizer)

    print("=== Training Pipeline Completed ===")
    return metrics
def main(argv=None):
    """Command-line entry point: train the model, or keep an existing one.

    Parses ``--subset`` and, if ``./model/config.json`` already exists, asks
    the user whether to retrain. Training runs via ``train_pipeline`` unless
    the user declines.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what a
            plain ``main()`` call did before this parameter existed.

    Returns:
        None.
    """
    parser = argparse.ArgumentParser(
        description="DistilBERT Sentiment Analysis - Training Only"
    )
    parser.add_argument(
        "--subset",
        type=int,
        default=None,
        help="Use subset of data for training (for testing)",
    )
    args = parser.parse_args(argv)

    # Check if model already exists
    if os.path.exists("./model") and os.path.exists("./model/config.json"):
        try:
            response = input("Model already exists. Retrain? (y/n): ")
        except EOFError:
            # No interactive stdin (pipe, container, CI). Keep the existing
            # model instead of crashing or silently overwriting it.
            print()
            response = ""
        # Tolerate surrounding whitespace and the spelled-out "yes".
        if response.strip().lower() not in ("y", "yes"):
            print("Skipping training...")
            print("To run the app: python app.py")
            return

    # Train the model
    train_pipeline(subset_size=args.subset)
    print("\n🎉 Training completed!")
    print("To run the app: python app.py")
# Run the training CLI only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()