IOI-RUN / db_router.py
Roudrigus's picture
Update db_router.py
9aabbbe verified
# -*- coding: utf-8 -*-
"""
db_router.py — Roteia Engine/SessionLocal para:
- 'prod' → Load.db
- 'test' → Load_teste.db
- 'treinamento' → Load_treinamento.db
Mantém a escolha do usuário em st.session_state (quando disponível) e em variável
de ambiente (DB_CHOICE) para persistir entre reruns/contexts.
APIs expostas (compatíveis com o app):
• get_available_choices() -> list[str]
• set_current_db_choice(choice: str) -> None
• set_db_choice(choice: str) -> None
• current_db_choice() -> str
• bank_label(choice: str) -> str
• get_engine() -> sqlalchemy.Engine
• get_session_factory() -> sqlalchemy.orm.sessionmaker
• SessionLocal() -> sqlalchemy.orm.Session
Inclui garantia de criação do diretório pai do SQLite com fallback para ~/.ioirun.
Compatível com execução fora do Streamlit (fallback em estado global).
"""
from __future__ import annotations
import os
from typing import Dict, Optional, Any
# Streamlit é preferível; mas se não houver (execução fora do app), caímos em fallback
try:
import streamlit as st
_HAS_ST = True
except Exception:
_HAS_ST = False
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# ============================
# Configuração e caminhos base
# ============================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Nomes de arquivos de banco conforme sua especificação
PROD_DB_NAME = "Load.db"
TEST_DB_NAME = "Load_teste.db"
TREINAMENTO_DB_NAME = "Load_treinamento.db"
# (Opcional) uso de .env/Secrets para apontar outras URIs (Postgres/MySQL/SQLite abs.)
DB1_PROD_URL = os.getenv("DB1_PROD_URL", f"sqlite:///{os.path.join(BASE_DIR, PROD_DB_NAME)}")
DB2_TEST_URL = os.getenv("DB2_TEST_URL", f"sqlite:///{os.path.join(BASE_DIR, TEST_DB_NAME)}")
DB3_TREINAMENTO_URL = os.getenv("DB3_TREINAMENTO_URL", f"sqlite:///{os.path.join(BASE_DIR, TREINAMENTO_DB_NAME)}")
DB_URLS: Dict[str, str] = {
"prod": DB1_PROD_URL,
"test": DB2_TEST_URL,
"treinamento": DB3_TREINAMENTO_URL,
}
# Aliases aceitos (ex.: "train" → "treinamento")
CHOICE_ALIASES: Dict[str, str] = {
"train": "treinamento",
}
DB_LABELS: Dict[str, str] = {
"prod": "🟢 Produção",
"test": "🔴 Teste",
"treinamento": "🟡 Treinamento",
}
# ============================
# Garantir diretório do SQLite
# ============================
def _ensure_parent_dir_sqlite(url: str) -> str:
"""
Garante que a pasta do arquivo SQLite exista. Se não conseguir,
direciona para ~/.ioirun/<arquivo>.db (gravável no Spaces).
"""
if not url or not url.startswith("sqlite"):
return url
# Extrai caminho local do SQLite (sqlite:////abs/path ou sqlite:///rel/path)
prefix = "sqlite:///"
path = url[len(prefix):] if url.startswith(prefix) else url
# Normaliza e cria parent
file_path = os.path.abspath(path)
parent = os.path.dirname(file_path)
try:
os.makedirs(parent, exist_ok=True)
return url
except Exception:
# Fallback para HOME gravável
home_dir = os.path.join(os.path.expanduser("~"), ".ioirun")
os.makedirs(home_dir, exist_ok=True)
alt = os.path.join(home_dir, os.path.basename(file_path))
return f"sqlite:///{alt}"
# ============================
# Helpers de UI
# ============================
def get_available_choices() -> list[str]:
"""Lista as chaves de bancos disponíveis (para UI)."""
return list(DB_URLS.keys())
def list_banks() -> list[str]:
"""Compat anterior — alias de get_available_choices()."""
return get_available_choices()
def bank_label(choice: str) -> str:
"""Rótulo amigável para a UI."""
return DB_LABELS.get(choice, choice)
# ============================
# Chaves de sessão (ou fallback global)
# ============================
SESSION_DB_CHOICE_KEY = "__db_choice__" # "prod" | "test" | "treinamento"
SESSION_DB_ENGINE_KEY = "__db_engine__" # cache de engine
SESSION_DB_FACTORY_KEY = "__db_session_factory__" # cache de sessionmaker
# Fallback global quando não houver Streamlit
_GLOBAL_STATE: Dict[str, Any] = {
SESSION_DB_CHOICE_KEY: os.getenv("DB_CHOICE", "prod"),
SESSION_DB_ENGINE_KEY: None,
SESSION_DB_FACTORY_KEY: None,
}
def _session_get(key: str, default=None):
if _HAS_ST:
return st.session_state.get(key, default)
return _GLOBAL_STATE.get(key, default)
def _session_set(key: str, value):
if _HAS_ST:
st.session_state[key] = value
else:
_GLOBAL_STATE[key] = value
def _session_pop(key: str):
if _HAS_ST:
st.session_state.pop(key, None)
else:
_GLOBAL_STATE.pop(key, None)
# ============================
# Normalização da escolha
# ============================
def _normalize_choice(raw: Optional[str]) -> str:
val = (raw or "").strip().lower()
if val in CHOICE_ALIASES:
val = CHOICE_ALIASES[val]
if val not in DB_URLS:
val = "prod"
return val
# ============================
# Escolha do banco
# ============================
def set_db_choice(choice: str):
"""
Define o banco ativo para a sessão atual.
choice ∈ {"prod", "test", "treinamento"} (aceita alias "train" → "treinamento").
Invalida caches de engine/session e atualiza ENV (DB_CHOICE).
"""
choice = _normalize_choice(choice)
# Se houver engine cacheado, faça dispose antes de trocar
cached_engine = _session_get(SESSION_DB_ENGINE_KEY)
if cached_engine is not None:
try:
cached_engine.dispose()
except Exception:
pass
_session_set(SESSION_DB_CHOICE_KEY, choice)
os.environ["DB_CHOICE"] = choice # permite que outros módulos leiam
# Ao trocar, invalida caches
_session_pop(SESSION_DB_ENGINE_KEY)
_session_pop(SESSION_DB_FACTORY_KEY)
def set_current_db_choice(choice: str) -> None:
"""Alias compatível com app: set_current_db_choice(choice)."""
set_db_choice(choice)
def current_db_choice() -> str:
"""
Retorna 'prod' | 'test' | 'treinamento' (default: 'prod').
Prioriza session_state; se ausente, lê DB_CHOICE do ambiente.
"""
# 1) Session
val = _session_get(SESSION_DB_CHOICE_KEY)
# 2) ENV (permite seleção por URL/externo)
if val is None:
env_val = os.getenv("DB_CHOICE")
if env_val:
val = _normalize_choice(env_val)
_session_set(SESSION_DB_CHOICE_KEY, val)
# 3) Default
if val is None:
val = "prod"
_session_set(SESSION_DB_CHOICE_KEY, val)
# Sanitize
val = _normalize_choice(val)
if val != _session_get(SESSION_DB_CHOICE_KEY):
_session_set(SESSION_DB_CHOICE_KEY, val)
return val
# ============================
# Engine / Session por ambiente
# ============================
def _url_for_choice(choice: str) -> str:
return DB_URLS[choice]
def _engine_args_for_url(url: str) -> dict:
args = {
"echo": False,
"pool_pre_ping": True,
}
if url.startswith("sqlite:///"):
# evita erro em threads do Streamlit
args["connect_args"] = {"check_same_thread": False}
return args
def get_engine():
"""
Entrega o engine do banco ATIVO (por sessão). Cria e cacheia se necessário.
"""
choice = current_db_choice()
cached = _session_get(SESSION_DB_ENGINE_KEY)
if cached is not None and getattr(cached, "__db_choice__", None) == choice:
return cached
url = _url_for_choice(choice)
url = _ensure_parent_dir_sqlite(url) # ⬅️ garante diretório pai se for SQLite
eng = create_engine(url, **_engine_args_for_url(url))
setattr(eng, "__db_choice__", choice)
_session_set(SESSION_DB_ENGINE_KEY, eng)
return eng
def get_session_factory():
"""
Entrega um sessionmaker vinculado ao engine do banco ATIVO (em cache).
"""
choice = current_db_choice()
fac = _session_get(SESSION_DB_FACTORY_KEY)
if fac is not None and getattr(fac, "__db_choice__", None) == choice:
return fac
fac = sessionmaker(bind=get_engine(), autocommit=False, autoflush=False)
setattr(fac, "__db_choice__", choice)
_session_set(SESSION_DB_FACTORY_KEY, fac)
return fac
def SessionLocal():
"""
Cria uma sessão no banco ATIVO.
Uso:
db = SessionLocal()
try:
...
finally:
db.close()
"""
return get_session_factory()()