Spaces:
Runtime error
Runtime error
| # train_model.py | |
| from load_dataset import load_dataset_from_hub | |
| from model_pipeline import GrievanceClassifier | |
| from configs import get_config | |
| import time | |
| import os | |
def _log(message):
    """Print *message* prefixed with a local timestamp and flush immediately.

    Flushing makes each status line appear right away in a live terminal /
    log view instead of sitting in the stdout buffer.
    """
    print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}", flush=True)


def _safe_len(split):
    """Return ``len(split)``, or ``"unknown"`` if the length cannot be taken.

    Used only for display, so any failure (e.g. ``None`` or a split without
    ``__len__``) is reported as ``"unknown"`` rather than raised.
    """
    try:
        return len(split)
    except Exception:
        return "unknown"


def _split_sizes(dataset):
    """Return display sizes ``(train, eval, test)`` for a dict-like *dataset*.

    A falsy *dataset* yields ``"no dataset"`` for all three entries. Missing
    splits come back as ``"unknown"``.
    """
    if not dataset:
        return ("no dataset",) * 3
    return tuple(_safe_len(dataset.get(name)) for name in ("train", "eval", "test"))


def run_grievance_training_pipeline():
    """Run the end-to-end grievance-classifier training pipeline.

    The pipeline runs these steps in order:

    1. Load configs with ``get_config()``.
    2. Load the dataset and its metadata with ``load_dataset_from_hub``.
    3. Build a ``GrievanceClassifier``.
    4. Run ``classifier.train_pipeline`` on the ``train`` / ``eval`` /
       ``test`` splits.
    5. If ``configs.retrain_space_id`` is set, ask the Hub API to pause that
       Space.

    Timestamped status messages are printed along the way for dynamic
    terminal viewing.

    Returns:
        Whatever ``classifier.train_pipeline(...)`` returns.

    Raises:
        RuntimeError: if any step fails. The original exception is chained
            as ``__cause__`` so its traceback is preserved.
    """
    try:
        _log("Loading configurations...")
        configs = get_config()
        # Short, non-sensitive summary: tokens/API keys are deliberately omitted.
        _log(
            f"Configs loaded: dataset_repo_id={configs.dataset_repo_id}, "
            f"model_checkpoint={configs.model_checkpoint}, hub_model_id={configs.hub_model_id}, "
            f"num_labels={len(configs.label2id)}"
        )

        _log(f"Loading dataset from hub: {configs.dataset_repo_id} ...")
        data = load_dataset_from_hub(
            model_repo=configs.dataset_repo_id,
            hf_token=configs.hf_token,
        )
        dataset = data["dataset"]
        dataset_metadata = data["metadata"]
        train_len, eval_len, test_len = _split_sizes(dataset)
        _log(f"Dataset loaded: train={train_len}, eval={eval_len}, test={test_len}")

        _log(f"Initializing classifier (checkpoint={configs.model_checkpoint}) ...")
        classifier = GrievanceClassifier(
            model_checkpoint=configs.model_checkpoint,
            num_labels=len(configs.label2id),
            id2label=configs.id2label,
            label2id=configs.label2id,
            hf_token=configs.hf_token,
            wandb_api_key=configs.wandb_api_key,
            wandb_project_name=configs.wandb_project_name,
        )
        _log("Classifier initialized.")

        _log("Start training the model ...")
        result = classifier.train_pipeline(
            train_dataset=dataset["train"],
            eval_dataset=dataset["eval"],
            test_dataset=dataset["test"],
            dataset_metadata=dataset_metadata,
            space_repo_id=configs.space_repo_id,
            hf_training_args={"hub_model_id": configs.hub_model_id},
            api_endpoint=configs.api_endpoint,
            early_stopping_patience=configs.early_stopping_patience,
            deployed_sample_size=configs.deployed_sample_size,
            decision_threshold=configs.decision_threshold,
        )
        _log("Training completed successfully!")

        # Brief, best-effort summary of the result; a failing __str__ must not
        # turn a successful training run into an error.
        try:
            if isinstance(result, dict):
                _log(f"Result keys: {list(result.keys())}")
            else:
                _log(f"Result: {result}")
        except Exception:
            _log("Training finished (could not display result details).")

        # Pause the HF Space used for retraining, if one is configured. This
        # is best-effort: a pause failure is only reported as a warning.
        # NOTE(review): if retrain_space_id is the Space running this script,
        # the pause may stop the process before `return result` is reached
        # — confirm this is intended.
        if configs.retrain_space_id:
            try:
                _log("Attempting to pause Hugging Face Space...")
                classifier.api.pause_space(repo_id=configs.retrain_space_id, token=configs.hf_token)
                _log("Pause command executed.")
            except Exception as e:
                _log(f"WARNING: Failed to pause HF Space: {e}")

        return result
    except Exception as e:
        _log(f"ERROR: Grievance training pipeline failed: {e}")
        # Chain explicitly so callers and tracebacks see the real root cause.
        raise RuntimeError(f"Grievance training pipeline failed: {e}") from e
if __name__ == "__main__":
    # Script entry point. A pipeline failure propagates as a RuntimeError,
    # so the process exits non-zero with a traceback.
    run_grievance_training_pipeline()