# chainlit-supplychain-app / src/components/model_risk_predictor.py
# Commit f3de9c7 (verified) by samithcs: "Update src/components/model_risk_predictor.py"
import logging
import re
from io import BytesIO
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
import requests
# Pretrained risk classifier, downloaded from the Hugging Face Hub at import time.
MODEL_URL = (
    "https://huggingface.co/samithcs/risk_predictor/resolve/main/"
    "risk_predictor/hist_gradient_boosting_risk_predictor.joblib"
)
# Defined up front so these module names always exist, even if loading fails.
RESPONSE = None
risk_model = None
FEATURE_COLS = None
try:
    # (connect, read) timeouts so a stalled download cannot hang module import.
    RESPONSE = requests.get(MODEL_URL, timeout=(10, 120))
    # Fail fast on HTTP errors instead of handing an error page to joblib.
    RESPONSE.raise_for_status()
    # SECURITY: joblib.load unpickles the payload and can execute arbitrary code.
    # Only load from a trusted source (consider pinning a revision instead of "main").
    risk_model = joblib.load(BytesIO(RESPONSE.content))
    # Column order the model was fitted with, if the estimator records it.
    FEATURE_COLS = (
        list(risk_model.feature_names_in_)
        if hasattr(risk_model, "feature_names_in_")
        else None
    )
except Exception:
    # Best-effort: predict_risk checks these for None and skips the ML model.
    logging.getLogger(__name__).warning(
        "Could not load risk model from %s; ML scoring disabled",
        MODEL_URL,
        exc_info=True,
    )
    risk_model = None
    FEATURE_COLS = None
# Baseline risk score per region/port name (cities and countries mixed).
# Keys are matched exactly, so spelling and capitalisation matter.
REGION_BASE_RISKS = {
    "Shanghai": 0.55,
    "Singapore": 0.30,
    "Mumbai": 0.45,
    "Dubai": 0.35,
    "UAE": 0.35,
    "USA": 0.30,
    "Germany": 0.25,
    "China": 0.55,
    "India": 0.45,
    "Hong Kong": 0.50,
    "Rotterdam": 0.28,
    "Los Angeles": 0.40,
}
# Additive risk weight per incident keyword. Keys are lowercase because
# incident descriptions are lowercased before matching.
EVENT_RISK_MULTIPLIERS = {
    "strike": 0.30,
    "port strike": 0.35,
    "typhoon": 0.35,
    "hurricane": 0.35,
    "earthquake": 0.40,
    "flood": 0.25,
    "port closure": 0.45,
    "supplier outage": 0.25,
    "customs delay": 0.15,
    "congestion": 0.20,
    "pandemic": 0.30,
    "war": 0.50,
    "sanctions": 0.40,
}
def build_feature_row(feature_cols, query_dict, reference_row):
    """Build a one-hot feature row for the risk model from a query.

    Starts from a copy of ``reference_row`` and sets to 1 every column that
    corresponds to the query's shipping mode or region, plus any
    ``Order_Status_COMPLETE`` column.

    Args:
        feature_cols: Iterable of feature column names to scan (iterated once).
        query_dict: Mapping with optional ``"shipping_mode"`` and ``"region"``
            keys. A missing/None/empty shipping mode defaults to
            ``"Standard Class"``; a missing/None/empty region sets no region column.
        reference_row: Baseline row supporting ``.copy()`` and item assignment
            (e.g. a ``pandas.Series`` or ``dict``). It is not modified.

    Returns:
        The updated copy of ``reference_row``.
    """
    row = reference_row.copy()
    # ``or`` maps an explicit None to the default: ``None in col`` would raise
    # TypeError, and ``"" in col`` would match every shipping-mode column.
    shipping_mode = query_dict.get("shipping_mode") or "Standard Class"
    # An empty region is a substring of every column name; without the guard
    # below every Order_Country/Order_Region column would be switched on.
    region = query_dict.get("region") or ""
    for col in feature_cols:
        # Substring matching assumes column names embed the raw category value
        # (e.g. "Order_Country_India") — TODO confirm against the training pipeline.
        if "Shipping_Mode" in col and shipping_mode in col:
            row[col] = 1
        if region and ("Order_Country" in col or "Order_Region" in col) and region in col:
            row[col] = 1
        if "Order_Status_COMPLETE" in col:
            row[col] = 1
    return row
def calculate_rule_based_risk(region, days, incidents, region_risks=None,
                              event_multipliers=None):
    """Return a heuristic risk score (capped at 1.0) from region, lead time and incidents.

    score = 0.5 * base_risk + 0.4 * event_risk + 0.1 * time_factor, where

    * ``base_risk`` is the region's baseline (0.40 for unknown regions);
    * ``event_risk`` sums, over all incidents, the weight of each keyword found
      in that incident (each keyword counted at most once per incident);
    * ``time_factor`` = max(0.1, 1 - days / 30), so shorter lead times score higher.

    Keywords match whole words, optionally followed by a simple suffix
    (s/es/ing/ed). Longer keywords win over shorter ones they contain, so
    "port strike" is scored once at its own weight rather than also as "strike".

    Args:
        region: Region/port name looked up in ``region_risks``.
        days: Lead time in days.
        incidents: Iterable of incident descriptions (converted with ``str``).
            A single string is treated as one incident; None/empty means none.
        region_risks: Optional region -> baseline mapping; defaults to
            ``REGION_BASE_RISKS``.
        event_multipliers: Optional lowercase keyword -> weight mapping;
            defaults to ``EVENT_RISK_MULTIPLIERS``.

    Returns:
        The combined score as a float, capped at 1.0.
    """
    if region_risks is None:
        region_risks = REGION_BASE_RISKS
    if event_multipliers is None:
        event_multipliers = EVENT_RISK_MULTIPLIERS
    base_risk = region_risks.get(region, 0.40)

    # Iterating a bare string would scan it character by character.
    if isinstance(incidents, str):
        incidents = [incidents]

    event_risk = 0.0
    # An empty alternation would match the empty string everywhere, hence the guard.
    if incidents and event_multipliers:
        # Whole-word matching so e.g. "war" no longer fires on "warehouse" or
        # "software"; longest keywords first so an overlapping shorter keyword
        # inside a longer one is not matched a second time.
        alternation = "|".join(
            re.escape(keyword)
            for keyword in sorted(event_multipliers, key=len, reverse=True)
        )
        pattern = re.compile(rf"\b({alternation})(?:s|es|ing|ed)?\b")
        for incident in incidents:
            # Collapse runs of whitespace so "port  strike" still matches "port strike".
            text = " ".join(str(incident).lower().split())
            found = {match.group(1) for match in pattern.finditer(text)}
            event_risk += sum(
                weight for keyword, weight in event_multipliers.items() if keyword in found
            )

    time_factor = max(0.1, 1.0 - (days / 30.0))
    return min(1.0, base_risk * 0.5 + event_risk * 0.4 + time_factor * 0.1)
def predict_risk(region: str, days: int = 5, origin=None, destination=None,
                 event_type=None, incidents=None, shipping_mode=None):
    """Return a blended supply-chain risk score in [0, 1], rounded to 2 decimals.

    The rule-based score from ``calculate_rule_based_risk`` is blended with the
    class-1 probability from ``risk_model.predict_proba``:

    * with incidents: 40% model, 60% rules;
    * without incidents: 70% model, 30% rules.

    If the model or its feature list is unavailable, or the model prediction
    fails, a fixed model score of 0.40 is used instead.

    Args:
        region: Region/port name.
        days: Lead time in days.
        origin: Included in the query passed to the feature builder.
        destination: Included in the query passed to the feature builder.
        event_type: Accepted for interface compatibility; currently unused.
        incidents: Optional list of incident descriptions.
        shipping_mode: Shipping mode; None means ``"Standard Class"``.

    Returns:
        The blended score, or 0.50 if scoring fails entirely (the error is logged).
    """
    logger = logging.getLogger(__name__)
    try:
        if shipping_mode is None:
            shipping_mode = "Standard Class"
        rule_risk = calculate_rule_based_risk(region, days, incidents or [])

        # Fallback model score; kept if the model is missing or prediction fails.
        ml_risk = 0.40
        if risk_model is not None and FEATURE_COLS is not None:
            try:
                reference_row = pd.Series({col: 0 for col in FEATURE_COLS})
                query_dict = {
                    "region": region,
                    "days": days,
                    "origin": origin,
                    "destination": destination,
                    "shipping_mode": shipping_mode,
                }
                test_features = pd.DataFrame(
                    [build_feature_row(FEATURE_COLS, query_dict, reference_row)]
                )
                ml_risk = float(risk_model.predict_proba(test_features)[0, 1])
            except Exception:
                # Best-effort: keep the fallback score but leave a trace.
                logger.warning(
                    "ML risk prediction failed for region=%r; using fallback score",
                    region,
                    exc_info=True,
                )

        # When incidents are reported, lean on the rule-based score, which scores them.
        if incidents:
            final_risk = ml_risk * 0.40 + rule_risk * 0.60
        else:
            final_risk = ml_risk * 0.70 + rule_risk * 0.30
        return round(float(np.clip(final_risk, 0.0, 1.0)), 2)
    except Exception:
        logger.exception(
            "Risk prediction failed for region=%r; returning default 0.50", region
        )
        return 0.50