# smc_multimodal_atr_report.py
"""
SMC Multimodal + Gemini Auto-label + Batch CSV Prediction with ATR-based SL/TP + HTML report.
Requirements (example):
pip install torch torchvision pandas scikit-learn pillow gradio tqdm joblib matplotlib google-genai
Set GENAI_API_KEY environment variable for Gemini if using that tab.
"""
import os
import io
import zipfile
import tempfile
import shutil
import json
import base64
from typing import List, Tuple, Optional
from datetime import datetime
import numpy as np
import pandas as pd
from PIL import Image
import joblib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import gradio as gr
# --- Gemini client (google-genai) -----------------
try:
from google import genai
except Exception:
genai = None
# -------------------------
# Config
# -------------------------
# Prefer GPU when available; models and tensors are moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = 128   # chart images are resized to IMAGE_SIZE x IMAGE_SIZE pixels
SEQ_LEN = 32       # number of OHLCV rows per model input window
SEQ_BATCH = 64     # batch size used for sliding-window inference
EPOCHS = 8         # default number of training epochs
MODEL_DIR = "models_smc_multimodal"  # checkpoints, scaler, trade reports
AUTO_DIR = "auto_labeled"            # output dir for auto-labeled samples
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(AUTO_DIR, exist_ok=True)
# Class order matters: model output index i maps to LABELS[i].
LABELS = ["Sell", "Hold", "Buy"]
# -------------------------
# Helpers and CSV robustness
# -------------------------
# Accepted column-name aliases for each canonical OHLCV field; the first alias
# found in an input CSV is renamed to the capitalized canonical name.
COMMON_COLS = {
    "open": ["open", "Open", "O", "o"],
    "high": ["high", "High", "H", "h"],
    "low": ["low", "Low", "L", "l"],
    "close":["close", "Close", "C", "c"],
    "volume":["volume", "Volume", "V", "v"]
}
def find_col(df, choices):
    """Return the first name in *choices* that is a column of *df*, else None."""
    return next((name for name in choices if name in df.columns), None)
def handle_ohlcv_csv(csv_file):
    """Load and normalize an OHLCV CSV from an uploaded file object.

    csv_file: file-like object with a ``.name`` path attribute (Gradio upload).
    Renames recognized column aliases to Open/High/Low/Close/Volume, sorts by
    Date when present, coerces values to numeric and forward/back-fills gaps.

    Returns the cleaned DataFrame.
    Raises ValueError when a required column is absent or too many values are
    missing after filling.
    """
    df = pd.read_csv(csv_file.name)
    col_map = {}
    for key, choices in COMMON_COLS.items():
        found = find_col(df, choices)
        if not found:
            raise ValueError(f"OHLCV CSV missing column for '{key}'. Tried: {choices}")
        col_map[found] = key.capitalize()
    df = df.rename(columns=col_map)
    if "Date" in df.columns:
        try:
            # pandas >= 2.0 removed infer_datetime_format; format inference
            # is automatic, so parse without the deprecated keyword.
            df["Date"] = pd.to_datetime(df["Date"])
        except Exception:
            # best-effort: keep unparsed dates rather than failing the load
            pass
        df = df.sort_values("Date").reset_index(drop=True)
    else:
        df = df.reset_index(drop=True)
    ohlcv = ["Open", "High", "Low", "Close", "Volume"]
    for c in ohlcv:
        df[c] = pd.to_numeric(df[c], errors="coerce")
    # fillna(method=...) is deprecated (removed in pandas 3.x); use ffill/bfill.
    df[ohlcv] = df[ohlcv].ffill().bfill()
    if df[ohlcv].isnull().any().any():
        raise ValueError("OHLCV CSV contains too many missing values after fill.")
    return df
def label_to_index_generic(l):
    """Map a label (class name or integer index) to its index in LABELS.

    Accepts "Sell"/"Hold"/"Buy" case-insensitively (leading/trailing spaces
    ignored), or anything convertible to an int in {0, 1, 2} (ints, floats,
    numeric strings).

    Raises ValueError for anything unrecognized.
    """
    if isinstance(l, str):
        s = l.strip().lower()
        for i, lab in enumerate(LABELS):
            if lab.lower() == s:
                return i
    # Fall through to the numeric path: narrow the except clause instead of a
    # bare `except:` so KeyboardInterrupt/SystemExit are never swallowed.
    try:
        idx = int(l)
    except (TypeError, ValueError):
        pass
    else:
        if idx in (0, 1, 2):
            return idx
    raise ValueError(f"Unknown label: {l}")
# -------------------------
# Models (same as before)
# -------------------------
class MultimodalDataset(Dataset):
    """Torch dataset pairing a chart image with a scaled OHLCV sequence.

    Each item is a (image_path, sequence_array, label) triple; __getitem__
    yields (CHW float32 image tensor in [0, 1], float32 sequence tensor,
    long label tensor).
    """
    def __init__(self, items: List[Tuple[str, np.ndarray, int]]):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        path, seq, lab = self.items[idx]
        rgb = Image.open(path).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
        pixels = np.asarray(rgb, dtype=np.float32) / 255.0
        chw = pixels.transpose(2, 0, 1)  # HWC -> CHW for conv layers
        return (
            torch.tensor(chw, dtype=torch.float32),
            torch.tensor(seq.astype(np.float32)),
            torch.tensor(int(lab), dtype=torch.long),
        )
class SimpleCNNEmbedding(nn.Module):
    """Three conv/ReLU/pool stages followed by a linear projection to emb_size."""
    def __init__(self, emb_size=128):
        super().__init__()
        stages = []
        in_ch = 3
        for out_ch in (16, 32, 64):
            stages += [nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
            in_ch = out_ch
        self.conv = nn.Sequential(*stages)
        # three 2x max-pools shrink each spatial dimension by a factor of 8
        flat = (IMAGE_SIZE // 8) * (IMAGE_SIZE // 8) * 64
        self.head = nn.Sequential(nn.Flatten(), nn.Linear(flat, emb_size), nn.ReLU())

    def forward(self, x):
        return self.head(self.conv(x))
class SimpleLSTMEmbedding(nn.Module):
    """Bidirectional single-layer LSTM; projects the last timestep to emb_size."""
    def __init__(self, n_features, hidden=64, emb_size=128):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_features,
            hidden_size=hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        # bidirectional output concatenates both directions -> 2 * hidden
        self.proj = nn.Sequential(nn.Linear(hidden * 2, emb_size), nn.ReLU())

    def forward(self, x):
        seq_out, _ = self.lstm(x)
        final_step = seq_out[:, -1, :]
        return self.proj(final_step)
class MultimodalNet(nn.Module):
    """Fuses a CNN image embedding and an LSTM sequence embedding, then classifies."""
    def __init__(self, n_features, emb_size=128, n_classes=3):
        super().__init__()
        self.img_model = SimpleCNNEmbedding(emb_size=emb_size)
        self.seq_model = SimpleLSTMEmbedding(n_features=n_features, hidden=64, emb_size=emb_size)
        self.head = nn.Sequential(
            nn.Linear(emb_size * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, n_classes),
        )

    def forward(self, img, seq):
        # concatenate the two modality embeddings along the feature axis
        fused = torch.cat([self.img_model(img), self.seq_model(seq)], dim=1)
        return self.head(fused)
# -------------------------
# Helper: ATR
# -------------------------
def compute_atr(df: pd.DataFrame, n=14):
    """Average True Range: rolling mean of the true range over n bars.

    True range = max(high - low, |high - prev_close|, |low - prev_close|);
    the first bar reuses its own close as prev_close, which reduces to
    high - low for a well-formed candle. Returns a numpy array with one
    value per row (min_periods=1, so early values average fewer bars).
    """
    high = df["High"].to_numpy()
    low = df["Low"].to_numpy()
    close = df["Close"].to_numpy()
    prev_close = np.empty_like(close)
    prev_close[0] = close[0]
    prev_close[1:] = close[:-1]
    true_range = np.maximum.reduce([
        high - low,
        np.abs(high - prev_close),
        np.abs(low - prev_close),
    ])
    return pd.Series(true_range).rolling(window=n, min_periods=1).mean().values
# -------------------------
# Load & predict (single)
# -------------------------
def load_multimodal_model():
    """Restore the trained multimodal checkpoint from MODEL_DIR.

    Returns the model in eval mode on DEVICE, or None when no checkpoint
    file exists yet.
    """
    ckpt_path = os.path.join(MODEL_DIR, "multimodal_model.pt")
    if not os.path.exists(ckpt_path):
        return None
    checkpoint = torch.load(ckpt_path, map_location=DEVICE)
    # the checkpoint stores the feature count the LSTM branch was trained with
    net = MultimodalNet(
        n_features=checkpoint.get("n_features"),
        emb_size=128,
        n_classes=len(LABELS),
    ).to(DEVICE)
    net.load_state_dict(checkpoint["model_state"])
    net.eval()
    return net
def predict_multimodal(image_bytes, csv_snippet_file, scaler_path=os.path.join(MODEL_DIR,"csv_scaler.pkl")):
    """Classify a single chart image + OHLCV snippet as Sell/Hold/Buy.

    image_bytes: raw bytes of the chart screenshot.
    csv_snippet_file: file-like object (with .name) holding >= SEQ_LEN OHLCV rows.
    Returns (label, {label: probability}) on success, or (error message, None).
    """
    model = load_multimodal_model()
    if model is None:
        return "No trained multimodal model found.", None
    if not os.path.exists(scaler_path):
        return "CSV scaler not found. Train first.", None
    # Image branch: RGB, resized, scaled to [0, 1], CHW, batch of one.
    pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
    chw = np.transpose(np.array(pil_img).astype(np.float32) / 255.0, (2, 0, 1))
    img_batch = torch.tensor(chw[None], dtype=torch.float32).to(DEVICE)
    # Sequence branch: last SEQ_LEN rows of the scaled OHLCV features.
    snippet = pd.read_csv(csv_snippet_file.name)
    if snippet.shape[0] < SEQ_LEN:
        return f"CSV snippet too short. Need at least {SEQ_LEN} rows.", None
    required = {"Open","High","Low","Close","Volume"}
    if not required.issubset(set(snippet.columns)):
        return f"CSV snippet must have columns: {required}", None
    scaler = joblib.load(scaler_path)
    scaled = scaler.transform(snippet[["Open","High","Low","Close","Volume"]].values.astype(np.float32))
    seq_batch = torch.tensor(scaled[-SEQ_LEN:][None], dtype=torch.float32).to(DEVICE)
    with torch.no_grad():
        probs = torch.softmax(model(img_batch, seq_batch), dim=-1).cpu().numpy()[0]
    best = int(probs.argmax())
    return LABELS[best], {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
# -------------------------
# Batch predict with ATR-based SL/TP + HTML report
# -------------------------
def batch_predict_extract_trades_atr(full_csv_file, stop_mult=1.0, tp_mult=2.0, atr_period=14, max_hold=32, scaler_path=os.path.join(MODEL_DIR,"csv_scaler.pkl")):
    """
    Sliding-window sequence-only predictions and ATR-based stop/take management.

    stop_mult: ATR multiplier for stop loss (e.g., 1.0)
    tp_mult: ATR multiplier for take profit (e.g., 2.0)
    atr_period: ATR period length (default 14)
    max_hold: maximum candles to hold a position before a forced exit
    scaler_path: path to the fitted feature scaler saved at training time

    Returns (trades_csv_path, plot_path, report_path, summary_text) on success,
    or (None, None, None, error_message) when prerequisites are missing.

    The multimodal model's image branch is fed an all-zero embedding here, so
    signals come purely from the sequence (LSTM) branch.
    """
    # require scaler & model
    if not os.path.exists(scaler_path):
        return None, None, None, "CSV scaler not found. Train first."
    model = load_multimodal_model()
    if model is None:
        return None, None, None, "No trained multimodal model found."
    scaler = joblib.load(scaler_path)
    # load & clean df
    df = handle_ohlcv_csv(full_csv_file)
    n = df.shape[0]
    if n < SEQ_LEN:
        return None, None, None, f"CSV too short: need at least {SEQ_LEN} rows."
    feats = df[["Open","High","Low","Close","Volume"]].values.astype(np.float32)
    feats_scaled = scaler.transform(feats)
    # sliding windows (same as earlier); window k covers the SEQ_LEN rows
    # ending at row indices[k]
    windows = []
    indices = []
    for end_idx in range(SEQ_LEN-1, n):
        start = end_idx - (SEQ_LEN-1)
        windows.append(feats_scaled[start:end_idx+1])
        indices.append(end_idx)
    X = np.stack(windows) # (B, SEQ_LEN, features)
    # predict using sequence branch (zero image emb)
    device = DEVICE
    preds = []
    probs_all = []
    with torch.no_grad():
        B = X.shape[0]
        batch_size = SEQ_BATCH
        for i in range(0, B, batch_size):
            xb = torch.tensor(X[i:i+batch_size], dtype=torch.float32).to(device)
            seq_emb = model.seq_model(xb) # (b, emb_size)
            # zero image embedding: classification driven by the sequence alone
            zero_img_emb = torch.zeros_like(seq_emb)
            head_in = torch.cat([zero_img_emb, seq_emb], dim=1)
            logits = model.head(head_in)
            probs = torch.softmax(logits, dim=-1).cpu().numpy()
            batch_preds = probs.argmax(axis=1)
            preds.extend(batch_preds.tolist())
            probs_all.extend(probs.tolist())
    # Pad with None so label_series[t] is the prediction of the window ending at row t.
    label_series = [None] * (SEQ_LEN-1) + [LABELS[int(p)] for p in preds]
    # NOTE(review): prob_series is built but never read below — kept as-is.
    prob_series = [None] * (SEQ_LEN-1) + [probs_all[i] for i in range(len(preds))]
    # compute ATR
    atr = compute_atr(df, n=atr_period) # numpy array
    trades = []
    state = "flat"
    entry_idx = None
    entry_price = None
    side = None
    hold_count = 0
    t = SEQ_LEN-1
    # Outer loop walks candle-by-candle looking for entries; once a position is
    # opened the inner loop manages it to completion before t advances.
    while t < n:
        lab = label_series[t]
        # enter
        if state == "flat":
            if lab == "Buy" or lab == "Sell":
                # entry at next candle open if possible (no lookahead: the
                # signal from window ending at t is acted on at candle t+1)
                entry_idx = t+1 if (t+1 < n) else t
                entry_open = float(df["Open"].iloc[entry_idx]) if (t+1 < n) else float(df["Close"].iloc[t])
                entry_price = entry_open
                side = "Long" if lab == "Buy" else "Short"
                # determine ATR at entry index (use most recent available)
                atr_val = float(atr[entry_idx]) if entry_idx < len(atr) else float(atr[-1])
                if np.isnan(atr_val) or atr_val == 0:
                    # fallback: use close-open
                    atr_val = float(abs(df["Close"].iloc[entry_idx] - df["Open"].iloc[entry_idx]))
                    if atr_val == 0:
                        atr_val = 1e-6
                # stop/take bracket symmetric around the entry, scaled by ATR
                if side == "Long":
                    stop_price = entry_price - stop_mult * atr_val
                    take_price = entry_price + tp_mult * atr_val
                else:
                    stop_price = entry_price + stop_mult * atr_val
                    take_price = entry_price - tp_mult * atr_val
                hold_count = 0
                # start scanning from entry candle (inclusive) for stop/take/opposite signal
                i = entry_idx
                exited = False
                while i < n:
                    # examine candle i high/low
                    high_i = float(df["High"].iloc[i])
                    low_i = float(df["Low"].iloc[i])
                    # check stop / take first (intra-candle); when both levels
                    # fall inside one candle's range the stop is assumed to
                    # fill first (conservative fill assumption)
                    if side == "Long":
                        if low_i <= stop_price:
                            exit_price = stop_price
                            exit_idx = i
                            reason = "stop"
                            exited = True
                        elif high_i >= take_price:
                            exit_price = take_price
                            exit_idx = i
                            reason = "take"
                            exited = True
                    else: # Short
                        if high_i >= stop_price:
                            exit_price = stop_price
                            exit_idx = i
                            reason = "stop"
                            exited = True
                        elif low_i <= take_price:
                            exit_price = take_price
                            exit_idx = i
                            reason = "take"
                            exited = True
                    if exited:
                        # record trade; pnl is a fractional return on entry price
                        if side == "Long":
                            pnl = exit_price / entry_price - 1.0
                        else:
                            pnl = entry_price / exit_price - 1.0
                        trades.append({
                            "entry_idx": entry_idx,
                            "entry_time": df["Date"].iloc[entry_idx] if "Date" in df.columns else int(entry_idx),
                            "entry_price": entry_price,
                            "exit_idx": exit_idx,
                            "exit_time": df["Date"].iloc[exit_idx] if "Date" in df.columns else int(exit_idx),
                            "exit_price": exit_price,
                            "side": side,
                            "pnl": pnl,
                            "exit_reason": reason
                        })
                        state = "flat"
                        entry_idx = None; entry_price = None; side = None
                        # move t to exit_idx + 1
                        t = exit_idx + 1
                        break
                    # check opposite signal at this candle's aligned label (label_series aligned to window end)
                    lab_i = label_series[i] if i < len(label_series) else None
                    if (side == "Long" and lab_i == "Sell") or (side == "Short" and lab_i == "Buy"):
                        # exit at next open if exists
                        exit_idx = i+1 if (i+1 < n) else i
                        exit_price = float(df["Open"].iloc[exit_idx]) if (i+1 < n) else float(df["Close"].iloc[i])
                        reason = "opp_signal"
                        if side == "Long":
                            pnl = exit_price / entry_price - 1.0
                        else:
                            pnl = entry_price / exit_price - 1.0
                        trades.append({
                            "entry_idx": entry_idx,
                            "entry_time": df["Date"].iloc[entry_idx] if "Date" in df.columns else int(entry_idx),
                            "entry_price": entry_price,
                            "exit_idx": exit_idx,
                            "exit_time": df["Date"].iloc[exit_idx] if "Date" in df.columns else int(exit_idx),
                            "exit_price": exit_price,
                            "side": side,
                            "pnl": pnl,
                            "exit_reason": reason
                        })
                        state = "flat"
                        entry_idx = None; entry_price = None; side = None
                        t = exit_idx + 1
                        exited = True
                        break
                    # hold count: time-based exit after max_hold candles in a trade
                    hold_count += 1
                    if hold_count >= max_hold:
                        exit_idx = i+1 if (i+1 < n) else i
                        exit_price = float(df["Open"].iloc[exit_idx]) if (i+1 < n) else float(df["Close"].iloc[i])
                        reason = "max_hold"
                        if side == "Long":
                            pnl = exit_price / entry_price - 1.0
                        else:
                            pnl = entry_price / exit_price - 1.0
                        trades.append({
                            "entry_idx": entry_idx,
                            "entry_time": df["Date"].iloc[entry_idx] if "Date" in df.columns else int(entry_idx),
                            "entry_price": entry_price,
                            "exit_idx": exit_idx,
                            "exit_time": df["Date"].iloc[exit_idx] if "Date" in df.columns else int(exit_idx),
                            "exit_price": exit_price,
                            "side": side,
                            "pnl": pnl,
                            "exit_reason": reason
                        })
                        state = "flat"
                        entry_idx = None; entry_price = None; side = None
                        t = exit_idx + 1
                        exited = True
                        break
                    # continue to next candle
                    i += 1
                if not exited:
                    # reached end of series without exit -> close at last close
                    exit_idx = n-1
                    exit_price = float(df["Close"].iloc[exit_idx])
                    reason = "eod"
                    if side == "Long":
                        pnl = exit_price / entry_price - 1.0
                    else:
                        pnl = entry_price / exit_price - 1.0
                    trades.append({
                        "entry_idx": entry_idx,
                        "entry_time": df["Date"].iloc[entry_idx] if "Date" in df.columns else int(entry_idx),
                        "entry_price": entry_price,
                        "exit_idx": exit_idx,
                        "exit_time": df["Date"].iloc[exit_idx] if "Date" in df.columns else int(exit_idx),
                        "exit_price": exit_price,
                        "side": side,
                        "pnl": pnl,
                        "exit_reason": reason
                    })
                    # last candle consumed: nothing left to scan, report now
                    return _save_and_report(df, trades) # end
                continue # continue outer while (t already advanced to exit_idx+1)
        # increment t if no entry or after handled
        t += 1
    # done scanning
    return _save_and_report(df, trades)
def _save_and_report(df, trades):
    """Persist trades to CSV, draw the annotated price chart, build the HTML report.

    df: cleaned OHLCV DataFrame the trades were generated from.
    trades: list of trade dicts (entry/exit index, price, side, pnl, reason).
    Returns (trades_csv_path, plot_path, report_path, summary_text).
    """
    trades_df = pd.DataFrame(trades)
    ts = int(datetime.utcnow().timestamp())
    trades_csv_path = os.path.join(MODEL_DIR, f"trades_{ts}.csv")
    trades_df.to_csv(trades_csv_path, index=False)

    # Price chart: entries as circles, exits as crosses; green=long, red=short.
    plt.figure(figsize=(12,5))
    plt.plot(df.index, df["Close"].values, label="Close")
    marks = [
        (int(tr["entry_idx"]), tr["entry_price"],
         int(tr["exit_idx"]), tr["exit_price"],
         "g" if tr["side"] == "Long" else "r")
        for tr in trades
    ]
    for ein, eprice, _, _, col in marks:
        plt.scatter([ein], [eprice], marker="o", color=col, zorder=5)
    for _, _, xin, xprice, col in marks:
        plt.scatter([xin], [xprice], marker="x", color=col, zorder=5)
    plt.title("Close price with entries (o) and exits (x)")
    plt.xlabel("index")
    plt.ylabel("Close")
    plt.legend()
    plot_path = os.path.join(MODEL_DIR, f"trades_plot_{ts}.png")
    plt.tight_layout()
    plt.savefig(plot_path)
    plt.close()

    if trades_df.empty:
        summary_text = "No trades generated."
        report_path = _generate_html_report(df, trades_df, plot_path, summary_text, {})
        return trades_csv_path, plot_path, report_path, summary_text

    pnl = trades_df["pnl"].astype(float).values
    # Multiplicative equity curve: each trade compounds the account by (1 + pnl);
    # drawdown is measured relative to the running peak of that curve.
    equity = np.cumprod(1 + pnl)
    running_peak = np.maximum.accumulate(equity)
    dd = (running_peak - equity) / running_peak
    summary = {
        "n_trades": int(len(pnl)),
        "total_pnl": float(np.nansum(pnl)),
        "avg_pnl": float(np.nanmean(pnl)),
        "win_rate": float(float((pnl > 0).sum()) / len(pnl)) if len(pnl) > 0 else 0.0,
        "max_drawdown": float(np.max(dd)) if len(dd) > 0 else 0.0
    }
    summary_text = f"Trades: {summary['n_trades']}, Total PnL (factorized): {summary['total_pnl']:.4f}, Win rate: {summary['win_rate']:.3f}, Max Drawdown: {summary['max_drawdown']:.3f}"
    report_path = _generate_html_report(df, trades_df, plot_path, summary_text, summary)
    return trades_csv_path, plot_path, report_path, summary_text
def _generate_html_report(df, trades_df, plot_path, summary_text, summary):
# embed plot as base64
with open(plot_path, "rb") as f:
img_b64 = base64.b64encode(f.read()).decode("ascii")
img_tag = f''
# trades table html
if trades_df.empty:
trades_html = "
No trades
" else: trades_html = trades_df.to_html(classes="table table-striped", index=False, float_format="%.6f") # summary HTML block summary_html = "Generated: {datetime.utcnow().isoformat()} UTC
{summary_text}