matanzig committed on
Commit
49e6cce
·
verified ·
1 Parent(s): c98bf6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -11,7 +11,7 @@ REPO_ID = "matanzig/flight-price-prediction"
11
  reg_model = joblib.load(hf_hub_download(repo_id=REPO_ID, filename="flight_price_rf_model.pkl"))
12
  cls_model = joblib.load(hf_hub_download(repo_id=REPO_ID, filename="flight_price_classifier_rf.pkl"))
13
 
14
- # הרשימה המדויקת של העמודות, כולל הכפילויות שהמודל מצפה לקבל!
15
  COLUMNS = [
16
  'startingAirport', 'destinationAirport', 'isBasicEconomy', 'isRefundable',
17
  'isNonStop', 'seatsRemaining', 'totalTravelDistance', 'month',
@@ -23,7 +23,7 @@ COLUMNS = [
23
  'cluster_group_3', 'cluster_group_1', 'cluster_group_2', 'cluster_group_3'
24
  ]
25
 
26
- # מילון חברות התעופה
27
  AIRLINE_MAPPING = {
28
  "Alaska Airlines": 0, "American Airlines": 1, "Boutique Air": 2, "Cape Air": 3,
29
  "Contour Airlines": 4, "Delta": 5, "Frontier Airlines": 6, "Hawaiian Airlines": 7,
@@ -33,7 +33,7 @@ AIRLINE_MAPPING = {
33
 
34
  # --- 2. Prediction Engine ---
35
  def predict_flight_price(flight_date, distance, duration, days_until, seats, airline_name, seen_price, is_nonstop, is_basic_economy):
36
- # --- ולידציה קשוחה לתאריך ---
37
  try:
38
  dt = pd.to_datetime(flight_date, format="%Y-%m-%d")
39
  month_val = dt.month
@@ -42,39 +42,39 @@ def predict_flight_price(flight_date, distance, duration, days_until, seats, air
42
  except ValueError:
43
  raise gr.Error("❌ Invalid Date Format! Please use exactly YYYY-MM-DD (e.g., 2026-07-15).")
44
 
45
- # נרמול החודש (לפי ממוצע וסטיית תקן משוערים של נתוני האימון)
46
  scaled_month = (month_val - 7.0) / 2.0
47
 
48
- # נרמול הימים לטיסה
49
  scaled_days = (days_until - 30) / 15.0
50
  scaled_days_sq = scaled_days ** 2
51
 
52
- # המרת חברת התעופה למספר שהמודל מבין
53
  airline_id = AIRLINE_MAPPING.get(airline_name, 5)
 
54
 
55
- # בניית 27 העמודות במדויק
56
  row_data = [
57
- 4, 5, # שדות תעופה קבועים להדגמה
58
  int(is_basic_economy), 0, int(is_nonstop), int(seats), float(distance),
59
  float(scaled_month), float(scaled_days), float(duration), int(airline_id), 1, -0.32,
60
  day_name == 'Monday', day_name == 'Saturday', day_name == 'Sunday',
61
  day_name == 'Thursday', day_name == 'Tuesday', day_name == 'Wednesday',
62
  is_weekend, float(scaled_days_sq),
63
- False, False, False, False, False, False # קלאסטרים קבועים להדגמה
64
  ]
65
 
66
- # יצירת ה-DataFrame
67
  input_df = pd.DataFrame([row_data], columns=COLUMNS)
68
 
69
- # ביצוע החיזוי
70
  reg_prediction = reg_model.predict(input_df)[0]
71
  cls_prediction = cls_model.predict(input_df)[0]
72
 
73
- # תרגום הסיווג העסקי
74
  tier_mapping = {0: "Budget Expected 🟢", 1: "Standard Expected 🟡", 2: "Premium Expected 🔴"}
75
  expected_tier = tier_mapping.get(cls_prediction, "Unknown")
76
 
77
- # מנתח עסקאות לעומת המחיר שהמשתמש מצא
78
  if seen_price > 0:
79
  diff = seen_price - reg_prediction
80
  if diff < -25:
 
11
  reg_model = joblib.load(hf_hub_download(repo_id=REPO_ID, filename="flight_price_rf_model.pkl"))
12
  cls_model = joblib.load(hf_hub_download(repo_id=REPO_ID, filename="flight_price_classifier_rf.pkl"))
13
 
14
+
15
  COLUMNS = [
16
  'startingAirport', 'destinationAirport', 'isBasicEconomy', 'isRefundable',
17
  'isNonStop', 'seatsRemaining', 'totalTravelDistance', 'month',
 
23
  'cluster_group_3', 'cluster_group_1', 'cluster_group_2', 'cluster_group_3'
24
  ]
25
 
26
+
27
  AIRLINE_MAPPING = {
28
  "Alaska Airlines": 0, "American Airlines": 1, "Boutique Air": 2, "Cape Air": 3,
29
  "Contour Airlines": 4, "Delta": 5, "Frontier Airlines": 6, "Hawaiian Airlines": 7,
 
33
 
34
  # --- 2. Prediction Engine ---
35
  def predict_flight_price(flight_date, distance, duration, days_until, seats, airline_name, seen_price, is_nonstop, is_basic_economy):
36
+
37
  try:
38
  dt = pd.to_datetime(flight_date, format="%Y-%m-%d")
39
  month_val = dt.month
 
42
  except ValueError:
43
  raise gr.Error("❌ Invalid Date Format! Please use exactly YYYY-MM-DD (e.g., 2026-07-15).")
44
 
45
+
46
  scaled_month = (month_val - 7.0) / 2.0
47
 
48
+
49
  scaled_days = (days_until - 30) / 15.0
50
  scaled_days_sq = scaled_days ** 2
51
 
52
+
53
  airline_id = AIRLINE_MAPPING.get(airline_name, 5)
54
+
55
 
 
56
  row_data = [
57
+ 4, 5,
58
  int(is_basic_economy), 0, int(is_nonstop), int(seats), float(distance),
59
  float(scaled_month), float(scaled_days), float(duration), int(airline_id), 1, -0.32,
60
  day_name == 'Monday', day_name == 'Saturday', day_name == 'Sunday',
61
  day_name == 'Thursday', day_name == 'Tuesday', day_name == 'Wednesday',
62
  is_weekend, float(scaled_days_sq),
63
+ False, False, False, False, False, False
64
  ]
65
 
66
+
67
  input_df = pd.DataFrame([row_data], columns=COLUMNS)
68
 
69
+
70
  reg_prediction = reg_model.predict(input_df)[0]
71
  cls_prediction = cls_model.predict(input_df)[0]
72
 
73
+
74
  tier_mapping = {0: "Budget Expected 🟢", 1: "Standard Expected 🟡", 2: "Premium Expected 🔴"}
75
  expected_tier = tier_mapping.get(cls_prediction, "Unknown")
76
 
77
+
78
  if seen_price > 0:
79
  diff = seen_price - reg_prediction
80
  if diff < -25: