import gradio as gr
import pandas as pd
import numpy as np
import joblib
import json
import matplotlib.pyplot as plt
import seaborn as sns
import datetime
import folium
import requests
import os
import firebase_admin
from firebase_admin import credentials, db
from folium.plugins import MarkerCluster
from tensorflow.keras.models import load_model
# Firebase configuration.
# FIREBASE_WEB_API_KEY is used by the REST sign-up / sign-in helpers.
# FIREBASE_KEY_JSON holds the service-account credentials as JSON text; when the
# environment variable is unset, fall back to a local "firebase_key.json" file.
FIREBASE_WEB_API_KEY = os.environ.get("FIREBASE_WEB_API_KEY")
FIREBASE_KEY_JSON = os.environ.get("FIREBASE_KEY_JSON")
if not FIREBASE_KEY_JSON and os.path.exists("firebase_key.json"):
    with open("firebase_key.json", encoding="utf-8") as f:
        FIREBASE_KEY_JSON = f.read()


def _firebase_default_app_exists():
    """Return True if the default Firebase app has already been initialized."""
    try:
        firebase_admin.get_app()
    except ValueError:
        # get_app() raises ValueError when the default app does not exist yet.
        return False
    return True


# initialize_app() raises if the default app already exists, so only call it once
# per process (uses the public get_app() instead of the private `_apps` dict).
if FIREBASE_KEY_JSON and not _firebase_default_app_exists():
    try:
        cred_dict = json.loads(FIREBASE_KEY_JSON)
        cred = credentials.Certificate(cred_dict)
        # The "<project>-default-rtdb.firebaseio.com" host only matches the default
        # database location; FIREBASE_DATABASE_URL lets other instances be used.
        database_url = os.environ.get("FIREBASE_DATABASE_URL") or (
            f"https://{cred_dict['project_id']}-default-rtdb.firebaseio.com/"
        )
        firebase_admin.initialize_app(cred, {'databaseURL': database_url})
        print("✅ Firebase Admin Connected.")
    except Exception as e:
        print(f"❌ Firebase Init Error: {e}")
# Model / data artifact paths.
MODEL_PATH = 'seq2seq_model3-2.h5'
SCALER_X_PATH = 'scaler_X3-2.save'
SCALER_Y_PATH = 'scaler_y3-2.save'
STATION_DICT_PATH = 'station_dict4-2.json'
DATA_PATH = 'tashu.csv'

# Sequence length (number of input rows) the model consumes.
TIMESTEPS = 10
# Weather columns used as model inputs; time features are derived from `_time`.
FEATURE_COLS = ['temp', 'wind_speed']

# Example input payload: six readings, 10 minutes apart, for station ST0956
# under constant weather while the parking count drops from 3 to 0.
DEFAULT_JSON = json.dumps(
    [
        {
            "_time": f"2025-11-27 15:{10 * step:02d}:00",
            "station_id": "ST0956",
            "temp": 16.0,
            "wind_speed": 3.0,
            "parking_count": count,
        }
        for step, count in enumerate((3, 3, 2, 2, 1, 0))
    ],
    indent=2,
)
print("Loading System...")
# Load the trained model, its input/target scalers and the station-id -> integer
# index mapping. On any failure fall back to placeholders so the UI still starts;
# run_prediction_logic returns zeros while `model` is None.
try:
    model = load_model(MODEL_PATH, compile=False)
    scaler_X = joblib.load(SCALER_X_PATH)
    scaler_y = joblib.load(SCALER_Y_PATH)
    with open(STATION_DICT_PATH, 'r', encoding='utf-8') as f:
        station_dict = json.load(f)
    print("✅ Model Artifacts Loaded.")
except Exception as e:
    # Report the cause instead of failing silently; a missing model is otherwise
    # only visible as all-zero predictions.
    print(f"❌ Model Load Error: {e}")
    model, scaler_X, scaler_y, station_dict = None, None, None, {}
# Load the historical parking log used by the analysis tabs and derive
# Korea-local (Asia/Seoul) time features from its `_time` column.
try:
    df = pd.read_csv(DATA_PATH)
    df['datetime'] = pd.to_datetime(df['_time'])
    # Naive timestamps are treated as UTC. NOTE(review): confirm the export format.
    if df['datetime'].dt.tz is None:
        df['datetime'] = df['datetime'].dt.tz_localize('UTC')
    df['datetime_kr'] = df['datetime'].dt.tz_convert('Asia/Seoul')
    df['hour_kr'] = df['datetime_kr'].dt.hour
    df['dayofweek_kr'] = df['datetime_kr'].dt.dayofweek
    # pandas dayofweek is Monday=0 ... Sunday=6, so 5 and 6 are the weekend.
    df['is_weekend'] = np.where(df['dayofweek_kr'] >= 5, 'Weekend', 'Weekday')
    print("✅ CSV Data Loaded.")
except Exception as e:
    # Report the cause; the analysis tabs then render empty figures.
    print(f"❌ CSV Load Error: {e}")
    df = pd.DataFrame()
# ------------------------------------------------
# 2. 내부 로직
# ------------------------------------------------
def register_user(email, password, *, timeout=10):
    """Create a Firebase Auth e-mail/password account through the REST API.

    Args:
        email: Account e-mail address.
        password: Account password.
        timeout: Seconds to wait for the Identity Toolkit response. Without a
            timeout, ``requests`` can block indefinitely on a stalled connection
            and freeze the UI callback.

    Returns:
        A user-facing status message: success, the Firebase error code, or a
        communication error.
    """
    url = f"https://identitytoolkit.googleapis.com/v1/accounts:signUp?key={FIREBASE_WEB_API_KEY}"
    payload = {"email": email, "password": password, "returnSecureToken": True}
    try:
        res = requests.post(url, json=payload, timeout=timeout)
        if res.status_code == 200:
            return "✅ 회원가입 성공! 로그인해주세요."
        return f"❌ 오류: {res.json().get('error', {}).get('message')}"
    except Exception as e:
        return f"❌ 통신 오류: {e}"
def login_user(email, password, *, timeout=10):
    """Sign in with e-mail/password through the Firebase Auth REST API.

    Args:
        email: Account e-mail address.
        password: Account password.
        timeout: Seconds to wait for the Identity Toolkit response. Without a
            timeout, ``requests`` can block indefinitely on a stalled connection.

    Returns:
        ``(uid, message)``: the Firebase ``localId`` and a success message on
        success, otherwise ``(None, error message)``.
    """
    url = f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={FIREBASE_WEB_API_KEY}"
    payload = {"email": email, "password": password, "returnSecureToken": True}
    try:
        res = requests.post(url, json=payload, timeout=timeout)
        if res.status_code == 200:
            return res.json()['localId'], f"✅ 로그인 성공! (ID: {email})"
        return None, f"❌ 로그인 실패: {res.json().get('error', {}).get('message')}"
    except Exception as e:
        return None, f"❌ 통신 오류: {e}"
def add_favorite(user_id, station_id, alias="회사 근처"):
    """Store ``station_id`` (with ``alias`` as its value) in the user's favorites.

    Favorites live at ``users/<user_id>/favorites/<station_id>`` in the
    Realtime Database.

    Returns:
        A user-facing status message.
    """
    if not user_id:
        return "로그인이 필요합니다."
    station_id = "" if station_id is None else str(station_id).strip()
    # The id becomes a database path segment: reject blanks (which would address
    # the parent node) and characters Firebase forbids or treats as separators.
    if not station_id or any(ch in station_id for ch in '.#$[]/'):
        return "유효하지 않은 정류장 ID입니다."
    try:
        fav_ref = db.reference(f'users/{user_id}/favorites/{station_id}')
        if fav_ref.get() is not None:
            return "이미 등록된 정류장입니다."
        # Write only this child instead of rewriting the whole favorites node,
        # so favorites added concurrently elsewhere are not overwritten.
        fav_ref.set(alias)
        return f"✅ 추가 완료: {station_id}"
    except Exception as e:
        return f"DB 오류: {str(e)}"
def delete_favorite(user_id, station_key):
    """Remove one station from the user's favorites.

    Only the first whitespace-separated token of ``station_key`` is used as the
    station id, so a key such as ``"ST0956"`` or ``"ST0956 <label>"`` works.

    Returns:
        A user-facing status message.
    """
    if not user_id:
        return "로그인이 필요합니다."
    tokens = str(station_key).split() if station_key is not None else []
    st_id = tokens[0] if tokens else ""
    # An empty id would make the path resolve to the whole favorites node and
    # delete every favorite, so refuse it explicitly.
    if not st_id:
        return "삭제할 정류장을 선택해주세요."
    if any(ch in st_id for ch in '.#$[]/'):
        return "유효하지 않은 정류장 ID입니다."
    try:
        db.reference(f'users/{user_id}/favorites/{st_id}').delete()
        return f"🗑️ 삭제 완료: {st_id}"
    except Exception as e:
        return f"삭제 오류: {str(e)}"
def get_favorites(user_id):
    """Return the user's favorites as a ``{station_id: alias}`` dict.

    Best-effort: returns ``{}`` when not logged in, when nothing is stored, or
    when the database read fails (the failure is printed rather than hidden).
    """
    if not user_id:
        return {}
    try:
        favs = db.reference(f'users/{user_id}/favorites').get()
    except Exception as e:
        print(f"❌ Favorites Load Error: {e}")
        return {}
    return favs or {}
def run_prediction_logic(input_df):
    """Run the seq2seq model on a window of recent readings.

    Args:
        input_df: DataFrame with a ``_time`` column, a ``station_id`` column
            and the FEATURE_COLS columns. It is copied, never modified.

    Rows are ordered by ``_time`` and only the most recent TIMESTEPS rows are
    used; shorter windows are zero-padded at the front. Station ids missing
    from ``station_dict`` map to index 0.

    Returns:
        1-D numpy array of predictions, inverse-transformed with ``scaler_y``.
        When the model failed to load, ``np.zeros(6)`` (6 is presumably the
        model's output horizon - TODO confirm).

    Raises:
        ValueError: if ``input_df`` has no rows.
    """
    if model is None:
        return np.zeros(6)
    if len(input_df) == 0:
        raise ValueError("input_df must contain at least one row")

    # Work on a copy so the caller's DataFrame does not gain helper columns.
    frame = input_df.copy()
    stamps = pd.to_datetime(frame['_time'])
    frame['datetime'] = stamps
    frame['hour'] = stamps.dt.hour
    frame['minute'] = stamps.dt.minute
    frame['dayofweek'] = stamps.dt.dayofweek
    frame['station_idx'] = frame['station_id'].map(station_dict).fillna(0)

    # Oldest -> newest, keeping only the last TIMESTEPS rows; a longer input
    # would otherwise make the (1, TIMESTEPS, n_features) reshape below fail.
    window = frame.sort_values('datetime', kind='stable').tail(TIMESTEPS)

    X_num = window[FEATURE_COLS + ['hour', 'minute', 'dayofweek']].values.astype(np.float32)
    X_scaled = scaler_X.transform(X_num)
    station_seq = window['station_idx'].values.astype(np.int32)

    pad_len = TIMESTEPS - len(X_scaled)
    if pad_len > 0:
        X_scaled = np.vstack([np.zeros((pad_len, X_scaled.shape[1]), dtype=np.float32), X_scaled])
        station_seq = np.concatenate([np.zeros(pad_len, dtype=np.int32), station_seq])

    X_scaled = X_scaled.reshape(1, TIMESTEPS, X_scaled.shape[1])
    station_seq = station_seq.reshape(1, TIMESTEPS)
    y_pred_scaled = model.predict([X_scaled, station_seq], verbose=0)
    return scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).reshape(-1)
# ------------------------------------------------
# 3. 탭별 시각화 함수
# ------------------------------------------------
def _to_kst_timestamp(value):
    """Parse ``value`` as an Asia/Seoul timestamp; return None if blank or unparseable."""
    if value is None or (isinstance(value, str) and not value.strip()):
        return None
    try:
        ts = pd.to_datetime(value)
    except (ValueError, TypeError, OverflowError):
        return None
    # pd.to_datetime("") and friends yield NaT, which would match no rows at all.
    if not isinstance(ts, pd.Timestamp) or pd.isna(ts):
        return None
    # Naive inputs are interpreted as Korea-local; aware inputs are converted.
    return ts.tz_localize('Asia/Seoul') if ts.tzinfo is None else ts.tz_convert('Asia/Seoul')


def filter_dataframe(station, s_date, e_date, h_start, h_end, data=None):
    """Filter the parking log by date range, Korea-local hour range and station.

    Args:
        station: Station id to keep; blank or None keeps all stations.
        s_date: Start date (inclusive, Asia/Seoul). Blank or unparseable leaves
            the range open on this side.
        e_date: End date; one day is added so the whole end date is included.
            Blank or unparseable leaves the range open on this side.
        h_start: Lowest ``hour_kr`` to keep (inclusive).
        h_end: Highest ``hour_kr`` to keep (inclusive).
        data: Frame to filter; defaults to the module-level ``df``.

    Returns:
        A new DataFrame; the source frame is never modified. An empty source
        yields an empty frame instead of raising.
    """
    source = df if data is None else data
    if source.empty:
        return source.copy()

    mask = (source['hour_kr'] >= h_start) & (source['hour_kr'] <= h_end)
    start = _to_kst_timestamp(s_date)
    if start is not None:
        mask &= source['datetime_kr'] >= start
    end = _to_kst_timestamp(e_date)
    if end is not None:
        mask &= source['datetime_kr'] < end + pd.Timedelta(days=1)
    station_id = "" if station is None else str(station).strip()
    if station_id:
        mask &= source['station_id'] == station_id
    # .copy() so callers can add columns without SettingWithCopyWarning.
    return source[mask].copy()
def plot_time_analysis(station, s_date, e_date, h_start, h_end):
    """Line-plot parking_count against Korea-local hour for the filtered rows.

    When an ``is_weekend`` column is present, separate lines are drawn for
    weekdays and weekends. Returns a blank figure if no rows match.
    """
    subset = filter_dataframe(station, s_date, e_date, h_start, h_end)
    if subset.empty:
        return plt.figure()

    fig = plt.figure(figsize=(10, 5))
    line_kwargs = dict(data=subset, x='hour_kr', y='parking_count', errorbar=None)
    if 'is_weekend' in subset.columns:
        line_kwargs['hue'] = 'is_weekend'
    sns.lineplot(**line_kwargs)
    plt.title('Hourly Trend')
    plt.grid(True, alpha=0.3)
    return fig
def plot_temp_analysis(s_date, e_date, h_start, h_end):
    """Plot mean parking_count per temperature bucket (nearest multiple of 2).

    Uses every station within the date/hour filters. Returns a blank figure if
    no rows match.
    """
    target_df = filter_dataframe("", s_date, e_date, h_start, h_end)
    if target_df.empty:
        return plt.figure()

    # assign() builds a new frame instead of writing a column into a filtered
    # slice, which pandas flags with SettingWithCopyWarning.
    temp_agg = (
        target_df.assign(temp_round=(target_df['temp'] / 2).round() * 2)
        .groupby('temp_round')['parking_count']
        .mean()
        .reset_index()
    )

    fig = plt.figure(figsize=(10, 5))
    sns.scatterplot(data=temp_agg, x='temp_round', y='parking_count', s=100, color='red')
    sns.lineplot(data=temp_agg, x='temp_round', y='parking_count', color='red', alpha=0.3)
    plt.title('Parking Count vs Temperature')
    plt.grid(True, alpha=0.3)
    return fig
def plot_heatmap(s_date, e_date, h_start, h_end, heatmap_type):
    """Plot the correlation matrix of parking_count, temp, wind_speed and hour_kr.

    Args:
        heatmap_type: "Weekday" or "Weekend" restricts rows to that day type;
            any other value uses all filtered rows.

    Returns a blank figure when fewer than two rows remain, because a
    correlation needs at least two observations.
    """
    target_df = filter_dataframe("", s_date, e_date, h_start, h_end)
    if not target_df.empty and heatmap_type in ("Weekday", "Weekend"):
        target_df = target_df[target_df['is_weekend'] == heatmap_type]
    # Check after the day-type filter: an empty (or single-row) selection would
    # otherwise be drawn as an all-NaN correlation matrix.
    if len(target_df) < 2:
        return plt.figure()

    fig = plt.figure(figsize=(8, 6))
    corr = target_df[['parking_count', 'temp', 'wind_speed', 'hour_kr']].corr()
    sns.heatmap(corr, annot=True, cmap='coolwarm', fmt='.2f')
    plt.title(f'Correlation ({heatmap_type})')
    return fig
def simulate_scenario(station_id, temp, wind, current_count):
    """Forecast a station's parking count under constant weather.

    Builds a synthetic history of TIMESTEPS rows spaced 10 minutes apart and
    ending now, with the given temperature, wind speed and parking count on
    every row, then runs the model on it.

    Returns:
        ``(table, figure)``. The table is a DataFrame of "+Nm" offsets with
        predictions clipped at 0, and the figure is a line plot of the same
        values. On failure the table is a one-column "Error" DataFrame and the
        figure is None.

    NOTE(review): ``parking_count`` is not in FEATURE_COLS, so ``current_count``
    does not appear to influence the prediction; confirm this is intended.
    NOTE(review): ``datetime.now()`` uses the server's local time zone; confirm
    it matches the time basis the model was trained on.
    """
    try:
        st_id = str(station_id).strip()
        now = datetime.datetime.now()
        # Conversions live inside the try so bad UI input is reported, not raised.
        history = [
            {
                '_time': (now - datetime.timedelta(minutes=(TIMESTEPS - 1 - i) * 10)).strftime('%Y-%m-%d %H:%M:%S'),
                'station_id': st_id,
                'temp': float(temp),
                'wind_speed': float(wind),
                'parking_count': float(current_count),
            }
            for i in range(TIMESTEPS)
        ]
        preds = run_prediction_logic(pd.DataFrame(history))

        labels = [f"+{(step + 1) * 10}m" for step in range(len(preds))]
        # Counts cannot be negative, so clip the model output at 0.
        values = [max(0.0, float(v)) for v in preds]
        rows = [{"Time": t, "Pred": round(v, 2)} for t, v in zip(labels, values)]

        fig = plt.figure(figsize=(10, 4))
        plt.plot(labels, values, marker='o', color='green')
        plt.grid(True, alpha=0.3)
        plt.ylim(bottom=0)
        return pd.DataFrame(rows), fig
    except Exception as e:
        # Surface the failure like predict_json_manual does instead of returning
        # a silently empty table.
        print(f"❌ Simulation Error: {e}")
        return pd.DataFrame({"Error": [str(e)]}), None
def draw_map_and_add_fav(temp, wind):
    """Render a clustered folium map of up to 100 sampled stations as HTML.

    Stations come from ``stations_meta.csv`` (``name``, ``station_id``, ``lat``
    and ``lon`` columns). The sample is fixed with ``random_state=42``. Every
    marker gets the same colour: red when ``15 <= temp <= 25`` and ``wind < 5``,
    green otherwise.

    Returns:
        The map's HTML, or "파일 없음" when the CSV cannot be read.
    """
    import html

    try:
        meta_df = pd.read_csv("stations_meta.csv")
    except (OSError, ValueError):
        # OSError covers a missing or unreadable file. ValueError covers pandas
        # parse/empty-data errors and decoding errors.
        return "파일 없음"

    m = folium.Map(location=[36.3504, 127.3845], zoom_start=13)
    cluster = MarkerCluster().add_to(m)
    # The colour depends only on the weather inputs, so compute it once.
    color = 'red' if (15 <= temp <= 25 and wind < 5) else 'green'
    sample = meta_df.sample(n=min(100, len(meta_df)), random_state=42)
    for _, r in sample.iterrows():
        # A string popup is rendered as HTML: escape the CSV text and break the
        # two lines with <br>. The original literal newline inside the f-string
        # was a syntax error.
        popup = f"{html.escape(str(r['name']))}<br>{html.escape(str(r['station_id']))}"
        folium.Marker(
            [r['lat'], r['lon']],
            popup=popup,
            icon=folium.Icon(color=color, icon='bicycle', prefix='fa'),
        ).add_to(cluster)
    return m._repr_html_()
def predict_json_manual(j):
    """Predict from a JSON array of reading objects supplied as text.

    Returns a DataFrame with one row per output step (columns "Step" and
    "Pred"). If parsing or prediction fails, returns a single-column "Error"
    DataFrame holding the exception message.
    """
    try:
        predictions = run_prediction_logic(pd.DataFrame(json.loads(j)))
        rows = [
            {"Step": step, "Pred": float(value)}
            for step, value in enumerate(predictions)
        ]
        return pd.DataFrame(rows)
    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})
# [Tab 5] 마이페이지 (UI 개선 적용됨)
def render_mypage_with_delete_list(user_id, temp, wind):
if not user_id:
return "
{st_id}
{st_txt}