Spaces:
Running
Running
| """ | |
| Scheduled Training Script for Gapura AI | |
| Automatically retrain models when new data is available or on schedule | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import logging | |
| from datetime import datetime, timedelta | |
| from typing import Dict, Any, Optional | |
| import argparse | |
# Root logger configuration for the whole script: INFO level, timestamped lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Make the project root importable so the sibling `data`/`training` packages
# resolve when this file is executed directly as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.sheets_service import GoogleSheetsService
from data.cache_service import get_cache
from training.train_regression import ResolutionTimePredictor
from training.train_nlp import SeverityClassifier, IssueTypeClassifier, train_multitask_and_export

# JSON file under <project root>/models/ that persists training history and
# the retraining configuration between runs.
TRAINING_HISTORY_FILE = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "models",
    "training_history.json",
)
class TrainingScheduler:
    """Manages scheduled model retraining.

    Persists a JSON history file (last training time, record count, recent
    events, config) and decides whether retraining is due based on either a
    maximum interval in days or a minimum number of new records in the
    Google Sheets source. Orchestrates regression + NLP model training.
    """

    def __init__(self):
        self.history = self._load_history()
        # Lazily created Google Sheets client; see _init_sheets_service().
        self.sheets_service = None

    @staticmethod
    def _project_root() -> str:
        """Absolute path of the project root (parent of this script's directory)."""
        return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    @staticmethod
    def _fmt_metric(value) -> str:
        """Format a numeric metric as e.g. '0.123', or 'N/A' when missing.

        Previously a missing metric produced f"{'N/A':.3f}", which raises
        ValueError ('f' is not a valid format code for str).
        """
        return f"{value:.3f}" if isinstance(value, (int, float)) else "N/A"

    def _load_history(self) -> Dict[str, Any]:
        """Load training history from TRAINING_HISTORY_FILE.

        Falls back to a default structure when the file is missing or
        unreadable (a corrupt file is logged, not raised).
        """
        if os.path.exists(TRAINING_HISTORY_FILE):
            try:
                with open(TRAINING_HISTORY_FILE, "r") as f:
                    return json.load(f)
            except Exception as e:
                logger.warning(f"Failed to load training history: {e}")
        return {
            "last_training": None,
            "last_record_count": 0,
            "training_events": [],
            "config": {
                "min_new_records": 50,
                "max_interval_days": 7,
                "auto_retrain_enabled": True,
            },
        }

    def _save_history(self):
        """Persist the in-memory history to TRAINING_HISTORY_FILE."""
        os.makedirs(os.path.dirname(TRAINING_HISTORY_FILE), exist_ok=True)
        with open(TRAINING_HISTORY_FILE, "w") as f:
            # default=str serializes any datetime values that end up in events.
            json.dump(self.history, f, indent=2, default=str)
        logger.info(f"Training history saved to {TRAINING_HISTORY_FILE}")

    def _init_sheets_service(self):
        """Return the Google Sheets service, creating it (with cache) on first use."""
        if self.sheets_service is None:
            cache = get_cache()
            self.sheets_service = GoogleSheetsService(cache=cache)
        return self.sheets_service

    def check_should_retrain(self) -> tuple[bool, str]:
        """Decide whether retraining is needed based on the persisted config.

        Returns:
            (should_retrain, reason) where reason is a human-readable string.
        """
        config = self.history.get("config", {})
        if not config.get("auto_retrain_enabled", True):
            return False, "Auto retraining is disabled"
        last_training = self.history.get("last_training")
        last_record_count = self.history.get("last_record_count", 0)
        if last_training is None:
            return True, "No previous training found"
        try:
            last_dt = datetime.fromisoformat(last_training)
            days_since = (datetime.now() - last_dt).days
            if days_since >= config.get("max_interval_days", 7):
                return True, f"Max interval reached ({days_since} days)"
            # Fail fast with a clear reason instead of passing None to the
            # Sheets API (previously this surfaced as an opaque fetch error).
            spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
            if not spreadsheet_id:
                return False, "GOOGLE_SHEET_ID not set"
            sheets = self._init_sheets_service()
            # bypass_cache so the count reflects the live sheets, not stale data.
            non_cargo = sheets.fetch_sheet_data(
                spreadsheet_id, "NON CARGO", bypass_cache=True
            )
            cargo = sheets.fetch_sheet_data(spreadsheet_id, "CGO", bypass_cache=True)
            current_count = len(non_cargo) + len(cargo)
            new_records = current_count - last_record_count
            if new_records >= config.get("min_new_records", 50):
                return (
                    True,
                    f"New records threshold reached ({new_records} new records)",
                )
            return (
                False,
                f"Retraining not needed ({new_records} new records, {days_since} days since last training)",
            )
        except Exception as e:
            logger.error(f"Error checking retrain condition: {e}")
            return False, f"Error: {str(e)}"

    def _train_nlp_models(self, all_data) -> Dict[str, Any]:
        """Best-effort training of the NLP models; returns collected metrics.

        Each model is trained in its own try/except so a failure in one does
        not block the others (the regression model has already been saved).
        """
        nlp_metrics: Dict[str, Any] = {}
        nlp_dir = os.path.join(self._project_root(), "models", "nlp")
        # Create the directory up front: previously only the severity branch
        # created it, so issue_clf.save could fail when severity was skipped.
        os.makedirs(nlp_dir, exist_ok=True)
        try:
            severity_clf = SeverityClassifier()
            sev_metrics = severity_clf.train(all_data)
            if sev_metrics:
                severity_clf.save(os.path.join(nlp_dir, "severity_classifier"))
                nlp_metrics["severity"] = sev_metrics
        except Exception as e:
            logger.warning(f"Severity training failed: {e}")
        try:
            issue_clf = IssueTypeClassifier()
            iss_metrics = issue_clf.train(all_data)
            if iss_metrics:
                issue_clf.save(os.path.join(nlp_dir, "issue_classifier"))
                nlp_metrics["issue_type"] = iss_metrics
        except Exception as e:
            logger.warning(f"Issue type training failed: {e}")
        # Train multi-task transformer and export ONNX (best-effort).
        try:
            mt_output_dir = os.path.join(self._project_root(), "models")
            train_multitask_and_export(all_data, output_dir=mt_output_dir)
            logger.info("✓ Multi-task transformer training/export attempted")
        except Exception as e:
            logger.warning(f"Multi-task training failed: {e}")
        return nlp_metrics

    def run_training(self, force: bool = False) -> Dict[str, Any]:
        """Run model training end to end.

        Args:
            force: train even when check_should_retrain says it is not needed.

        Returns:
            A status dict: {"status": "skipped" | "success" | "error", ...}
            with metrics and the model path on success.
        """
        should_retrain, reason = self.check_should_retrain()
        if not should_retrain and not force:
            logger.info(f"Skipping training: {reason}")
            return {
                "status": "skipped",
                "reason": reason,
                "timestamp": datetime.now().isoformat(),
            }
        logger.info(f"Starting training: {reason}")
        try:
            sheets = self._init_sheets_service()
            spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
            if not spreadsheet_id:
                raise ValueError("GOOGLE_SHEET_ID not set")
            logger.info("Fetching data from Google Sheets...")
            non_cargo = sheets.fetch_sheet_data(spreadsheet_id, "NON CARGO", "A1:AA500")
            cargo = sheets.fetch_sheet_data(spreadsheet_id, "CGO", "A1:Z500")
            all_data = non_cargo + cargo
            logger.info(f"Total records: {len(all_data)}")
            if len(all_data) < 10:
                raise ValueError(f"Not enough data to train: {len(all_data)} records")
            predictor = ResolutionTimePredictor()
            metrics = predictor.train(all_data)
            if not metrics:
                raise ValueError("Training failed - no metrics returned")
            model_dir = os.path.join(self._project_root(), "models", "regression")
            os.makedirs(model_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_path = os.path.join(
                model_dir, f"resolution_predictor_{timestamp}.pkl"
            )
            latest_path = os.path.join(model_dir, "resolution_predictor_latest.pkl")
            predictor.save(model_path)
            # Also save under a stable name so serving code can load "latest".
            predictor.save(latest_path)
            nlp_metrics = self._train_nlp_models(all_data)
            self.history["last_training"] = datetime.now().isoformat()
            self.history["last_record_count"] = len(all_data)
            # setdefault guards against history files written before this key existed.
            self.history.setdefault("training_events", []).append(
                {
                    "timestamp": datetime.now().isoformat(),
                    "records": len(all_data),
                    "metrics": {
                        "test_mae": metrics.get("test_mae"),
                        "test_r2": metrics.get("test_r2"),
                        "nlp": nlp_metrics,
                    },
                    "trigger": reason,
                }
            )
            # Keep only the 50 most recent events to bound the file size.
            if len(self.history["training_events"]) > 50:
                self.history["training_events"] = self.history["training_events"][-50:]
            self._save_history()
            # Invalidate cached sheet data so callers see post-training state.
            cache = get_cache()
            cache.delete_pattern("sheets:*")
            logger.info("Cache invalidated after training")
            logger.info("Training completed successfully!")
            logger.info(f"Model saved to: {model_path}")
            # _fmt_metric avoids the ValueError the old code raised when a
            # metric key was missing ('N/A' formatted with :.3f).
            logger.info(f"Test MAE: {self._fmt_metric(metrics.get('test_mae'))}")
            logger.info(f"Test R²: {self._fmt_metric(metrics.get('test_r2'))}")
            return {
                "status": "success",
                "reason": reason,
                "timestamp": datetime.now().isoformat(),
                "records_trained": len(all_data),
                "model_path": model_path,
                "metrics": {
                    "test_mae": metrics.get("test_mae"),
                    "test_rmse": metrics.get("test_rmse"),
                    "test_r2": metrics.get("test_r2"),
                },
            }
        except Exception as e:
            logger.error(f"Training failed: {e}", exc_info=True)
            return {
                "status": "error",
                "error": str(e),
                "timestamp": datetime.now().isoformat(),
            }

    def update_config(self, **kwargs):
        """Merge keyword arguments into the persisted training configuration."""
        config = self.history.get("config", {})
        config.update(kwargs)
        self.history["config"] = config
        self._save_history()
        logger.info(f"Config updated: {kwargs}")

    def get_status(self) -> Dict[str, Any]:
        """Return a snapshot of training state for the CLI --status output."""
        should_retrain, reason = self.check_should_retrain()
        return {
            "last_training": self.history.get("last_training"),
            "last_record_count": self.history.get("last_record_count"),
            "should_retrain": should_retrain,
            "reason": reason,
            "config": self.history.get("config", {}),
            "recent_events": self.history.get("training_events", [])[-5:],
        }
def main():
    """CLI entry point: show status, update config, or run training
    (one-shot or on an APScheduler interval)."""
    parser = argparse.ArgumentParser(description="Gapura AI Scheduled Training")
    parser.add_argument(
        "--force", action="store_true", help="Force training regardless of conditions"
    )
    parser.add_argument("--status", action="store_true", help="Show training status")
    parser.add_argument("--config", type=str, help="Update config (JSON format)")
    parser.add_argument(
        "--schedule", type=int, help="Run as scheduled job every N hours"
    )
    args = parser.parse_args()
    scheduler = TrainingScheduler()
    if args.status:
        status = scheduler.get_status()
        print(json.dumps(status, indent=2, default=str))
        return
    if args.config:
        try:
            config_updates = json.loads(args.config)
            scheduler.update_config(**config_updates)
            print(f"Config updated: {config_updates}")
        except json.JSONDecodeError as e:
            print(f"Invalid JSON config: {e}")
            sys.exit(1)
        return
    if args.schedule:
        # Imported lazily so apscheduler is only required in scheduled mode.
        from apscheduler.schedulers.blocking import BlockingScheduler

        sched = BlockingScheduler()

        def scheduled_training():
            logger.info("Running scheduled training check...")
            result = scheduler.run_training()
            logger.info(f"Training result: {result}")

        # BUG FIX: the job was never registered with the scheduler, so
        # sched.start() blocked forever without ever training. Register the
        # interval job explicitly before starting the blocking loop.
        sched.add_job(scheduled_training, "interval", hours=args.schedule)
        logger.info(f"Starting scheduled training (every {args.schedule} hours)")
        sched.start()
    else:
        result = scheduler.run_training(force=args.force)
        print(json.dumps(result, indent=2, default=str))


if __name__ == "__main__":
    main()