Arrechenash committed on
Commit
903715e
·
1 Parent(s): 4daa2c1

Simple scatter plot with altair

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. src/streamlit_app.py +77 -101
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
  pandas
2
  duckdb
3
  streamlit
4
- plotly
 
1
  pandas
2
  duckdb
3
  streamlit
4
+ altair
src/streamlit_app.py CHANGED
@@ -1,10 +1,13 @@
1
  import os
 
 
2
  import duckdb
3
  import streamlit as st
4
- import plotly.express as px
5
 
6
- # Configuration
7
  st.set_page_config(layout="wide")
 
 
8
  url = (
9
  "stocks.parquet"
10
  if os.getenv("APP_ENV") == "development"
@@ -14,10 +17,10 @@ url = (
14
 
15
  @st.cache_data
16
  def get_data(filters=None):
17
- base_query = "SELECT * FROM read_parquet('{}')".format(url)
18
  if filters:
19
- base_query += f" WHERE {' AND '.join(filters)}"
20
- return duckdb.query(base_query + " ORDER BY date DESC").to_df()
21
 
22
 
23
  @st.cache_data
@@ -31,24 +34,22 @@ def load_symbols():
31
  )
32
 
33
 
34
- # Initialize filters and chart settings
35
- if "filters" not in st.session_state:
36
- st.session_state.update(
37
- {
38
- "date": None,
39
- "symbols": [],
40
- "min_open": 1.0,
41
- "min_gap": None,
42
- "min_run": None,
43
- "min_vol": 10000.0,
44
- "min_relvol": 5.0,
45
- "chart_type": "scatter",
46
- "x_axis": "open",
47
- "y_axis": "change_pct",
48
- }
49
- )
50
-
51
- # Filters panel
52
  with st.sidebar:
53
  st.session_state.date = st.date_input("Date", st.session_state.date)
54
  st.session_state.symbols = st.multiselect(
@@ -74,26 +75,28 @@ with st.sidebar:
74
  "Min rel volume", value=st.session_state.min_relvol
75
  )
76
 
77
- # Build query filters
78
- query_filters = []
79
- if st.session_state.date:
80
- query_filters.append(f"date = '{st.session_state.date.strftime('%Y-%m-%d')}'")
81
- if st.session_state.symbols:
82
- query_filters.append(f"symbol IN {tuple(st.session_state.symbols)}")
83
- if st.session_state.min_open:
84
- query_filters.append(f"open >= {st.session_state.min_open}")
85
- if st.session_state.min_vol:
86
- query_filters.append(f"volume >= {st.session_state.min_vol}")
87
- if st.session_state.min_relvol:
88
- query_filters.append(f"relative_volume >= {st.session_state.min_relvol}")
89
- if st.session_state.min_gap:
90
- query_filters.append(f"gap_pct >= {st.session_state.min_gap}")
91
- if st.session_state.min_run:
92
- query_filters.append(f"run_pct >= {st.session_state.min_run}")
93
-
94
- # Get and display data
95
- df = get_data(query_filters if query_filters else None)
96
-
 
 
97
  if df.empty:
98
  st.info("No data found with current filters")
99
  else:
@@ -103,73 +106,46 @@ else:
103
  st.dataframe(df, use_container_width=True)
104
 
105
  with tab2:
106
- # Chart configuration
107
- cols = st.columns(3)
108
  numeric_cols = df.select_dtypes("number").columns.tolist()
109
 
110
- with cols[0]:
111
- st.session_state.chart_type = st.selectbox(
112
- "Chart Type",
113
- ["scatter", "line", "bar", "histogram"],
114
- index=["scatter", "line", "bar", "histogram"].index(
115
- st.session_state.chart_type
116
- ),
117
- )
118
- with cols[1]:
119
- st.session_state.x_axis = st.selectbox(
120
  "X-axis",
121
  numeric_cols,
122
  index=(
123
- numeric_cols.index(st.session_state.x_axis)
124
- if st.session_state.x_axis in numeric_cols
125
  else 0
126
  ),
 
127
  )
128
- with cols[2]:
129
- # For histogram, we don't need Y-axis selection
130
- if st.session_state.chart_type != "histogram":
131
- st.session_state.y_axis = st.selectbox(
132
- "Y-axis",
133
- numeric_cols,
134
- index=(
135
- numeric_cols.index(st.session_state.y_axis)
136
- if st.session_state.y_axis in numeric_cols
137
- else 1
138
- ),
139
- )
140
-
141
- # Generate the appropriate chart
142
- if st.session_state.chart_type == "scatter":
143
- fig = px.scatter(
144
- df,
145
- x=st.session_state.x_axis,
146
- y=st.session_state.y_axis,
147
- hover_data=["symbol"],
148
- title=f"{st.session_state.y_axis} vs {st.session_state.x_axis}",
149
- )
150
- elif st.session_state.chart_type == "line":
151
- fig = px.line(
152
- df,
153
- x=st.session_state.x_axis,
154
- y=st.session_state.y_axis,
155
- hover_data=["symbol"],
156
- title=f"{st.session_state.y_axis} over {st.session_state.x_axis}",
157
- )
158
- elif st.session_state.chart_type == "bar":
159
- fig = px.bar(
160
- df,
161
- x=st.session_state.x_axis,
162
- y=st.session_state.y_axis,
163
- hover_data=["symbol"],
164
- title=f"{st.session_state.y_axis} by {st.session_state.x_axis}",
165
  )
166
- elif st.session_state.chart_type == "histogram":
167
- fig = px.histogram(
168
- df,
169
- x=st.session_state.x_axis,
170
- title=f"Distribution of {st.session_state.x_axis}",
 
 
 
 
 
171
  )
 
 
 
172
 
173
- st.plotly_chart(fig, use_container_width=True)
174
 
175
- st.write(f"Results: {len(df)}")
 
1
  import os
2
+
3
+ import altair as alt
4
  import duckdb
5
  import streamlit as st
 
6
 
7
+ # Page configuration
8
  st.set_page_config(layout="wide")
9
+
10
+ # Data source
11
  url = (
12
  "stocks.parquet"
13
  if os.getenv("APP_ENV") == "development"
 
17
 
18
  @st.cache_data
19
  def get_data(filters=None):
20
+ query = f"SELECT * FROM read_parquet('{url}')"
21
  if filters:
22
+ query += " WHERE " + " AND ".join(filters)
23
+ return duckdb.query(query + " ORDER BY date DESC").to_df()
24
 
25
 
26
  @st.cache_data
 
34
  )
35
 
36
 
37
+ # Initialize session state
38
+ defaults = {
39
+ "date": None,
40
+ "symbols": [],
41
+ "min_open": 1.0,
42
+ "min_gap": None,
43
+ "min_run": None,
44
+ "min_vol": 10000.0,
45
+ "min_relvol": 5.0,
46
+ "x_axis": "open",
47
+ "y_axis": "change_pct",
48
+ }
49
+ for key, value in defaults.items():
50
+ st.session_state.setdefault(key, value)
51
+
52
+ # Sidebar filters
 
 
53
  with st.sidebar:
54
  st.session_state.date = st.date_input("Date", st.session_state.date)
55
  st.session_state.symbols = st.multiselect(
 
75
  "Min rel volume", value=st.session_state.min_relvol
76
  )
77
 
78
+ # Construct query filters
79
+ f = st.session_state
80
+ filters = []
81
+ if f.date:
82
+ filters.append(f"date = '{f.date.strftime('%Y-%m-%d')}'")
83
+ if f.symbols:
84
+ filters.append(f"symbol IN {tuple(f.symbols)}")
85
+ if f.min_open:
86
+ filters.append(f"open >= {f.min_open}")
87
+ if f.min_vol:
88
+ filters.append(f"volume >= {f.min_vol}")
89
+ if f.min_relvol:
90
+ filters.append(f"relative_volume >= {f.min_relvol}")
91
+ if f.min_gap:
92
+ filters.append(f"gap_pct >= {f.min_gap}")
93
+ if f.min_run:
94
+ filters.append(f"run_pct >= {f.min_run}")
95
+
96
+ # Load data
97
+ df = get_data(filters if filters else None)
98
+
99
+ # UI
100
  if df.empty:
101
  st.info("No data found with current filters")
102
  else:
 
106
  st.dataframe(df, use_container_width=True)
107
 
108
  with tab2:
 
 
109
  numeric_cols = df.select_dtypes("number").columns.tolist()
110
 
111
+ col1, col2 = st.columns(2)
112
+ with col1:
113
+ x_axis = st.selectbox(
 
 
 
 
 
 
 
114
  "X-axis",
115
  numeric_cols,
116
  index=(
117
+ numeric_cols.index(f.get("x_axis", "open"))
118
+ if f.get("x_axis", "open") in numeric_cols
119
  else 0
120
  ),
121
+ key="x_axis",
122
  )
123
+ with col2:
124
+ y_axis = st.selectbox(
125
+ "Y-axis",
126
+ numeric_cols,
127
+ index=(
128
+ numeric_cols.index(f.get("y_axis", "change_pct"))
129
+ if f.get("y_axis", "change_pct") in numeric_cols
130
+ else 1
131
+ ),
132
+ key="y_axis",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  )
134
+
135
+ # Altair scatter plot
136
+ chart = (
137
+ alt.Chart(df)
138
+ .mark_circle(size=60)
139
+ .encode(
140
+ x=alt.X(x_axis, title=x_axis.capitalize()),
141
+ y=alt.Y(y_axis, title=y_axis.capitalize()),
142
+ color=alt.Color("symbol:N", legend=None),
143
+ tooltip=["symbol", x_axis, y_axis],
144
  )
145
+ .interactive()
146
+ .properties(title=f"{y_axis} vs {x_axis}")
147
+ )
148
 
149
+ st.altair_chart(chart, use_container_width=True)
150
 
151
+ st.write(f"Results: {len(df)}")