# Tarang_v2 / app.py
# (Hugging Face Space page header, commented out so the file parses as Python:
#  author unknownfriend00007, commit "Update app.py", 6cf9ab2 verified)
import os
import time
import threading
import numpy as np
import pandas as pd
import torch
from flask import Flask, jsonify, render_template, request
from config import config
from data_loader import load_asset_series, make_features, build_windows
from trainer import TarangTrainer
from storage_manager import StorageManager
from live_ingest import update_stooq_daily, update_binance_klines
# Flask app; HTML templates live in ./templates next to this file.
app = Flask(__name__, template_folder="templates")

# Serializes all model/optimizer access across the web, bootstrap,
# ingest, and checkpoint threads.
MODEL_LOCK = threading.Lock()
# Set once bootstrap (checkpoint load OR initial training) has finished;
# the ingest loop waits on this before doing any live work.
BOOTSTRAP_DONE = threading.Event()

# Process-wide status shared by all threads and exposed via /api/status.
# NOTE(review): mutated from several threads without a dedicated lock —
# single-key assignments are safe in CPython, but compound reads may be stale.
STATE = {
    "boot_time": time.strftime("%Y-%m-%d %H:%M:%S"),
    "training": False,           # a trainer.fit() is currently running
    "trained": False,            # model has usable weights (loaded or trained)
    "error": None,               # last recorded error string, if any
    "last_train_status": None,   # dict from TarangTrainer.fit plus metadata
    "last_ingest_at": None,
    "last_ingest_changed": [],   # symbols whose CSVs changed on last ingest
    "last_ingest_errors": [],
    "last_checkpoint_at": None,
    "neon_enabled": bool(config.NEON_CONNECTION),
    "ingest_every_seconds": config.INGEST_EVERY_SECONDS,
    "checkpoint_every_seconds": config.CHECKPOINT_EVERY_SECONDS,
    # Debug info
    "checkpoint_loop_alive": True,
    "ingest_loop_alive": True,
}

# Global trainer; its model/optimizer are guarded by MODEL_LOCK.
trainer = TarangTrainer()
# Lazily-created StorageManager for Neon; see _ensure_storage().
storage = None
def _asset_csv_path(asset):
    """Return the on-disk CSV path holding *asset*'s candle data.

    Stocks live under <DATA_DIR>/stocks/stooq/<symbol>/, crypto under
    <DATA_DIR>/crypto/binance/<symbol>/; the filename is the configured
    interval (e.g. "1d.csv").
    """
    if asset.asset_type == "stock":
        source_dirs = ("stocks", "stooq")
    else:
        source_dirs = ("crypto", "binance")
    return os.path.join(config.DATA_DIR, *source_dirs, asset.symbol, f"{config.INTERVAL}.csv")
def _predict_next_for_asset(asset):
    """Predict the next close for one asset with the current model.

    Returns a JSON-serializable dict of the prediction, or None when there
    is not enough history to build a window. Caller is expected to hold
    MODEL_LOCK (see api_predict).
    """
    feats = make_features(asset.df)
    if len(feats) < config.WINDOW + 2:
        return None
    X, _, _ = build_windows(feats, window=config.WINDOW, horizon=config.HORIZON_DAYS)
    x_last = X[-1]
    # Normalize with mean/std computed over *this asset's* windows at
    # predict time. NOTE(review): this may not match the normalization used
    # during training — confirm against TarangTrainer.fit.
    mu = X.reshape(-1, X.shape[-1]).mean(axis=0, keepdims=True)
    sd = X.reshape(-1, X.shape[-1]).std(axis=0, keepdims=True) + 1e-6  # epsilon guards div-by-zero
    x_last = (x_last - mu) / sd
    xb = torch.tensor(x_last, dtype=torch.float32).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        # Model output is treated downstream as a log-return over the horizon.
        pred_lr = float(trainer.model(xb).squeeze().cpu().item())
    last_close = float(feats["close"].iloc[-1])
    last_date = pd.Timestamp(feats["date"].iloc[-1])
    # Stocks advance by business days; crypto trades every calendar day.
    if asset.asset_type == "stock":
        pred_date = last_date + pd.tseries.offsets.BDay(config.HORIZON_DAYS)
    else:
        pred_date = last_date + pd.Timedelta(days=config.HORIZON_DAYS)
    pred_close = last_close * float(np.exp(pred_lr))  # invert log-return to a price
    return {
        "symbol": asset.symbol,
        "asset_type": asset.asset_type,
        "last_date": last_date.date().isoformat(),
        "predicted_for": pd.Timestamp(pred_date).date().isoformat(),
        "pred_log_return": pred_lr,
        "pred_pct_return": pred_lr * 100.0,
        "last_close": last_close,
        "pred_close": float(pred_close),
    }
def _train_epochs(epochs: int, reason: str):
    """Run *epochs* training epochs under MODEL_LOCK and record the outcome.

    No-op for non-positive epoch counts. Errors are captured into
    STATE["error"] instead of propagating; STATE["training"] is always
    cleared on exit.
    """
    if epochs <= 0:
        return
    STATE["training"] = True
    STATE["error"] = None
    try:
        with MODEL_LOCK:
            fit_status = trainer.fit(epochs)
        summary = dict(fit_status)
        summary["reason"] = reason
        summary["epochs_ran"] = int(epochs)
        STATE["last_train_status"] = summary
        STATE["trained"] = True
        print(f"[train] ✓ Completed: {reason}", flush=True)
    except Exception as exc:
        STATE["error"] = f"train_error: {exc}"
        print(f"[train] ERROR: {exc}", flush=True)
    finally:
        STATE["training"] = False
def _ensure_storage():
    """Lazily create the global StorageManager (only when Neon is configured)."""
    global storage
    if storage is None and config.NEON_CONNECTION:
        storage = StorageManager()
def _load_checkpoint_from_neon():
    """Try to restore model/optimizer state from the latest Neon checkpoint.

    Returns True on success; False when Neon is disabled, no checkpoint
    exists, the model lock cannot be acquired in time, or loading fails
    (the failure is recorded in STATE["error"]).
    """
    if not config.NEON_CONNECTION:
        return False
    try:
        _ensure_storage()
        ckpt = storage.load_latest()
        if not ckpt:
            return False
        # Bounded wait so a wedged training run cannot deadlock boot.
        if not MODEL_LOCK.acquire(timeout=30):
            print("[bootstrap] WARNING: Could not acquire lock for checkpoint load", flush=True)
            return False
        try:
            trainer.model.load_state_dict(ckpt["model_state"])
            trainer.opt.load_state_dict(ckpt["optim_state"])
            trainer.model.eval()
        finally:
            MODEL_LOCK.release()
        STATE["trained"] = True
        STATE["last_checkpoint_at"] = ckpt.get("created_at")
        print(f"[bootstrap] ✓ Loaded checkpoint from {ckpt.get('created_at')}", flush=True)
        return True
    except Exception as exc:
        STATE["error"] = f"neon_load_error: {exc}"
        print(f"[bootstrap] ERROR loading checkpoint: {exc}", flush=True)
        return False
def _checkpoint_once():
    """Save one model/optimizer checkpoint to Neon, when enabled and trained.

    Silently skips when Neon is off or no weights exist yet; any failure
    is recorded in STATE["error"] rather than raised.
    """
    if not (config.NEON_CONNECTION and STATE["trained"]):
        return
    try:
        _ensure_storage()
        saved_at = time.strftime("%Y-%m-%d %H:%M:%S")
        meta = {
            "saved_at": saved_at,
            "train_status": STATE.get("last_train_status"),
            "last_ingest_at": STATE.get("last_ingest_at"),
            "last_ingest_changed": STATE.get("last_ingest_changed", []),
            "last_ingest_errors": STATE.get("last_ingest_errors", []),
        }
        # Bounded wait so a long-running fit cannot wedge this thread forever.
        if not MODEL_LOCK.acquire(timeout=120):
            print(f"[checkpoint] WARNING: Could not acquire lock after 120s, skipping", flush=True)
            STATE["error"] = "checkpoint_timeout: could not acquire MODEL_LOCK"
            return
        try:
            storage.save_checkpoint(trainer.model.state_dict(), trainer.opt.state_dict(), meta)
        finally:
            MODEL_LOCK.release()
        STATE["last_checkpoint_at"] = saved_at
        print(f"[checkpoint] ✓ Saved to Neon at {saved_at}", flush=True)
    except Exception as exc:
        STATE["error"] = f"checkpoint_error: {exc}"
        print(f"[checkpoint] ERROR: {exc}", flush=True)
def _checkpoint_loop():
    """
    Runs every CHECKPOINT_EVERY_SECONDS (3600s = 1 hour).
    Wrapped in try/except to survive errors. [web:360][web:361]
    """
    while True:
        try:
            print(f"[checkpoint] Sleeping for {config.CHECKPOINT_EVERY_SECONDS}s...", flush=True)
            time.sleep(config.CHECKPOINT_EVERY_SECONDS)
            print(f"[checkpoint] Woke up at {time.strftime('%Y-%m-%d %H:%M:%S')}, starting checkpoint...", flush=True)
            _checkpoint_once()
        except Exception as e:
            print(f"[checkpoint] LOOP ERROR: {e}", flush=True)
            STATE["error"] = f"checkpoint_loop_error: {e}"
            STATE["checkpoint_loop_alive"] = False
            time.sleep(60)  # Brief pause before retry
            # Loop resumes after the pause, so mark it healthy again.
            STATE["checkpoint_loop_alive"] = True
def _ingest_once():
    """Run one ingestion pass over all configured assets.

    Refreshes each asset's CSV from its live source (Stooq for stocks,
    Binance for crypto), records the pass outcome in STATE, and returns
    the list of symbols whose data actually changed.
    """
    changed = []
    errors = []
    for asset in load_asset_series():
        csv_path = _asset_csv_path(asset)
        try:
            if asset.asset_type == "stock":
                updated = update_stooq_daily(asset.symbol, csv_path)
            else:
                updated = update_binance_klines(asset.symbol, config.INTERVAL, csv_path)
            if updated:
                changed.append(asset.symbol)
        except Exception as exc:
            errors.append({"symbol": asset.symbol, "asset_type": asset.asset_type, "error": str(exc)[:200]})
    STATE["last_ingest_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
    STATE["last_ingest_changed"] = changed
    STATE["last_ingest_errors"] = errors
    if changed:
        print(f"[ingest] ✓ Updated {len(changed)} assets: {changed}", flush=True)
    if errors:
        print(f"[ingest] Errors for {len(errors)} assets", flush=True)
    return changed
def _ingest_loop():
    """
    Waits for bootstrap, then ingests every INGEST_EVERY_SECONDS (1800s = 30 min).
    Wrapped in try/except to survive errors. [web:360][web:361]
    """
    print("[ingest] Waiting for bootstrap to complete...", flush=True)
    BOOTSTRAP_DONE.wait()
    print("[ingest] Bootstrap done, starting ingestion loop", flush=True)
    while True:
        try:
            # Don't ingest while training is active
            if STATE["training"]:
                print("[ingest] Training in progress, waiting 10s...", flush=True)
                time.sleep(10)
                continue
            changed = _ingest_once()
            # Update-train only if new rows arrived AND no training is running
            if changed and not STATE["training"]:
                print(f"[ingest] Triggering update-train for {len(changed)} assets", flush=True)
                _train_epochs(config.UPDATE_TRAIN_EPOCHS, reason=f"ingest_update({len(changed)} assets)")
                # Checkpoint immediately after update-train completes
                print(f"[ingest] Update-train done, triggering immediate checkpoint", flush=True)
                _checkpoint_once()
            print(f"[ingest] Sleeping for {config.INGEST_EVERY_SECONDS}s...", flush=True)
            time.sleep(config.INGEST_EVERY_SECONDS)
        except Exception as e:
            print(f"[ingest] LOOP ERROR: {e}", flush=True)
            STATE["error"] = f"ingest_loop_error: {e}"
            STATE["ingest_loop_alive"] = False
            time.sleep(60)  # Brief pause before retry
            # Loop resumes after the pause, so mark it healthy again.
            STATE["ingest_loop_alive"] = True
def _bootstrap_flow():
    """
    Flow:
      1) Try load checkpoint from Neon.
      2) If not found: train on historical data (when RETRAIN_ON_START).
      3) Save checkpoint to Neon once.
      4) Mark BOOTSTRAP_DONE so ingestion can start.
    """
    print("[bootstrap] Starting bootstrap flow...", flush=True)
    if _load_checkpoint_from_neon():
        print("[bootstrap] ✓ Loaded from Neon, bootstrap complete", flush=True)
        BOOTSTRAP_DONE.set()
        return
    if config.RETRAIN_ON_START:
        print("[bootstrap] No checkpoint found, training on historical data...", flush=True)
        _train_epochs(config.INITIAL_TRAIN_EPOCHS, reason="boot_initial_from_historical")
        # Persist the freshly trained weights once (no-op when Neon is off).
        print("[bootstrap] Saving initial checkpoint...", flush=True)
        _checkpoint_once()
    print("[bootstrap] ✓ Bootstrap complete", flush=True)
    BOOTSTRAP_DONE.set()
@app.route("/")
def index():
    """Render the main dashboard page."""
    template_name = "dashboard.html"
    return render_template(template_name)
@app.route("/api/status")
def api_status():
    """Report process health and pipeline progress as JSON.

    Mirrors most STATE keys verbatim, plus a data-dir existence check and
    the bootstrap event's current value. Key order matches the original
    response shape.
    """
    payload = {
        "ok": True,
        "boot_time": STATE["boot_time"],
        "data_dir_exists": os.path.exists(config.DATA_DIR),
    }
    for key in ("training", "trained", "error", "last_train_status",
                "last_ingest_at", "last_ingest_changed", "last_ingest_errors",
                "last_checkpoint_at", "neon_enabled", "ingest_every_seconds",
                "checkpoint_every_seconds"):
        payload[key] = STATE[key]
    payload["bootstrap_done"] = BOOTSTRAP_DONE.is_set()
    # Debug: liveness flags for the background loops
    payload["checkpoint_loop_alive"] = STATE["checkpoint_loop_alive"]
    payload["ingest_loop_alive"] = STATE["ingest_loop_alive"]
    return jsonify(payload)
@app.route("/api/predict")
def api_predict():
    """Predict next closes for all assets, or one asset via ?symbol=.

    Returns 503 while training is running or before any weights exist.
    Predictions are sorted by absolute expected percentage move, largest
    first.
    """
    if STATE["training"]:
        return jsonify({"ok": False, "error": "Training in progress. Try again soon."}), 503
    if not STATE["trained"]:
        return jsonify({"ok": False, "error": "Model not trained yet."}), 503
    wanted = request.args.get("symbol", "").strip().lower()
    assets = load_asset_series()
    if wanted:
        assets = [a for a in assets if a.symbol.lower() == wanted]
    predictions = []
    with MODEL_LOCK:
        for asset in assets:
            result = _predict_next_for_asset(asset)
            if result:
                predictions.append(result)
    predictions.sort(key=lambda p: abs(p["pred_pct_return"]), reverse=True)
    return jsonify({"ok": True, "count": len(predictions), "predictions": predictions})
if __name__ == "__main__":
    print("="*60, flush=True)
    print(f"Tarang v2 starting at {STATE['boot_time']}", flush=True)
    print("="*60, flush=True)
    # 1) Bootstrap first (load Neon OR train historical then save Neon)
    threading.Thread(target=_bootstrap_flow, daemon=True, name="bootstrap").start()
    # 2) Start live ingestion only AFTER bootstrap is done
    threading.Thread(target=_ingest_loop, daemon=True, name="ingest").start()
    # 3) Start hourly Neon checkpoint loop (1 write per hour)
    if config.NEON_CONNECTION:
        threading.Thread(target=_checkpoint_loop, daemon=True, name="checkpoint").start()
    # debug=False matters: the Flask reloader would re-import the module and
    # spawn duplicate background threads. Daemon threads die with the process.
    app.run(host=config.HOST, port=config.PORT, debug=False)