File size: 1,376 Bytes
86b932c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import logging
import os
import subprocess
import sys
import time
from typing import Optional

# Project root: the parent of the directory that contains this script.
_PROJECT_ROOT = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)
# Prepend the root to the import path (once) so modules living under the
# project root can be imported by absolute name.
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

# Root logging is configured at import time: INFO and above, pipe-separated
# fields (timestamp | level | logger name | message).
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("stage3_training")

# (display name, script filename) pairs, executed in this order. The
# meta-classifier runs last, presumably because it consumes the base models'
# outputs — TODO confirm against meta_classifier.py.
_TRAINING_SCRIPTS = (
    ("Logistic Regression", "logistic_model.py"),
    ("Bi-LSTM", "lstm_model.py"),
    ("DistilBERT", "distilbert_model.py"),
    ("RoBERTa", "roberta_model.py"),
    ("Meta-Classifier", "meta_classifier.py"),
)


def _abort_stage(t0: float) -> None:
    """Log the elapsed stage time and terminate with exit code 1."""
    logger.error("STAGE 3 FAILED after %.2f seconds", time.perf_counter() - t0)
    sys.exit(1)


def run_training(cfg: Optional[dict] = None) -> None:
    """Run each stage-3 model training script sequentially.

    Every script in ``<project root>/src/models`` listed in
    ``_TRAINING_SCRIPTS`` is launched as a separate process with the current
    Python interpreter (``sys.executable``) and the project root as working
    directory. The stage stops at the first script that is missing or exits
    with a non-zero status.

    Args:
        cfg: Not read by this function; presumably kept so every pipeline
            stage shares the same call signature — TODO confirm.

    Raises:
        SystemExit: With code 1 if a script file is missing or a script
            exits with a non-zero return code.
    """
    t0 = time.perf_counter()
    logger.info("STAGE 3: TRAINING START")

    python_exe = sys.executable
    models_dir = os.path.join(_PROJECT_ROOT, "src", "models")

    for name, script_name in _TRAINING_SCRIPTS:
        script_path = os.path.join(models_dir, script_name)
        # Fail fast with an explicit message rather than letting the child
        # interpreter print "can't open file" and exit with code 2.
        if not os.path.isfile(script_path):
            logger.error("%s script not found: %s", name, script_path)
            _abort_stage(t0)

        logger.info("==> Launching %s Training (%s)", name, script_name)
        # argv list, no shell: paths containing spaces are passed verbatim.
        result = subprocess.run(
            [python_exe, script_path], cwd=_PROJECT_ROOT, check=False
        )
        if result.returncode != 0:
            logger.error("%s aborted with exit code %s", name, result.returncode)
            _abort_stage(t0)

    logger.info("STAGE 3 FINISHED in %.2f seconds", time.perf_counter() - t0)

if __name__ == "__main__":
    # Script entry point: run the whole training stage without a config.
    run_training(cfg=None)