""" Main Training Script This script orchestrates the model training process. It is configuration-driven and uses MLflow for experiment tracking. Usage: python scripts/train.py --config-path configs/base_config.yaml """ from pathlib import Path import sys import argparse import yaml from typing import Dict, Optional, Any import pandas as pd import torch import mlflow from torch.utils.data import DataLoader, TensorDataset from tqdm import tqdm # Ensure the backend is in the path to import registry and preprocessing sys.path.append(str(Path(__file__).resolve().parents[1])) from config import TARGET_LEN from backend.utils.preprocessing import preprocess_spectrum from models.registry import build def load_data(data_path: Path, target_len: int): """Load and preprocess data from a CSV file.""" df = pd.read_csv(data_path) # This is a placeholder for your actual data loading. # You need to parse your 'spectra' column into x and y values. # For this example, we assume 'y_values' are stored as a string of numbers. # A more robust solution would use np.load or similar if data is saved in binary format. all_y = [] # This loop is inefficient and for demonstration only. Vectorize in production. for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Processing {data_path.name}"): # Dummy x_values, as preprocess_spectrum primarily uses y_values x_values = range(len(row['spectrum'].split())) y_values = [float(y) for y in row['spectrum'].split()] _, y_processed = preprocess_spectrum( x_values, y_values, modality='raman') all_y.append(y_processed) features = torch.tensor(all_y, dtype=torch.float32).unsqueeze(1) labels = torch.tensor(df['label'].values, dtype=torch.long) return TensorDataset(features, labels) def train(config: dict, jobs_db: Optional[Dict[str, Any]] = None, job_id: Optional[str] = None): """Main training and validation loop.""" try: # --- MLflow Setup --- mlflow.set_experiment(config['experiment_name']) with mlflow.start_run(run_name=config.get('run_name', 'default_run')) as run: mlflow.log_params(config) if jobs_db and job_id: jobs_db[job_id]['mlflow_run_id'] = run.info.run_id jobs_db[job_id]['status'] = 'RUNNING' print(f"MLflow Run ID: {run.info.run_id}") # --- Data Loading --- data_dir = Path(config['data_dir']) train_dataset = load_data(data_dir / config['train_csv'], TARGET_LEN) val_dataset = load_data(data_dir / config['val_csv'], TARGET_LEN) train_loader = DataLoader( train_dataset, batch_size=config['batch_size'], shuffle=True) val_loader = DataLoader(val_dataset, batch_size=config['batch_size']) # --- Model, Optimizer, Loss --- device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") model = build(config['model_name'], TARGET_LEN).to(device) optimizer = getattr(torch.optim, config['optimizer'])( model.parameters(), lr=config['learning_rate']) criterion = getattr(torch.nn, config['loss_function'])() # --- Training Loop --- best_val_loss = float('inf') for epoch in range(config['epochs']): model.train() train_loss = 0.0 for features, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']} [Train]"): features, labels = features.to(device), labels.to(device) optimizer.zero_grad() outputs = model(features) loss = criterion(outputs, labels) loss.backward() optimizer.step() train_loss += loss.item() avg_train_loss = train_loss / len(train_loader) mlflow.log_metric("train_loss", avg_train_loss, step=epoch) # --- Validation Loop --- model.eval() val_loss = 0.0 with torch.no_grad(): for features, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{config['epochs']} [Val]"): features, labels = features.to(device), labels.to(device) outputs = model(features) loss = criterion(outputs, labels) val_loss += loss.item() avg_val_loss = val_loss / len(val_loader) mlflow.log_metric("val_loss", avg_val_loss, step=epoch) print( f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}") # --- Progress Update for Web UI --- if jobs_db and job_id: progress = (epoch + 1) / config['epochs'] jobs_db[job_id]['progress'] = progress jobs_db[job_id]['metrics']['train_loss'].append(avg_train_loss) jobs_db[job_id]['metrics']['val_loss'].append(avg_val_loss) jobs_db[job_id]['current_epoch'] = epoch + 1 # --- Save Best Model --- if avg_val_loss < best_val_loss: best_val_loss = avg_val_loss mlflow.pytorch.log_model( model, "model", registered_model_name=f"{config.get('run_name', 'default_run')}_best") print( f"New best model saved at epoch {epoch+1} with validation loss: {best_val_loss:.4f}") if jobs_db and job_id: jobs_db[job_id]['status'] = 'COMPLETED' jobs_db[job_id]['progress'] = 1.0 print("✅ Training complete.") except Exception as e: print(f"❌ Training failed: {e}") if jobs_db and job_id: jobs_db[job_id]['status'] = 'FAILED' jobs_db[job_id]['error'] = str(e) raise if __name__ == "__main__": parser = argparse.ArgumentParser( description="Train a spectral classification model.") parser.add_argument( "--config-path", type=Path, required=True, help="Path to the YAML configuration file." ) args = parser.parse_args() with open(args.config_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) # Run training from CLI without web-specific job tracking train(config=config)