"""
Script to train and update all models for India, States, and Markets.

Run this script to update all forecasting models without using the UI.
Exits with status 0 when every model was updated, 1 when any update (or the
whole run) failed, and 130 when interrupted with Ctrl+C.
"""
import itertools
import sys

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from xgboost import XGBRegressor
from tqdm import tqdm

from src.agri_predict import fetch_and_process_data
from src.agri_predict.constants import state_market_dict
from src.agri_predict.features import (
    create_forecasting_features,
    create_forecasting_features_1m,
    create_forecasting_features_3m,
)
from src.agri_predict.config import get_collections

# Define forecast horizons
FORECAST_HORIZONS = [14, 30, 90]  # 14 days, 1 month, 3 months

# Per-horizon settings: (feature builder, train/test split date, key into
# get_collections() of the collection that stores the tuned parameters).
HORIZON_CONFIG = {
    14: (create_forecasting_features, '2024-01-01', 'best_params_collection'),
    30: (create_forecasting_features_1m, '2023-01-01', 'best_params_collection_1m'),
    90: (create_forecasting_features_3m, '2023-01-01', 'best_params_collection_3m'),
}

# Human-readable horizon labels used in progress output.
HORIZON_NAMES = {14: "14 days", 30: "1 month", 90: "3 months"}

TARGET_COLUMN = 'Modal Price (Rs./Quintal)'
DATE_COLUMN = 'Reported Date'

# Exhaustive grid searched for every model. Key order matters: it fixes the
# search order and therefore which combination wins a tie.
PARAM_GRID = {
    'learning_rate': [0.01, 0.1, 0.2],
    'max_depth': [3, 5, 7],
    'n_estimators': [50, 100, 150],
    'booster': ['gbtree', 'dart'],
}

# Default scopes updated by the batch run.
STATES = ("Karnataka", "Madhya Pradesh", "Gujarat", "Uttar Pradesh", "Telangana")
MARKETS = ("Rajkot", "Gondal", "Kalburgi", "Amreli")


def _horizon_name(days):
    """Return the display label for a forecast horizon given in days."""
    return HORIZON_NAMES.get(days, f"{days} days")


def _print_banner(title):
    """Print a section header framed by rules of '=' characters."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _tune_hyperparameters(X_train, y_train, X_test, y_test):
    """Grid-search PARAM_GRID and return ``(best_params, best_r2_score)``.

    Candidates are ranked by ``XGBRegressor.score`` (R²) on the test set;
    the first combination reaching the highest score wins.

    NOTE(review): candidates are ranked on the same test set that is later
    used to report RMSE/MAE, so those metrics are optimistically biased. A
    separate validation split would give honest numbers but would change
    which parameters get selected.

    Raises:
        RuntimeError: If no combination produced a comparable (non-NaN) score.
    """
    names = list(PARAM_GRID)
    combinations = list(itertools.product(*PARAM_GRID.values()))
    model = XGBRegressor()
    best_score = float('-inf')
    best_params = None

    with tqdm(total=len(combinations), desc="  Tuning hyperparameters") as pbar:
        for values in combinations:
            params = dict(zip(names, values))
            model.set_params(**params)
            model.fit(X_train, y_train)
            # NaN never compares greater than best_score, so NaN scores are skipped.
            score = model.score(X_test, y_test)
            if score > best_score:
                best_score = score
                best_params = params
            pbar.update(1)

    if best_params is None:
        raise RuntimeError(
            "Hyperparameter search produced no usable score "
            "(R² was NaN for every combination)"
        )
    return best_params, best_score


def train_model_batch(df, filter_key, days):
    """Tune, train and persist an XGBoost model without UI components.

    Args:
        df: Processed price data as returned by ``fetch_and_process_data``.
        filter_key: Identifier of the scope ("India", "state_<name>",
            "market_<name>") used as the upsert key in MongoDB.
        days: Forecast horizon; must be a key of ``HORIZON_CONFIG``.

    Returns:
        Tuple ``(best_params, rmse, mae)``. The metrics are computed on the
        rows dated on or after the horizon's split date.

    Raises:
        ValueError: If ``days`` is not a supported horizon, or the split
            leaves the train or test set empty.
        RuntimeError: If no hyperparameter combination produced a usable
            score (e.g. R² is NaN because the test set has a single row).
    """
    try:
        build_features, split_date, collection_key = HORIZON_CONFIG[days]
    except KeyError:
        raise ValueError(
            f"Unsupported forecast horizon: {days} days "
            f"(expected one of {sorted(HORIZON_CONFIG)})"
        ) from None

    cols = get_collections()
    df_features = build_features(df)

    # Chronological split: rows before split_date train, the rest evaluate.
    train_df = df_features[df_features[DATE_COLUMN] < split_date]
    test_df = df_features[df_features[DATE_COLUMN] >= split_date]
    if train_df.empty or test_df.empty:
        raise ValueError(
            f"Split at {split_date} leaves {len(train_df)} training and "
            f"{len(test_df)} test rows; both must be non-empty"
        )

    non_features = [TARGET_COLUMN, DATE_COLUMN]
    X_train = train_df.drop(columns=non_features)
    y_train = train_df[TARGET_COLUMN]
    X_test = test_df.drop(columns=non_features)
    y_test = test_df[TARGET_COLUMN]

    best_params, best_score = _tune_hyperparameters(X_train, y_train, X_test, y_test)

    # Refit a fresh model with the winning parameters and evaluate it.
    best_model = XGBRegressor(**best_params)
    best_model.fit(X_train, y_train)
    y_pred = best_model.predict(X_test)

    # Cast to builtin float so the stored values are plain doubles.
    rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
    mae = float(mean_absolute_error(y_test, y_pred))

    cols[collection_key].replace_one(
        {'filter_key': filter_key},
        {
            **best_params,
            'filter_key': filter_key,
            'last_updated': pd.Timestamp.now().isoformat(),
            'rmse': rmse,
            'mae': mae,
            'score': float(best_score),
        },
        upsert=True,
    )
    return best_params, rmse, mae


def _update_scope(label, query_filter, filter_key):
    """Fetch data for one scope and train a model for every forecast horizon.

    Each horizon is trained independently, so one failure does not stop the
    others. Returns the number of horizons that were not updated. A fetch
    that raises, or that returns None, counts as a failure for every horizon.
    """
    try:
        df = fetch_and_process_data(query_filter)
    except Exception as e:  # keep the batch going for the remaining scopes
        print(f"❌ [{label}] Error fetching data: {e}")
        return len(FORECAST_HORIZONS)
    if df is None:
        print(f"❌ [{label}] No data available")
        return len(FORECAST_HORIZONS)

    failures = 0
    for days in FORECAST_HORIZONS:
        horizon_name = _horizon_name(days)
        print(f"[{label}] Training {horizon_name} forecast model...")
        try:
            _, rmse, mae = train_model_batch(df, filter_key, days)
        except Exception as e:  # report and continue with the next horizon
            failures += 1
            print(f"❌ [{label}] Error updating {horizon_name} model: {e}")
        else:
            print(f"✅ [{label}] {horizon_name} model updated successfully")
            print(f"   RMSE: {rmse:.2f}, MAE: {mae:.2f}")
    return failures


def update_india_models():
    """Update models for all of India. Returns the number of failed updates."""
    _print_banner("UPDATING INDIA MODELS")
    return _update_scope("India", {}, "India")


def update_state_models(states=STATES):
    """Update models for each state. Returns the number of failed updates."""
    _print_banner("UPDATING STATE MODELS")
    failures = 0
    for state in states:
        print(f"\n--- Processing State: {state} ---")
        failures += _update_scope(state, {"State Name": state}, f"state_{state}")
    return failures


def update_market_models(markets=MARKETS):
    """Update models for each market. Returns the number of failed updates."""
    _print_banner("UPDATING MARKET MODELS")
    failures = 0
    for market in markets:
        print(f"\n--- Processing Market: {market} ---")
        failures += _update_scope(market, {"Market Name": market}, f"market_{market}")
    return failures


def main():
    """Update every India, state and market model.

    Returns:
        Process exit status: 0 if every model was updated, 1 if any update
        or the run itself failed, 130 if interrupted by the user.
    """
    print("\n" + "🌾" * 30)
    print("AGRIPREDICT - BATCH MODEL UPDATE")
    print("🌾" * 30)
    print("\nThis script will train and update all forecasting models.")
    print("This may take several minutes to complete.\n")

    try:
        failures = update_india_models()
        failures += update_state_models()
        failures += update_market_models()
    except KeyboardInterrupt:
        print("\n\n⚠️ Process interrupted by user")
        return 130
    except Exception as e:  # top-level boundary: report instead of a traceback
        print(f"\n\n❌ Fatal error: {e}")
        return 1

    if failures:
        _print_banner(f"⚠️ FINISHED WITH {failures} FAILED MODEL UPDATE(S)")
        return 1
    _print_banner("✅ ALL MODELS UPDATED SUCCESSFULLY")
    return 0


if __name__ == "__main__":
    sys.exit(main())