Sahithi27 commited on
Commit
77c2e7e
·
verified ·
1 Parent(s): 79db7e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -17
app.py CHANGED
@@ -1,25 +1,108 @@
 
1
  import joblib
2
- from skl2onnx import convert_sklearn
3
- from skl2onnx.common.data_types import FloatTensorType
4
 
5
- # Load trained artifact
6
- artifact = joblib.load("dynamic_pricing_artifact.joblib")
7
 
8
- model = artifact["model"] # trained ML model
9
- features = artifact["features"] # feature list
10
- n_features = len(features)
 
 
11
 
12
- print("Number of features:", n_features)
 
 
 
 
 
 
 
 
 
13
 
14
- # Convert to ONNX
15
- onnx_model = convert_sklearn(
16
- model,
17
- initial_types=[("input", FloatTensorType([None, n_features]))]
18
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Save ONNX model
21
- with open("dynamic_pricing_model.onnx", "wb") as f:
22
- f.write(onnx_model.SerializeToString())
23
 
24
- print("✅ ONNX model saved as dynamic_pricing_model.onnx")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
 
 
1
+ import gradio as gr
2
  import joblib
3
+ import pandas as pd
4
+ import numpy as np
5
 
6
+ ARTIFACT_PATH = "dynamic_pricing_artifact.joblib"
 
7
 
8
+ artifact = joblib.load(ARTIFACT_PATH)
9
+ model = artifact["model"]
10
+ FEATURES = artifact["features"]
11
+ FIXED_FARE = artifact["fixed_fare"]
12
+ RATE_PER_KM = artifact["rate_per_km"]
13
 
14
+ # -------------------------------
15
+ # Prediction Logic (FINAL)
16
+ # -------------------------------
17
+ def predict_dynamic_price(
18
+ zone_id, demand, supply, driver_availability,n
19
+ weather_factor, event_factor, temperature, traffic_index,
20
+ distance_km, duration_min,
21
+ hour, day_of_week, is_weekend, month,
22
+ is_holiday, is_festival
23
+ ):
24
 
25
+ # Base Price (EXACT – Distance Based)
26
+ base_price = FIXED_FARE + (distance_km * RATE_PER_KM)
27
+
28
+ # Build model input row
29
+ row = {f: 0.0 for f in FEATURES}
30
+
31
+ inputs = {
32
+ "zone_id": zone_id,
33
+ "hour": hour,
34
+ "day_of_week": day_of_week,
35
+ "is_weekend": is_weekend,
36
+ "month": month,
37
+ "is_holiday": is_holiday,
38
+ "is_festival": is_festival,
39
+ "demand": demand,
40
+ "supply": supply,
41
+ "driver_availability": driver_availability,
42
+ "weather_factor": weather_factor,
43
+ "event_factor": event_factor,
44
+ "temperature": temperature,
45
+ "traffic_index": traffic_index,
46
+ "distance_km": distance_km,
47
+ "duration_min": duration_min,
48
+ "base_fare": base_price
49
+ }
50
+
51
+ for k, v in inputs.items():
52
+ if k in row:
53
+ row[k] = float(v)
54
+
55
+ # demand-supply ratio
56
+ row["demand_supply_ratio"] = np.clip(demand / (supply + 1), 0, 20)
57
+
58
+ df_row = pd.DataFrame([[row[f] for f in FEATURES]], columns=FEATURES)
59
 
60
+ # Predict Surge
61
+ surge = float(model.predict(df_row)[0])
62
+ surge = np.clip(surge, 1.0, 2.5) # realistic Rapido/Ola cap
63
 
64
+ final_price = base_price * surge
65
+
66
+ return round(base_price, 2), round(surge, 3), round(final_price, 2)
67
+
68
+ # -------------------------------
69
+ # Gradio UI
70
+ # -------------------------------
71
+ inputs = [
72
+ gr.Number(label="Zone ID", value=1),
73
+ gr.Number(label="Demand", value=150),
74
+ gr.Number(label="Supply", value=80),
75
+ gr.Number(label="Driver Availability", value=60),
76
+ gr.Number(label="Weather Factor (1.0–1.35)", value=1.0),
77
+ gr.Number(label="Event Factor (1.0–1.5)", value=1.0),
78
+ gr.Number(label="Temperature (°C)", value=30),
79
+ gr.Number(label="Traffic Index (0–1)", value=0.5),
80
+ gr.Number(label="Distance (km)", value=10),
81
+ gr.Number(label="Duration (min)", value=20),
82
+ gr.Number(label="Hour (0–23)", value=18),
83
+ gr.Number(label="Day of Week (0=Mon)", value=4),
84
+ gr.Number(label="Is Weekend (0/1)", value=0),
85
+ gr.Number(label="Month (1–12)", value=11),
86
+ gr.Number(label="Is Holiday (0/1)", value=0),
87
+ gr.Number(label="Is Festival (0/1)", value=0),
88
+ ]
89
+
90
+ outputs = [
91
+ gr.Number(label="Base Price (Distance Based)"),
92
+ gr.Number(label="Predicted Surge Factor"),
93
+ gr.Number(label="Final Dynamic Price"),
94
+ ]
95
+
96
+ demo = gr.Interface(
97
+ fn=predict_dynamic_price,
98
+ inputs=inputs,
99
+ outputs=outputs,
100
+ title="Dynamic Pricing Model",
101
+ description=(
102
+ "Base Fare = Fixed Fare + Distance × KM Rate. "
103
+ "Surge adapts to demand, supply, traffic, weather, events, "
104
+ "weekends, holidays, festivals and zone peak hours."
105
+ )
106
+ )
107
 
108
+ demo.launch()