import os
import time

from configs import get_config
from load_dataset import load_dataset_from_hub
from model_pipeline import GrievanceClassifier
|
def run_grievance_training_pipeline():
    """
    Load configs and the dataset, initialize the grievance classifier,
    and run the end-to-end training pipeline.

    Progress is reported via timestamped ``print(..., flush=True)`` calls
    so output streams correctly when watched in a live terminal.

    Returns:
        Whatever ``GrievanceClassifier.train_pipeline`` returns
        (typically a dict — see that method; shape not guaranteed here).

    Raises:
        RuntimeError: if any stage fails. The original exception is
            chained via ``from`` so the full traceback is preserved.
    """
    try:
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Loading configurations...", flush=True)
        configs = get_config()

        print(f"[{time.strftime('%H:%M:%S')}] Configs loaded: dataset_repo_id={configs.dataset_repo_id}, "
              f"model_checkpoint={configs.model_checkpoint}, hub_model_id={configs.hub_model_id}, "
              f"num_labels={len(configs.label2id)}",
              flush=True)

        print(f"[{time.strftime('%H:%M:%S')}] Loading dataset from hub: {configs.dataset_repo_id} ...", flush=True)
        data = load_dataset_from_hub(
            model_repo=configs.dataset_repo_id,
            hf_token=configs.hf_token
        )
        dataset = data['dataset']
        dataset_metadata = data['metadata']

        def _safe_len(split):
            # A split may be None or unsized; a diagnostic length report
            # must never break the pipeline itself.
            try:
                return len(split)
            except Exception:
                return "unknown"

        train_len = _safe_len(dataset.get('train')) if dataset else "no dataset"
        eval_len = _safe_len(dataset.get('eval')) if dataset else "no dataset"
        test_len = _safe_len(dataset.get('test')) if dataset else "no dataset"
        print(f"[{time.strftime('%H:%M:%S')}] Dataset loaded: train={train_len}, eval={eval_len}, test={test_len}", flush=True)

        print(f"[{time.strftime('%H:%M:%S')}] Initializing classifier (checkpoint={configs.model_checkpoint}) ...", flush=True)
        classifier = GrievanceClassifier(
            model_checkpoint=configs.model_checkpoint,
            num_labels=len(configs.label2id),
            id2label=configs.id2label,
            label2id=configs.label2id,
            hf_token=configs.hf_token,
            wandb_api_key=configs.wandb_api_key,
            wandb_project_name=configs.wandb_project_name,
        )
        print(f"[{time.strftime('%H:%M:%S')}] Classifier initialized.", flush=True)

        print(f"[{time.strftime('%H:%M:%S')}] Start training the model ...", flush=True)
        result = classifier.train_pipeline(
            train_dataset=dataset['train'],
            eval_dataset=dataset['eval'],
            test_dataset=dataset['test'],
            dataset_metadata=dataset_metadata,
            space_repo_id=configs.space_repo_id,
            hf_training_args={"hub_model_id": configs.hub_model_id},
            api_endpoint=configs.api_endpoint,
            early_stopping_patience=configs.early_stopping_patience,
            deployed_sample_size=configs.deployed_sample_size,
            decision_threshold=configs.decision_threshold
        )

        print(f"[{time.strftime('%H:%M:%S')}] Training completed successfully!", flush=True)

        # Best-effort result summary — never fail the run over display.
        try:
            if isinstance(result, dict):
                print(f"[{time.strftime('%H:%M:%S')}] Result keys: {list(result.keys())}", flush=True)
            else:
                print(f"[{time.strftime('%H:%M:%S')}] Result: {result}", flush=True)
        except Exception:
            print(f"[{time.strftime('%H:%M:%S')}] Training finished (could not display result details).", flush=True)

        # Optionally pause the retraining HF Space once training is done
        # (presumably to stop it consuming compute — confirm with owner).
        # A failure here only warns; it must not abort a finished run.
        if configs.retrain_space_id:
            try:
                print(f"[{time.strftime('%H:%M:%S')}] Attempting to pause Hugging Face Space...", flush=True)
                classifier.api.pause_space(repo_id=configs.retrain_space_id, token=configs.hf_token)
                print(f"[{time.strftime('%H:%M:%S')}] Pause command executed.", flush=True)
            except Exception as e:
                print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] WARNING: Failed to pause HF Space: {e}", flush=True)

        return result

    except Exception as e:
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] ERROR: Grievance training pipeline failed: {e}", flush=True)
        # Chain the cause: the bare re-raise previously discarded the
        # original traceback, hiding which stage actually failed.
        raise RuntimeError(f"Grievance training pipeline failed: {e}") from e
|
|
|
|
# Script entry point: run the whole training pipeline when executed directly.
if __name__ == "__main__":
    run_grievance_training_pipeline()
|
|