eshan6704 commited on
Commit
475d23a
Β·
verified Β·
1 Parent(s): 8e4dd54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -130
app.py CHANGED
@@ -1,193 +1,150 @@
1
- """
2
- TALib + mplfinance + Gradio
3
- - Fetch data ONLY when symbol/date changes
4
- - Reuse cached data for chart type & pattern changes
5
- - Clean dashboard layout
6
- """
7
-
8
  import yfinance as yf
9
  import pandas as pd
10
  import numpy as np
11
  import mplfinance as mpf
12
  import talib
13
  import gradio as gr
14
- from datetime import date
15
  import os
16
- from typing import Optional, List
17
 
18
  # =====================================================
19
- # TALib pattern mapping
20
  # =====================================================
21
 
22
- TALIB_PATTERNS = sorted([n for n in dir(talib) if n.startswith("CDL")])
23
- PATTERN_DISPLAY_MAP = {n.replace("CDL", ""): n for n in TALIB_PATTERNS}
24
- DISPLAY_PATTERNS = ["None"] + list(PATTERN_DISPLAY_MAP.keys())
25
-
26
  CHART_TYPE_MAP = {
27
  "Candlestick": "candle",
28
  "OHLC": "ohlc",
29
  "Line": "line"
30
  }
31
 
 
 
 
 
32
  # =====================================================
33
- # Data utilities
34
  # =====================================================
35
 
36
- def _normalize_col_name(col) -> str:
37
- if isinstance(col, (tuple, list)):
38
- return "_".join(str(c) for c in col if c).lower()
39
- return str(col).strip().lower()
40
-
41
- def _find_best_col(key: str, columns: List[str]) -> Optional[str]:
42
- if key in columns:
43
- return key
44
- for c in columns:
45
- if c.endswith("_" + key):
46
- return c
47
- for c in columns:
48
- if key in c:
49
- return c
50
- return None
51
-
52
- def clean_ohlc(df: pd.DataFrame) -> pd.DataFrame:
53
  df = df.copy()
54
- df.columns = [_normalize_col_name(c) for c in df.columns]
55
-
56
- found = {}
57
- for k in ["open", "high", "low", "close"]:
58
- col = _find_best_col(k, df.columns)
59
- if not col:
60
- raise ValueError(f"Missing column: {k}")
61
- found[k] = col
62
-
63
- vol_col = _find_best_col("volume", df.columns)
64
-
65
- cols = [found["open"], found["high"], found["low"], found["close"]]
66
- if vol_col:
67
- cols.append(vol_col)
68
-
69
- df = df[cols].rename(columns={
70
- found["open"]: "Open",
71
- found["high"]: "High",
72
- found["low"]: "Low",
73
- found["close"]: "Close",
74
- vol_col: "Volume" if vol_col else None
75
  })
76
 
77
- df.index = pd.to_datetime(df.index, errors="coerce")
78
- df = df.dropna(subset=["Open", "High", "Low", "Close"])
79
- df = df.apply(pd.to_numeric, errors="coerce")
80
  df = df.dropna().sort_index()
81
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  if df.empty:
83
- raise ValueError("No valid OHLC data")
84
 
85
- return df
 
 
 
 
 
86
 
87
  # =====================================================
88
- # Pattern detection
89
  # =====================================================
90
 
91
- def get_pattern_addplots(df, pattern_name):
92
- talib_name = PATTERN_DISPLAY_MAP[pattern_name]
93
- func = getattr(talib, talib_name)
94
 
95
- o, h, l, c = (
 
96
  df["Open"].values,
97
  df["High"].values,
98
  df["Low"].values,
99
- df["Close"].values,
100
  )
101
 
102
- s = pd.Series(func(o, h, l, c), index=df.index)
103
- apds = []
104
 
105
  if (s > 0).any():
106
  bull = pd.Series(np.nan, index=df.index)
107
- bull[s > 0] = df.loc[s > 0, "Low"] * 0.98
108
- apds.append(mpf.make_addplot(bull, type="scatter",
109
- marker="^", markersize=90,
110
- color="green", alpha=0.85))
 
111
 
112
  if (s < 0).any():
113
  bear = pd.Series(np.nan, index=df.index)
114
- bear[s < 0] = df.loc[s < 0, "High"] * 1.02
115
- apds.append(mpf.make_addplot(bear, type="scatter",
116
- marker="v", markersize=90,
117
- color="red", alpha=0.85))
118
- return apds
119
-
120
- # =====================================================
121
- # Step 1: Fetch & cache data (ONLY on symbol/date change)
122
- # =====================================================
123
-
124
- def fetch_stock_data(symbol, start, end):
125
- if not symbol:
126
- return None, "Symbol required"
127
-
128
- df = yf.download(symbol, start=start, end=end, progress=False)
129
- if df.empty:
130
- return None, "No data found"
131
 
132
- try:
133
- df_clean = clean_ohlc(df)
134
- except Exception as e:
135
- return None, str(e)
136
-
137
- return df_clean, f"Data loaded for {symbol}"
138
 
139
  # =====================================================
140
- # Step 2: Render chart (reuse cached data)
141
  # =====================================================
142
 
143
- def render_chart(df, symbol, chart_type, pattern):
144
  if df is None:
145
- return None, "No cached data"
146
-
147
- addplots = []
148
- pattern_label = ""
149
-
150
- if chart_type != "Line" and pattern != "None":
151
- addplots = get_pattern_addplots(df, pattern)
152
- pattern_label = f" | Pattern: {pattern}"
153
 
 
154
  mpf_type = CHART_TYPE_MAP[chart_type]
155
 
156
  os.makedirs("/tmp", exist_ok=True)
157
- path = f"/tmp/{symbol}_{pd.Timestamp.now().strftime('%Y%m%d%H%M%S')}.png"
158
 
159
  fig, _ = mpf.plot(
160
  df,
161
  type=mpf_type,
162
- volume="Volume" in df.columns and mpf_type != "line",
163
- addplot=addplots if addplots and mpf_type != "line" else None,
164
  style="yahoo",
165
- title=f"{symbol} β€’ {chart_type}{pattern_label}",
166
  figscale=1.7,
167
  returnfig=True
168
  )
169
 
170
  fig.savefig(path, dpi=150, bbox_inches="tight")
171
- return path, "Chart updated"
172
 
173
  # =====================================================
174
- # Gradio UI
175
  # =====================================================
176
 
177
- with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as iface:
178
- gr.Markdown(
179
- "# πŸ“Š TALib Candlestick Pattern Dashboard\n"
180
- "**Optimized data fetching – instant UI updates**"
181
- )
182
 
183
  cached_df = gr.State(None)
184
 
185
  with gr.Row():
186
- with gr.Column(scale=1, min_width=320):
187
  symbol = gr.Textbox(label="Symbol", value="MSFT")
188
  start = gr.Textbox(label="Start Date", value="2024-01-01")
189
  end = gr.Textbox(label="End Date", value=date.today().strftime("%Y-%m-%d"))
190
 
 
 
191
  chart_type = gr.Dropdown(
192
  label="Chart Type",
193
  choices=list(CHART_TYPE_MAP.keys()),
@@ -195,33 +152,32 @@ with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as iface:
195
  )
196
 
197
  pattern = gr.Dropdown(
198
- label="Pattern",
199
- choices=DISPLAY_PATTERNS,
200
  value="HAMMER"
201
  )
202
 
203
- load_btn = gr.Button("πŸ”„ Load Data", variant="primary")
 
204
  status = gr.Textbox(label="Status", interactive=False)
205
 
206
  with gr.Column(scale=3):
207
- chart = gr.Image(type="filepath", height=720, show_label=False)
208
 
209
- # Fetch only when symbol/date changes
210
  load_btn.click(
211
- fetch_stock_data,
212
  inputs=[symbol, start, end],
213
  outputs=[cached_df, status]
214
  )
215
 
216
- # Render chart when options change
217
- for trigger in (chart_type, pattern):
218
- trigger.change(
219
- render_chart,
220
- inputs=[cached_df, symbol, chart_type, pattern],
221
- outputs=[chart, status]
222
- )
223
 
224
- # Auto-disable pattern for Line chart
225
  chart_type.change(
226
  lambda ct: gr.update(interactive=(ct != "Line")),
227
  inputs=chart_type,
@@ -229,4 +185,4 @@ with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as iface:
229
  )
230
 
231
  if __name__ == "__main__":
232
- iface.launch()
 
 
 
 
 
 
 
 
1
  import yfinance as yf
2
  import pandas as pd
3
  import numpy as np
4
  import mplfinance as mpf
5
  import talib
6
  import gradio as gr
 
7
  import os
8
+ from datetime import date
9
 
10
  # =====================================================
11
+ # CONFIG
12
  # =====================================================
13
 
 
 
 
 
14
  CHART_TYPE_MAP = {
15
  "Candlestick": "candle",
16
  "OHLC": "ohlc",
17
  "Line": "line"
18
  }
19
 
20
+ TALIB_PATTERNS = sorted([n for n in dir(talib) if n.startswith("CDL")])
21
+ PATTERN_MAP = {n.replace("CDL", ""): n for n in TALIB_PATTERNS}
22
+ PATTERN_LIST = ["None"] + list(PATTERN_MAP.keys())
23
+
24
  # =====================================================
25
+ # DATA CLEANING
26
  # =====================================================
27
 
28
+ def clean_ohlc(df):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  df = df.copy()
30
+ df.columns = [c.lower() for c in df.columns]
31
+
32
+ df = df.rename(columns={
33
+ "open": "Open",
34
+ "high": "High",
35
+ "low": "Low",
36
+ "close": "Close",
37
+ "volume": "Volume"
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  })
39
 
40
+ df.index = pd.to_datetime(df.index)
41
+ df = df[["Open", "High", "Low", "Close", "Volume"]]
 
42
  df = df.dropna().sort_index()
43
 
44
+ return df
45
+
46
+ # =====================================================
47
+ # LOAD DATA (INTERNET ONLY)
48
+ # =====================================================
49
+
50
+ def load_data(symbol, start, end):
51
+ if not symbol:
52
+ return None, "❌ Symbol required"
53
+
54
+ df = yf.download(symbol, start=start, end=end, progress=False)
55
+
56
  if df.empty:
57
+ return None, "❌ No data fetched"
58
 
59
+ try:
60
+ df = clean_ohlc(df)
61
+ except Exception as e:
62
+ return None, f"❌ Data error: {e}"
63
+
64
+ return df, f"βœ… Data loaded for {symbol} ({len(df)} rows)"
65
 
66
  # =====================================================
67
+ # PATTERN DETECTION
68
  # =====================================================
69
 
70
+ def pattern_addplots(df, pattern):
71
+ if pattern == "None":
72
+ return []
73
 
74
+ func = getattr(talib, PATTERN_MAP[pattern])
75
+ res = func(
76
  df["Open"].values,
77
  df["High"].values,
78
  df["Low"].values,
79
+ df["Close"].values
80
  )
81
 
82
+ s = pd.Series(res, index=df.index)
83
+ aps = []
84
 
85
  if (s > 0).any():
86
  bull = pd.Series(np.nan, index=df.index)
87
+ bull[s > 0] = df["Low"][s > 0] * 0.98
88
+ aps.append(mpf.make_addplot(
89
+ bull, type="scatter", marker="^",
90
+ color="green", markersize=90
91
+ ))
92
 
93
  if (s < 0).any():
94
  bear = pd.Series(np.nan, index=df.index)
95
+ bear[s < 0] = df["High"][s < 0] * 1.02
96
+ aps.append(mpf.make_addplot(
97
+ bear, type="scatter", marker="v",
98
+ color="red", markersize=90
99
+ ))
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+ return aps
 
 
 
 
 
102
 
103
  # =====================================================
104
+ # BUILD CHART (NO INTERNET)
105
  # =====================================================
106
 
107
+ def build_chart(df, symbol, chart_type, pattern):
108
  if df is None:
109
+ return None, "❌ Load data first"
 
 
 
 
 
 
 
110
 
111
+ aps = pattern_addplots(df, pattern)
112
  mpf_type = CHART_TYPE_MAP[chart_type]
113
 
114
  os.makedirs("/tmp", exist_ok=True)
115
+ path = f"/tmp/{symbol}_{pd.Timestamp.now().strftime('%H%M%S')}.png"
116
 
117
  fig, _ = mpf.plot(
118
  df,
119
  type=mpf_type,
120
+ volume=(mpf_type != "line"),
121
+ addplot=aps if mpf_type != "line" else None,
122
  style="yahoo",
123
+ title=f"{symbol} | {chart_type} | Pattern: {pattern}",
124
  figscale=1.7,
125
  returnfig=True
126
  )
127
 
128
  fig.savefig(path, dpi=150, bbox_inches="tight")
129
+ return path, "πŸ“Š Chart built successfully"
130
 
131
  # =====================================================
132
+ # UI
133
  # =====================================================
134
 
135
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
136
+ gr.Markdown("## πŸ“ˆ Stock Chart & Candlestick Pattern Analyzer")
 
 
 
137
 
138
  cached_df = gr.State(None)
139
 
140
  with gr.Row():
141
+ with gr.Column(scale=1):
142
  symbol = gr.Textbox(label="Symbol", value="MSFT")
143
  start = gr.Textbox(label="Start Date", value="2024-01-01")
144
  end = gr.Textbox(label="End Date", value=date.today().strftime("%Y-%m-%d"))
145
 
146
+ load_btn = gr.Button("🌐 Load Data", variant="primary")
147
+
148
  chart_type = gr.Dropdown(
149
  label="Chart Type",
150
  choices=list(CHART_TYPE_MAP.keys()),
 
152
  )
153
 
154
  pattern = gr.Dropdown(
155
+ label="Candlestick Pattern",
156
+ choices=PATTERN_LIST,
157
  value="HAMMER"
158
  )
159
 
160
+ build_btn = gr.Button("πŸ“Š Build Chart", variant="secondary")
161
+
162
  status = gr.Textbox(label="Status", interactive=False)
163
 
164
  with gr.Column(scale=3):
165
+ chart = gr.Image(show_label=False, height=720)
166
 
167
+ # BUTTON ACTIONS
168
  load_btn.click(
169
+ load_data,
170
  inputs=[symbol, start, end],
171
  outputs=[cached_df, status]
172
  )
173
 
174
+ build_btn.click(
175
+ build_chart,
176
+ inputs=[cached_df, symbol, chart_type, pattern],
177
+ outputs=[chart, status]
178
+ )
 
 
179
 
180
+ # Disable pattern for line chart
181
  chart_type.change(
182
  lambda ct: gr.update(interactive=(ct != "Line")),
183
  inputs=chart_type,
 
185
  )
186
 
187
  if __name__ == "__main__":
188
+ app.launch()