# scripts/scheduled_training.py
# Author: Muhammad Ridzki Nugraha
# Deploy API and config (Batch 3) — commit 07476a1
"""
Scheduled Training Script for Gapura AI
Automatically retrain models when new data is available or on schedule
"""
import os
import sys
import json
import logging
from datetime import datetime, timedelta
from typing import Dict, Any, Optional
import argparse
# Module-level logging: timestamped, logger-name- and level-tagged lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Make the project root importable so the data/ and training/ packages below
# resolve when this script is run directly from the scripts/ directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.sheets_service import GoogleSheetsService
from data.cache_service import get_cache
from training.train_regression import ResolutionTimePredictor
from training.train_nlp import SeverityClassifier, IssueTypeClassifier, train_multitask_and_export

# JSON file (under <project root>/models/) recording past training runs and
# the retraining policy configuration.
TRAINING_HISTORY_FILE = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "models",
    "training_history.json",
)
class TrainingScheduler:
    """Manages scheduled model retraining.

    Persists a JSON history of training runs (``TRAINING_HISTORY_FILE``)
    and decides — from elapsed time and the count of new records in the
    Google Sheets source — whether the models should be retrained.
    """

    def __init__(self):
        # Persisted run history plus the retrain-policy config.
        self.history = self._load_history()
        # Created lazily so construction does no credential/network work.
        self.sheets_service = None

    @staticmethod
    def _project_dir() -> str:
        """Absolute path of the project root (parent of this script's directory)."""
        return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    @staticmethod
    def _fmt_metric(value) -> str:
        """Format a metric to 3 decimals, or 'N/A' when it is not numeric."""
        if isinstance(value, (int, float)):
            return f"{value:.3f}"
        return "N/A"

    def _load_history(self) -> Dict[str, Any]:
        """Load training history from file.

        Returns:
            The parsed history dict, or a fresh default structure when the
            file is missing or unreadable.
        """
        if os.path.exists(TRAINING_HISTORY_FILE):
            try:
                with open(TRAINING_HISTORY_FILE, "r") as f:
                    return json.load(f)
            except Exception as e:
                # A corrupt history file is non-fatal: log and start fresh.
                logger.warning(f"Failed to load training history: {e}")
        return {
            "last_training": None,
            "last_record_count": 0,
            "training_events": [],
            "config": {
                "min_new_records": 50,
                "max_interval_days": 7,
                "auto_retrain_enabled": True,
            },
        }

    def _save_history(self):
        """Write the in-memory history back to TRAINING_HISTORY_FILE."""
        os.makedirs(os.path.dirname(TRAINING_HISTORY_FILE), exist_ok=True)
        with open(TRAINING_HISTORY_FILE, "w") as f:
            # default=str keeps datetimes and other objects serializable.
            json.dump(self.history, f, indent=2, default=str)
        logger.info(f"Training history saved to {TRAINING_HISTORY_FILE}")

    def _init_sheets_service(self):
        """Initialize (once) and return the Google Sheets service with caching."""
        if self.sheets_service is None:
            cache = get_cache()
            self.sheets_service = GoogleSheetsService(cache=cache)
        return self.sheets_service

    def check_should_retrain(self) -> tuple[bool, str]:
        """Check if retraining is needed based on config.

        Returns:
            (should_retrain, human-readable reason).
        """
        config = self.history.get("config", {})
        if not config.get("auto_retrain_enabled", True):
            return False, "Auto retraining is disabled"
        last_training = self.history.get("last_training")
        last_record_count = self.history.get("last_record_count", 0)
        if last_training is None:
            return True, "No previous training found"
        try:
            last_dt = datetime.fromisoformat(last_training)
            days_since = (datetime.now() - last_dt).days
            if days_since >= config.get("max_interval_days", 7):
                return True, f"Max interval reached ({days_since} days)"
            # Explicit guard: without an ID the fetch below would just raise.
            spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
            if not spreadsheet_id:
                return False, "GOOGLE_SHEET_ID not set"
            sheets = self._init_sheets_service()
            # Bypass the cache so the record count reflects the live sheet.
            non_cargo = sheets.fetch_sheet_data(
                spreadsheet_id, "NON CARGO", bypass_cache=True
            )
            cargo = sheets.fetch_sheet_data(spreadsheet_id, "CGO", bypass_cache=True)
            current_count = len(non_cargo) + len(cargo)
            new_records = current_count - last_record_count
            if new_records >= config.get("min_new_records", 50):
                return (
                    True,
                    f"New records threshold reached ({new_records} new records)",
                )
            return (
                False,
                f"Retraining not needed ({new_records} new records, {days_since} days since last training)",
            )
        except Exception as e:
            # Fail closed: any error means "do not retrain right now".
            logger.error(f"Error checking retrain condition: {e}")
            return False, f"Error: {str(e)}"

    def _train_nlp_models(self, all_data) -> Dict[str, Any]:
        """Train the NLP models on all_data; failures are logged, never raised.

        Returns:
            Dict of per-model metrics for the models that trained successfully.
        """
        nlp_metrics: Dict[str, Any] = {}
        nlp_dir = os.path.join(self._project_dir(), "models", "nlp")
        try:
            severity_clf = SeverityClassifier()
            sev_metrics = severity_clf.train(all_data)
            if sev_metrics:
                os.makedirs(nlp_dir, exist_ok=True)
                severity_clf.save(os.path.join(nlp_dir, "severity_classifier"))
                nlp_metrics["severity"] = sev_metrics
        except Exception as e:
            logger.warning(f"Severity training failed: {e}")
        try:
            issue_clf = IssueTypeClassifier()
            iss_metrics = issue_clf.train(all_data)
            if iss_metrics:
                issue_clf.save(os.path.join(nlp_dir, "issue_classifier"))
                nlp_metrics["issue_type"] = iss_metrics
        except Exception as e:
            logger.warning(f"Issue type training failed: {e}")
        # Multi-task transformer + ONNX export is best-effort as well.
        try:
            train_multitask_and_export(
                all_data, output_dir=os.path.join(self._project_dir(), "models")
            )
            logger.info("✓ Multi-task transformer training/export attempted")
        except Exception as e:
            logger.warning(f"Multi-task training failed: {e}")
        return nlp_metrics

    def run_training(self, force: bool = False) -> Dict[str, Any]:
        """Run model training.

        Args:
            force: Train even when check_should_retrain says no.

        Returns:
            Result dict with "status" of "skipped", "success", or "error".
        """
        should_retrain, reason = self.check_should_retrain()
        if not should_retrain and not force:
            logger.info(f"Skipping training: {reason}")
            return {
                "status": "skipped",
                "reason": reason,
                "timestamp": datetime.now().isoformat(),
            }
        logger.info(f"Starting training: {reason}")
        try:
            sheets = self._init_sheets_service()
            spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
            if not spreadsheet_id:
                raise ValueError("GOOGLE_SHEET_ID not set")
            logger.info("Fetching data from Google Sheets...")
            non_cargo = sheets.fetch_sheet_data(spreadsheet_id, "NON CARGO", "A1:AA500")
            cargo = sheets.fetch_sheet_data(spreadsheet_id, "CGO", "A1:Z500")
            all_data = non_cargo + cargo
            logger.info(f"Total records: {len(all_data)}")
            if len(all_data) < 10:
                raise ValueError(f"Not enough data to train: {len(all_data)} records")

            # Regression model: resolution-time prediction.
            predictor = ResolutionTimePredictor()
            metrics = predictor.train(all_data)
            if not metrics:
                raise ValueError("Training failed - no metrics returned")
            model_dir = os.path.join(self._project_dir(), "models", "regression")
            os.makedirs(model_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_path = os.path.join(
                model_dir, f"resolution_predictor_{timestamp}.pkl"
            )
            latest_path = os.path.join(model_dir, "resolution_predictor_latest.pkl")
            # Keep both a timestamped snapshot and a stable "latest" alias.
            predictor.save(model_path)
            predictor.save(latest_path)

            # NLP models are best-effort; failures only log warnings.
            nlp_metrics = self._train_nlp_models(all_data)

            # Record the run and trim history to the most recent 50 events.
            self.history["last_training"] = datetime.now().isoformat()
            self.history["last_record_count"] = len(all_data)
            self.history["training_events"].append(
                {
                    "timestamp": datetime.now().isoformat(),
                    "records": len(all_data),
                    "metrics": {
                        "test_mae": metrics.get("test_mae"),
                        "test_r2": metrics.get("test_r2"),
                        "nlp": nlp_metrics,
                    },
                    "trigger": reason,
                }
            )
            if len(self.history["training_events"]) > 50:
                self.history["training_events"] = self.history["training_events"][-50:]
            self._save_history()

            # Invalidate sheet caches so consumers see fresh data.
            cache = get_cache()
            cache.delete_pattern("sheets:*")
            logger.info("Cache invalidated after training")

            logger.info("Training completed successfully!")
            logger.info(f"Model saved to: {model_path}")
            # BUGFIX: formatting a missing metric ('N/A' string) with ':.3f'
            # raised ValueError here and made a completed run report "error".
            logger.info(f"Test MAE: {self._fmt_metric(metrics.get('test_mae'))}")
            logger.info(f"Test R²: {self._fmt_metric(metrics.get('test_r2'))}")
            return {
                "status": "success",
                "reason": reason,
                "timestamp": datetime.now().isoformat(),
                "records_trained": len(all_data),
                "model_path": model_path,
                "metrics": {
                    "test_mae": metrics.get("test_mae"),
                    "test_rmse": metrics.get("test_rmse"),
                    "test_r2": metrics.get("test_r2"),
                },
            }
        except Exception as e:
            logger.error(f"Training failed: {e}", exc_info=True)
            return {
                "status": "error",
                "error": str(e),
                "timestamp": datetime.now().isoformat(),
            }

    def update_config(self, **kwargs):
        """Update training configuration keys and persist the history file."""
        config = self.history.get("config", {})
        config.update(kwargs)
        self.history["config"] = config
        self._save_history()
        logger.info(f"Config updated: {kwargs}")

    def get_status(self) -> Dict[str, Any]:
        """Get current training status, including whether a retrain is due."""
        should_retrain, reason = self.check_should_retrain()
        return {
            "last_training": self.history.get("last_training"),
            "last_record_count": self.history.get("last_record_count"),
            "should_retrain": should_retrain,
            "reason": reason,
            "config": self.history.get("config", {}),
            "recent_events": self.history.get("training_events", [])[-5:],
        }
def main():
    """CLI entry point: status query, config update, one-shot or scheduled run."""
    parser = argparse.ArgumentParser(description="Gapura AI Scheduled Training")
    parser.add_argument(
        "--force", action="store_true", help="Force training regardless of conditions"
    )
    parser.add_argument("--status", action="store_true", help="Show training status")
    parser.add_argument("--config", type=str, help="Update config (JSON format)")
    parser.add_argument(
        "--schedule", type=int, help="Run as scheduled job every N hours"
    )
    opts = parser.parse_args()

    runner = TrainingScheduler()

    # Status query: print current state and exit without training.
    if opts.status:
        print(json.dumps(runner.get_status(), indent=2, default=str))
        return

    # Config update: parse the JSON payload, apply it, and exit.
    if opts.config:
        try:
            updates = json.loads(opts.config)
        except json.JSONDecodeError as e:
            print(f"Invalid JSON config: {e}")
            sys.exit(1)
        runner.update_config(**updates)
        print(f"Config updated: {updates}")
        return

    # One-shot mode: run once (optionally forced) and print the result.
    if not opts.schedule:
        outcome = runner.run_training(force=opts.force)
        print(json.dumps(outcome, indent=2, default=str))
        return

    # Recurring mode: block forever, checking every N hours.
    from apscheduler.schedulers.blocking import BlockingScheduler

    def scheduled_training():
        logger.info("Running scheduled training check...")
        result = runner.run_training()
        logger.info(f"Training result: {result}")

    blocking = BlockingScheduler()
    blocking.add_job(scheduled_training, "interval", hours=opts.schedule)
    logger.info(f"Starting scheduled training (every {opts.schedule} hours)")
    blocking.start()


if __name__ == "__main__":
    main()