Spaces:
Sleeping
Sleeping
| # smc_multimodal_atr_report.py | |
| """ | |
| SMC Multimodal + Gemini Auto-label + Batch CSV Prediction with ATR-based SL/TP + HTML report. | |
| Requirements (example): | |
| pip install torch torchvision pandas scikit-learn pillow gradio tqdm joblib matplotlib google-generativeai | |
| Set GENAI_API_KEY environment variable for Gemini if using that tab. | |
| """ | |
import base64
import io
import json
import os
import shutil
import tempfile
import zipfile
from datetime import datetime, timezone
from typing import List, Optional, Tuple

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import gradio as gr
# --- Gemini client (google-genai) -----------------
# Optional dependency: only the Gemini auto-label path needs it, so a missing
# package degrades gracefully to genai=None instead of crashing the app.
try:
    from google import genai
except Exception:
    genai = None
# -------------------------
# Config
# -------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # training/inference device
IMAGE_SIZE = 128  # chart images are resized to IMAGE_SIZE x IMAGE_SIZE pixels
SEQ_LEN = 32      # number of OHLCV rows per model input window
SEQ_BATCH = 64    # batch size used for sliding-window batch prediction
EPOCHS = 8        # default number of training epochs
MODEL_DIR = "models_smc_multimodal"  # checkpoints, scaler, trade CSVs, reports
AUTO_DIR = "auto_labeled"            # Gemini auto-labelled samples
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(AUTO_DIR, exist_ok=True)
# Class order matters: model logits index into this list (0=Sell, 1=Hold, 2=Buy).
LABELS = ["Sell", "Hold", "Buy"]
# -------------------------
# Helpers and CSV robustness
# -------------------------
# Accepted spellings for each canonical OHLCV column, checked in order by
# find_col(); the first match wins.
COMMON_COLS = {
    "open": ["open", "Open", "O", "o"],
    "high": ["high", "High", "H", "h"],
    "low": ["low", "Low", "L", "l"],
    "close":["close", "Close", "C", "c"],
    "volume":["volume", "Volume", "V", "v"]
}
def find_col(df, choices):
    """Return the first name in *choices* that is a column of *df*, else None."""
    return next((candidate for candidate in choices if candidate in df.columns), None)
def handle_ohlcv_csv(csv_file):
    """Load an OHLCV CSV and normalize it for downstream use.

    - Renames columns to canonical "Open"/"High"/"Low"/"Close"/"Volume"
      using the COMMON_COLS spelling table.
    - Parses and sorts by "Date" when that column exists.
    - Coerces the OHLCV columns to numeric and forward/back-fills gaps.

    Args:
        csv_file: object with a ``.name`` attribute pointing at the CSV
            (e.g. a Gradio file upload).

    Returns:
        A cleaned pandas DataFrame.

    Raises:
        ValueError: when a required column is missing or too many values
            remain NaN after filling.
    """
    df = pd.read_csv(csv_file.name)
    col_map = {}
    for key, choices in COMMON_COLS.items():
        found = find_col(df, choices)
        if not found:
            raise ValueError(f"OHLCV CSV missing column for '{key}'. Tried: {choices}")
        col_map[found] = key.capitalize()
    df = df.rename(columns=col_map)
    if "Date" in df.columns:
        try:
            # pandas >= 2.0 infers the format by default; the old
            # infer_datetime_format= keyword is deprecated/removed.
            df["Date"] = pd.to_datetime(df["Date"])
        except Exception:
            pass  # leave unparsed dates as-is; sorting still works lexically
        df = df.sort_values("Date").reset_index(drop=True)
    else:
        df = df.reset_index(drop=True)
    ohlcv_cols = ["Open", "High", "Low", "Close", "Volume"]
    for c in ohlcv_cols:
        df[c] = pd.to_numeric(df[c], errors="coerce")
    # fillna(method=...) is deprecated (removed in pandas 3.x); use the
    # dedicated ffill()/bfill() methods instead.
    df[ohlcv_cols] = df[ohlcv_cols].ffill().bfill()
    if df[ohlcv_cols].isnull().any().any():
        raise ValueError("OHLCV CSV contains too many missing values after fill.")
    return df
def label_to_index_generic(l):
    """Map a label to its index in LABELS.

    Accepts a case-insensitive label string ("Sell"/"Hold"/"Buy") or an
    int-convertible value in {0, 1, 2}.

    Raises:
        ValueError: when the value cannot be mapped.
    """
    if isinstance(l, str):
        s = l.strip().lower()
        for i, lab in enumerate(LABELS):
            if lab.lower() == s:
                return i
    # A bare except here would also swallow KeyboardInterrupt/SystemExit;
    # only the conversion failures are expected.
    try:
        idx = int(l)
    except (TypeError, ValueError):
        pass
    else:
        if idx in (0, 1, 2):
            return idx
    raise ValueError(f"Unknown label: {l}")
| # ------------------------- | |
| # Models (same as before) | |
| # ------------------------- | |
class MultimodalDataset(Dataset):
    """Dataset pairing a chart image with an OHLCV window and a class label.

    Each item is a tuple ``(img_path, seq_arr, label)``; ``__getitem__``
    loads the image from disk, resizes it to IMAGE_SIZE x IMAGE_SIZE,
    scales pixels to [0, 1] and returns CHW float32 tensors plus the label
    as a long tensor.
    """
    def __init__(self, items: List[Tuple[str, np.ndarray, int]]):
        self.items = items
    def __len__(self): return len(self.items)
    def __getitem__(self, idx):
        img_path, seq_arr, label = self.items[idx]
        # Image: force RGB, resize, normalize to [0, 1], then HWC -> CHW.
        img = Image.open(img_path).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
        img_arr = (np.array(img).astype(np.float32) / 255.0)
        img_arr = np.transpose(img_arr, (2,0,1))
        img_tensor = torch.tensor(img_arr, dtype=torch.float32)
        seq_tensor = torch.tensor(seq_arr.astype(np.float32))
        label_tensor = torch.tensor(int(label), dtype=torch.long)
        return img_tensor, seq_tensor, label_tensor
class SimpleCNNEmbedding(nn.Module):
    """Three-stage conv stack mapping a 3xHxW image to an `emb_size` vector."""

    def __init__(self, emb_size=128):
        super().__init__()
        stages = []
        channels = [3, 16, 32, 64]
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stages += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        self.conv = nn.Sequential(*stages)
        # Three 2x poolings shrink each spatial dimension by a factor of 8.
        flat_features = 64 * (IMAGE_SIZE // 8) * (IMAGE_SIZE // 8)
        self.head = nn.Sequential(nn.Flatten(), nn.Linear(flat_features, emb_size), nn.ReLU())

    def forward(self, x):
        return self.head(self.conv(x))
class SimpleLSTMEmbedding(nn.Module):
    """Bidirectional LSTM encoder: (B, T, n_features) -> (B, emb_size)."""

    def __init__(self, n_features, hidden=64, emb_size=128):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_features,
            hidden_size=hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.proj = nn.Sequential(nn.Linear(hidden * 2, emb_size), nn.ReLU())

    def forward(self, x):
        # Take the output at the final time step (both directions concatenated)
        # and project it down to the embedding size.
        outputs, _ = self.lstm(x)
        final_step = outputs[:, -1, :]
        return self.proj(final_step)
class MultimodalNet(nn.Module):
    """Two-branch classifier: CNN image embedding plus LSTM sequence embedding,
    concatenated and fed to an MLP head that produces n_classes logits."""
    def __init__(self, n_features, emb_size=128, n_classes=3):
        super().__init__()
        self.img_model = SimpleCNNEmbedding(emb_size=emb_size)
        self.seq_model = SimpleLSTMEmbedding(n_features=n_features, hidden=64, emb_size=emb_size)
        self.head = nn.Sequential(nn.Linear(emb_size*2, 256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, n_classes))
    def forward(self, img, seq):
        img_emb = self.img_model(img)
        seq_emb = self.seq_model(seq)
        # Fuse the two modalities by simple concatenation along the feature dim.
        cat = torch.cat([img_emb, seq_emb], dim=1)
        return self.head(cat)
| # ------------------------- | |
| # Helper: ATR | |
| # ------------------------- | |
def compute_atr(df: pd.DataFrame, n=14):
    """Average True Range over *n* periods (simple rolling mean of TR).

    TR_t = max(High-Low, |High - prev Close|, |Low - prev Close|). For the
    first row the previous close is taken as that row's own close, so TR_0
    reduces to High-Low whenever Close lies inside [Low, High].

    Args:
        df: DataFrame with numeric "High", "Low", "Close" columns.
        n: ATR lookback window (default 14).

    Returns:
        numpy array of the same length as df (min_periods=1, so the head is
        a mean over the rows available so far rather than NaN).
    """
    high = df["High"].values
    low = df["Low"].values
    close = df["Close"].values
    # Previous close per row, computed once (the original built this array
    # twice inline); first element backfilled with the first close.
    prev_close = np.concatenate(([close[0]], close[:-1]))
    tr = np.maximum(high - low, np.maximum(np.abs(high - prev_close), np.abs(low - prev_close)))
    atr = pd.Series(tr).rolling(window=n, min_periods=1).mean().values
    return atr  # numpy array length n_rows
| # ------------------------- | |
| # Load & predict (single) | |
| # ------------------------- | |
def load_multimodal_model():
    """Load the trained MultimodalNet checkpoint from MODEL_DIR.

    Returns the model on DEVICE in eval mode, or None when no checkpoint
    file exists. The checkpoint dict is expected to contain "n_features"
    and "model_state".
    """
    path = os.path.join(MODEL_DIR, "multimodal_model.pt")
    if not os.path.exists(path):
        return None
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints this app itself has written.
    ckpt = torch.load(path, map_location=DEVICE)
    n_features = ckpt.get("n_features")
    model = MultimodalNet(n_features=n_features, emb_size=128, n_classes=len(LABELS)).to(DEVICE)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model
def predict_multimodal(image_bytes, csv_snippet_file, scaler_path=os.path.join(MODEL_DIR,"csv_scaler.pkl")):
    """Classify one chart image plus an OHLCV snippet.

    Returns (label, {label: probability}) on success, or
    (error message string, None) when prerequisites are missing or the
    snippet is malformed.
    """
    model = load_multimodal_model()
    if model is None:
        return "No trained multimodal model found.", None
    if not os.path.exists(scaler_path):
        return "CSV scaler not found. Train first.", None

    # Image branch: RGB, resized, [0,1] floats, HWC -> CHW, plus batch dim.
    pil = Image.open(io.BytesIO(image_bytes)).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
    pixels = np.transpose(np.array(pil).astype(np.float32) / 255.0, (2, 0, 1))
    img_tensor = torch.tensor(pixels[None], dtype=torch.float32).to(DEVICE)

    # Sequence branch: last SEQ_LEN rows of scaled OHLCV features.
    df = pd.read_csv(csv_snippet_file.name)
    if df.shape[0] < SEQ_LEN:
        return f"CSV snippet too short. Need at least {SEQ_LEN} rows.", None
    required = {"Open","High","Low","Close","Volume"}
    if not required.issubset(set(df.columns)):
        return f"CSV snippet must have columns: {required}", None
    scaler = joblib.load(scaler_path)
    raw = df[["Open","High","Low","Close","Volume"]].values.astype(np.float32)
    window = scaler.transform(raw)[-SEQ_LEN:]
    seq_tensor = torch.tensor(window[None], dtype=torch.float32).to(DEVICE)

    with torch.no_grad():
        probs = torch.softmax(model(img_tensor, seq_tensor), dim=-1).cpu().numpy()[0]
    best = int(probs.argmax())
    return LABELS[best], {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
| # ------------------------- | |
| # Batch predict with ATR-based SL/TP + HTML report | |
| # ------------------------- | |
def batch_predict_extract_trades_atr(full_csv_file, stop_mult=1.0, tp_mult=2.0, atr_period=14, max_hold=32, scaler_path=os.path.join(MODEL_DIR,"csv_scaler.pkl")):
    """
    Sliding-window sequence-only predictions and ATR-based stop/take management.

    Runs the trained model's sequence branch (the image embedding is zeroed)
    over every SEQ_LEN-long window of the OHLCV CSV, then simulates trades:
    enter on a Buy/Sell signal at the next candle's open; exit on an
    ATR-based stop/take level, an opposite signal, a max-hold timeout, or
    end of data.

    stop_mult: ATR multiplier for stop loss (e.g., 1.0)
    tp_mult: ATR multiplier for take profit (e.g., 2.0)
    atr_period: ATR period length (default 14)
    max_hold: maximum number of candles a position may be held
    scaler_path: fitted StandardScaler for the OHLCV feature columns

    Returns (trades_csv_path, plot_path, report_path, summary_text) on
    success, or (None, None, None, error_message) on failure.
    """
    # require scaler & model
    if not os.path.exists(scaler_path):
        return None, None, None, "CSV scaler not found. Train first."
    model = load_multimodal_model()
    if model is None:
        return None, None, None, "No trained multimodal model found."
    scaler = joblib.load(scaler_path)
    # load & clean df
    df = handle_ohlcv_csv(full_csv_file)
    n = df.shape[0]
    if n < SEQ_LEN:
        return None, None, None, f"CSV too short: need at least {SEQ_LEN} rows."
    feats = df[["Open","High","Low","Close","Volume"]].values.astype(np.float32)
    feats_scaled = scaler.transform(feats)
    # sliding windows (same as earlier): one window per candle, ending at end_idx
    windows = []
    indices = []  # NOTE(review): collected but never used below
    for end_idx in range(SEQ_LEN-1, n):
        start = end_idx - (SEQ_LEN-1)
        windows.append(feats_scaled[start:end_idx+1])
        indices.append(end_idx)
    X = np.stack(windows)  # (B, SEQ_LEN, features)
    # predict using sequence branch (zero image emb)
    device = DEVICE
    preds = []
    probs_all = []
    with torch.no_grad():
        B = X.shape[0]
        batch_size = SEQ_BATCH
        for i in range(0, B, batch_size):
            xb = torch.tensor(X[i:i+batch_size], dtype=torch.float32).to(device)
            seq_emb = model.seq_model(xb)  # (b, emb_size)
            # The classifier head expects [img_emb, seq_emb] concatenated;
            # feed zeros for the missing image branch.
            zero_img_emb = torch.zeros_like(seq_emb)
            head_in = torch.cat([zero_img_emb, seq_emb], dim=1)
            logits = model.head(head_in)
            probs = torch.softmax(logits, dim=-1).cpu().numpy()
            batch_preds = probs.argmax(axis=1)
            preds.extend(batch_preds.tolist())
            probs_all.extend(probs.tolist())
    # Align predictions with candle indices: the first SEQ_LEN-1 candles have
    # no complete window and therefore no label.
    label_series = [None] * (SEQ_LEN-1) + [LABELS[int(p)] for p in preds]
    prob_series = [None] * (SEQ_LEN-1) + [probs_all[i] for i in range(len(preds))]  # NOTE(review): unused
    # compute ATR
    atr = compute_atr(df, n=atr_period)  # numpy array
    trades = []
    state = "flat"  # NOTE(review): only ever "flat"; positions are handled inline below
    entry_idx = None
    entry_price = None
    side = None
    hold_count = 0
    t = SEQ_LEN-1  # first candle index that has a prediction
    while t < n:
        lab = label_series[t]
        # enter
        if state == "flat":
            if lab == "Buy" or lab == "Sell":
                # entry at next candle open if possible; fall back to this
                # candle's close when the signal lands on the last row
                entry_idx = t+1 if (t+1 < n) else t
                entry_open = float(df["Open"].iloc[entry_idx]) if (t+1 < n) else float(df["Close"].iloc[t])
                entry_price = entry_open
                side = "Long" if lab == "Buy" else "Short"
                # determine ATR at entry index (use most recent available)
                atr_val = float(atr[entry_idx]) if entry_idx < len(atr) else float(atr[-1])
                if np.isnan(atr_val) or atr_val == 0:
                    # fallback: use close-open; guard against a zero range
                    atr_val = float(abs(df["Close"].iloc[entry_idx] - df["Open"].iloc[entry_idx]))
                    if atr_val == 0:
                        atr_val = 1e-6
                # Stop/take levels mirror around the entry depending on side.
                if side == "Long":
                    stop_price = entry_price - stop_mult * atr_val
                    take_price = entry_price + tp_mult * atr_val
                else:
                    stop_price = entry_price + stop_mult * atr_val
                    take_price = entry_price - tp_mult * atr_val
                hold_count = 0
                # start scanning from entry candle (inclusive) for stop/take/opposite signal
                i = entry_idx
                exited = False
                while i < n:
                    # examine candle i high/low
                    high_i = float(df["High"].iloc[i])
                    low_i = float(df["Low"].iloc[i])
                    # check stop / take first (intra-candle); the stop is
                    # tested before the take, so a candle that touches both
                    # levels is counted as a stop (conservative assumption)
                    if side == "Long":
                        if low_i <= stop_price:
                            exit_price = stop_price
                            exit_idx = i
                            reason = "stop"
                            exited = True
                        elif high_i >= take_price:
                            exit_price = take_price
                            exit_idx = i
                            reason = "take"
                            exited = True
                    else:  # Short
                        if high_i >= stop_price:
                            exit_price = stop_price
                            exit_idx = i
                            reason = "stop"
                            exited = True
                        elif low_i <= take_price:
                            exit_price = take_price
                            exit_idx = i
                            reason = "take"
                            exited = True
                    if exited:
                        # record trade; pnl is the fractional return (shorts
                        # approximated as entry/exit - 1)
                        if side == "Long":
                            pnl = exit_price / entry_price - 1.0
                        else:
                            pnl = entry_price / exit_price - 1.0
                        trades.append({
                            "entry_idx": entry_idx,
                            "entry_time": df["Date"].iloc[entry_idx] if "Date" in df.columns else int(entry_idx),
                            "entry_price": entry_price,
                            "exit_idx": exit_idx,
                            "exit_time": df["Date"].iloc[exit_idx] if "Date" in df.columns else int(exit_idx),
                            "exit_price": exit_price,
                            "side": side,
                            "pnl": pnl,
                            "exit_reason": reason
                        })
                        state = "flat"
                        entry_idx = None; entry_price = None; side = None
                        # move t to exit_idx + 1
                        t = exit_idx + 1
                        break
                    # check opposite signal at this candle's aligned label (label_series aligned to window end)
                    lab_i = label_series[i] if i < len(label_series) else None
                    if (side == "Long" and lab_i == "Sell") or (side == "Short" and lab_i == "Buy"):
                        # exit at next open if exists
                        exit_idx = i+1 if (i+1 < n) else i
                        exit_price = float(df["Open"].iloc[exit_idx]) if (i+1 < n) else float(df["Close"].iloc[i])
                        reason = "opp_signal"
                        if side == "Long":
                            pnl = exit_price / entry_price - 1.0
                        else:
                            pnl = entry_price / exit_price - 1.0
                        trades.append({
                            "entry_idx": entry_idx,
                            "entry_time": df["Date"].iloc[entry_idx] if "Date" in df.columns else int(entry_idx),
                            "entry_price": entry_price,
                            "exit_idx": exit_idx,
                            "exit_time": df["Date"].iloc[exit_idx] if "Date" in df.columns else int(exit_idx),
                            "exit_price": exit_price,
                            "side": side,
                            "pnl": pnl,
                            "exit_reason": reason
                        })
                        state = "flat"
                        entry_idx = None; entry_price = None; side = None
                        t = exit_idx + 1
                        exited = True
                        break
                    # hold count: force an exit after max_hold candles
                    hold_count += 1
                    if hold_count >= max_hold:
                        exit_idx = i+1 if (i+1 < n) else i
                        exit_price = float(df["Open"].iloc[exit_idx]) if (i+1 < n) else float(df["Close"].iloc[i])
                        reason = "max_hold"
                        if side == "Long":
                            pnl = exit_price / entry_price - 1.0
                        else:
                            pnl = entry_price / exit_price - 1.0
                        trades.append({
                            "entry_idx": entry_idx,
                            "entry_time": df["Date"].iloc[entry_idx] if "Date" in df.columns else int(entry_idx),
                            "entry_price": entry_price,
                            "exit_idx": exit_idx,
                            "exit_time": df["Date"].iloc[exit_idx] if "Date" in df.columns else int(exit_idx),
                            "exit_price": exit_price,
                            "side": side,
                            "pnl": pnl,
                            "exit_reason": reason
                        })
                        state = "flat"
                        entry_idx = None; entry_price = None; side = None
                        t = exit_idx + 1
                        exited = True
                        break
                    # continue to next candle
                    i += 1
                if not exited:
                    # reached end of series without exit -> close at last close
                    exit_idx = n-1
                    exit_price = float(df["Close"].iloc[exit_idx])
                    reason = "eod"
                    if side == "Long":
                        pnl = exit_price / entry_price - 1.0
                    else:
                        pnl = entry_price / exit_price - 1.0
                    trades.append({
                        "entry_idx": entry_idx,
                        "entry_time": df["Date"].iloc[entry_idx] if "Date" in df.columns else int(entry_idx),
                        "entry_price": entry_price,
                        "exit_idx": exit_idx,
                        "exit_time": df["Date"].iloc[exit_idx] if "Date" in df.columns else int(exit_idx),
                        "exit_price": exit_price,
                        "side": side,
                        "pnl": pnl,
                        "exit_reason": reason
                    })
                    return _save_and_report(df, trades)  # end
                continue  # continue outer while (t already advanced to exit_idx+1)
        # increment t if no entry or after handled
        t += 1
    # done scanning
    return _save_and_report(df, trades)
def _save_and_report(df, trades):
    """Persist trades to CSV, plot entries/exits over the close price, compute
    summary statistics, and write an HTML report.

    Args:
        df: cleaned OHLCV DataFrame (used for the close-price plot).
        trades: list of trade dicts produced by the simulator.

    Returns:
        (trades_csv_path, plot_path, report_path, summary_text)
    """
    # save trades CSV
    trades_df = pd.DataFrame(trades)
    # Use an aware UTC datetime: naive datetime.utcnow().timestamp() is
    # interpreted as *local* time and yields a wrong epoch on non-UTC
    # machines (and utcnow() is deprecated since Python 3.12).
    ts = int(datetime.now(timezone.utc).timestamp())
    trades_csv_path = os.path.join(MODEL_DIR, f"trades_{ts}.csv")
    trades_df.to_csv(trades_csv_path, index=False)
    # make plot: close-price line plus one marker per trade entry/exit
    plt.figure(figsize=(12,5))
    plt.plot(df.index, df["Close"].values, label="Close")
    entries_x = []
    entries_y = []
    exits_x = []
    exits_y = []
    colors = []
    for tr in trades:
        eidx = int(tr["entry_idx"]); exidx = int(tr["exit_idx"])
        entries_x.append(eidx); entries_y.append(tr["entry_price"])
        exits_x.append(exidx); exits_y.append(tr["exit_price"])
        colors.append("g" if tr["side"]=="Long" else "r")
    # plot entries/exits (green = long, red = short)
    if len(entries_x):
        for (x,y,c) in zip(entries_x, entries_y, colors):
            plt.scatter([x],[y], marker="o", color=c, zorder=5)
    if len(exits_x):
        for (x,y,c) in zip(exits_x, exits_y, colors):
            plt.scatter([x],[y], marker="x", color=c, zorder=5)
    plt.title("Close price with entries (o) and exits (x)")
    plt.xlabel("index")
    plt.ylabel("Close")
    plt.legend()
    plot_path = os.path.join(MODEL_DIR, f"trades_plot_{ts}.png")
    plt.tight_layout()
    plt.savefig(plot_path)
    plt.close()
    # compute metrics
    if trades_df.empty:
        summary_text = "No trades generated."
        report_path = _generate_html_report(df, trades_df, plot_path, summary_text, {})
        return trades_csv_path, plot_path, report_path, summary_text
    pnl_list = trades_df["pnl"].astype(float).values
    total_pnl = float(np.nansum(pnl_list))
    avg_pnl = float(np.nanmean(pnl_list))
    wins = float((pnl_list > 0).sum())
    win_rate = float(wins / len(pnl_list)) if len(pnl_list) > 0 else 0.0
    # equity curve: compound (1 + pnl) per trade for multiplicative growth,
    # then measure drawdown against the running peak
    equity = np.cumprod(1 + pnl_list)  # start at 1
    peak = np.maximum.accumulate(equity)
    drawdowns = (peak - equity) / peak
    max_dd = float(np.max(drawdowns)) if len(drawdowns)>0 else 0.0
    summary = {
        "n_trades": int(len(pnl_list)),
        "total_pnl": total_pnl,
        "avg_pnl": avg_pnl,
        "win_rate": win_rate,
        "max_drawdown": max_dd
    }
    summary_text = f"Trades: {summary['n_trades']}, Total PnL (factorized): {summary['total_pnl']:.4f}, Win rate: {summary['win_rate']:.3f}, Max Drawdown: {summary['max_drawdown']:.3f}"
    report_path = _generate_html_report(df, trades_df, plot_path, summary_text, summary)
    return trades_csv_path, plot_path, report_path, summary_text
def _generate_html_report(df, trades_df, plot_path, summary_text, summary):
    """Render a self-contained HTML report and return its file path.

    The plot is embedded as base64 so the report is a single portable file;
    trades become an HTML table and *summary* a bullet list.
    """
    # embed plot as base64
    with open(plot_path, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("ascii")
    img_tag = f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%;height:auto;" />'
    # trades table html
    if trades_df.empty:
        trades_html = "<p>No trades</p>"
    else:
        trades_html = trades_df.to_html(classes="table table-striped", index=False, float_format="%.6f")
    # summary HTML block (join instead of repeated string concatenation)
    summary_items = "".join(f"<li><strong>{k}</strong>: {v}</li>" for k, v in summary.items())
    summary_html = "<ul>" + summary_items + "</ul>"
    # Aware UTC "now": datetime.utcnow() is deprecated and utcnow().timestamp()
    # gives a wrong epoch on non-UTC machines. The displayed timestamp keeps
    # the original naive ISO format.
    now_utc = datetime.now(timezone.utc)
    html = f"""
<html>
<head>
<meta charset="utf-8"/>
<title>SMC Multimodal - ATR Trades Report</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.header {{ margin-bottom: 20px; }}
.section {{ margin-bottom: 30px; }}
table {{ border-collapse: collapse; width: 100%; }}
table, th, td {{ border: 1px solid #ddd; }}
th, td {{ padding: 8px; text-align: left; }}
th {{ background-color: #f2f2f2; }}
</style>
</head>
<body>
<div class="header">
<h1>SMC Multimodal - ATR Trades Report</h1>
<p>Generated: {now_utc.replace(tzinfo=None).isoformat()} UTC</p>
<p>{summary_text}</p>
</div>
<div class="section">
<h2>Price chart with trades</h2>
{img_tag}
</div>
<div class="section">
<h2>Trades</h2>
{trades_html}
</div>
<div class="section">
<h2>Summary metrics</h2>
{summary_html}
</div>
</body>
</html>
"""
    ts = int(now_utc.timestamp())
    report_path = os.path.join(MODEL_DIR, f"report_{ts}.html")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(html)
    return report_path
| # ------------------------- | |
| # The rest (Gemini, auto-label, training) left mostly unchanged from earlier app | |
| # For brevity, reuse previous functions for ai_teacher and training wrappers. | |
| # (You can paste the previous implementation here for a full runnable app.) | |
| # ------------------------- | |
| # For this response I will implement minimal wrappers to keep the app runnable: | |
# Gemini API key: prefer GENAI_API_KEY, fall back to GOOGLE_API_KEY.
GENAI_API_KEY = os.environ.get("GENAI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
# Instruction prompt sent to Gemini alongside the chart image. This was an
# f-string with zero placeholders, which forced `{{`/`}}` escaping for every
# literal brace; a plain string with single braces yields the same text.
PROMPT_JSON_SCHEMA = """
You are a chart-analysis assistant. INPUT: a candlestick chart image (standard candlestick, no overlays).
You MUST OUTPUT a single JSON object only (no explanatory text). Schema:
{
"label": "Buy" | "Sell" | "Hold",
"confidence": 0.0-1.0,
"seq_start": <integer index into the provided OHLCV CSV> OR a date string "YYYY-MM-DD" (one of the two),
"note": "brief explanation (optional)"
}
Constraints:
- Only output valid JSON following exactly the keys above.
- If uncertain, return confidence < 0.6.
- seq_start should be either an integer or an ISO date string (not both).
- Do NOT include any other keys.
"""
def ai_teacher_with_gemini(image_bytes: bytes, prompt_extra: str = "") -> dict:
    """Ask Gemini to label a candlestick chart image and return the parsed JSON.

    Args:
        image_bytes: raw PNG/JPEG bytes of the chart.
        prompt_extra: optional extra instructions appended to the schema prompt.

    Raises:
        RuntimeError: when the google-genai package or the API key is missing.
        ValueError: when no JSON object can be extracted from the response.
    """
    if genai is None:
        raise RuntimeError("google-genai package not installed. pip install google-generativeai")
    api_key = GENAI_API_KEY
    if not api_key:
        raise RuntimeError("Set GENAI_API_KEY environment variable with your Gemini API key.")
    client = genai.Client(api_key=api_key)
    prompt = PROMPT_JSON_SCHEMA + "\n" + prompt_extra
    pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    contents = [prompt, pil_img]
    # google-genai's generate_content does not accept temperature /
    # max_output_tokens / response_format as direct keyword arguments —
    # generation settings must be passed via a GenerateContentConfig.
    from google.genai import types as genai_types
    response = client.models.generate_content(
        model="gemini-2.5-flash",
        contents=contents,
        config=genai_types.GenerateContentConfig(temperature=0.0, max_output_tokens=800),
    )
    # Extract the text parts of the first candidate; fall back to str(response)
    # if the response shape is unexpected.
    text = None
    try:
        cand = response.candidates[0]
        if hasattr(cand, "content") and hasattr(cand.content, "parts"):
            parts = cand.content.parts
            collected = []
            for p in parts:
                if getattr(p, "text", None):
                    collected.append(p.text)
            text = "\n".join(collected).strip()
        else:
            text = str(response)
    except Exception:
        text = str(response)
    # Parse strict JSON first; otherwise grab the outermost {...} span, since
    # models sometimes wrap JSON in prose or code fences.
    try:
        parsed = json.loads(text)
    except Exception:
        import re
        m = re.search(r"(\{.*\})", text, flags=re.S)
        if not m:
            raise ValueError(f"Could not parse JSON from model response: {text[:500]}")
        parsed = json.loads(m.group(1))
    return parsed
| # Minimal stubs for auto-label/train/predict wrappers so UI launches | |
# Directory and mapping CSV used by the (trimmed) auto-labelling workflow.
AUTO_IMAGES_DIR = os.path.join(AUTO_DIR, "images")
os.makedirs(AUTO_IMAGES_DIR, exist_ok=True)
AUTO_MAPPING_CSV = os.path.join(AUTO_DIR, "mapping_auto.csv")
# Seed an empty mapping file on first run so later appends can assume it exists.
if not os.path.exists(AUTO_MAPPING_CSV):
    pd.DataFrame(columns=["filename","label","seq_start"]).to_csv(AUTO_MAPPING_CSV, index=False)
def ui_ai_label_and_suggest(img_filepath, ohlcv_csv=None):
    """Placeholder AI-labelling callback; echoes the image path back unchanged."""
    status = "AI tab not active in this trimmed example (install GenAI)"
    return status, "{}", None, img_filepath
def ui_accept_auto_label(img_filepath, parsed_json_text, seq_start_override, save_prefix="auto"):
    """Placeholder: accepting auto-labels is disabled in this trimmed build."""
    status = "Not implemented in trimmed example"
    return status, None
def ui_train_from_auto(ohlcv_csv, epochs=EPOCHS):
    """Placeholder: training from auto-labelled data is disabled in this trimmed build."""
    status = "Train-from-auto not implemented in trimmed example"
    return status, None
def ui_train_multimodal_wrapper(img_zip_file, mapping_file, ohlcv_file_in, epochs_in):
    """Placeholder: multimodal training is disabled in this trimmed build."""
    status = "Train wrapper not implemented in trimmed example"
    return status, None
def ui_predict_multimodal(img_filepath, csv_snippet_file):
    """Gradio callback: read the chart image from disk and run prediction.

    Returns (label_or_error, probs_string_or_None, image_path_or_None).
    """
    try:
        with open(img_filepath, "rb") as fh:
            raw = fh.read()
    except Exception as exc:
        return f"Could not read image file: {exc}", None, None
    try:
        label, probs = predict_multimodal(raw, csv_snippet_file)
    except Exception as exc:
        return f"Prediction error: {exc}", None, None
    return label, str(probs), img_filepath
def ui_batch_predict_and_export(full_csv_file, stop_mult, tp_mult, atr_period, max_hold):
    """Gradio callback wrapping batch_predict_extract_trades_atr; never raises."""
    try:
        trades_csv, plot_path, report_path, msg = batch_predict_extract_trades_atr(
            full_csv_file,
            stop_mult=float(stop_mult),
            tp_mult=float(tp_mult),
            atr_period=int(atr_period),
            max_hold=int(max_hold),
        )
    except Exception as exc:
        return f"Batch predict error: {exc}", None, None, None
    if trades_csv is None:
        # The simulator signalled an error in *msg* (missing model/scaler etc.)
        return msg, None, None, None
    return msg, trades_csv, plot_path, report_path
| # ------------------------- | |
| # Build Gradio UI | |
| # ------------------------- | |
# Gradio UI: a single batch-prediction tab plus a static notes tab.
with gr.Blocks() as demo:
    gr.Markdown("# SMC Multimodal Trainer + ATR-based Trade Extraction & HTML Report")
    with gr.Tab("Query / Batch Predict (ATR)"):
        gr.Markdown("Upload a full OHLCV CSV and choose ATR/SL/TP settings to generate trades and an HTML report.")
        full_csv = gr.File(label="Full OHLCV CSV (for batch prediction)")
        with gr.Row():
            stop_mult_input = gr.Number(value=1.0, label="Stop ATR multiplier (stop = entry ± ATR * stop_mult)")
            tp_mult_input = gr.Number(value=2.0, label="TakeProfit ATR multiplier (take = entry ± ATR * tp_mult)")
            atr_period_input = gr.Number(value=14, label="ATR period")
            max_hold_input = gr.Number(value=32, label="Max holding candles")
        batch_btn = gr.Button("Batch Predict -> ATR trades + report")
        batch_out = gr.Textbox(label="Batch Output")
        batch_trades_file = gr.File(label="Trades CSV (download)")
        batch_plot = gr.Image(label="Trades plot (entries/exits)")
        batch_report = gr.File(label="HTML report (download)")
        # Wire the button to the batch callback: inputs map positionally to
        # ui_batch_predict_and_export's parameters, outputs to its 4-tuple.
        batch_btn.click(fn=ui_batch_predict_and_export, inputs=[full_csv, stop_mult_input, tp_mult_input, atr_period_input, max_hold_input], outputs=[batch_out, batch_trades_file, batch_plot, batch_report])
    with gr.Tab("Notes"):
        gr.Markdown("""
**ATR-based trade simulation notes**
- This simulator uses ATR for stop loss and take profit levels.
- Entry is at next candle Open after the signal.
- The simulator checks each candle to see if SL/TP hit (intra-candle).
- If the opposite model signal appears, it exits at the next open.
- HTML report includes embedded plot and trade table, plus summary metrics (win rate, max drawdown).
- Validate the results and tune multipliers (stop_mult / tp_mult) before trusting strategies.
""")
if __name__ == "__main__":
    demo.launch(debug=True, share=False)