# smc_multimodal_atr_report.py
"""
SMC Multimodal + Gemini Auto-label + Batch CSV Prediction with ATR-based SL/TP + HTML report.
Requirements (example):
pip install torch torchvision pandas scikit-learn pillow gradio tqdm joblib matplotlib google-genai
Set GENAI_API_KEY environment variable for Gemini if using that tab.
"""
import os
import io
import zipfile
import tempfile
import shutil
import json
import base64
from typing import List, Tuple, Optional
from datetime import datetime
import numpy as np
import pandas as pd
from PIL import Image
import joblib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import gradio as gr
# --- Gemini client (google-genai) -----------------
try:
    from google import genai
    from google.genai import types as genai_types
except Exception:
    genai = None
    genai_types = None
# -------------------------
# Config
# -------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = 128
SEQ_LEN = 32
SEQ_BATCH = 64
EPOCHS = 8
MODEL_DIR = "models_smc_multimodal"
AUTO_DIR = "auto_labeled"
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(AUTO_DIR, exist_ok=True)
LABELS = ["Sell", "Hold", "Buy"]
# -------------------------
# Helpers and CSV robustness
# -------------------------
COMMON_COLS = {
"open": ["open", "Open", "O", "o"],
"high": ["high", "High", "H", "h"],
"low": ["low", "Low", "L", "l"],
"close":["close", "Close", "C", "c"],
"volume":["volume", "Volume", "V", "v"]
}
def find_col(df, choices):
for c in choices:
if c in df.columns:
return c
return None
def handle_ohlcv_csv(csv_file):
df = pd.read_csv(csv_file.name)
col_map = {}
for key, choices in COMMON_COLS.items():
found = find_col(df, choices)
if not found:
raise ValueError(f"OHLCV CSV missing column for '{key}'. Tried: {choices}")
col_map[found] = key.capitalize()
df = df.rename(columns=col_map)
if "Date" in df.columns:
try:
df["Date"] = pd.to_datetime(df["Date"], infer_datetime_format=True)
except Exception:
pass
df = df.sort_values("Date").reset_index(drop=True)
else:
df = df.reset_index(drop=True)
for c in ["Open","High","Low","Close","Volume"]:
df[c] = pd.to_numeric(df[c], errors="coerce")
df[["Open","High","Low","Close","Volume"]] = df[["Open","High","Low","Close","Volume"]].fillna(method="ffill").fillna(method="bfill")
if df[["Open","High","Low","Close","Volume"]].isnull().any().any():
raise ValueError("OHLCV CSV contains too many missing values after fill.")
return df
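# Illustrative input for handle_ohlcv_csv (several common column spellings are accepted per
# COMMON_COLS; the optional Date column is parsed and used for sorting). Example CSV head:
#   Date,Open,High,Low,Close,Volume
#   2024-01-02,1.1045,1.1062,1.1031,1.1050,125000
#   2024-01-03,1.1050,1.1079,1.1040,1.1071,98000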
def label_to_index_generic(l):
if isinstance(l, str):
s = l.strip().lower()
for i, lab in enumerate(LABELS):
if lab.lower() == s:
return i
    try:
        idx = int(l)
        if idx in (0, 1, 2):
            return idx
    except (TypeError, ValueError):
        pass
raise ValueError(f"Unknown label: {l}")
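# Quick illustration of label_to_index_generic (indices follow LABELS = ["Sell", "Hold", "Buy"]):
#   label_to_index_generic("Buy")    -> 2
#   label_to_index_generic(" sell ") -> 0
#   label_to_index_generic(1)        -> 1  (already a valid index)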
# -------------------------
# Model components
# -------------------------
class MultimodalDataset(Dataset):
def __init__(self, items: List[Tuple[str, np.ndarray, int]]):
self.items = items
def __len__(self): return len(self.items)
def __getitem__(self, idx):
img_path, seq_arr, label = self.items[idx]
img = Image.open(img_path).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
img_arr = (np.array(img).astype(np.float32) / 255.0)
img_arr = np.transpose(img_arr, (2,0,1))
img_tensor = torch.tensor(img_arr, dtype=torch.float32)
seq_tensor = torch.tensor(seq_arr.astype(np.float32))
label_tensor = torch.tensor(int(label), dtype=torch.long)
return img_tensor, seq_tensor, label_tensor
class SimpleCNNEmbedding(nn.Module):
def __init__(self, emb_size=128):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)
conv_out = (IMAGE_SIZE // 8) * (IMAGE_SIZE // 8) * 64
self.head = nn.Sequential(nn.Flatten(), nn.Linear(conv_out, emb_size), nn.ReLU())
def forward(self, x): return self.head(self.conv(x))
class SimpleLSTMEmbedding(nn.Module):
def __init__(self, n_features, hidden=64, emb_size=128):
super().__init__()
self.lstm = nn.LSTM(input_size=n_features, hidden_size=hidden, num_layers=1, batch_first=True, bidirectional=True)
self.proj = nn.Sequential(nn.Linear(hidden*2, emb_size), nn.ReLU())
def forward(self, x):
out, _ = self.lstm(x)
last = out[:, -1, :]
return self.proj(last)
class MultimodalNet(nn.Module):
def __init__(self, n_features, emb_size=128, n_classes=3):
super().__init__()
self.img_model = SimpleCNNEmbedding(emb_size=emb_size)
self.seq_model = SimpleLSTMEmbedding(n_features=n_features, hidden=64, emb_size=emb_size)
self.head = nn.Sequential(nn.Linear(emb_size*2, 256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, n_classes))
def forward(self, img, seq):
img_emb = self.img_model(img)
seq_emb = self.seq_model(seq)
cat = torch.cat([img_emb, seq_emb], dim=1)
return self.head(cat)
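# Shape sketch for MultimodalNet (using the defaults above): the image branch expects
# (B, 3, IMAGE_SIZE, IMAGE_SIZE) tensors and the sequence branch (B, SEQ_LEN, n_features).
# For a batch of 4 with the 5 OHLCV features:
#   model = MultimodalNet(n_features=5)
#   img = torch.zeros(4, 3, IMAGE_SIZE, IMAGE_SIZE)
#   seq = torch.zeros(4, SEQ_LEN, 5)
#   logits = model(img, seq)   # -> shape (4, 3), one logit per class in LABELS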
# -------------------------
# Helper: ATR
# -------------------------
def compute_atr(df: pd.DataFrame, n=14):
    high = df["High"].values
    low = df["Low"].values
    close = df["Close"].values
    # Previous close; the first bar falls back to its own close, so its TR reduces to high - low.
    prev_close = np.concatenate(([close[0]], close[:-1]))
    tr = np.maximum(high - low, np.maximum(np.abs(high - prev_close), np.abs(low - prev_close)))
    atr = pd.Series(tr).rolling(window=n, min_periods=1).mean().values
    return atr  # numpy array, one ATR value per row
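# Worked example (illustrative numbers): for bars with (High, Low, Close) of
#   (12, 10, 11), (13, 11, 12), (15, 12, 14)
# the true ranges are 2 (first bar: high - low), 2 (max of 13-11, |13-11|, |11-11|)
# and 3 (max of 15-12, |15-12|, |12-12|), so with n=3 the rolling-mean ATR values
# are 2.0, 2.0 and 7/3 ≈ 2.33 (min_periods=1 lets early bars average fewer TRs).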
# -------------------------
# Load & predict (single)
# -------------------------
def load_multimodal_model():
path = os.path.join(MODEL_DIR, "multimodal_model.pt")
if not os.path.exists(path):
return None
ckpt = torch.load(path, map_location=DEVICE)
n_features = ckpt.get("n_features")
model = MultimodalNet(n_features=n_features, emb_size=128, n_classes=len(LABELS)).to(DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()
return model
def predict_multimodal(image_bytes, csv_snippet_file, scaler_path=os.path.join(MODEL_DIR,"csv_scaler.pkl")):
model = load_multimodal_model()
if model is None:
return "No trained multimodal model found.", None
if not os.path.exists(scaler_path):
return "CSV scaler not found. Train first.", None
img = Image.open(io.BytesIO(image_bytes)).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
img_arr = (np.array(img).astype(np.float32) / 255.0); img_arr = np.transpose(img_arr, (2,0,1))
img_tensor = torch.tensor(img_arr[None], dtype=torch.float32).to(DEVICE)
df = pd.read_csv(csv_snippet_file.name)
if df.shape[0] < SEQ_LEN:
return f"CSV snippet too short. Need at least {SEQ_LEN} rows.", None
required = {"Open","High","Low","Close","Volume"}
if not required.issubset(set(df.columns)):
return f"CSV snippet must have columns: {required}", None
scaler = joblib.load(scaler_path)
feats = df[["Open","High","Low","Close","Volume"]].values.astype(np.float32)
feats_scaled = scaler.transform(feats)
seq_arr = feats_scaled[-SEQ_LEN:]
seq_tensor = torch.tensor(seq_arr[None], dtype=torch.float32).to(DEVICE)
with torch.no_grad():
logits = model(img_tensor, seq_tensor)
probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
idx = int(probs.argmax())
return LABELS[idx], {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
# -------------------------
# Batch predict with ATR-based SL/TP + HTML report
# -------------------------
def batch_predict_extract_trades_atr(full_csv_file, stop_mult=1.0, tp_mult=2.0, atr_period=14, max_hold=32, scaler_path=os.path.join(MODEL_DIR,"csv_scaler.pkl")):
"""
Sliding-window sequence-only predictions and ATR-based stop/take management.
stop_mult: ATR multiplier for stop loss (e.g., 1.0)
tp_mult: ATR multiplier for take profit (e.g., 2.0)
atr_period: ATR period length (default 14)
"""
# require scaler & model
if not os.path.exists(scaler_path):
return None, None, None, "CSV scaler not found. Train first."
model = load_multimodal_model()
if model is None:
return None, None, None, "No trained multimodal model found."
scaler = joblib.load(scaler_path)
# load & clean df
df = handle_ohlcv_csv(full_csv_file)
n = df.shape[0]
if n < SEQ_LEN:
return None, None, None, f"CSV too short: need at least {SEQ_LEN} rows."
feats = df[["Open","High","Low","Close","Volume"]].values.astype(np.float32)
feats_scaled = scaler.transform(feats)
    # build sliding windows over the scaled features
windows = []
indices = []
for end_idx in range(SEQ_LEN-1, n):
start = end_idx - (SEQ_LEN-1)
windows.append(feats_scaled[start:end_idx+1])
indices.append(end_idx)
X = np.stack(windows) # (B, SEQ_LEN, features)
# predict using sequence branch (zero image emb)
device = DEVICE
preds = []
probs_all = []
with torch.no_grad():
B = X.shape[0]
batch_size = SEQ_BATCH
for i in range(0, B, batch_size):
xb = torch.tensor(X[i:i+batch_size], dtype=torch.float32).to(device)
seq_emb = model.seq_model(xb) # (b, emb_size)
zero_img_emb = torch.zeros_like(seq_emb)
head_in = torch.cat([zero_img_emb, seq_emb], dim=1)
logits = model.head(head_in)
probs = torch.softmax(logits, dim=-1).cpu().numpy()
batch_preds = probs.argmax(axis=1)
preds.extend(batch_preds.tolist())
probs_all.extend(probs.tolist())
label_series = [None] * (SEQ_LEN-1) + [LABELS[int(p)] for p in preds]
prob_series = [None] * (SEQ_LEN-1) + [probs_all[i] for i in range(len(preds))]
# compute ATR
atr = compute_atr(df, n=atr_period) # numpy array
trades = []
state = "flat"
entry_idx = None
entry_price = None
side = None
hold_count = 0
t = SEQ_LEN-1
while t < n:
lab = label_series[t]
# enter
if state == "flat":
if lab == "Buy" or lab == "Sell":
# entry at next candle open if possible
entry_idx = t+1 if (t+1 < n) else t
entry_open = float(df["Open"].iloc[entry_idx]) if (t+1 < n) else float(df["Close"].iloc[t])
entry_price = entry_open
side = "Long" if lab == "Buy" else "Short"
# determine ATR at entry index (use most recent available)
atr_val = float(atr[entry_idx]) if entry_idx < len(atr) else float(atr[-1])
if np.isnan(atr_val) or atr_val == 0:
# fallback: use close-open
atr_val = float(abs(df["Close"].iloc[entry_idx] - df["Open"].iloc[entry_idx]))
if atr_val == 0:
atr_val = 1e-6
if side == "Long":
stop_price = entry_price - stop_mult * atr_val
take_price = entry_price + tp_mult * atr_val
else:
stop_price = entry_price + stop_mult * atr_val
take_price = entry_price - tp_mult * atr_val
hold_count = 0
# start scanning from entry candle (inclusive) for stop/take/opposite signal
i = entry_idx
exited = False
while i < n:
# examine candle i high/low
high_i = float(df["High"].iloc[i])
low_i = float(df["Low"].iloc[i])
# check stop / take first (intra-candle)
if side == "Long":
if low_i <= stop_price:
exit_price = stop_price
exit_idx = i
reason = "stop"
exited = True
elif high_i >= take_price:
exit_price = take_price
exit_idx = i
reason = "take"
exited = True
else: # Short
if high_i >= stop_price:
exit_price = stop_price
exit_idx = i
reason = "stop"
exited = True
elif low_i <= take_price:
exit_price = take_price
exit_idx = i
reason = "take"
exited = True
if exited:
# record trade
if side == "Long":
pnl = exit_price / entry_price - 1.0
else:
pnl = entry_price / exit_price - 1.0
trades.append({
"entry_idx": entry_idx,
"entry_time": df["Date"].iloc[entry_idx] if "Date" in df.columns else int(entry_idx),
"entry_price": entry_price,
"exit_idx": exit_idx,
"exit_time": df["Date"].iloc[exit_idx] if "Date" in df.columns else int(exit_idx),
"exit_price": exit_price,
"side": side,
"pnl": pnl,
"exit_reason": reason
})
state = "flat"
entry_idx = None; entry_price = None; side = None
# move t to exit_idx + 1
t = exit_idx + 1
break
# check opposite signal at this candle's aligned label (label_series aligned to window end)
lab_i = label_series[i] if i < len(label_series) else None
if (side == "Long" and lab_i == "Sell") or (side == "Short" and lab_i == "Buy"):
# exit at next open if exists
exit_idx = i+1 if (i+1 < n) else i
exit_price = float(df["Open"].iloc[exit_idx]) if (i+1 < n) else float(df["Close"].iloc[i])
reason = "opp_signal"
if side == "Long":
pnl = exit_price / entry_price - 1.0
else:
pnl = entry_price / exit_price - 1.0
trades.append({
"entry_idx": entry_idx,
"entry_time": df["Date"].iloc[entry_idx] if "Date" in df.columns else int(entry_idx),
"entry_price": entry_price,
"exit_idx": exit_idx,
"exit_time": df["Date"].iloc[exit_idx] if "Date" in df.columns else int(exit_idx),
"exit_price": exit_price,
"side": side,
"pnl": pnl,
"exit_reason": reason
})
state = "flat"
entry_idx = None; entry_price = None; side = None
t = exit_idx + 1
exited = True
break
# hold count
hold_count += 1
if hold_count >= max_hold:
exit_idx = i+1 if (i+1 < n) else i
exit_price = float(df["Open"].iloc[exit_idx]) if (i+1 < n) else float(df["Close"].iloc[i])
reason = "max_hold"
if side == "Long":
pnl = exit_price / entry_price - 1.0
else:
pnl = entry_price / exit_price - 1.0
trades.append({
"entry_idx": entry_idx,
"entry_time": df["Date"].iloc[entry_idx] if "Date" in df.columns else int(entry_idx),
"entry_price": entry_price,
"exit_idx": exit_idx,
"exit_time": df["Date"].iloc[exit_idx] if "Date" in df.columns else int(exit_idx),
"exit_price": exit_price,
"side": side,
"pnl": pnl,
"exit_reason": reason
})
state = "flat"
entry_idx = None; entry_price = None; side = None
t = exit_idx + 1
exited = True
break
# continue to next candle
i += 1
if not exited:
# reached end of series without exit -> close at last close
exit_idx = n-1
exit_price = float(df["Close"].iloc[exit_idx])
reason = "eod"
if side == "Long":
pnl = exit_price / entry_price - 1.0
else:
pnl = entry_price / exit_price - 1.0
trades.append({
"entry_idx": entry_idx,
"entry_time": df["Date"].iloc[entry_idx] if "Date" in df.columns else int(entry_idx),
"entry_price": entry_price,
"exit_idx": exit_idx,
"exit_time": df["Date"].iloc[exit_idx] if "Date" in df.columns else int(exit_idx),
"exit_price": exit_price,
"side": side,
"pnl": pnl,
"exit_reason": reason
})
return _save_and_report(df, trades) # end
continue # continue outer while (t already advanced to exit_idx+1)
# increment t if no entry or after handled
t += 1
# done scanning
return _save_and_report(df, trades)
def _save_and_report(df, trades):
# save trades CSV
trades_df = pd.DataFrame(trades)
ts = int(datetime.utcnow().timestamp())
trades_csv_path = os.path.join(MODEL_DIR, f"trades_{ts}.csv")
trades_df.to_csv(trades_csv_path, index=False)
# make plot
plt.figure(figsize=(12,5))
plt.plot(df.index, df["Close"].values, label="Close")
entries_x = []
entries_y = []
exits_x = []
exits_y = []
colors = []
for tr in trades:
eidx = int(tr["entry_idx"]); exidx = int(tr["exit_idx"])
entries_x.append(eidx); entries_y.append(tr["entry_price"])
exits_x.append(exidx); exits_y.append(tr["exit_price"])
colors.append("g" if tr["side"]=="Long" else "r")
# plot entries/exits
if len(entries_x):
for (x,y,c) in zip(entries_x, entries_y, colors):
plt.scatter([x],[y], marker="o", color=c, zorder=5)
if len(exits_x):
for (x,y,c) in zip(exits_x, exits_y, colors):
plt.scatter([x],[y], marker="x", color=c, zorder=5)
plt.title("Close price with entries (o) and exits (x)")
plt.xlabel("index")
plt.ylabel("Close")
plt.legend()
plot_path = os.path.join(MODEL_DIR, f"trades_plot_{ts}.png")
plt.tight_layout()
plt.savefig(plot_path)
plt.close()
# compute metrics
if trades_df.empty:
summary_text = "No trades generated."
report_path = _generate_html_report(df, trades_df, plot_path, summary_text, {})
return trades_csv_path, plot_path, report_path, summary_text
pnl_list = trades_df["pnl"].astype(float).values
total_pnl = float(np.nansum(pnl_list))
avg_pnl = float(np.nanmean(pnl_list))
wins = float((pnl_list > 0).sum())
win_rate = float(wins / len(pnl_list)) if len(pnl_list) > 0 else 0.0
    # Equity curve: compound per-trade returns multiplicatively, assuming full reinvestment
    # of one unit of capital per trade. Equity starts at 1.0.
    equity = np.cumprod(1 + pnl_list)
peak = np.maximum.accumulate(equity)
drawdowns = (peak - equity) / peak
max_dd = float(np.max(drawdowns)) if len(drawdowns)>0 else 0.0
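    # Illustrative example of the drawdown calculation: pnl_list = [0.10, -0.05, 0.02]
    # gives equity = [1.10, 1.045, 1.0659], running peak = [1.10, 1.10, 1.10] and
    # drawdowns = [0.0, 0.05, ~0.031], so max_dd = 0.05.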
summary = {
"n_trades": int(len(pnl_list)),
"total_pnl": total_pnl,
"avg_pnl": avg_pnl,
"win_rate": win_rate,
"max_drawdown": max_dd
}
summary_text = f"Trades: {summary['n_trades']}, Total PnL (factorized): {summary['total_pnl']:.4f}, Win rate: {summary['win_rate']:.3f}, Max Drawdown: {summary['max_drawdown']:.3f}"
report_path = _generate_html_report(df, trades_df, plot_path, summary_text, summary)
return trades_csv_path, plot_path, report_path, summary_text
def _generate_html_report(df, trades_df, plot_path, summary_text, summary):
# embed plot as base64
with open(plot_path, "rb") as f:
img_b64 = base64.b64encode(f.read()).decode("ascii")
img_tag = f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%;height:auto;" />'
# trades table html
if trades_df.empty:
trades_html = "<p>No trades</p>"
else:
trades_html = trades_df.to_html(classes="table table-striped", index=False, float_format="%.6f")
# summary HTML block
summary_html = "<ul>"
for k,v in summary.items():
summary_html += f"<li><strong>{k}</strong>: {v}</li>"
summary_html += "</ul>"
html = f"""
<html>
<head>
<meta charset="utf-8"/>
<title>SMC Multimodal - ATR Trades Report</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.header {{ margin-bottom: 20px; }}
.section {{ margin-bottom: 30px; }}
table {{ border-collapse: collapse; width: 100%; }}
table, th, td {{ border: 1px solid #ddd; }}
th, td {{ padding: 8px; text-align: left; }}
th {{ background-color: #f2f2f2; }}
</style>
</head>
<body>
<div class="header">
<h1>SMC Multimodal - ATR Trades Report</h1>
<p>Generated: {datetime.utcnow().isoformat()} UTC</p>
<p>{summary_text}</p>
</div>
<div class="section">
<h2>Price chart with trades</h2>
{img_tag}
</div>
<div class="section">
<h2>Trades</h2>
{trades_html}
</div>
<div class="section">
<h2>Summary metrics</h2>
{summary_html}
</div>
</body>
</html>
"""
ts = int(datetime.utcnow().timestamp())
report_path = os.path.join(MODEL_DIR, f"report_{ts}.html")
with open(report_path, "w", encoding="utf-8") as f:
f.write(html)
return report_path
# -------------------------
# Gemini auto-labeling and training
# The full auto-label/training pipeline is omitted here; the minimal wrappers below
# keep the app runnable, and the earlier implementation can be pasted in for the
# complete workflow.
# -------------------------
GENAI_API_KEY = os.environ.get("GENAI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
PROMPT_JSON_SCHEMA = f"""
You are a chart-analysis assistant. INPUT: a candlestick chart image (standard candlestick, no overlays).
You MUST OUTPUT a single JSON object only (no explanatory text). Schema:
{{
"label": "Buy" | "Sell" | "Hold",
"confidence": 0.0-1.0,
"seq_start": <integer index into the provided OHLCV CSV> OR a date string "YYYY-MM-DD" (one of the two),
"note": "brief explanation (optional)"
}}
Constraints:
- Only output valid JSON following exactly the keys above.
- If uncertain, return confidence < 0.6.
- seq_start should be either an integer or an ISO date string (not both).
- Do NOT include any other keys.
"""
def ai_teacher_with_gemini(image_bytes: bytes, prompt_extra: str = "") -> dict:
    if genai is None:
        raise RuntimeError("google-genai package not installed. pip install google-genai")
    api_key = GENAI_API_KEY
    if not api_key:
        raise RuntimeError("Set GENAI_API_KEY environment variable with your Gemini API key.")
    client = genai.Client(api_key=api_key)
    prompt = PROMPT_JSON_SCHEMA + "\n" + prompt_extra
    pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    contents = [prompt, pil_img]
    # Generation parameters go through a config object in the google-genai SDK,
    # not as keyword arguments to generate_content.
    config = genai_types.GenerateContentConfig(temperature=0.0, max_output_tokens=800) if genai_types else None
    response = client.models.generate_content(model="gemini-2.5-flash", contents=contents, config=config)
text = None
try:
cand = response.candidates[0]
if hasattr(cand, "content") and hasattr(cand.content, "parts"):
parts = cand.content.parts
collected = []
for p in parts:
if getattr(p, "text", None):
collected.append(p.text)
text = "\n".join(collected).strip()
else:
text = str(response)
except Exception:
text = str(response)
try:
parsed = json.loads(text)
except Exception:
import re
m = re.search(r"(\{.*\})", text, flags=re.S)
if not m:
raise ValueError(f"Could not parse JSON from model response: {text[:500]}")
parsed = json.loads(m.group(1))
return parsed
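# Hedged usage sketch (assumes GENAI_API_KEY is set and a local "chart.png" exists;
# the filename is a placeholder and this call is not wired into the UI below):
#   with open("chart.png", "rb") as f:
#       parsed = ai_teacher_with_gemini(f.read())
#   print(parsed.get("label"), parsed.get("confidence"))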
# Minimal stubs for auto-label/train/predict wrappers so UI launches
AUTO_IMAGES_DIR = os.path.join(AUTO_DIR, "images")
os.makedirs(AUTO_IMAGES_DIR, exist_ok=True)
AUTO_MAPPING_CSV = os.path.join(AUTO_DIR, "mapping_auto.csv")
if not os.path.exists(AUTO_MAPPING_CSV):
pd.DataFrame(columns=["filename","label","seq_start"]).to_csv(AUTO_MAPPING_CSV, index=False)
def ui_ai_label_and_suggest(img_filepath, ohlcv_csv=None):
return "AI tab not active in this trimmed example (install GenAI)", "{}", None, img_filepath
def ui_accept_auto_label(img_filepath, parsed_json_text, seq_start_override, save_prefix="auto"):
return "Not implemented in trimmed example", None
def ui_train_from_auto(ohlcv_csv, epochs=EPOCHS):
return "Train-from-auto not implemented in trimmed example", None
def ui_train_multimodal_wrapper(img_zip_file, mapping_file, ohlcv_file_in, epochs_in):
return "Train wrapper not implemented in trimmed example", None
def ui_predict_multimodal(img_filepath, csv_snippet_file):
try:
with open(img_filepath, "rb") as f:
image_bytes = f.read()
except Exception as e:
return f"Could not read image file: {e}", None, None
try:
label, probs = predict_multimodal(image_bytes, csv_snippet_file)
return label, str(probs), img_filepath
except Exception as e:
return f"Prediction error: {e}", None, None
def ui_batch_predict_and_export(full_csv_file, stop_mult, tp_mult, atr_period, max_hold):
try:
trades_csv, plot_path, report_path, msg = batch_predict_extract_trades_atr(full_csv_file, stop_mult=float(stop_mult), tp_mult=float(tp_mult), atr_period=int(atr_period), max_hold=int(max_hold))
if trades_csv is None:
return msg, None, None, None
return msg, trades_csv, plot_path, report_path
except Exception as e:
return f"Batch predict error: {e}", None, None, None
# -------------------------
# Build Gradio UI
# -------------------------
with gr.Blocks() as demo:
gr.Markdown("# SMC Multimodal Trainer + ATR-based Trade Extraction & HTML Report")
with gr.Tab("Query / Batch Predict (ATR)"):
gr.Markdown("Upload a full OHLCV CSV and choose ATR/SL/TP settings to generate trades and an HTML report.")
full_csv = gr.File(label="Full OHLCV CSV (for batch prediction)")
with gr.Row():
stop_mult_input = gr.Number(value=1.0, label="Stop ATR multiplier (stop = entry ± ATR * stop_mult)")
tp_mult_input = gr.Number(value=2.0, label="TakeProfit ATR multiplier (take = entry ± ATR * tp_mult)")
atr_period_input = gr.Number(value=14, label="ATR period")
max_hold_input = gr.Number(value=32, label="Max holding candles")
batch_btn = gr.Button("Batch Predict -> ATR trades + report")
batch_out = gr.Textbox(label="Batch Output")
batch_trades_file = gr.File(label="Trades CSV (download)")
batch_plot = gr.Image(label="Trades plot (entries/exits)")
batch_report = gr.File(label="HTML report (download)")
batch_btn.click(fn=ui_batch_predict_and_export, inputs=[full_csv, stop_mult_input, tp_mult_input, atr_period_input, max_hold_input], outputs=[batch_out, batch_trades_file, batch_plot, batch_report])
with gr.Tab("Notes"):
gr.Markdown("""
**ATR-based trade simulation notes**
- This simulator uses ATR for stop loss and take profit levels.
- Entry is at next candle Open after the signal.
- The simulator checks each candle to see if SL/TP hit (intra-candle).
- If the opposite model signal appears, it exits at the next open.
- HTML report includes embedded plot and trade table, plus summary metrics (win rate, max drawdown).
- Validate the results and tune multipliers (stop_mult / tp_mult) before trusting strategies.
""")
if __name__ == "__main__":
demo.launch(debug=True, share=False)