Petr-UI-GA committed on
Commit
f35d1ab
·
verified ·
1 Parent(s): 7d0adf4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -65
app.py CHANGED
@@ -1,81 +1,201 @@
 
 
1
 
2
- import os, json
3
- os.environ["KERAS_BACKEND"] = "tensorflow" # сказать Keras-3, что бэкенд — TF
4
-
5
  import numpy as np
6
  import pandas as pd
7
- import gradio as gr
8
  import joblib
9
- import keras # standalone Keras 3
 
10
 
11
- MODEL_PATH = "neiro1/model_v6.keras"
12
- PP_PATH = "neiro1/preprocess_v6.joblib"
13
- META_PATH = "neiro1/meta_v6.json"
 
 
14
 
15
- # === загрузка артефактов ===
16
- model = keras.models.load_model(MODEL_PATH, compile=False) # compile=False чтобы не требовать старые оптимайзеры
17
- preprocess = joblib.load(PP_PATH)
18
  with open(META_PATH, "r", encoding="utf-8") as f:
19
- meta = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- # Список фич для x_in
22
- FEATURES = None
23
- for attr in ["feature_names_in_", "features_in_", "input_features_"]:
24
- if hasattr(preprocess, attr):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  try:
26
- arr = getattr(preprocess, attr)
27
- if arr is not None and len(arr) > 0:
28
- FEATURES = list(arr)
29
- break
30
  except Exception:
31
- pass
32
- if FEATURES is None:
33
- FEATURES = meta.get("features", []) or meta.get("expected_features", [])
34
- if not FEATURES:
35
- # fallback, чтобы интерфейс хоть как-то поднялся
36
- FEATURES = [f"feat{i+1}" for i in range(14)]
37
-
38
- # Имя/описание базовой фичи для второго входа
39
- BASELINE_NAME = meta.get("baseline_name", "baseline_in")
40
- BASELINE_HINT = meta.get("baseline_hint", "Базовое значение (скаляp)")
41
-
42
- def build_features(inputs_dict):
43
- X = pd.DataFrame([{k: inputs_dict[k] for k in FEATURES}])
44
- Xp = preprocess.transform(X)
45
- return Xp
46
-
47
- def predict_fn(baseline_value, **kwargs):
48
- # kwargs содержит остальные признаки
49
- Xp = build_features(kwargs) # (1, n_features) -> препроцессор
50
- b = np.array([[float(baseline_value)]], dtype=np.float32) # (1,1)
51
- # Модель ожидает 2 входа: [x_in_preprocessed, baseline_in]
52
- y = model.predict([Xp, b], verbose=0)
53
- if y.ndim == 2 and y.shape[1] == 1:
54
- y = y.ravel()
55
- return f"Прогноз: {float(y[0]):,.4f}"
56
-
57
- # === UI ===
58
- with gr.Blocks(title="Нейронка v6") as demo:
59
- gr.Markdown("# Нейронка v6\nЗаполни параметры и базовое значение — получи прогноз.")
60
-
61
- # сначала baseline для второго входа
62
- baseline_input = gr.Number(label=BASELINE_NAME, info=BASELINE_HINT)
63
-
64
- # затем признаки для препроцессора/первого входа
65
- feat_inputs = [gr.Number(label=f) for f in FEATURES]
66
-
67
- btn = gr.Button("Предсказать")
68
- out = gr.Textbox(label="Результат", lines=2)
69
-
70
- def _wrap(b, *vals):
71
- data = {FEATURES[i]: vals[i] for i in range(len(FEATURES))}
72
- return predict_fn(baseline_value=b, **data)
73
-
74
- btn.click(_wrap, inputs=[baseline_input, *feat_inputs], outputs=out)
75
  try:
76
- demo.queue(max_size=32) # gradio>=4
77
  except TypeError:
78
  demo.queue()
79
 
 
80
  if __name__ == "__main__":
81
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
+ import os, json, calendar
2
+ from datetime import date, timedelta
3
 
4
+ # Keras 3 поверх TF
5
+ os.environ["KERAS_BACKEND"] = "tensorflow"
 
6
  import numpy as np
7
  import pandas as pd
 
8
  import joblib
9
+ import gradio as gr
10
+ import keras # standalone Keras 3
11
 
12
+ # ====== ПУТИ К АРТЕФАКТАМ ======
13
MODEL_DIR = "neiro1"  # the Space keeps the neiro1 folder next to app.py
MODEL_PATH = os.path.join(MODEL_DIR, "model_v6.keras")
PP_PATH = os.path.join(MODEL_DIR, "preprocess_v6.joblib")
META_PATH = os.path.join(MODEL_DIR, "meta_v6.json")

# ====== LOADING ======
# compile=False: the model is only used for inference, so the training-time
# optimizer configuration does not have to be restored.
model = keras.models.load_model(MODEL_PATH, compile=False)
# NOTE: joblib.load unpickles arbitrary Python objects; only load artifacts
# that ship with this repository (never user-supplied files).
ct = joblib.load(PP_PATH)
with open(META_PATH, "r", encoding="utf-8") as f:
    META = json.load(f)

# ====== PARAMETERS AND METADATA ======
ROUTE = META.get("route", "SFO_MIA")
# partition() instead of split(): a route with no underscore or with several
# underscores must not crash the app at import time.
ORIGIN, _route_sep, DEST = ROUTE.partition("_")
AIRLINES = META.get("airlines", [])
# JSON object keys are always strings, so month keys are converted back to int.
M_BLEND = {int(k): float(v) for k, v in META.get("blend_month", {}).items()}
A_BLEND = {str(k): float(v) for k, v in META.get("blend_airline", {}).items()}
# Keyed by "<airline>_<month>" (see baseline_pair).
AM_MEANS = {k: float(v) for k, v in META.get("am_means", {}).items()}
AM_COUNTS = {k: int(v) for k, v in META.get("am_counts", {}).items()}
# Cast to float like every other numeric setting so a string or int in the
# JSON cannot leak into the blending arithmetic.
W_MONTH = float(META.get("w_month", 0.0))
W_AIR = float(META.get("w_air", 0.0))
W_AM = float(META.get("w_am", 0.0))
FLOOR_FRAC = float(META.get("floor_frac", 0.25))
EXOG = META.get("exog_params", {})
K_AIRLINE = float(META.get("calibration_k_airline", 1.0))
K_ROUTE = float(META.get("calibration_k_route", 1.0))
WHITELIST = META.get("airline_whitelist", AIRLINES)
SHARE = META.get("airline_share_overall", {})
CAP_TARGET = float(META.get("capacity_target_summer", 1000.0))
40
+
41
+ # ====== ХЕЛПЕРЫ ======
42
def nth_weekday(year, month, weekday, n):
    """Return the date of the n-th occurrence of a weekday within a month.

    Args:
        year: Calendar year.
        month: Month number (1-12).
        weekday: Day of the week in ``date.weekday()`` numbering
            (0 = Monday ... 6 = Sunday).
        n: 1-based occurrence index (1 = first such weekday of the month).

    Returns:
        The matching ``datetime.date``.

    Raises:
        ValueError: If ``weekday`` is outside 0..6, ``n`` is smaller than 1,
            or the month has no n-th occurrence of that weekday.
    """
    if not 0 <= weekday <= 6:
        raise ValueError(f"weekday must be in 0..6, got {weekday!r}")
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n!r}")
    first = date(year, month, 1)
    # Days from the 1st of the month to the first requested weekday.
    offset = (weekday - first.weekday()) % 7
    result = first + timedelta(days=offset + 7 * (n - 1))
    # Without this check a too-large n would silently spill into a later month.
    if result.month != month or result.year != year:
        raise ValueError(
            f"{year}-{month:02d} has no occurrence #{n} of weekday {weekday}"
        )
    return result
46
+
47
def last_weekday(year, month, weekday):
    """Return the last date of the month that falls on ``weekday``.

    ``weekday`` uses ``date.weekday()`` numbering (0 = Monday ... 6 = Sunday).
    """
    _, days_in_month = calendar.monthrange(year, month)
    month_end = date(year, month, days_in_month)
    # How many days to step back from the month's last day.
    step_back = (month_end.weekday() - weekday) % 7
    return month_end - timedelta(days=step_back)
51
+
52
def daterange(d0, d1):
    """Return the list of consecutive days from ``d0`` to ``d1``, both inclusive.

    The result is empty when ``d1`` lies before ``d0``.
    """
    span_days = (d1 - d0).days
    return [d0 + timedelta(days=shift) for shift in range(span_days + 1)]
58
+
59
def holiday_windows_for_year(y):
    """Build the holiday date windows anchored in year ``y``.

    Returns a dict mapping a window name to a set of ``datetime.date``:

    * ``memorial``     - last Monday of May and the 3 days before it;
    * ``july4``        - June 30 through July 7;
    * ``labor``        - first Monday of September and the 3 days before it;
    * ``thanksgiving`` - 4th Thursday of November, from 2 days before
      to 3 days after;
    * ``xmas_newyear`` - December 20 of ``y`` through January 5 of ``y + 1``.
    """
    memorial_anchor = last_weekday(y, 5, 0)
    labor_anchor = nth_weekday(y, 9, 0, 1)
    thanksgiving_anchor = nth_weekday(y, 11, 3, 4)

    windows = {
        "memorial": (memorial_anchor - timedelta(days=3), memorial_anchor),
        "july4": (date(y, 6, 30), date(y, 7, 7)),
        "labor": (labor_anchor - timedelta(days=3), labor_anchor),
        "thanksgiving": (
            thanksgiving_anchor - timedelta(days=2),
            thanksgiving_anchor + timedelta(days=3),
        ),
        # Dec 20-31 of y plus Jan 1-5 of y+1 form one contiguous range.
        "xmas_newyear": (date(y, 12, 20), date(y + 1, 1, 5)),
    }
    return {name: set(daterange(start, end)) for name, (start, end) in windows.items()}
66
+
67
def exog_multiplier(ts: pd.Timestamp) -> float:
    """Return the combined seasonal/holiday demand multiplier for a date.

    Individual factors come from ``EXOG`` (metadata), falling back to the
    defaults below, and are multiplied together when their condition holds:
    summer months (June-August), the spring-break window (Mar 10 - Apr 10),
    Friday/Saturday/Sunday, and the holiday windows produced by
    ``holiday_windows_for_year``.

    Every check uses the calendar date of ``ts``. Comparing full timestamps
    would drop an April 10 timestamp carrying a time of day from the
    spring-break window and would raise for tz-aware timestamps, whereas the
    holiday checks already work on dates.

    Args:
        ts: Timestamp of the flight date.

    Returns:
        The product of all applicable multipliers (1.0 when none apply).
    """
    defaults = dict(
        SUMMER_MULT=1.60, FRI_MULT=1.10, SAT_MULT=1.15, SUN_MULT=1.10,
        SPRING_BREAK_MULT=1.12, MEMORIAL_MULT=1.35, JULY4_MULT=1.30,
        LABOR_MULT=1.25, THANKSGIVING_MULT=1.60, XMAS_NEWYEAR_MULT=1.70,
    )
    params = {name: float(EXOG.get(name, fallback)) for name, fallback in defaults.items()}

    day = ts.date()
    year, month, dow = int(day.year), int(day.month), int(day.weekday())

    mult = 1.0
    if month in (6, 7, 8):
        mult *= params["SUMMER_MULT"]
    if date(year, 3, 10) <= day <= date(year, 4, 10):
        mult *= params["SPRING_BREAK_MULT"]
    if dow == 4:
        mult *= params["FRI_MULT"]
    if dow == 5:
        mult *= params["SAT_MULT"]
    if dow == 6:
        mult *= params["SUN_MULT"]

    windows = holiday_windows_for_year(year)
    if day in windows["memorial"]:
        mult *= params["MEMORIAL_MULT"]
    if day in windows["july4"]:
        mult *= params["JULY4_MULT"]
    if day in windows["labor"]:
        mult *= params["LABOR_MULT"]
    if day in windows["thanksgiving"]:
        mult *= params["THANKSGIVING_MULT"]
    # Jan 1-5 belong to the window anchored in the previous year, so that
    # window has to be consulted as well.
    if (
        day in windows["xmas_newyear"]
        or day in holiday_windows_for_year(year - 1)["xmas_newyear"]
    ):
        mult *= params["XMAS_NEWYEAR_MULT"]
    return float(mult)
89
+
90
+ def _blend(child_mean, child_n, prior_mean, prior_w):
91
+ child_n = float(child_n)
92
+ return float((child_n*child_mean + prior_w*prior_mean) / (child_n + prior_w))
93
 
94
def baseline_pair(a, m):
    """Return the baseline level for airline ``a`` in month ``m``.

    The prior is the equal-weight average of the airline-level and the
    month-level blended values; each falls back to the mean of all airline
    values (or 0.0 when there are none). If statistics exist for the exact
    "<airline>_<month>" pair, its mean is blended with the prior using the
    pair's observation count and ``W_AM`` as the prior weight.
    """
    month = int(m)
    fallback = np.mean(list(A_BLEND.values())) if A_BLEND else 0.0
    airline_level = A_BLEND.get(str(a), fallback)
    month_level = M_BLEND.get(month, fallback)
    prior = 0.5 * airline_level + 0.5 * month_level

    pair_key = f"{a}_{month}"
    if pair_key not in AM_MEANS:
        return float(prior)
    return _blend(AM_MEANS[pair_key], AM_COUNTS.get(pair_key, 0), prior, W_AM)
104
+
105
def make_inputs(airline: str, date_str: str):
    """Build both model inputs for one airline and one date.

    Args:
        airline: Airline name, passed to the preprocessor as the ``airline``
            column.
        date_str: Date in any format ``pd.to_datetime`` accepts
            (the UI asks for YYYY-MM-DD).

    Returns:
        A tuple ``(X, B, info)``:
        ``X`` is the output of ``ct.transform`` on a one-row feature frame,
        ``B`` is a 1x1 float array holding baseline * exogenous multiplier,
        ``info`` is a dict with the derived calendar values, the multiplier
        and the baseline for display.

    Raises:
        ValueError: If ``date_str`` does not yield a valid date. An empty
            string parses to NaT, which would otherwise fail later with an
            obscure "cannot convert float NaN to integer" error.
    """
    d = pd.to_datetime(date_str)
    if pd.isna(d):
        raise ValueError(f"Cannot parse a date from {date_str!r}")

    dow, month = int(d.weekday()), int(d.month)
    weekofyear = int(d.isocalendar().week)
    is_weekend = int(dow in [5, 6])
    # Cyclic encodings so that Sunday/Monday and December/January are adjacent.
    sin_dow, cos_dow = np.sin(2 * np.pi * dow / 7), np.cos(2 * np.pi * dow / 7)
    sin_mon = np.sin(2 * np.pi * (month - 1) / 12)
    cos_mon = np.cos(2 * np.pi * (month - 1) / 12)
    exog = exog_multiplier(d)
    X_df = pd.DataFrame([{
        "airline": airline,
        "sin_dow": sin_dow, "cos_dow": cos_dow,
        "sin_month": sin_mon, "cos_month": cos_mon,
        "weekofyear": weekofyear, "is_weekend": is_weekend,
        "exog_mult": exog,
    }])
    base = baseline_pair(airline, month) * exog
    return ct.transform(X_df), np.array([[base]], dtype=float), dict(
        dow=dow, month=month, weekofyear=weekofyear, is_weekend=bool(is_weekend),
        exog=exog, baseline=float(base),
    )
125
+
126
def predict_one(airline, date_str):
    """Predict for a single airline and date.

    Returns a tuple ``(calibrated, info, raw)``: the raw model output scaled
    by ``K_AIRLINE`` and clamped from below by the safety floor, the info
    dict from ``make_inputs``, and the raw model output.
    """
    features, baseline, details = make_inputs(airline, date_str)
    raw = float(model.predict([features, baseline], verbose=0)[0, 0])
    # Safety floor: a fraction of the baseline. The fraction is
    # META['floor_frac'] but never lower than 0.25.
    lower_bound = float(baseline[0, 0]) * max(FLOOR_FRAC, 0.25)
    calibrated = max(raw * K_AIRLINE, lower_bound)
    return calibrated, details, raw
134
+
135
+ # ====== UI ======
136
with gr.Blocks(title="Прогноз пассажиров (v6.2)") as demo:
    gr.Markdown("### Прогноз числа пассажиров (учтены сезонность/праздники)")
    # NOTE(review): indentation was lost in the diff rendering this block was
    # recovered from; all four inputs are assumed to share one row - confirm.
    with gr.Row():
        gr.Textbox(label="Маршрут", value=ROUTE.replace('_', '→'), interactive=False)
        dd_airline = gr.Dropdown(
            choices=(WHITELIST or AIRLINES or ["American airlines", "United"]),
            value=(WHITELIST[0] if WHITELIST else (AIRLINES[0] if AIRLINES else "American airlines")),
            label="Авиакомпания",
            allow_custom_value=False,
        )
        tb_date = gr.Textbox(label="Дата (YYYY-MM-DD)", value=str(pd.Timestamp.today().date()))
        sum_all = gr.Checkbox(label="Суммарно по всем авиакомпаниям", value=True)

    def predict_ui(airline, date_str, sum_all):
        """Click handler: return ``(number, markdown_explanation)``.

        With ``sum_all`` set, raw predictions of every airline are summed and
        scaled by ``K_ROUTE``; otherwise the calibrated prediction for the
        selected airline is returned. An invalid date yields 0 plus an error
        message instead of an exception.
        """
        # Simple date validation. Catching broadly is deliberate: this is the
        # UI boundary and any parsing failure should become a message.
        try:
            parsed = pd.to_datetime(date_str)
        except Exception:
            parsed = None
        # An empty textbox parses to NaT without raising, so "no exception"
        # alone does not prove the date is usable.
        if parsed is None or pd.isna(parsed):
            return 0, "Некорректная дата, формат YYYY-MM-DD."

        if sum_all:
            total_raw, lines = 0.0, []
            for a in (WHITELIST or AIRLINES or [airline]):
                X, B, info = make_inputs(a, date_str)
                y_raw = float(model.predict([X, B], verbose=0)[0, 0])
                total_raw += y_raw
                lines.append(f"- {a}: base×M={info['baseline']:.1f} (M={info['exog']:.2f}) → raw {y_raw:.1f}")
            # Route-level calibration is applied to the sum, not per airline.
            total = total_raw * K_ROUTE
            msg = (
                f"Маршрут: **{ORIGIN}→{DEST}**, дата: **{date_str}** \n"
                f"Суммарный прогноз: **{round(total)}** "
                f"(raw_sum={total_raw:.1f}, k_route={K_ROUTE:.2f}; целевой летний≈{int(CAP_TARGET)})\n"
                + "\n".join(lines)
            )
            return round(total), msg

        y, info, y_raw = predict_one(airline, date_str)
        warn = ""
        # Warn when the airline's overall share on the route is below 5%.
        if SHARE and SHARE.get(airline, 0.0) < 0.05:
            warn = (
                "\n\n_Примечание: доля этой авиакомпании на маршруте <5%, "
                "прогноз может быть занижен; попробуйте «Суммарно по всем»._"
            )
        msg = (
            f"Маршрут: **{ORIGIN}→{DEST}**, авиакомпания: **{airline}**, дата: **{date_str}** \n"
            f"M(date)={info['exog']:.2f}; baseline×M={info['baseline']:.1f} \n"
            f"Прогноз: **{round(y)}** (raw={y_raw:.1f}, k_airline={K_AIRLINE:.2f}){warn}"
        )
        return round(y), msg

    btn = gr.Button("Рассчитать")
    out_pred = gr.Number(label="Прогноз пассажиров", interactive=False)
    out_expl = gr.Markdown()

    btn.click(predict_ui, [dd_airline, tb_date, sum_all], [out_pred, out_expl])
191
+
192
# Enable the request queue in a cross-version way: if this Gradio version's
# queue() rejects the max_size argument (TypeError), fall back to defaults.
try:
    demo.queue(max_size=32)
except TypeError:
    demo.queue()

# HF Spaces launches the app itself; for a local run use:
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
201
+