Sahithi27 commited on
Commit
79db7e5
·
verified ·
1 Parent(s): 5ddc312

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -82
app.py CHANGED
@@ -1,90 +1,25 @@
1
- import gradio as gr
2
  import joblib
3
- import pandas as pd
4
- import numpy as np
5
 
6
- ARTIFACT_PATH = "dynamic_pricing_artifact.joblib"
 
7
 
8
- artifact = joblib.load(ARTIFACT_PATH)
 
 
9
 
10
- model = artifact["model"]
11
- FEATURES = artifact["features"]
12
- FIXED_FARE = artifact["fixed_fare"]
13
- RATE_PER_KM = artifact["rate_per_km"]
14
 
15
- def predict_dynamic_price(
16
- zone_id, demand, supply, driver_availability,
17
- weather_factor, event_factor, temperature, traffic_index,
18
- distance_km, duration_min,
19
- hour, day_of_week, is_weekend, month,
20
- is_holiday, is_festival
21
- ):
22
- # 1️⃣ Base price (distance based)
23
- base_price = FIXED_FARE + distance_km * RATE_PER_KM
24
-
25
- # 2️⃣ Build feature row
26
- row = {f: 0.0 for f in FEATURES}
27
-
28
- row.update({
29
- "zone_id": zone_id,
30
- "hour": hour,
31
- "day_of_week": day_of_week,
32
- "is_weekend": is_weekend,
33
- "month": month,
34
- "is_holiday": is_holiday,
35
- "is_festival": is_festival,
36
- "demand": demand,
37
- "supply": supply,
38
- "driver_availability": driver_availability,
39
- "weather_factor": weather_factor,
40
- "event_factor": event_factor,
41
- "temperature": temperature,
42
- "traffic_index": traffic_index,
43
- "distance_km": distance_km,
44
- "duration_min": duration_min,
45
- "base_fare": base_price
46
- })
47
-
48
- # 3️⃣ Demand–Supply ratio
49
- row["demand_supply_ratio"] = np.clip(demand / (supply + 1), 0, 20)
50
-
51
- df_row = pd.DataFrame([[row[f] for f in FEATURES]], columns=FEATURES)
52
-
53
- # 4️⃣ Predict surge
54
- surge = float(model.predict(df_row)[0])
55
- surge = np.clip(surge, 1.0, 2.5) # Rapido/Ola realistic cap
56
-
57
- final_price = base_price * surge
58
 
59
- return round(base_price, 2), round(surge, 3), round(final_price, 2)
 
 
60
 
61
- demo = gr.Interface(
62
- fn=predict_dynamic_price,
63
- inputs=[
64
- gr.Number(label="Zone ID", value=1),
65
- gr.Number(label="Demand", value=150),
66
- gr.Number(label="Supply", value=80),
67
- gr.Number(label="Driver Availability", value=60),
68
- gr.Number(label="Weather Factor", value=1.0),
69
- gr.Number(label="Event Factor", value=1.0),
70
- gr.Number(label="Temperature (°C)", value=30),
71
- gr.Number(label="Traffic Index", value=0.5),
72
- gr.Number(label="Distance (km)", value=10),
73
- gr.Number(label="Duration (min)", value=20),
74
- gr.Number(label="Hour", value=18),
75
- gr.Number(label="Day of Week", value=4),
76
- gr.Number(label="Is Weekend (0/1)", value=0),
77
- gr.Number(label="Month", value=11),
78
- gr.Number(label="Is Holiday (0/1)", value=0),
79
- gr.Number(label="Is Festival (0/1)", value=0),
80
- ],
81
- outputs=[
82
- gr.Number(label="Base Price"),
83
- gr.Number(label="Surge Factor"),
84
- gr.Number(label="Final Price"),
85
- ],
86
- title="Dynamic Pricing Model",
87
- description="Realistic dynamic pricing similar to Rapido/Ola"
88
- )
89
 
90
- demo.launch()
 
 
1
  import joblib
2
+ from skl2onnx import convert_sklearn
3
+ from skl2onnx.common.data_types import FloatTensorType
4
 
5
+ # Load trained artifact
6
+ artifact = joblib.load("dynamic_pricing_artifact.joblib")
7
 
8
+ model = artifact["model"] # trained ML model
9
+ features = artifact["features"] # feature list
10
+ n_features = len(features)
11
 
12
+ print("Number of features:", n_features)
 
 
 
13
 
14
+ # Convert to ONNX
15
+ onnx_model = convert_sklearn(
16
+ model,
17
+ initial_types=[("input", FloatTensorType([None, n_features]))]
18
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ # Save ONNX model
21
+ with open("dynamic_pricing_model.onnx", "wb") as f:
22
+ f.write(onnx_model.SerializeToString())
23
 
24
+ print("✅ ONNX model saved as dynamic_pricing_model.onnx")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25