Spaces:
Running
Running
File size: 8,239 Bytes
3029a46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
"""
Script to train and update all models for India, States, and Markets.
Run this script to update all forecasting models without using the UI.
"""
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from xgboost import XGBRegressor
from tqdm import tqdm
from src.agri_predict import fetch_and_process_data
from src.agri_predict.constants import state_market_dict
from src.agri_predict.features import (
create_forecasting_features,
create_forecasting_features_1m,
create_forecasting_features_3m,
)
from src.agri_predict.config import get_collections
# Forecast horizons, in days. Each update_* function below trains one model
# per entry by passing the value as `days` to train_model_batch.
FORECAST_HORIZONS = [14, 30, 90] # 14 days, 1 month, 3 months
def train_model_batch(df, filter_key, days):
    """Tune, evaluate and persist an XGBoost price model for one horizon.

    Runs without any UI components so it can be used for batch processing.

    Args:
        df: Processed price data (callers in this file pass the result of
            ``fetch_and_process_data``). After feature creation it must
            contain the ``'Reported Date'`` and
            ``'Modal Price (Rs./Quintal)'`` columns.
        filter_key: Identifier of the data slice, e.g. ``"India"`` or
            ``"state_Karnataka"``. Used as the upsert key in MongoDB.
        days: Forecast horizon in days. ``14`` and ``30`` select their own
            feature builder, split date and collection; any other value is
            handled as the 3-month (90-day) horizon.

    Returns:
        A tuple ``(best_params, rmse, mae)``: the winning hyperparameter
        dict and the test-set RMSE and MAE of a model refitted with it.

    Raises:
        ValueError: If the date split leaves no training rows or no test
            rows, or if no hyperparameter combination produced a score
            that could be compared (e.g. every score was NaN).
    """
    # Function-scope import keeps this block self-contained.
    from itertools import product

    target_col = 'Modal Price (Rs./Quintal)'
    date_col = 'Reported Date'

    cols = get_collections()

    # Select feature creation function, split date and target collection
    # based on the horizon.
    if days == 14:
        df_features = create_forecasting_features(df)
        split_date = '2024-01-01'
        collection_key = 'best_params_collection'
    elif days == 30:
        df_features = create_forecasting_features_1m(df)
        split_date = '2023-01-01'
        collection_key = 'best_params_collection_1m'
    else:  # 90 days
        df_features = create_forecasting_features_3m(df)
        split_date = '2023-01-01'
        collection_key = 'best_params_collection_3m'

    # Chronological split: everything before split_date trains, the rest tests.
    train_df = df_features[df_features[date_col] < split_date]
    test_df = df_features[df_features[date_col] >= split_date]
    if train_df.empty or test_df.empty:
        # Fail with a clear message instead of an opaque error from fit/score.
        raise ValueError(
            f"Not enough data for '{filter_key}' around split date {split_date}: "
            f"{len(train_df)} training rows, {len(test_df)} test rows"
        )

    X_train = train_df.drop(columns=[target_col, date_col])
    y_train = train_df[target_col]
    X_test = test_df.drop(columns=[target_col, date_col])
    y_test = test_df[target_col]

    # Hyperparameter tuning with progress bar.
    param_grid = {
        'learning_rate': [0.01, 0.1, 0.2],
        'max_depth': [3, 5, 7],
        'n_estimators': [50, 100, 150],
        'booster': ['gbtree', 'dart'],
    }
    param_names = list(param_grid)
    # product() varies the last parameter fastest, i.e. the same visiting
    # order as nesting one loop per parameter in dict order.
    combinations = list(product(*param_grid.values()))

    # NOTE(review): candidates are ranked by their score on the same test set
    # that the reported RMSE/MAE are computed on, so those metrics are
    # optimistic. A separate validation split would avoid this.
    model = XGBRegressor()
    best_score = float('-inf')
    best_params = None
    with tqdm(total=len(combinations), desc=" Tuning hyperparameters") as pbar:
        for values in combinations:
            params = dict(zip(param_names, values))
            model.set_params(**params)
            model.fit(X_train, y_train)
            score = model.score(X_test, y_test)
            # A NaN score never compares greater, so it is never selected.
            if score > best_score:
                best_score = score
                best_params = params
            pbar.update(1)

    if best_params is None:
        raise ValueError(
            f"No hyperparameter combination produced a valid score for '{filter_key}'"
        )

    # Train final model with best params.
    best_model = XGBRegressor(**best_params)
    best_model.fit(X_train, y_train)
    y_pred = best_model.predict(X_test)

    # Calculate metrics; plain Python floats keep NumPy scalar types out of
    # the stored document and the return value.
    rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
    mae = float(mean_absolute_error(y_test, y_pred))

    # Save to MongoDB, replacing any earlier entry for this filter key.
    cols[collection_key].replace_one(
        {'filter_key': filter_key},
        {
            **best_params,
            'filter_key': filter_key,
            'last_updated': pd.Timestamp.now().isoformat(),
            'rmse': rmse,
            'mae': mae,
            'score': float(best_score),
        },
        upsert=True,
    )
    return best_params, rmse, mae
def update_india_models():
    """Update the forecast models for all of India.

    Fetches data with an empty query filter and calls ``train_model_batch``
    once per entry in ``FORECAST_HORIZONS`` under the filter key ``"India"``.
    Progress and results are printed to stdout. An error while training one
    horizon is reported and the remaining horizons are still attempted.
    Returns ``None``.
    """
    print("\n" + "="*60)
    print("UPDATING INDIA MODELS")
    print("="*60)

    df = fetch_and_process_data({})
    if df is None:
        print("❌ [India] No data available")
        return

    # Any horizon other than 14 or 30 days is labelled as the 3-month one.
    horizon_names = {14: "14 days", 30: "1 month"}
    for days in FORECAST_HORIZONS:
        horizon_name = horizon_names.get(days, "3 months")
        print(f"\n[India] Training {horizon_name} forecast model...")
        try:
            _, rmse, mae = train_model_batch(df, "India", days)
            print(f"✅ [India] {horizon_name} model updated successfully")
            print(f" RMSE: {rmse:.2f}, MAE: {mae:.2f}")
        except Exception as e:
            # Deliberate best-effort boundary: one failed horizon must not
            # abort the rest of the batch run.
            print(f"❌ [India] Error updating {horizon_name} model: {e}")
def update_state_models():
    """Update the forecast models for each state in a fixed list.

    For every state, fetches data filtered on ``"State Name"`` and calls
    ``train_model_batch`` once per entry in ``FORECAST_HORIZONS`` under the
    filter key ``"state_<name>"``. Progress and results are printed to
    stdout. A state without data, or an error while training one horizon,
    is reported and processing continues. Returns ``None``.
    """
    print("\n" + "="*60)
    print("UPDATING STATE MODELS")
    print("="*60)

    states = ["Karnataka", "Madhya Pradesh", "Gujarat", "Uttar Pradesh", "Telangana"]
    # Any horizon other than 14 or 30 days is labelled as the 3-month one.
    horizon_names = {14: "14 days", 30: "1 month"}

    for state in states:
        print(f"\n--- Processing State: {state} ---")
        df = fetch_and_process_data({"State Name": state})
        if df is None:
            print(f"❌ [{state}] No data available")
            continue

        filter_key = f"state_{state}"
        for days in FORECAST_HORIZONS:
            horizon_name = horizon_names.get(days, "3 months")
            print(f"[{state}] Training {horizon_name} forecast model...")
            try:
                _, rmse, mae = train_model_batch(df, filter_key, days)
                print(f"✅ [{state}] {horizon_name} model updated successfully")
                print(f" RMSE: {rmse:.2f}, MAE: {mae:.2f}")
            except Exception as e:
                # Deliberate best-effort boundary: one failed horizon must
                # not abort the remaining horizons or states.
                print(f"❌ [{state}] Error updating {horizon_name} model: {e}")
def update_market_models():
    """Update the forecast models for each market in a fixed list.

    For every market, fetches data filtered on ``"Market Name"`` and calls
    ``train_model_batch`` once per entry in ``FORECAST_HORIZONS`` under the
    filter key ``"market_<name>"``. Progress and results are printed to
    stdout. A market without data, or an error while training one horizon,
    is reported and processing continues. Returns ``None``.
    """
    print("\n" + "="*60)
    print("UPDATING MARKET MODELS")
    print("="*60)

    markets = ["Rajkot", "Gondal", "Kalburgi", "Amreli"]
    # Any horizon other than 14 or 30 days is labelled as the 3-month one.
    horizon_names = {14: "14 days", 30: "1 month"}

    for market in markets:
        print(f"\n--- Processing Market: {market} ---")
        df = fetch_and_process_data({"Market Name": market})
        if df is None:
            print(f"❌ [{market}] No data available")
            continue

        filter_key = f"market_{market}"
        for days in FORECAST_HORIZONS:
            horizon_name = horizon_names.get(days, "3 months")
            print(f"[{market}] Training {horizon_name} forecast model...")
            try:
                _, rmse, mae = train_model_batch(df, filter_key, days)
                print(f"✅ [{market}] {horizon_name} model updated successfully")
                print(f" RMSE: {rmse:.2f}, MAE: {mae:.2f}")
            except Exception as e:
                # Deliberate best-effort boundary: one failed horizon must
                # not abort the remaining horizons or markets.
                print(f"❌ [{market}] Error updating {horizon_name} model: {e}")
def main():
    """Run the full batch update: India, then states, then markets.

    Prints a banner, calls the three ``update_*`` functions in order and
    prints a closing banner. ``KeyboardInterrupt`` and any other exception
    that reaches this level are reported on stdout instead of propagating.
    Returns ``None``.

    NOTE(review): the ``update_*`` functions catch and print per-model
    training errors themselves, so the closing "successfully" banner only
    means that all three stages ran to completion, not that every model
    was trained.
    """
    print("\n" + "🌾" * 30)
    print("AGRIPREDICT - BATCH MODEL UPDATE")
    print("🌾" * 30)
    print("\nThis script will train and update all forecasting models.")
    print("This may take several minutes to complete.\n")

    try:
        # Update India models
        update_india_models()
        # Update State models
        update_state_models()
        # Update Market models
        update_market_models()

        print("\n" + "="*60)
        print("✅ ALL MODELS UPDATED SUCCESSFULLY")
        print("="*60)
    except KeyboardInterrupt:
        print("\n\n⚠️ Process interrupted by user")
    except Exception as e:
        # Top-level boundary of the script: report and end the run.
        print(f"\n\n❌ Fatal error: {e}")
# Run the batch update only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|