# nse-bot-backend / storage.py
# Provenance: uploaded by ash001 — "Deploy from GitHub Actions to nse-bot-backend"
# (commit ddee686, verified)
import json
import os
import uuid
from pathlib import Path
from typing import Iterable
import pandas as pd
# SQLAlchemy is an optional dependency: when it is missing (or DATABASE_URL
# is unset) the storage layer degrades to CSV/JSON files only. The names are
# nulled out so later `is not None` / `is None` guards can detect availability.
try:
    from sqlalchemy import create_engine, inspect, text
    from sqlalchemy.types import Text, Float, BigInteger
except Exception:  # pragma: no cover
    create_engine = None
    inspect = None
    text = None
    Text = None
    Float = None
    BigInteger = None
# --- File-system layout -----------------------------------------------------
# All CSV/JSON state lives under a `data/` directory next to this module.
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"
DATA_DIR.mkdir(exist_ok=True)
OPEN_PATH = DATA_DIR / "open_trades.csv"
CLOSED_PATH = DATA_DIR / "closed_trades.csv"
SKIPPED_PATH = DATA_DIR / "skipped_trades.csv"
STATUS_PATH = DATA_DIR / "bot_status.json"
ARCHIVED_OPEN_PATH = DATA_DIR / "archived_open_trades.csv"
ARCHIVED_CLOSED_PATH = DATA_DIR / "archived_closed_trades.csv"
ARCHIVED_SKIPPED_PATH = DATA_DIR / "archived_skipped_trades.csv"
# SQL table names for the "current" trade snapshots, keyed by trade kind.
CURRENT_TRADE_TABLES = {
    "open": "open_trades",
    "closed": "closed_trades",
    "skipped": "skipped_trades",
}
# SQL table names for archived (historical) trade snapshots.
ARCHIVED_TRADE_TABLES = {
    "open": "archived_open_trades",
    "closed": "archived_closed_trades",
    "skipped": "archived_skipped_trades",
}
# Columns parsed with pd.to_datetime and stored as text in SQL.
DATETIME_COLUMNS = [
    "BUY TIME",
    "BUY SIGNAL TIME",
    "SELL SIGNAL TIME",
    "EXIT TIME",
    "TARGET ACHIEVED TIME (1:1)",
    "TARGET ACHIEVED TIME (1:2)",
]
# Columns coerced to nullable strings (SQL: Text).
TEXT_COLUMNS = {
    "STOCK NAME",
    "OPTION SYMBOL",
    "SECTOR",
    "BUY SIGNAL",
    "SELL SIGNAL",
    "CALL/PUT",
    "TRADE SIDE",
    "TRADE KEY",
    "STATUS",
    "REASON SKIPPED",
    "EXECUTION MODE",
    "ENTRY ORDER ID",
    "ENTRY ORDER STATUS",
    "ENTRY ORDER TYPE",
    "EXIT ORDER ID",
    "EXIT ORDER STATUS",
    "EXIT ORDER TYPE",
    "LIVE ORDER MESSAGE",
}
# Columns coerced to nullable integers (SQL: BigInteger).
INTEGER_COLUMNS = {
    "STRIKE",
    "QUANTITY PER LOT",
}
# Columns coerced to floats (SQL: Float).
FLOAT_COLUMNS = {
    "BUY PRICE",
    "STOP LOSS",
    "TARGET ACHIEVED (PROFIT) (1:1)",
    "PERCENTAGE CHANGE (1:1)",
    "TARGET ACHIEVED (PROFIT) (1:2)",
    "PERCENTAGE CHANGE (1:2)",
    "STOP LOSS AMOUNT FOR ONE LOT",
    "CAPITAL INVESTED (PER LOT)",
    "TOTAL PROFIT (PER LOT) (1:1)",
    "TOTAL PROFIT (PER LOT) (1:2)",
    "MODEL SCORE",
    "EXIT PRICE",
    "REALIZED GROSS PNL (PER LOT)",
    "ESTIMATED CHARGES",
    "REALIZED NET PNL (PER LOT)",
}
def _normalize_db_url(db_url: str) -> str:
db_url = (db_url or "").strip()
if not db_url:
return ""
if db_url.startswith("postgres://"):
return "postgresql+psycopg://" + db_url[len("postgres://"):]
if db_url.startswith("postgresql://"):
return "postgresql+psycopg://" + db_url[len("postgresql://"):]
return db_url
def _normalize_trade_ts_series(series: pd.Series) -> pd.Series:
s = pd.Series(series)
out = []
for val in s:
ts = pd.to_datetime(val, errors="coerce")
if pd.isna(ts):
out.append(pd.NaT)
continue
ts = pd.Timestamp(ts)
if ts.tzinfo is None:
ts = ts.tz_localize("Asia/Kolkata")
else:
ts = ts.tz_convert("Asia/Kolkata")
out.append(ts)
return pd.Series(out, index=s.index)
class TradeStorage:
    """CSV + optional SQL persistence for trade logs and bot status.

    Trade frames are always written to CSV files under ``<base_dir>/data``.
    When the ``DATABASE_URL`` environment variable is set and SQLAlchemy is
    importable, the same frames are mirrored into SQL tables and reads prefer
    the database (falling back to CSV/JSON on any database error). Database
    errors during reads/status persistence are deliberately swallowed so the
    bot loop is never broken by storage issues.
    """

    def __init__(self, base_dir: Path | None = None):
        """Resolve all data-file paths and build the SQL engine if configured."""
        self.base_dir = Path(base_dir or BASE_DIR)
        self.data_dir = self.base_dir / "data"
        self.data_dir.mkdir(exist_ok=True)
        self.open_path = self.data_dir / "open_trades.csv"
        self.closed_path = self.data_dir / "closed_trades.csv"
        self.skipped_path = self.data_dir / "skipped_trades.csv"
        self.status_path = self.data_dir / "bot_status.json"
        self.archived_open_path = self.data_dir / "archived_open_trades.csv"
        self.archived_closed_path = self.data_dir / "archived_closed_trades.csv"
        self.archived_skipped_path = self.data_dir / "archived_skipped_trades.csv"
        self.import_open_path = self.data_dir / "import_open_trades.csv"
        self.import_closed_path = self.data_dir / "import_closed_trades.csv"
        self.import_skipped_path = self.data_dir / "import_skipped_trades.csv"
        self.db_url = _normalize_db_url(os.getenv("DATABASE_URL", ""))
        self.engine = None
        self.backend = "csv"
        if self.db_url and create_engine is not None:
            # Small, self-healing connection pool; all sizes/timeouts are
            # overridable via environment variables.
            self.engine = create_engine(
                self.db_url,
                pool_pre_ping=True,
                future=True,
                pool_size=int(os.getenv("DB_POOL_SIZE", "2")),
                max_overflow=int(os.getenv("DB_MAX_OVERFLOW", "0")),
                pool_recycle=int(os.getenv("DB_POOL_RECYCLE", "300")),
                pool_timeout=int(os.getenv("DB_POOL_TIMEOUT", "10")),
                connect_args={"connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", "10"))},
            )
            self.backend = "database"

    def using_database(self) -> bool:
        """Return True when a SQL engine was successfully configured."""
        return self.engine is not None

    def _safe_to_sql_chunksize(self, columns: Iterable[str], hard_cap: int = 60000) -> int:
        """Pick a ``to_sql`` chunksize so rows*columns stays under ``hard_cap``.

        Keeps the number of bound parameters per multi-row INSERT below
        driver limits. Always returns at least 1.
        """
        col_count = max(1, len(list(columns)))
        return max(1, hard_cap // col_count)

    def _table_exists(self, table_name: str, bind=None) -> bool:
        """Return True if ``table_name`` exists; False when no DB is in use."""
        if not self.using_database():
            return False
        target = bind if bind is not None else self.engine
        inspector = inspect(target)
        return inspector.has_table(table_name)

    def _trade_paths(self, archived: bool = False) -> tuple[Path, Path, Path]:
        """Return (open, closed, skipped) CSV paths, current or archived."""
        if archived:
            return self.archived_open_path, self.archived_closed_path, self.archived_skipped_path
        return self.open_path, self.closed_path, self.skipped_path

    def _trade_tables(self, archived: bool = False) -> tuple[str, str, str]:
        """Return (open, closed, skipped) SQL table names, current or archived."""
        tables = ARCHIVED_TRADE_TABLES if archived else CURRENT_TRADE_TABLES
        return tables["open"], tables["closed"], tables["skipped"]

    def _empty_df(self, columns: Iterable[str]) -> pd.DataFrame:
        """Return an empty DataFrame with exactly ``columns``."""
        return pd.DataFrame(columns=list(columns))

    def _load_csv(self, path: Path, columns: Iterable[str]) -> pd.DataFrame:
        """Load a CSV reindexed to ``columns``; empty frame on miss or error."""
        if path.exists():
            try:
                df = pd.read_csv(path)
                return df.reindex(columns=list(columns))
            except Exception:
                # Corrupt/unreadable CSV: treat as empty rather than crash.
                return self._empty_df(columns)
        return self._empty_df(columns)

    def _serialize_df(self, df: pd.DataFrame, columns: Iterable[str]) -> pd.DataFrame:
        """Normalize a frame for SQL storage.

        Text columns become nullable strings (string-ified NA artifacts mapped
        back to None), datetimes become ISO-ish strings, integers become
        nullable Int64, floats become numeric; all remaining NA values become
        None so the DB receives NULLs.
        """
        out = df.copy().reindex(columns=list(columns))
        for col in TEXT_COLUMNS:
            if col in out.columns:
                out[col] = out[col].astype("string")
                # astype("string") renders missing values as literal text in
                # some upstream paths; fold those artifacts back to NULL.
                out[col] = out[col].replace(
                    {"<NA>": None, "nan": None, "None": None, "": None}
                )
        for col in DATETIME_COLUMNS:
            if col in out.columns:
                out[col] = pd.to_datetime(out[col], errors="coerce").astype("string")
                out[col] = out[col].replace({"<NA>": None, "NaT": None})
        for col in INTEGER_COLUMNS:
            if col in out.columns:
                out[col] = pd.to_numeric(out[col], errors="coerce").astype("Int64")
        for col in FLOAT_COLUMNS:
            if col in out.columns:
                out[col] = pd.to_numeric(out[col], errors="coerce")
        return out.where(pd.notna(out), None)

    def _dtype_map(self, columns: Iterable[str]) -> dict:
        """Build the SQLAlchemy dtype map for ``to_sql``; {} without SQLAlchemy."""
        dtype = {}
        if Text is None:
            return dtype
        for col in columns:
            if col in TEXT_COLUMNS or col in DATETIME_COLUMNS:
                dtype[col] = Text()
            elif col in INTEGER_COLUMNS:
                dtype[col] = BigInteger()
            elif col in FLOAT_COLUMNS:
                dtype[col] = Float()
        return dtype

    def _quote_ident(self, ident: str) -> str:
        """Double-quote a SQL identifier, escaping embedded double quotes."""
        return '"' + str(ident).replace('"', '""') + '"'

    def _sync_snapshot_by_trade_key(self, conn, table_name: str, payload: pd.DataFrame, columns: list[str]):
        """Make ``table_name`` match ``payload`` exactly, keyed by TRADE KEY.

        Strategy: stage the payload in a uniquely-named temp table, then
        (1) delete keyed rows absent from the snapshot, (2) update keyed rows
        present in both, (3) insert snapshot-only keyed rows, and
        (4) wholesale replace all NULL-key rows (delete then re-insert).
        The staging table is dropped in ``finally`` even on failure.
        """
        quoted_table = self._quote_ident(table_name)
        quoted_key = self._quote_ident("TRADE KEY")
        quoted_cols = [self._quote_ident(col) for col in columns]
        col_list_sql = ", ".join(quoted_cols)
        # Random suffix avoids collisions between concurrent syncs.
        temp_table = f"_tmp_{table_name}_{uuid.uuid4().hex[:12]}"
        payload.to_sql(
            temp_table,
            conn,
            if_exists="fail",
            index=False,
            method="multi",
            chunksize=self._safe_to_sql_chunksize(columns),
            dtype=self._dtype_map(columns),
        )
        quoted_temp = self._quote_ident(temp_table)
        non_key_assignments = [
            f"{self._quote_ident(col)} = s.{self._quote_ident(col)}"
            for col in columns
            if col != "TRADE KEY"
        ]
        delete_missing_keyed_sql = f"""
            DELETE FROM {quoted_table} AS t
            WHERE t.{quoted_key} IS NOT NULL
              AND NOT EXISTS (
                SELECT 1
                FROM {quoted_temp} AS s
                WHERE s.{quoted_key} IS NOT NULL
                  AND s.{quoted_key} = t.{quoted_key}
            )
        """
        update_sql = None
        if non_key_assignments:
            update_sql = f"""
                UPDATE {quoted_table} AS t
                SET {", ".join(non_key_assignments)}
                FROM {quoted_temp} AS s
                WHERE s.{quoted_key} IS NOT NULL
                  AND t.{quoted_key} = s.{quoted_key}
            """
        insert_new_keyed_sql = f"""
            INSERT INTO {quoted_table} ({col_list_sql})
            SELECT {col_list_sql}
            FROM {quoted_temp} AS s
            WHERE s.{quoted_key} IS NOT NULL
              AND NOT EXISTS (
                SELECT 1
                FROM {quoted_table} AS t
                WHERE t.{quoted_key} = s.{quoted_key}
            )
        """
        replace_null_keyed_sql = f"""
            DELETE FROM {quoted_table}
            WHERE {quoted_key} IS NULL
        """
        insert_null_keyed_sql = f"""
            INSERT INTO {quoted_table} ({col_list_sql})
            SELECT {col_list_sql}
            FROM {quoted_temp}
            WHERE {quoted_key} IS NULL
        """
        try:
            conn.execute(text(delete_missing_keyed_sql))
            if update_sql:
                conn.execute(text(update_sql))
            conn.execute(text(insert_new_keyed_sql))
            conn.execute(text(replace_null_keyed_sql))
            conn.execute(text(insert_null_keyed_sql))
        finally:
            conn.execute(text(f'DROP TABLE IF EXISTS {quoted_temp}'))

    def _write_table(self, table_name: str, df: pd.DataFrame, columns: Iterable[str], bind=None):
        """Persist ``df`` into ``table_name`` (no-op when no DB configured).

        Creates the table on first use. Frames carrying a TRADE KEY column are
        synced in place via :meth:`_sync_snapshot_by_trade_key`; other frames
        are replaced wholesale (DELETE then append).
        """
        if not self.using_database():
            return
        cols = list(columns)
        payload = self._serialize_df(df, cols)
        dtype_map = self._dtype_map(cols)
        def _write_with_conn(conn):
            if not self._table_exists(table_name, bind=conn):
                # head(0) creates an empty table with the right schema.
                payload.head(0).to_sql(
                    table_name,
                    conn,
                    if_exists="fail",
                    index=False,
                    dtype=dtype_map,
                )
            if "TRADE KEY" in cols:
                self._sync_snapshot_by_trade_key(conn, table_name, payload, cols)
                return
            conn.execute(text(f'DELETE FROM {self._quote_ident(table_name)}'))
            if not payload.empty:
                payload.to_sql(
                    table_name,
                    conn,
                    if_exists="append",
                    index=False,
                    method="multi",
                    chunksize=self._safe_to_sql_chunksize(cols),
                    dtype=dtype_map,
                )
        if bind is not None:
            _write_with_conn(bind)
            return
        with self.engine.begin() as conn:
            _write_with_conn(conn)

    def _read_table(self, table_name: str, columns: Iterable[str], bind=None) -> pd.DataFrame:
        """Read a full table reindexed to ``columns``; empty frame on any error."""
        cols = list(columns)
        if not self.using_database():
            return self._empty_df(cols)
        def _read_with_conn(conn):
            if not self._table_exists(table_name, bind=conn):
                return self._empty_df(cols)
            df = pd.read_sql(text(f'SELECT * FROM "{table_name}"'), conn)
            return df.reindex(columns=cols)
        try:
            if bind is not None:
                return _read_with_conn(bind)
            with self.engine.begin() as conn:
                return _read_with_conn(conn)
        except Exception:
            # Reads are best-effort: a DB hiccup degrades to an empty frame.
            return self._empty_df(cols)

    def _dedupe_trade_df(self, df: pd.DataFrame, columns: Iterable[str]) -> pd.DataFrame:
        """Drop duplicate trades, keeping the last occurrence.

        Rows with a TRADE KEY are deduped by key; keyless rows by full-row
        equality. Keyed rows come first in the result.
        """
        out = df.copy().reindex(columns=list(columns))
        if out.empty:
            return out
        if "TRADE KEY" in out.columns:
            key_series = out["TRADE KEY"].astype("string")
            with_key = out.loc[key_series.notna()].copy()
            without_key = out.loc[key_series.isna()].copy()
            if not with_key.empty:
                with_key = with_key.drop_duplicates(subset=["TRADE KEY"], keep="last")
            if not without_key.empty:
                without_key = without_key.drop_duplicates(keep="last")
            out = pd.concat([with_key, without_key], ignore_index=True)
        else:
            out = out.drop_duplicates(keep="last")
        return out.reindex(columns=list(columns))

    def _combine_trade_frames(self, frames: list[pd.DataFrame], columns: Iterable[str]) -> pd.DataFrame:
        """Concatenate non-empty frames (aligned to ``columns``) and dedupe."""
        valid_frames = [f.reindex(columns=list(columns)) for f in frames if f is not None and not f.empty]
        if not valid_frames:
            return self._empty_df(columns)
        combined = pd.concat(valid_frames, ignore_index=True)
        return self._dedupe_trade_df(combined, columns)

    def _trade_datetime_fallback(self, df: pd.DataFrame) -> pd.Series:
        """Best-available trade timestamp per row.

        Preference order: BUY TIME, then BUY SIGNAL TIME, then SELL SIGNAL
        TIME; all normalized to Asia/Kolkata.
        """
        if df is None or df.empty:
            return pd.Series(dtype="object")
        buy_times = (
            _normalize_trade_ts_series(df["BUY TIME"])
            if "BUY TIME" in df.columns
            else pd.Series(pd.NaT, index=df.index)
        )
        buy_signal_times = (
            _normalize_trade_ts_series(df["BUY SIGNAL TIME"])
            if "BUY SIGNAL TIME" in df.columns
            else pd.Series(pd.NaT, index=df.index)
        )
        sell_signal_times = (
            _normalize_trade_ts_series(df["SELL SIGNAL TIME"])
            if "SELL SIGNAL TIME" in df.columns
            else pd.Series(pd.NaT, index=df.index)
        )
        combined = buy_times.combine_first(buy_signal_times).combine_first(sell_signal_times)
        return _normalize_trade_ts_series(combined)

    def _trade_key_timestamp_fallback(self, df: pd.DataFrame) -> pd.Series:
        """Parse a timestamp from the last '|'-separated segment of TRADE KEY.

        NOTE(review): assumes the key format ends with a parseable timestamp
        segment — confirm against the key generator. Unparseable keys yield NaT.
        """
        if df.empty or "TRADE KEY" not in df.columns:
            return pd.Series(pd.NaT, index=df.index)
        key_series = df["TRADE KEY"].astype("string")
        tail = key_series.str.rsplit("|", n=1).str[-1]
        parsed = pd.to_datetime(tail, errors="coerce")
        parsed.index = df.index
        return parsed

    def _sanitize_archive_frame(self, df: pd.DataFrame, kind: str, columns: Iterable[str]) -> pd.DataFrame:
        """Backfill missing timestamps before a frame is archived.

        BUY TIME is filled from signal times and finally from the TRADE KEY
        timestamp; for closed trades EXIT TIME is filled from the target-hit
        times (1:2 preferred over 1:1).
        """
        cols = list(columns)
        if df is None or df.empty:
            return self._empty_df(cols)
        out = df.copy().reindex(columns=cols)
        for col in DATETIME_COLUMNS:
            if col in out.columns:
                out[col] = pd.to_datetime(out[col], errors="coerce")
        trade_key_ts = self._trade_key_timestamp_fallback(out)
        buy_time = (
            pd.to_datetime(out["BUY TIME"], errors="coerce")
            if "BUY TIME" in out.columns
            else pd.Series(pd.NaT, index=out.index)
        )
        buy_signal_time = (
            pd.to_datetime(out["BUY SIGNAL TIME"], errors="coerce")
            if "BUY SIGNAL TIME" in out.columns
            else pd.Series(pd.NaT, index=out.index)
        )
        sell_signal_time = (
            pd.to_datetime(out["SELL SIGNAL TIME"], errors="coerce")
            if "SELL SIGNAL TIME" in out.columns
            else pd.Series(pd.NaT, index=out.index)
        )
        if "BUY TIME" in out.columns:
            out["BUY TIME"] = (
                buy_time
                .combine_first(buy_signal_time)
                .combine_first(sell_signal_time)
                .combine_first(trade_key_ts)
            )
        if kind == "closed" and "EXIT TIME" in out.columns:
            exit_time = (
                pd.to_datetime(out["EXIT TIME"], errors="coerce")
                if "EXIT TIME" in out.columns
                else pd.Series(pd.NaT, index=out.index)
            )
            target2_time = (
                pd.to_datetime(out["TARGET ACHIEVED TIME (1:2)"], errors="coerce")
                if "TARGET ACHIEVED TIME (1:2)" in out.columns
                else pd.Series(pd.NaT, index=out.index)
            )
            target1_time = (
                pd.to_datetime(out["TARGET ACHIEVED TIME (1:1)"], errors="coerce")
                if "TARGET ACHIEVED TIME (1:1)" in out.columns
                else pd.Series(pd.NaT, index=out.index)
            )
            out["EXIT TIME"] = exit_time.combine_first(target2_time).combine_first(target1_time)
        return out.reindex(columns=cols)

    def _extract_trade_date(self, df: pd.DataFrame, kind: str) -> pd.Series:
        """Derive the per-row trade date (Asia/Kolkata calendar date).

        Closed trades prefer EXIT TIME, falling back to the entry/signal
        timestamps; other kinds use the entry/signal timestamps directly.
        """
        if df.empty:
            return pd.Series(dtype="object")
        base_trade_ts = self._trade_datetime_fallback(df)
        if kind == "closed":
            exit_dates = (
                _normalize_trade_ts_series(df["EXIT TIME"])
                if "EXIT TIME" in df.columns
                else pd.Series(pd.NaT, index=df.index)
            )
            combined = _normalize_trade_ts_series(exit_dates.combine_first(base_trade_ts))
            return combined.dt.date
        normalized = _normalize_trade_ts_series(base_trade_ts)
        return normalized.dt.date

    def _split_keep_vs_archive(
        self,
        df: pd.DataFrame,
        kind: str,
        columns: Iterable[str],
        keep_trade_date=None,
        archive_all: bool = False,
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Split ``df`` into (rows to keep, rows to archive).

        With ``archive_all`` everything is archived. With ``keep_trade_date``
        set, rows whose trade date matches are kept; otherwise everything is
        kept. The archive half is deduped before being returned.
        """
        cols = list(columns)
        if df.empty:
            return self._empty_df(cols), self._empty_df(cols)
        working = df.copy().reindex(columns=cols)
        if archive_all:
            return self._empty_df(cols), self._dedupe_trade_df(working, cols)
        if keep_trade_date is None:
            return working, self._empty_df(cols)
        trade_dates = self._extract_trade_date(working, kind)
        keep_mask = trade_dates == keep_trade_date
        # Safety rule:
        # never auto-archive current OPEN/SKIPPED rows just because their trade date is missing.
        # Missing timestamps should stay in current tables until the bot rewrites them cleanly.
        if kind in {"open", "skipped"}:
            keep_mask = keep_mask | trade_dates.isna()
        keep_df = working.loc[keep_mask].copy()
        archive_df = working.loc[~keep_mask].copy()
        return keep_df.reindex(columns=cols), self._dedupe_trade_df(archive_df, cols)

    def load_trade_logs(self, columns: Iterable[str], archived: bool = False):
        """Load (open, closed, skipped) frames from the DB or, failing that, CSV.

        Database errors yield three empty frames (no CSV fallback in that
        case, since the DB is the source of truth when configured).
        """
        cols = list(columns)
        open_table, closed_table, skipped_table = self._trade_tables(archived=archived)
        open_path, closed_path, skipped_path = self._trade_paths(archived=archived)
        if self.using_database():
            try:
                with self.engine.begin() as conn:
                    open_df = self._read_table(open_table, cols, bind=conn)
                    closed_df = self._read_table(closed_table, cols, bind=conn)
                    skipped_df = self._read_table(skipped_table, cols, bind=conn)
                return open_df, closed_df, skipped_df
            except Exception:
                return self._empty_df(cols), self._empty_df(cols), self._empty_df(cols)
        return (
            self._load_csv(open_path, cols),
            self._load_csv(closed_path, cols),
            self._load_csv(skipped_path, cols),
        )

    def save_trade_logs(
        self,
        open_df: pd.DataFrame,
        closed_df: pd.DataFrame,
        skipped_df: pd.DataFrame,
        columns: Iterable[str],
        archived: bool = False,
    ):
        """Persist the three trade frames: always to CSV, and to SQL if enabled.

        Frames are deduped first; all three DB writes share one transaction.
        """
        cols = list(columns)
        open_df = self._dedupe_trade_df(open_df, cols)
        closed_df = self._dedupe_trade_df(closed_df, cols)
        skipped_df = self._dedupe_trade_df(skipped_df, cols)
        open_path, closed_path, skipped_path = self._trade_paths(archived=archived)
        open_df.to_csv(open_path, index=False)
        closed_df.to_csv(closed_path, index=False)
        skipped_df.to_csv(skipped_path, index=False)
        if self.using_database():
            open_table, closed_table, skipped_table = self._trade_tables(archived=archived)
            with self.engine.begin() as conn:
                self._write_table(open_table, open_df, cols, bind=conn)
                self._write_table(closed_table, closed_df, cols, bind=conn)
                self._write_table(skipped_table, skipped_df, cols, bind=conn)

    def clear_current_trade_logs(self, columns: Iterable[str]):
        """Replace the current (non-archived) trade logs with empty frames."""
        cols = list(columns)
        empty = self._empty_df(cols)
        self.save_trade_logs(empty, empty, empty, cols, archived=False)

    def archive_trade_logs(
        self,
        open_df: pd.DataFrame,
        closed_df: pd.DataFrame,
        skipped_df: pd.DataFrame,
        columns: Iterable[str],
        keep_trade_date=None,
        archive_all: bool = False,
    ):
        """Move stale rows from the given frames into the archived stores.

        Returns ``(keep_open, keep_closed, keep_skipped, info)`` where info
        holds per-kind archived/remaining counts. When nothing qualifies for
        archiving, no writes are performed at all.
        """
        cols = list(columns)
        current_open = open_df.copy().reindex(columns=cols)
        current_closed = closed_df.copy().reindex(columns=cols)
        current_skipped = skipped_df.copy().reindex(columns=cols)
        keep_open, archive_open = self._split_keep_vs_archive(
            current_open, "open", cols, keep_trade_date=keep_trade_date, archive_all=archive_all
        )
        keep_closed, archive_closed = self._split_keep_vs_archive(
            current_closed, "closed", cols, keep_trade_date=keep_trade_date, archive_all=archive_all
        )
        keep_skipped, archive_skipped = self._split_keep_vs_archive(
            current_skipped, "skipped", cols, keep_trade_date=keep_trade_date, archive_all=archive_all
        )
        info = {
            "archived_open": int(len(archive_open)),
            "archived_closed": int(len(archive_closed)),
            "archived_skipped": int(len(archive_skipped)),
            "archived_total": int(len(archive_open) + len(archive_closed) + len(archive_skipped)),
            "remaining_open": int(len(keep_open)),
            "remaining_closed": int(len(keep_closed)),
            "remaining_skipped": int(len(keep_skipped)),
            "archive_all": bool(archive_all),
            "keep_trade_date": str(keep_trade_date) if keep_trade_date is not None else None,
        }
        if info["archived_total"] == 0:
            # Nothing to move: skip all persistence work.
            return keep_open, keep_closed, keep_skipped, info
        archive_open = self._sanitize_archive_frame(archive_open, "open", cols)
        archive_closed = self._sanitize_archive_frame(archive_closed, "closed", cols)
        archive_skipped = self._sanitize_archive_frame(archive_skipped, "skipped", cols)
        # Merge new archive rows into the existing archived stores.
        archived_open_df, archived_closed_df, archived_skipped_df = self.load_trade_logs(cols, archived=True)
        archived_open_df = self._sanitize_archive_frame(
            self._combine_trade_frames([archived_open_df, archive_open], cols),
            "open",
            cols,
        )
        archived_closed_df = self._sanitize_archive_frame(
            self._combine_trade_frames([archived_closed_df, archive_closed], cols),
            "closed",
            cols,
        )
        archived_skipped_df = self._sanitize_archive_frame(
            self._combine_trade_frames([archived_skipped_df, archive_skipped], cols),
            "skipped",
            cols,
        )
        self.save_trade_logs(
            archived_open_df,
            archived_closed_df,
            archived_skipped_df,
            cols,
            archived=True,
        )
        self.save_trade_logs(keep_open, keep_closed, keep_skipped, cols, archived=False)
        return keep_open, keep_closed, keep_skipped, info

    def archive_current_trade_logs(
        self,
        columns: Iterable[str],
        keep_trade_date=None,
        archive_all: bool = False,
    ):
        """Load the current trade logs and archive them via archive_trade_logs."""
        cols = list(columns)
        open_df, closed_df, skipped_df = self.load_trade_logs(cols, archived=False)
        return self.archive_trade_logs(
            open_df,
            closed_df,
            skipped_df,
            cols,
            keep_trade_date=keep_trade_date,
            archive_all=archive_all,
        )

    def _table_columns(self, table_name: str, bind=None) -> list[str]:
        """Return the column names of ``table_name``; [] if absent or no DB."""
        if not self.using_database():
            return []
        target = bind if bind is not None else self.engine
        inspector = inspect(target)
        if not inspector.has_table(table_name):
            return []
        return [str(col.get("name")) for col in inspector.get_columns(table_name)]

    def load_status(self) -> dict:
        """Load bot status: DB row first, then the JSON file, else {}.

        ``last_archive_counts`` is stored as a JSON string in SQL and is
        decoded back to a dict here when possible.
        """
        if self.using_database():
            try:
                with self.engine.begin() as conn:
                    if not self._table_exists("bot_status_current", bind=conn):
                        raise RuntimeError("bot_status_current table missing")
                    df = pd.read_sql(text('SELECT * FROM "bot_status_current" LIMIT 1'), conn)
                if not df.empty:
                    row = df.iloc[0].to_dict()
                    if isinstance(row.get("last_archive_counts"), str):
                        try:
                            row["last_archive_counts"] = json.loads(row["last_archive_counts"])
                        except Exception:
                            pass
                    # SQL NULLs arrive as NaN; normalize them to None.
                    return {k: (None if pd.isna(v) else v) for k, v in row.items()}
            except Exception:
                # Fall through to the JSON file on any DB problem.
                pass
        if self.status_path.exists():
            try:
                return json.loads(self.status_path.read_text())
            except Exception:
                return {}
        return {}

    def save_status(self, status: dict):
        """Persist bot status to the JSON file and best-effort to SQL."""
        # Keep rich JSON on disk
        self.status_path.write_text(json.dumps(status, indent=2, default=str))
        if self.using_database():
            db_status = dict(status)
            # Serialize nested / non-scalar fields for SQL storage
            if isinstance(db_status.get("last_archive_counts"), (dict, list)):
                db_status["last_archive_counts"] = json.dumps(db_status["last_archive_counts"], default=str)
            try:
                with self.engine.begin() as conn:
                    existing_cols = self._table_columns("bot_status_current", bind=conn)
                    if existing_cols:
                        # Drop keys the existing table cannot store.
                        existing_cols_set = set(existing_cols)
                        filtered = {k: v for k, v in db_status.items() if k in existing_cols_set}
                    else:
                        filtered = db_status
                    if not filtered:
                        return
                    df = pd.DataFrame([filtered])
                    self._write_table("bot_status_current", df, df.columns, bind=conn)
            except Exception:
                # Status persistence must never break the bot loop.
                pass

    def _resolve_import_path(self, raw_path, default_path: Path) -> Path:
        """Resolve a user-supplied import path; relative paths live in data/."""
        if raw_path is None or str(raw_path).strip() == "":
            return default_path
        p = Path(str(raw_path).strip())
        if not p.is_absolute():
            p = self.data_dir / p
        return p

    def import_current_trade_logs_from_data_dir(
        self,
        columns: Iterable[str],
        open_csv: str | Path | None = None,
        closed_csv: str | Path | None = None,
        skipped_csv: str | Path | None = None,
        merge_with_existing: bool = False,
    ):
        """Seed the current trade logs from import CSVs in data/.

        At least one of the three CSVs must exist, otherwise
        FileNotFoundError is raised. With ``merge_with_existing`` the imports
        are combined with (and deduped against) the current logs; otherwise
        they replace them. Returns the saved frames plus an info dict with
        raw/imported counts and resolved paths.
        """
        cols = list(columns)
        open_path = self._resolve_import_path(open_csv, self.import_open_path)
        closed_path = self._resolve_import_path(closed_csv, self.import_closed_path)
        skipped_path = self._resolve_import_path(skipped_csv, self.import_skipped_path)
        open_exists = open_path.exists()
        closed_exists = closed_path.exists()
        skipped_exists = skipped_path.exists()
        if not any([open_exists, closed_exists, skipped_exists]):
            raise FileNotFoundError(
                "No import CSVs found in data/. Expected one or more of: "
                f"{open_path.name}, {closed_path.name}, {skipped_path.name}"
            )
        imported_open = self._load_csv(open_path, cols) if open_exists else self._empty_df(cols)
        imported_closed = self._load_csv(closed_path, cols) if closed_exists else self._empty_df(cols)
        imported_skipped = self._load_csv(skipped_path, cols) if skipped_exists else self._empty_df(cols)
        # Row counts before dedupe/merge, for the info dict.
        raw_open = len(imported_open)
        raw_closed = len(imported_closed)
        raw_skipped = len(imported_skipped)
        if merge_with_existing:
            existing_open, existing_closed, existing_skipped = self.load_trade_logs(cols, archived=False)
            imported_open = self._combine_trade_frames([existing_open, imported_open], cols)
            imported_closed = self._combine_trade_frames([existing_closed, imported_closed], cols)
            imported_skipped = self._combine_trade_frames([existing_skipped, imported_skipped], cols)
        else:
            imported_open = self._dedupe_trade_df(imported_open, cols)
            imported_closed = self._dedupe_trade_df(imported_closed, cols)
            imported_skipped = self._dedupe_trade_df(imported_skipped, cols)
        self.save_trade_logs(imported_open, imported_closed, imported_skipped, cols, archived=False)
        info = {
            "open_path": str(open_path),
            "closed_path": str(closed_path),
            "skipped_path": str(skipped_path),
            "open_exists": open_exists,
            "closed_exists": closed_exists,
            "skipped_exists": skipped_exists,
            "raw_open": int(raw_open),
            "raw_closed": int(raw_closed),
            "raw_skipped": int(raw_skipped),
            "imported_open": int(len(imported_open)),
            "imported_closed": int(len(imported_closed)),
            "imported_skipped": int(len(imported_skipped)),
            "merge_with_existing": bool(merge_with_existing),
        }
        return imported_open, imported_closed, imported_skipped, info

    def merge_and_seed_trade_logs(
        self,
        columns: Iterable[str],
        open_sources: list[Path],
        closed_sources: list[Path],
        skipped_sources: list[Path],
    ):
        """Merge multiple source CSVs per kind, dedupe, and save as current logs.

        Unreadable source files are skipped silently.
        """
        cols = list(columns)
        def _merge(paths: list[Path]) -> pd.DataFrame:
            frames = []
            for path in paths:
                if path.exists():
                    try:
                        frames.append(pd.read_csv(path).reindex(columns=cols))
                    except Exception:
                        continue
            return self._combine_trade_frames(frames, cols)
        open_df = _merge(open_sources)
        closed_df = _merge(closed_sources)
        skipped_df = _merge(skipped_sources)
        self.save_trade_logs(open_df, closed_df, skipped_df, cols)
        return open_df, closed_df, skipped_df
# Process-wide singleton TradeStorage instance.
_STORAGE = None
def get_storage(base_dir: Path | None = None) -> TradeStorage:
    """Return the shared TradeStorage, creating it on first use.

    NOTE(review): ``base_dir`` is honored only on the first call; subsequent
    calls return the existing instance regardless of the argument.
    """
    global _STORAGE
    if _STORAGE is None:
        _STORAGE = TradeStorage(base_dir=base_dir)
    return _STORAGE