import datetime
import json
import os

import firebase_admin
import folium
import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import seaborn as sns
from firebase_admin import credentials, db
from folium.plugins import MarkerCluster
from tensorflow.keras.models import load_model

# ------------------------------------------------
# 1. Configuration, Firebase and model/data loading
# ------------------------------------------------
FIREBASE_WEB_API_KEY = os.environ.get("FIREBASE_WEB_API_KEY")
FIREBASE_KEY_JSON = os.environ.get("FIREBASE_KEY_JSON")

# Fall back to a local service-account file when the env var is not set.
if not FIREBASE_KEY_JSON and os.path.exists("firebase_key.json"):
    with open("firebase_key.json", encoding="utf-8") as f:
        FIREBASE_KEY_JSON = f.read()

# `firebase_admin._apps` is non-empty once an app has been initialised; checking it
# avoids a duplicate-app error if this module is executed again in the same process.
if FIREBASE_KEY_JSON and not firebase_admin._apps:
    try:
        cred_dict = json.loads(FIREBASE_KEY_JSON)
        cred = credentials.Certificate(cred_dict)
        firebase_admin.initialize_app(cred, {
            'databaseURL': f"https://{cred_dict['project_id']}-default-rtdb.firebaseio.com/"
        })
        print("✅ Firebase Admin Connected.")
    except Exception as e:
        print(f"❌ Firebase Init Error: {e}")

# Model artifact paths
MODEL_PATH = 'seq2seq_model3-2.h5'
SCALER_X_PATH = 'scaler_X3-2.save'
SCALER_Y_PATH = 'scaler_y3-2.save'
STATION_DICT_PATH = 'station_dict4-2.json'
DATA_PATH = 'tashu.csv'

TIMESTEPS = 10  # length of the input sequence fed to the model (see reshape below)
FEATURE_COLS = ['temp', 'wind_speed']
FALLBACK_HORIZON = 6  # length of the all-zero forecast returned when no model is loaded
HTTP_TIMEOUT = 10  # seconds; keeps Firebase Auth REST calls from hanging forever

DEFAULT_JSON = json.dumps([
    {"_time": "2025-11-27 15:00:00", "station_id": "ST0956", "temp": 16.0, "wind_speed": 3.0, "parking_count": 3},
    {"_time": "2025-11-27 15:10:00", "station_id": "ST0956", "temp": 16.0, "wind_speed": 3.0, "parking_count": 3},
    {"_time": "2025-11-27 15:20:00", "station_id": "ST0956", "temp": 16.0, "wind_speed": 3.0, "parking_count": 2},
    {"_time": "2025-11-27 15:30:00", "station_id": "ST0956", "temp": 16.0, "wind_speed": 3.0, "parking_count": 2},
    {"_time": "2025-11-27 15:40:00", "station_id": "ST0956", "temp": 16.0, "wind_speed": 3.0, "parking_count": 1},
    {"_time": "2025-11-27 15:50:00", "station_id": "ST0956", "temp": 16.0, "wind_speed": 3.0, "parking_count": 0},
], indent=2)

print("Loading System...")
try:
    model = load_model(MODEL_PATH, compile=False)
    scaler_X = joblib.load(SCALER_X_PATH)
    scaler_y = joblib.load(SCALER_Y_PATH)
    with open(STATION_DICT_PATH, 'r', encoding='utf-8') as f:
        station_dict = json.load(f)
    print("✅ Model Artifacts Loaded.")
except Exception as e:
    # Keep the UI usable without the model: run_prediction_logic returns zeros.
    print(f"❌ Model Load Error: {e}")
    model, scaler_X, scaler_y, station_dict = None, None, None, {}

try:
    df = pd.read_csv(DATA_PATH)
    df['datetime'] = pd.to_datetime(df['_time'])
    # Naive timestamps in the CSV are treated as UTC before converting to KST.
    if df['datetime'].dt.tz is None:
        df['datetime'] = df['datetime'].dt.tz_localize('UTC')
    df['datetime_kr'] = df['datetime'].dt.tz_convert('Asia/Seoul')
    df['hour_kr'] = df['datetime_kr'].dt.hour
    df['dayofweek_kr'] = df['datetime_kr'].dt.dayofweek
    # dayofweek: Monday=0 ... Sunday=6, so 5 and 6 are Saturday/Sunday.
    df['is_weekend'] = np.where(df['dayofweek_kr'] >= 5, 'Weekend', 'Weekday')
    print("✅ CSV Data Loaded.")
except Exception as e:
    # The analysis tabs handle an empty frame; report why it is empty.
    print(f"❌ CSV Load Error: {e}")
    df = pd.DataFrame()


# ------------------------------------------------
# 2. Internal logic (auth, favorites, prediction)
# ------------------------------------------------
_IDENTITY_TOOLKIT_URL = "https://identitytoolkit.googleapis.com/v1/accounts:{action}"
_MISSING_API_KEY_MSG = "❌ 오류: FIREBASE_WEB_API_KEY가 설정되지 않았습니다."

# Characters the Firebase Realtime Database forbids in keys; '/' would also
# silently change which node a reference points at.
_FORBIDDEN_KEY_CHARS = frozenset(".$#[]/")


def _call_identity_toolkit(action, email, password):
    """POST email/password to the Firebase Auth REST endpoint ``accounts:{action}``.

    Returns the ``requests.Response``; raises ``requests.RequestException`` on
    network failure or timeout.
    """
    return requests.post(
        _IDENTITY_TOOLKIT_URL.format(action=action),
        params={"key": FIREBASE_WEB_API_KEY},
        json={"email": email, "password": password, "returnSecureToken": True},
        timeout=HTTP_TIMEOUT,
    )


def _firebase_error_message(res):
    """Return ``error.message`` from a Firebase Auth error response.

    Falls back to the HTTP status when the body is not JSON (e.g. a proxy error page).
    """
    try:
        return res.json().get('error', {}).get('message')
    except ValueError:
        return f"HTTP {res.status_code}"


def _clean_station_key(station_id):
    """Return ``station_id`` stripped, or None if it is blank or not a legal RTDB key."""
    if station_id is None:
        return None
    key = str(station_id).strip()
    if not key or _FORBIDDEN_KEY_CHARS.intersection(key):
        return None
    return key


def register_user(email, password):
    """Create a Firebase Auth account and return a status message for the UI."""
    if not FIREBASE_WEB_API_KEY:
        return _MISSING_API_KEY_MSG
    try:
        res = _call_identity_toolkit("signUp", email, password)
    except requests.RequestException as e:
        return f"❌ 통신 오류: {e}"
    if res.status_code == 200:
        return "✅ 회원가입 성공! 로그인해주세요."
    return f"❌ 오류: {_firebase_error_message(res)}"


def login_user(email, password):
    """Sign in with email/password.

    Returns:
        (local_id, message) on success, where ``local_id`` is Firebase's ``localId``;
        (None, message) on failure.
    """
    if not FIREBASE_WEB_API_KEY:
        return None, _MISSING_API_KEY_MSG
    try:
        res = _call_identity_toolkit("signInWithPassword", email, password)
    except requests.RequestException as e:
        return None, f"❌ 통신 오류: {e}"
    if res.status_code == 200:
        return res.json()['localId'], f"✅ 로그인 성공! (ID: {email})"
    return None, f"❌ 로그인 실패: {_firebase_error_message(res)}"


def add_favorite(user_id, station_id, alias="회사 근처"):
    """Store ``alias`` under ``users/<user_id>/favorites/<station_id>``.

    A blank or None alias falls back to the default alias (Gradio textboxes send ""
    rather than omitting the argument). Only the single child node is read and
    written, so other favorites are never overwritten.
    Returns a status message for the UI.
    """
    if not user_id:
        return "로그인이 필요합니다."
    st_id = _clean_station_key(station_id)
    if st_id is None:
        return "❌ 올바른 정류장 ID를 입력해주세요."
    alias = str(alias).strip() if alias is not None else ""
    alias = alias or "회사 근처"
    try:
        ref = db.reference(f'users/{user_id}/favorites/{st_id}')
        if ref.get() is not None:
            return "이미 등록된 정류장입니다."
        ref.set(alias)
        return f"✅ 추가 완료: {st_id}"
    except Exception as e:
        return f"DB 오류: {str(e)}"


def delete_favorite(user_id, station_key):
    """Delete one favorite station.

    ``station_key`` is a dropdown value; only its first whitespace-separated token is
    used as the station ID. Blank/None/illegal keys are rejected so the reference can
    never collapse to the parent ``favorites`` node (which would delete every favorite).
    Returns a status message for the UI.
    """
    if not user_id:
        return "로그인이 필요합니다."
    tokens = str(station_key).split() if station_key else []
    st_id = _clean_station_key(tokens[0]) if tokens else None
    if st_id is None:
        return "삭제할 정류장을 선택해주세요."
    try:
        db.reference(f'users/{user_id}/favorites/{st_id}').delete()
        return f"🗑️ 삭제 완료: {st_id}"
    except Exception as e:
        return f"삭제 오류: {str(e)}"


def get_favorites(user_id):
    """Return the user's favorites as ``{station_id: alias}``; {} when unavailable."""
    if not user_id:
        return {}
    try:
        favs = db.reference(f'users/{user_id}/favorites').get()
    except Exception as e:
        print(f"❌ Favorites Load Error: {e}")
        return {}
    # Callers iterate with .items(); anything other than a dict is treated as empty.
    return favs if isinstance(favs, dict) else {}


def run_prediction_logic(input_df):
    """Predict future parking counts from a recent history of one station.

    Args:
        input_df: rows with ``_time``, ``station_id`` and FEATURE_COLS. Rows are sorted
            by time; only the latest TIMESTEPS are used and shorter histories are
            left-padded with zeros. The caller's frame is not modified.

    Returns:
        1-D numpy array of unscaled predictions (all zeros, length FALLBACK_HORIZON,
        when no model is loaded).

    Raises:
        ValueError: if ``input_df`` has no rows.
    """
    if model is None:
        return np.zeros(FALLBACK_HORIZON)
    if input_df is None or input_df.empty:
        raise ValueError("input_df must contain at least one row")

    frame = input_df.copy()
    frame['datetime'] = pd.to_datetime(frame['_time'])
    # Chronological order, newest last; keep only the window the model accepts.
    frame = frame.sort_values('datetime', kind='stable').tail(TIMESTEPS)
    frame['hour'] = frame['datetime'].dt.hour
    frame['minute'] = frame['datetime'].dt.minute
    frame['dayofweek'] = frame['datetime'].dt.dayofweek
    # Stations missing from station_dict map to index 0 — presumably an
    # unknown/padding embedding; TODO confirm against the training code.
    frame['station_idx'] = frame['station_id'].map(station_dict).fillna(0)

    X_num = frame[FEATURE_COLS + ['hour', 'minute', 'dayofweek']].values.astype(np.float32)
    X_scaled = scaler_X.transform(X_num)
    station_seq = frame['station_idx'].values.astype(np.int32)

    pad_len = TIMESTEPS - len(X_scaled)
    if pad_len > 0:
        X_scaled = np.vstack([np.zeros((pad_len, X_scaled.shape[1]), dtype=np.float32), X_scaled])
        station_seq = np.hstack([np.zeros(pad_len, dtype=np.int32), station_seq])

    X_scaled = X_scaled.reshape(1, TIMESTEPS, X_scaled.shape[1])
    station_seq = station_seq.reshape(1, TIMESTEPS)
    y_pred_scaled = model.predict([X_scaled, station_seq], verbose=0)
    return scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).reshape(-1)
# ------------------------------------------------
# 3. Visualization functions for each tab
# ------------------------------------------------

# Parking count assumed for every past step on My Page (no live counts are fetched there).
MYPAGE_ASSUMED_COUNT = 3.0


def _empty_figure(message="No data for the selected filters"):
    """Return a blank figure showing ``message``.

    The figure is closed in pyplot so repeated callbacks do not accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes)
    ax.set_axis_off()
    plt.close(fig)
    return fig


def _parse_kst_date(value):
    """Parse a date string as an Asia/Seoul timestamp; None when blank or unparseable."""
    if value is None or not str(value).strip():
        return None
    try:
        ts = pd.to_datetime(str(value).strip())
    except (ValueError, TypeError):
        return None
    if pd.isna(ts):
        return None
    return ts.tz_localize('Asia/Seoul') if ts.tzinfo is None else ts.tz_convert('Asia/Seoul')


def filter_dataframe(station, s_date, e_date, h_start, h_end):
    """Return a filtered copy of the global ``df``.

    Args:
        station: optional station ID; blank/None keeps all stations.
        s_date, e_date: date strings; the range covers whole days in KST and includes
            ``e_date``. A blank or unparseable bound is ignored.
        h_start, h_end: inclusive bounds on ``hour_kr``.

    Returns:
        An independent DataFrame (safe to add columns to). Empty if ``df`` failed to load.
    """
    if df.empty:
        return df.copy()
    mask = df['hour_kr'].between(h_start, h_end)
    start = _parse_kst_date(s_date)
    end = _parse_kst_date(e_date)
    if start is not None:
        mask &= df['datetime_kr'] >= start
    if end is not None:
        mask &= df['datetime_kr'] < end + pd.Timedelta(days=1)
    if station and station.strip():
        mask &= df['station_id'] == station.strip()
    return df[mask].copy()


def plot_time_analysis(station, s_date, e_date, h_start, h_end):
    """Line plot of mean parking count per KST hour, split by weekday/weekend."""
    target_df = filter_dataframe(station, s_date, e_date, h_start, h_end)
    if target_df.empty:
        return _empty_figure()
    fig, ax = plt.subplots(figsize=(10, 5))
    hue = 'is_weekend' if 'is_weekend' in target_df.columns else None
    sns.lineplot(data=target_df, x='hour_kr', y='parking_count', hue=hue, errorbar=None, ax=ax)
    ax.set_title('Hourly Trend')
    ax.grid(True, alpha=0.3)
    plt.close(fig)
    return fig


def plot_temp_analysis(s_date, e_date, h_start, h_end):
    """Mean parking count per 2-degree temperature bucket, across all stations."""
    target_df = filter_dataframe("", s_date, e_date, h_start, h_end)
    if target_df.empty:
        return _empty_figure()
    temp_agg = (
        target_df.assign(temp_round=(target_df['temp'] / 2).round() * 2)
        .groupby('temp_round', as_index=False)['parking_count']
        .mean()
    )
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.scatterplot(data=temp_agg, x='temp_round', y='parking_count', s=100, color='red', ax=ax)
    sns.lineplot(data=temp_agg, x='temp_round', y='parking_count', color='red', alpha=0.3, ax=ax)
    ax.set_title('Parking Count vs Temperature')
    ax.grid(True, alpha=0.3)
    plt.close(fig)
    return fig


def plot_heatmap(s_date, e_date, h_start, h_end, heatmap_type):
    """Correlation heatmap of parking count, temperature, wind speed and hour.

    ``heatmap_type`` is "All", "Weekday" or "Weekend".
    """
    target_df = filter_dataframe("", s_date, e_date, h_start, h_end)
    if heatmap_type in ("Weekday", "Weekend") and not target_df.empty:
        target_df = target_df[target_df['is_weekend'] == heatmap_type]
    # A correlation needs at least two rows; re-checked after the day-type filter.
    if len(target_df) < 2:
        return _empty_figure()
    fig, ax = plt.subplots(figsize=(8, 6))
    corr = target_df[['parking_count', 'temp', 'wind_speed', 'hour_kr']].corr()
    sns.heatmap(corr, annot=True, cmap='coolwarm', fmt='.2f', ax=ax)
    ax.set_title(f'Correlation ({heatmap_type})')
    plt.close(fig)
    return fig


def _horizon_labels(n):
    """Labels "+10m", "+20m", ... for ``n`` forecast steps."""
    return [f"+{(h + 1) * 10}m" for h in range(n)]


def _constant_history(station_id, temp, wind, count, now=None):
    """Build TIMESTEPS rows spaced 10 minutes apart, ending at ``now``.

    Every row carries the same weather and parking count (a what-if scenario).
    """
    now = now or datetime.datetime.now()
    return pd.DataFrame([
        {
            '_time': (now - datetime.timedelta(minutes=(TIMESTEPS - 1 - i) * 10)).strftime('%Y-%m-%d %H:%M:%S'),
            'station_id': station_id,
            'temp': float(temp),
            'wind_speed': float(wind),
            'parking_count': float(count),
        }
        for i in range(TIMESTEPS)
    ])


def simulate_scenario(station_id, temp, wind, current_count):
    """Forecast one station under constant weather and parking count.

    Returns (table, figure). On failure, returns a one-column "Error" table and None,
    in the same table format as the JSON tab.
    """
    st_id = str(station_id).strip()
    try:
        preds = run_prediction_logic(_constant_history(st_id, temp, wind, current_count))
    except Exception as e:
        print(f"❌ Simulation Error: {e}")
        return pd.DataFrame({"Error": [str(e)]}), None

    values = [max(0, float(v)) for v in preds]  # negative bike counts are meaningless
    labels = _horizon_labels(len(values))
    table = pd.DataFrame({"Time": labels, "Pred": [round(v, 2) for v in values]})

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(labels, values, marker='o', color='green')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(bottom=0)
    plt.close(fig)
    return table, fig


def draw_map_and_add_fav(temp, wind):
    """Return the HTML of a Folium map with up to 100 sampled stations.

    Reads ``stations_meta.csv`` (columns used: lat, lon, name, station_id) and returns
    "파일 없음" if it cannot be read. All markers share one colour: red when
    15 <= temp <= 25 and wind < 5, otherwise green.
    """
    import html

    try:
        meta_df = pd.read_csv("stations_meta.csv")
    except (OSError, ValueError):
        # Missing/unreadable file; pandas empty/parse errors subclass ValueError.
        return "파일 없음"

    m = folium.Map(location=[36.3504, 127.3845], zoom_start=13)
    cluster = MarkerCluster().add_to(m)
    # The colour depends only on the weather inputs, so compute it once.
    color = 'red' if (15 <= temp <= 25 and wind < 5) else 'green'
    sample = meta_df.sample(n=min(100, len(meta_df)), random_state=42)
    for row in sample.to_dict('records'):
        # Folium renders popup text as HTML; escape the CSV values.
        popup = f"{html.escape(str(row['name']))}<br>{html.escape(str(row['station_id']))}"
        folium.Marker(
            [row['lat'], row['lon']],
            popup=popup,
            icon=folium.Icon(color=color, icon='bicycle', prefix='fa'),
        ).add_to(cluster)
    return m._repr_html_()


def predict_json_manual(j):
    """Run the model on a JSON list of rows; returns a Step/Pred table or an Error table."""
    try:
        preds = run_prediction_logic(pd.DataFrame(json.loads(j)))
    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})
    return pd.DataFrame({"Step": list(range(len(preds))), "Pred": [float(v) for v in preds]})


# [Tab 5] My Page (UI improvements applied)
def _message_html(text):
    """Wrap a short notice in a centred box for the My Page HTML panel."""
    return (
        "<div style='padding:24px; text-align:center; color:#555;'>"
        f"<h3>{text}</h3>"
        "</div>"
    )


def _status_style(pred):
    """Map a predicted bike count to (card background colour, status label)."""
    if pred < 1.5:
        return "#FFCDD2", "🔴 부족"
    if pred > 5:
        return "#BBDEFB", "🔵 과잉"
    return "#C8E6C9", "🟢 여유"


def render_mypage_with_delete_list(user_id, temp, wind):
    """Build the My Page report: status cards, a forecast plot and delete choices.

    Every favorite is forecast with the same weather and MYPAGE_ASSUMED_COUNT.
    Returns (cards_html, figure_or_None, dropdown_update).
    """
    import html

    if not user_id:
        return _message_html("로그인이 필요합니다."), None, gr.update(choices=[])
    favs = get_favorites(user_id)
    if not favs:
        return _message_html("등록된 즐겨찾기가 없습니다."), None, gr.update(choices=[])

    now = datetime.datetime.now()  # one timestamp so every station shares a timeline
    plot_data = {}
    cards = []
    for st_id, alias in favs.items():
        preds = run_prediction_logic(
            _constant_history(st_id, temp, wind, MYPAGE_ASSUMED_COUNT, now)
        )
        values = [max(0.0, float(v)) for v in preds]
        final_pred = values[-1]
        # [Request 2] label plot lines with the station ID rather than the alias.
        plot_data[str(st_id)] = values
        bg, st_txt = _status_style(final_pred)
        # [Request 1] card text is black (#000) for readability. The alias is
        # user-entered, so it is escaped before going into gr.HTML.
        cards.append(
            f"<div style='background:{bg}; color:#000; border-radius:12px; "
            "padding:16px; min-width:160px; box-shadow:0 1px 3px rgba(0,0,0,0.2);'>"
            f"<div style='font-size:1.1em; font-weight:bold;'>{html.escape(str(alias))}</div>"
            f"<div style='font-size:0.85em;'>{html.escape(str(st_id))}</div>"
            f"<div style='font-size:1.8em; font-weight:bold; margin:8px 0;'>{final_pred:.1f}대</div>"
            f"<div>{st_txt}</div>"
            "</div>"
        )
    cards_html = (
        "<div style='display:flex; flex-wrap:wrap; gap:12px;'>"
        + "".join(cards)
        + "</div>"
    )

    fig, ax = plt.subplots(figsize=(10, 5))
    for label, values in plot_data.items():
        ax.plot(_horizon_labels(len(values)), values, marker='o', label=label)
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.set_title("My Stations Forecast (ID based)")
    plt.close(fig)
    return cards_html, fig, gr.update(choices=list(favs.keys()), value=None)


# ------------------------------------------------
# 4. UI layout
# ------------------------------------------------
theme = gr.themes.Soft(primary_hue="green").set(body_background_fill="*neutral_50")

with gr.Blocks(theme=theme, title="Tashu AI Service") as demo:
    # Firebase localId of the logged-in user; None until login succeeds.
    user_id_state = gr.State(value=None)

    # [Tab 0] Login
    with gr.Tab("🔑 로그인", id="login_tab") as login_tab:
        gr.Markdown("# 🚲 대전 타슈 AI 관제 서비스에 오신 것을 환영합니다.")
        with gr.Row():
            with gr.Column():
                l_email = gr.Textbox(label="이메일")
                l_pass = gr.Textbox(label="비밀번호", type="password")
                btn_login = gr.Button("로그인", variant="primary")
            with gr.Column():
                r_email = gr.Textbox(label="이메일")
                r_pass = gr.Textbox(label="비밀번호", type="password")
                btn_reg = gr.Button("회원가입")
        login_msg = gr.Textbox(label="알림", interactive=False)

    # [Tab 0.5] Home (introduction)
    with gr.Tab("🏠 홈 (소개)", id="home_tab"):
        gr.Markdown(
            """
            # 👋 안녕하세요! 타슈 AI 관제 서비스입니다.

            본 서비스는 **Deep Learning (LSTM)** 모델을 활용하여 대전 공공자전거 '타슈'의 수요를 예측하고 관리합니다.

            ### 📌 주요 기능 안내
            * **📊 데이터 인사이트:** 과거 1개월치 데이터를 시간, 온도, 요일별로 심층 분석합니다.
            * **🔮 예측 시뮬레이터:** "오늘 날씨가 추우면 자전거가 남을까?" 가상의 시나리오를 돌려보세요.
            * **🗺️ 예측 지도:** 대전 전체 지도를 보며 1시간 뒤 자전거가 부족할 지역을 시각화합니다.
            * **👤 마이페이지:** 자주 가는 정류장을 '즐겨찾기'하고, 나만의 맞춤형 예측 리포트를 받아보세요.

            ---
            *좌측 상단의 탭을 클릭하여 기능을 이용해보세요!*
            """
        )

    # [Tab 1] Data analysis
    with gr.Tab("📊 데이터 인사이트"):
        gr.Markdown("### 📈 영역별 상세 분석")
        with gr.Group():
            gr.Markdown("#### 1️⃣ 시간대별 이용 패턴")
            with gr.Row():
                t1_st = gr.Textbox(label="정류장 ID (옵션)")
                t1_date_s = gr.Textbox(label="시작 날짜", value="2025-10-20")
                t1_date_e = gr.Textbox(label="종료 날짜", value="2025-11-25")
            with gr.Row():
                t1_h_s = gr.Slider(0, 23, 6, step=1, label="시작 시간")
                t1_h_e = gr.Slider(0, 24, 22, step=1, label="종료 시간")
            btn_plot1 = gr.Button("시간 분석 실행", variant="primary")
            plot1 = gr.Plot()
            btn_plot1.click(plot_time_analysis, inputs=[t1_st, t1_date_s, t1_date_e, t1_h_s, t1_h_e], outputs=plot1)
        gr.Markdown("---")
        with gr.Group():
            gr.Markdown("#### 2️⃣ 기온에 따른 이용률")
            with gr.Row():
                t2_date_s = gr.Textbox(label="시작 날짜", value="2025-10-20")
                t2_date_e = gr.Textbox(label="종료 날짜", value="2025-11-25")
                t2_h_s = gr.Slider(0, 23, 6, step=1, label="시작 시간")
                t2_h_e = gr.Slider(0, 24, 22, step=1, label="종료 시간")
            btn_plot2 = gr.Button("온도 분석 실행", variant="primary")
            plot2 = gr.Plot()
            btn_plot2.click(plot_temp_analysis, inputs=[t2_date_s, t2_date_e, t2_h_s, t2_h_e], outputs=plot2)
        gr.Markdown("---")
        with gr.Group():
            gr.Markdown("#### 3️⃣ 변수 간 상관관계")
            with gr.Row():
                t3_date_s = gr.Textbox(label="시작 날짜", value="2025-10-20")
                t3_date_e = gr.Textbox(label="종료 날짜", value="2025-11-25")
                t3_type = gr.Radio(["All", "Weekday", "Weekend"], value="All", label="분석 대상")
            with gr.Row():
                t3_h_s = gr.Slider(0, 23, 6, step=1, label="시작 시간")
                t3_h_e = gr.Slider(0, 24, 22, step=1, label="종료 시간")
            btn_plot3 = gr.Button("상관관계 분석 실행", variant="primary")
            plot3 = gr.Plot()
            btn_plot3.click(plot_heatmap, inputs=[t3_date_s, t3_date_e, t3_h_s, t3_h_e, t3_type], outputs=plot3)

    # [Tab 2] Simulator
    with gr.Tab("🔮 예측 시뮬레이터"):
        with gr.Row():
            with gr.Column(scale=1):
                inp_st = gr.Textbox(label="정류장 ID", value="ST0003")
                inp_temp = gr.Slider(-10, 40, 15, label="기온")
                inp_wind = gr.Slider(0, 30, 2, label="풍속")
                inp_cnt = gr.Slider(0, 30, 3, label="현재 대수")
                btn_sim = gr.Button("예측 실행", variant="primary")
            with gr.Column(scale=2):
                out_plot = gr.Plot()
                out_df = gr.Dataframe()
        btn_sim.click(simulate_scenario, inputs=[inp_st, inp_temp, inp_wind, inp_cnt], outputs=[out_df, out_plot])

    # [Tab 3] Map & favorites
    with gr.Tab("🗺️ 지도 & 즐겨찾기"):
        with gr.Row():
            m_temp = gr.Slider(-10, 35, 15, label="기온")
            m_wind = gr.Slider(0, 20, 2, label="풍속")
        btn_load_map = gr.Button("지도 불러오기")
        with gr.Row():
            with gr.Column(scale=3):
                out_map = gr.HTML(label="Map")
            with gr.Column(scale=1):
                fav_id = gr.Textbox(label="정류장 ID")
                fav_alias = gr.Textbox(label="별명")
                btn_add_fav = gr.Button("즐겨찾기 저장", variant="stop")
                fav_msg = gr.Textbox(label="결과")
        btn_load_map.click(draw_map_and_add_fav, inputs=[m_temp, m_wind], outputs=out_map)
        btn_add_fav.click(add_favorite, inputs=[user_id_state, fav_id, fav_alias], outputs=fav_msg)

    # [Tab 4] JSON input
    with gr.Tab("📝 JSON 입력"):
        inp_json = gr.Code(value=DEFAULT_JSON, language="json")
        btn_json = gr.Button("예측")
        out_json = gr.Dataframe()
        btn_json.click(predict_json_manual, inputs=inp_json, outputs=out_json)

    # [Tab 5] My Page
    with gr.Tab("👤 마이페이지"):
        gr.Markdown("### 🌟 나의 맞춤형 예측 리포트")
        btn_refresh = gr.Button("새로고침 / 불러오기", variant="primary")
        out_cards = gr.HTML()
        out_my_plot = gr.Plot()
        gr.Markdown("---")
        gr.Markdown("#### 🗑️ 즐겨찾기 관리")
        with gr.Row():
            del_dropdown = gr.Dropdown(label="삭제할 정류장 선택", choices=[], interactive=True)
            btn_delete = gr.Button("삭제하기", variant="stop")
        del_msg = gr.Textbox(label="삭제 결과", interactive=False)

        # My Page reuses the weather sliders from the map tab.
        mypage_inputs = [user_id_state, m_temp, m_wind]
        mypage_outputs = [out_cards, out_my_plot, del_dropdown]
        btn_refresh.click(render_mypage_with_delete_list, inputs=mypage_inputs, outputs=mypage_outputs)
        # Refresh the cards and dropdown after a delete so the removed station disappears.
        btn_delete.click(
            delete_favorite, inputs=[user_id_state, del_dropdown], outputs=del_msg
        ).then(render_mypage_with_delete_list, inputs=mypage_inputs, outputs=mypage_outputs)

    # Login logic
    def handle_login_and_hide(e, p):
        """Log in; on success store the uid and hide the login tab."""
        uid, msg = login_user(e, p)
        return uid, msg, gr.update(visible=uid is None)

    btn_login.click(handle_login_and_hide, inputs=[l_email, l_pass], outputs=[user_id_state, login_msg, login_tab])
    btn_reg.click(register_user, inputs=[r_email, r_pass], outputs=login_msg)

if __name__ == "__main__":
    demo.launch()