Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,81 +1,201 @@
|
|
|
|
|
|
|
|
| 1 |
|
| 2 |
-
|
| 3 |
-
os.environ["KERAS_BACKEND"] = "tensorflow"
|
| 4 |
-
|
| 5 |
import numpy as np
|
| 6 |
import pandas as pd
|
| 7 |
-
import gradio as gr
|
| 8 |
import joblib
|
| 9 |
-
import
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
model = keras.models.load_model(MODEL_PATH, compile=False)
|
| 17 |
-
|
| 18 |
with open(META_PATH, "r", encoding="utf-8") as f:
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
try:
|
| 26 |
-
|
| 27 |
-
if arr is not None and len(arr) > 0:
|
| 28 |
-
FEATURES = list(arr)
|
| 29 |
-
break
|
| 30 |
except Exception:
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
btn
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def _wrap(b, *vals):
|
| 71 |
-
data = {FEATURES[i]: vals[i] for i in range(len(FEATURES))}
|
| 72 |
-
return predict_fn(baseline_value=b, **data)
|
| 73 |
-
|
| 74 |
-
btn.click(_wrap, inputs=[baseline_input, *feat_inputs], outputs=out)
|
| 75 |
try:
|
| 76 |
-
demo.queue(max_size=32)
|
| 77 |
except TypeError:
|
| 78 |
demo.queue()
|
| 79 |
|
|
|
|
| 80 |
if __name__ == "__main__":
|
| 81 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
|
| 1 |
+
import os, json, calendar
|
| 2 |
+
from datetime import date, timedelta
|
| 3 |
|
| 4 |
+
# Keras 3 поверх TF
|
| 5 |
+
os.environ["KERAS_BACKEND"] = "tensorflow"
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import pandas as pd
|
|
|
|
| 8 |
import joblib
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import keras # standalone Keras 3
|
| 11 |
|
| 12 |
+
# ====== ПУТИ К АРТЕФАКТАМ ======
|
| 13 |
+
MODEL_DIR = "neiro1" # в Space лежит папка neiro1 рядом с app.py
|
| 14 |
+
MODEL_PATH = os.path.join(MODEL_DIR, "model_v6.keras")
|
| 15 |
+
PP_PATH = os.path.join(MODEL_DIR, "preprocess_v6.joblib")
|
| 16 |
+
META_PATH = os.path.join(MODEL_DIR, "meta_v6.json")
|
| 17 |
|
| 18 |
+
# ====== ЗАГРУЗКА ======
|
| 19 |
+
model = keras.models.load_model(MODEL_PATH, compile=False)
|
| 20 |
+
ct = joblib.load(PP_PATH)
|
| 21 |
with open(META_PATH, "r", encoding="utf-8") as f:
|
| 22 |
+
META = json.load(f)
|
| 23 |
+
|
| 24 |
+
# ====== ПАРАМЕТРЫ И МЕТАДАННЫЕ ======
|
| 25 |
+
ROUTE = META.get("route", "SFO_MIA")
|
| 26 |
+
ORIGIN, DEST = ROUTE.split("_")
|
| 27 |
+
AIRLINES = META.get("airlines", [])
|
| 28 |
+
M_BLEND = {int(k): float(v) for k, v in META.get("blend_month", {}).items()}
|
| 29 |
+
A_BLEND = {str(k): float(v) for k, v in META.get("blend_airline", {}).items()}
|
| 30 |
+
AM_MEANS = {k: float(v) for k, v in META.get("am_means", {}).items()}
|
| 31 |
+
AM_COUNTS= {k: int(v) for k, v in META.get("am_counts", {}).items()}
|
| 32 |
+
W_MONTH, W_AIR, W_AM = META.get("w_month", 0.0), META.get("w_air", 0.0), META.get("w_am", 0.0)
|
| 33 |
+
FLOOR_FRAC = float(META.get("floor_frac", 0.25))
|
| 34 |
+
EXOG = META.get("exog_params", {})
|
| 35 |
+
K_AIRLINE = float(META.get("calibration_k_airline", 1.0))
|
| 36 |
+
K_ROUTE = float(META.get("calibration_k_route", 1.0))
|
| 37 |
+
WHITELIST = META.get("airline_whitelist", AIRLINES)
|
| 38 |
+
SHARE = META.get("airline_share_overall", {})
|
| 39 |
+
CAP_TARGET = float(META.get("capacity_target_summer", 1000.0))
|
| 40 |
+
|
| 41 |
+
# ====== ХЕЛПЕРЫ ======
|
| 42 |
+
def nth_weekday(year, month, weekday, n):
|
| 43 |
+
d = date(year, month, 1)
|
| 44 |
+
add = (weekday - d.weekday()) % 7
|
| 45 |
+
return d + timedelta(days=add) + timedelta(weeks=n-1)
|
| 46 |
+
|
| 47 |
+
def last_weekday(year, month, weekday):
|
| 48 |
+
d = date(year, month, calendar.monthrange(year, month)[1])
|
| 49 |
+
sub = (d.weekday() - weekday) % 7
|
| 50 |
+
return d - timedelta(days=sub)
|
| 51 |
+
|
| 52 |
+
def daterange(d0, d1):
|
| 53 |
+
cur, out = d0, []
|
| 54 |
+
while cur <= d1:
|
| 55 |
+
out.append(cur)
|
| 56 |
+
cur += timedelta(days=1)
|
| 57 |
+
return out
|
| 58 |
+
|
| 59 |
+
def holiday_windows_for_year(y):
|
| 60 |
+
mem = set(daterange(last_weekday(y,5,0)-timedelta(days=3), last_weekday(y,5,0)))
|
| 61 |
+
july = set(daterange(date(y,6,30), date(y,7,7)))
|
| 62 |
+
labor= set(daterange(nth_weekday(y,9,0,1)-timedelta(days=3), nth_weekday(y,9,0,1)))
|
| 63 |
+
tg = set(daterange(nth_weekday(y,11,3,4)-timedelta(days=2), nth_weekday(y,11,3,4)+timedelta(days=3)))
|
| 64 |
+
xny = set(daterange(date(y,12,20), date(y,12,31))) | set(daterange(date(y+1,1,1), date(y+1,1,5)))
|
| 65 |
+
return {"memorial":mem, "july4":july, "labor":labor, "thanksgiving":tg, "xmas_newyear":xny}
|
| 66 |
+
|
| 67 |
+
def exog_multiplier(ts: pd.Timestamp) -> float:
|
| 68 |
+
y = int(ts.year); m = int(ts.month); dow = int(ts.weekday())
|
| 69 |
+
P = {k: float(EXOG.get(k, v)) for k, v in dict(
|
| 70 |
+
SUMMER_MULT=1.60, FRI_MULT=1.10, SAT_MULT=1.15, SUN_MULT=1.10,
|
| 71 |
+
SPRING_BREAK_MULT=1.12, MEMORIAL_MULT=1.35, JULY4_MULT=1.30,
|
| 72 |
+
LABOR_MULT=1.25, THANKSGIVING_MULT=1.60, XMAS_NEWYEAR_MULT=1.70
|
| 73 |
+
).items()}
|
| 74 |
+
mult = 1.0
|
| 75 |
+
if m in (6,7,8): mult *= P["SUMMER_MULT"]
|
| 76 |
+
if pd.Timestamp(year=y, month=3, day=10) <= ts <= pd.Timestamp(year=y, month=4, day=10):
|
| 77 |
+
mult *= P["SPRING_BREAK_MULT"]
|
| 78 |
+
if dow == 4: mult *= P["FRI_MULT"]
|
| 79 |
+
if dow == 5: mult *= P["SAT_MULT"]
|
| 80 |
+
if dow == 6: mult *= P["SUN_MULT"]
|
| 81 |
+
hw = holiday_windows_for_year(y); hw_prev = holiday_windows_for_year(y-1)
|
| 82 |
+
if ts.date() in hw["memorial"]: mult *= P["MEMORIAL_MULT"]
|
| 83 |
+
if ts.date() in hw["july4"]: mult *= P["JULY4_MULT"]
|
| 84 |
+
if ts.date() in hw["labor"]: mult *= P["LABOR_MULT"]
|
| 85 |
+
if ts.date() in hw["thanksgiving"]: mult *= P["THANKSGIVING_MULT"]
|
| 86 |
+
if ts.date() in hw["xmas_newyear"] or ts.date() in hw_prev["xmas_newyear"]:
|
| 87 |
+
mult *= P["XMAS_NEWYEAR_MULT"]
|
| 88 |
+
return float(mult)
|
| 89 |
+
|
| 90 |
+
def _blend(child_mean, child_n, prior_mean, prior_w):
|
| 91 |
+
child_n = float(child_n)
|
| 92 |
+
return float((child_n*child_mean + prior_w*prior_mean) / (child_n + prior_w))
|
| 93 |
|
| 94 |
+
def baseline_pair(a, m):
|
| 95 |
+
gmean = np.mean(list(A_BLEND.values())) if A_BLEND else 0.0
|
| 96 |
+
a0 = A_BLEND.get(str(a), gmean)
|
| 97 |
+
m0 = M_BLEND.get(int(m), gmean)
|
| 98 |
+
prior = 0.5*a0 + 0.5*m0
|
| 99 |
+
key = f"{a}_{int(m)}"
|
| 100 |
+
if key in AM_MEANS:
|
| 101 |
+
mean = AM_MEANS[key]; cnt = AM_COUNTS.get(key, 0)
|
| 102 |
+
return _blend(mean, cnt, prior, W_AM)
|
| 103 |
+
return float(prior)
|
| 104 |
+
|
| 105 |
+
def make_inputs(airline: str, date_str: str):
|
| 106 |
+
d = pd.to_datetime(date_str)
|
| 107 |
+
dow, month = int(d.weekday()), int(d.month)
|
| 108 |
+
weekofyear = int(d.isocalendar().week)
|
| 109 |
+
is_weekend = int(dow in [5,6])
|
| 110 |
+
sin_dow, cos_dow = np.sin(2*np.pi*dow/7), np.cos(2*np.pi*dow/7)
|
| 111 |
+
sin_mon, cos_mon = np.sin(2*np.pi*(month-1)/12), np.cos(2*np.pi*(month-1)/12)
|
| 112 |
+
exog = exog_multiplier(d)
|
| 113 |
+
X_df = pd.DataFrame([{
|
| 114 |
+
"airline": airline,
|
| 115 |
+
"sin_dow": sin_dow, "cos_dow": cos_dow,
|
| 116 |
+
"sin_month": sin_mon, "cos_month": cos_mon,
|
| 117 |
+
"weekofyear": weekofyear, "is_weekend": is_weekend,
|
| 118 |
+
"exog_mult": exog,
|
| 119 |
+
}])
|
| 120 |
+
base = baseline_pair(airline, month) * exog
|
| 121 |
+
return ct.transform(X_df), np.array([[base]], dtype=float), dict(
|
| 122 |
+
dow=dow, month=month, weekofyear=weekofyear, is_weekend=bool(is_weekend),
|
| 123 |
+
exog=exog, baseline=float(base)
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
def predict_one(airline, date_str):
|
| 127 |
+
X, B, info = make_inputs(airline, date_str)
|
| 128 |
+
y_raw = float(model.predict([X, B], verbose=0)[0,0])
|
| 129 |
+
y_cal = y_raw * K_AIRLINE
|
| 130 |
+
# safety floor: минимум 25% от baseline (или META['floor_frac'])
|
| 131 |
+
floor_val = float(B[0,0]) * max(FLOOR_FRAC, 0.25)
|
| 132 |
+
y_cal = max(y_cal, floor_val)
|
| 133 |
+
return y_cal, info, y_raw
|
| 134 |
+
|
| 135 |
+
# ====== UI ======
|
| 136 |
+
with gr.Blocks(title="Прогноз пассажиров (v6.2)") as demo:
|
| 137 |
+
gr.Markdown("### Прогноз числа пассажиров (учтены сезонность/праздники)")
|
| 138 |
+
with gr.Row():
|
| 139 |
+
gr.Textbox(label="Маршрут", value=ROUTE.replace('_', '→'), interactive=False)
|
| 140 |
+
dd_airline = gr.Dropdown(
|
| 141 |
+
choices=(WHITELIST or AIRLINES or ["American airlines","United"]),
|
| 142 |
+
value=(WHITELIST[0] if WHITELIST else (AIRLINES[0] if AIRLINES else "American airlines")),
|
| 143 |
+
label="Авиакомпания",
|
| 144 |
+
allow_custom_value=False
|
| 145 |
+
)
|
| 146 |
+
tb_date = gr.Textbox(label="Дата (YYYY-MM-DD)", value=str(pd.Timestamp.today().date()))
|
| 147 |
+
sum_all = gr.Checkbox(label="Суммарно по всем авиакомпаниям", value=True)
|
| 148 |
+
|
| 149 |
+
def predict_ui(airline, date_str, sum_all):
|
| 150 |
+
# простая валидация даты
|
| 151 |
try:
|
| 152 |
+
_ = pd.to_datetime(date_str)
|
|
|
|
|
|
|
|
|
|
| 153 |
except Exception:
|
| 154 |
+
return 0, "Некорректная дата, формат YYYY-MM-DD."
|
| 155 |
+
|
| 156 |
+
if sum_all:
|
| 157 |
+
total_raw, lines = 0.0, []
|
| 158 |
+
for a in (WHITELIST or AIRLINES or [airline]):
|
| 159 |
+
X, B, info = make_inputs(a, date_str)
|
| 160 |
+
y_raw = float(model.predict([X, B], verbose=0)[0,0])
|
| 161 |
+
total_raw += y_raw
|
| 162 |
+
lines.append(f"- {a}: base×M={info['baseline']:.1f} (M={info['exog']:.2f}) → raw {y_raw:.1f}")
|
| 163 |
+
total = total_raw * K_ROUTE
|
| 164 |
+
msg = (
|
| 165 |
+
f"Маршрут: **{ORIGIN}→{DEST}**, дата: **{date_str}** \n"
|
| 166 |
+
f"Суммарный прогноз: **{round(total)}** "
|
| 167 |
+
f"(raw_sum={total_raw:.1f}, k_route={K_ROUTE:.2f}; целевой летний≈{int(CAP_TARGET)})\n"
|
| 168 |
+
+ "\n".join(lines)
|
| 169 |
+
)
|
| 170 |
+
return round(total), msg
|
| 171 |
+
|
| 172 |
+
y, info, y_raw = predict_one(airline, date_str)
|
| 173 |
+
warn = ""
|
| 174 |
+
if SHARE and SHARE.get(airline, 0.0) < 0.05:
|
| 175 |
+
warn = (
|
| 176 |
+
"\n\n_Примечание: доля этой авиакомпании на маршруте <5%, "
|
| 177 |
+
"прогноз может быть занижен; попробуйте «Суммарно по всем»._"
|
| 178 |
+
)
|
| 179 |
+
msg = (
|
| 180 |
+
f"Маршрут: **{ORIGIN}→{DEST}**, авиакомпания: **{airline}**, дата: **{date_str}** \n"
|
| 181 |
+
f"M(date)={info['exog']:.2f}; baseline×M={info['baseline']:.1f} \n"
|
| 182 |
+
f"Прогноз: **{round(y)}** (raw={y_raw:.1f}, k_airline={K_AIRLINE:.2f}){warn}"
|
| 183 |
+
)
|
| 184 |
+
return round(y), msg
|
| 185 |
+
|
| 186 |
+
btn = gr.Button("Рассчитать")
|
| 187 |
+
out_pred = gr.Number(label="Прогноз пассажиров", interactive=False)
|
| 188 |
+
out_expl = gr.Markdown()
|
| 189 |
+
|
| 190 |
+
btn.click(predict_ui, [dd_airline, tb_date, sum_all], [out_pred, out_expl])
|
| 191 |
+
|
| 192 |
+
# очередь без спорных параметров (кросс-версийно)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
try:
|
| 194 |
+
demo.queue(max_size=32)
|
| 195 |
except TypeError:
|
| 196 |
demo.queue()
|
| 197 |
|
| 198 |
+
# HF Spaces сам запускает; локально — так:
|
| 199 |
if __name__ == "__main__":
|
| 200 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 201 |
+
|