# Source: mr-kush — commit 625fb41 "Add dataset metadata handling in training pipeline"
# train_model.py
from load_dataset import load_dataset_from_hub
from model_pipeline import GrievanceClassifier
from configs import get_config
import time
import os
def run_grievance_training_pipeline():
    """
    Run the end-to-end grievance-classification training pipeline.

    Loads configuration, pulls the dataset (and its metadata) from the
    Hugging Face Hub, initializes the classifier, trains it, and — when
    configured — pauses the retraining Space afterwards.  Prints timestamped
    status messages for dynamic terminal viewing.

    Returns:
        The result object produced by ``GrievanceClassifier.train_pipeline``.

    Raises:
        RuntimeError: if any step of the pipeline fails; the original
            exception is preserved as ``__cause__`` via exception chaining.
    """
    try:
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Loading configurations...", flush=True)
        configs = get_config()

        # Print a short, non-sensitive summary of configs (no tokens or keys).
        print(f"[{time.strftime('%H:%M:%S')}] Configs loaded: dataset_repo_id={configs.dataset_repo_id}, "
              f"model_checkpoint={configs.model_checkpoint}, hub_model_id={configs.hub_model_id}, "
              f"num_labels={len(configs.label2id)}",
              flush=True)

        print(f"[{time.strftime('%H:%M:%S')}] Loading dataset from hub: {configs.dataset_repo_id} ...", flush=True)
        data = load_dataset_from_hub(
            model_repo=configs.dataset_repo_id,
            hf_token=configs.hf_token
        )
        dataset = data['dataset']
        dataset_metadata = data['metadata']

        def _safe_len(split):
            # A split object may not support len(); a progress message must
            # never crash the pipeline, so fall back to a placeholder.
            try:
                return len(split)
            except Exception:
                return "unknown"

        # Print dataset splits and sizes if available.
        train_len = _safe_len(dataset.get('train')) if dataset else "no dataset"
        eval_len = _safe_len(dataset.get('eval')) if dataset else "no dataset"
        test_len = _safe_len(dataset.get('test')) if dataset else "no dataset"
        print(f"[{time.strftime('%H:%M:%S')}] Dataset loaded: train={train_len}, eval={eval_len}, test={test_len}", flush=True)

        print(f"[{time.strftime('%H:%M:%S')}] Initializing classifier (checkpoint={configs.model_checkpoint}) ...", flush=True)
        classifier = GrievanceClassifier(
            model_checkpoint=configs.model_checkpoint,
            num_labels=len(configs.label2id),
            id2label=configs.id2label,
            label2id=configs.label2id,
            hf_token=configs.hf_token,
            wandb_api_key=configs.wandb_api_key,
            wandb_project_name=configs.wandb_project_name,
        )
        print(f"[{time.strftime('%H:%M:%S')}] Classifier initialized.", flush=True)

        print(f"[{time.strftime('%H:%M:%S')}] Start training the model ...", flush=True)
        result = classifier.train_pipeline(
            train_dataset=dataset['train'],
            eval_dataset=dataset['eval'],
            test_dataset=dataset['test'],
            dataset_metadata=dataset_metadata,
            space_repo_id=configs.space_repo_id,
            hf_training_args={"hub_model_id": configs.hub_model_id},
            api_endpoint=configs.api_endpoint,
            early_stopping_patience=configs.early_stopping_patience,
            deployed_sample_size=configs.deployed_sample_size,
            decision_threshold=configs.decision_threshold
        )
        print(f"[{time.strftime('%H:%M:%S')}] Training completed successfully!", flush=True)

        # Print a brief summary of the result if it's a dict-like object;
        # best-effort only — displaying the result must not fail the run.
        try:
            if isinstance(result, dict):
                print(f"[{time.strftime('%H:%M:%S')}] Result keys: {list(result.keys())}", flush=True)
            else:
                print(f"[{time.strftime('%H:%M:%S')}] Result: {result}", flush=True)
        except Exception:
            print(f"[{time.strftime('%H:%M:%S')}] Training finished (could not display result details).", flush=True)

        # Pause the Space if the pipeline was run inside an HF Space
        # (best-effort: a pause failure only warns, it does not fail the run).
        if configs.retrain_space_id:
            try:
                print(f"[{time.strftime('%H:%M:%S')}] Attempting to pause Hugging Face Space...", flush=True)
                classifier.api.pause_space(repo_id=configs.retrain_space_id, token=configs.hf_token)
                print(f"[{time.strftime('%H:%M:%S')}] Pause command executed.", flush=True)
            except Exception as e:
                print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] WARNING: Failed to pause HF Space: {e}", flush=True)

        return result
    except Exception as e:
        # Top-level boundary: log, then re-raise with the original exception
        # chained so the full traceback is preserved for debugging.
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] ERROR: Grievance training pipeline failed: {e}", flush=True)
        raise RuntimeError(f"Grievance training pipeline failed: {e}") from e
# Script entry point: run the full pipeline when executed directly.
if __name__ == "__main__":
    run_grievance_training_pipeline()