zhangify commited on
Commit
3de369a
·
verified ·
1 Parent(s): 712e833

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +381 -36
src/streamlit_app.py CHANGED
@@ -1,40 +1,385 @@
1
- import altair as alt
 
 
 
 
 
2
  import numpy as np
3
  import pandas as pd
 
4
  import streamlit as st
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import io
3
+ import json
4
+ import os
5
+ from typing import Dict, Any, Optional, Tuple, List
6
+
7
  import numpy as np
8
  import pandas as pd
9
+ import matplotlib.pyplot as plt
10
  import streamlit as st
11
 
12
+
13
+ # =========================
14
+ # Theme (your spec + paper knobs)
15
+ # =========================
16
+ plt.rcParams["font.family"] = "monospace"
17
+
18
+ PRIMARY = np.array([166, 0, 0]) / 255
19
+ CONTRARY = np.array([0, 166, 166]) / 255
20
+ NEUTRAL_MEDIUM_GREY = np.array([128, 128, 128]) / 255
21
+ NEUTRAL_DARK_GREY = np.array([64, 64, 64]) / 255
22
+
23
+
24
+ def _mix(c1, c2, t: float):
25
+ c1 = np.array(c1, dtype=float)
26
+ c2 = np.array(c2, dtype=float)
27
+ return (1 - t) * c1 + t * c2
28
+
29
+
30
+ def palette():
31
+ white = np.array([1.0, 1.0, 1.0])
32
+ return [
33
+ PRIMARY,
34
+ CONTRARY,
35
+ NEUTRAL_DARK_GREY,
36
+ NEUTRAL_MEDIUM_GREY,
37
+ _mix(PRIMARY, white, 0.35),
38
+ _mix(CONTRARY, white, 0.35),
39
+ _mix(NEUTRAL_DARK_GREY, white, 0.45),
40
+ _mix(NEUTRAL_MEDIUM_GREY, white, 0.35),
41
+ ]
42
+
43
+
44
+ def set_paper_style(exaggerated: bool = True):
45
+ if exaggerated:
46
+ base = 18
47
+ label = 22
48
+ title = 24
49
+ tick = 18
50
+ legend = 18
51
+ else:
52
+ base = 12
53
+ label = 14
54
+ title = 16
55
+ tick = 12
56
+ legend = 12
57
+
58
+ plt.rcParams.update({
59
+ "font.size": base,
60
+ "axes.titlesize": title,
61
+ "axes.labelsize": label,
62
+ "xtick.labelsize": tick,
63
+ "ytick.labelsize": tick,
64
+ "legend.fontsize": legend,
65
+ "axes.linewidth": 1.6,
66
+ "lines.linewidth": 2.8,
67
+ "lines.markersize": 7.0,
68
+ "grid.alpha": 0.25,
69
+ "grid.linewidth": 1.0,
70
+ "figure.dpi": 120,
71
+ "savefig.dpi": 600,
72
+ "savefig.bbox": "tight",
73
+ "savefig.pad_inches": 0.03,
74
+ "xtick.direction": "out",
75
+ "ytick.direction": "out",
76
+ "xtick.major.size": 6.0,
77
+ "ytick.major.size": 6.0,
78
+ "xtick.major.width": 1.4,
79
+ "ytick.major.width": 1.4,
80
+ })
81
+
82
+
83
+ def clean_axes(ax):
84
+ ax.grid(True, which="major", axis="both")
85
+ ax.spines["top"].set_visible(False)
86
+ ax.spines["right"].set_visible(False)
87
+ return ax
88
+
89
+
90
+ def figure_size(preset: str) -> Tuple[float, float]:
91
+ presets = {
92
+ "single": (3.45, 2.60),
93
+ "single_tall": (3.45, 3.20),
94
+ "double": (7.10, 2.90),
95
+ "double_tall": (7.10, 3.80),
96
+ "square": (4.00, 4.00),
97
+ "wide": (7.10, 2.40),
98
+ }
99
+ return presets[preset]
100
+
101
+
102
+ # =========================
103
+ # Loading: csv / json / npz / npy
104
+ # =========================
105
+ def load_to_df(uploaded_file) -> pd.DataFrame:
106
+ name = uploaded_file.name
107
+ ext = os.path.splitext(name)[1].lower()
108
+ data = uploaded_file.getvalue()
109
+
110
+ if ext == ".csv":
111
+ return pd.read_csv(io.BytesIO(data))
112
+
113
+ if ext == ".json":
114
+ obj = json.loads(data.decode("utf-8"))
115
+ if isinstance(obj, dict):
116
+ return pd.DataFrame(obj)
117
+ if isinstance(obj, list):
118
+ return pd.DataFrame(obj)
119
+ raise ValueError("Unsupported JSON: use dict-of-lists or list-of-dicts.")
120
+
121
+ if ext == ".npz":
122
+ z = np.load(io.BytesIO(data), allow_pickle=True)
123
+ cols: Dict[str, Any] = {k: z[k] for k in z.files}
124
+ # try to flatten 1D arrays into columns
125
+ df = pd.DataFrame()
126
+ for k, v in cols.items():
127
+ v = np.asarray(v)
128
+ if v.ndim == 1:
129
+ df[k] = v
130
+ if len(df.columns) == 0:
131
+ raise ValueError(".npz has no 1D arrays to treat as columns.")
132
+ return df
133
+
134
+ if ext == ".npy":
135
+ arr = np.load(io.BytesIO(data), allow_pickle=True)
136
+ arr = np.asarray(arr)
137
+ if arr.dtype.names:
138
+ return pd.DataFrame({n: arr[n] for n in arr.dtype.names})
139
+ if arr.ndim == 1:
140
+ return pd.DataFrame({"y": arr})
141
+ if arr.ndim == 2:
142
+ # columns: y0,y1,...
143
+ return pd.DataFrame(arr, columns=[f"y{i}" for i in range(arr.shape[1])])
144
+ raise ValueError("Unsupported .npy shape. Use 1D or 2D array or structured array.")
145
+
146
+ raise ValueError(f"Unsupported file extension: {ext}")
147
+
148
+
149
+ # =========================
150
+ # Aggregation for error bars
151
+ # =========================
152
+ def aggregate_xy(x: np.ndarray, y: np.ndarray, mode: str):
153
+ # groups by exact x
154
+ df = pd.DataFrame({"x": x, "y": y}).dropna()
155
+ g = df.groupby("x")["y"]
156
+ mean = g.mean()
157
+ if mode == "std":
158
+ err = g.std(ddof=1).fillna(0.0)
159
+ elif mode == "sem":
160
+ err = (g.std(ddof=1) / np.sqrt(g.count())).fillna(0.0)
161
+ else:
162
+ err = pd.Series(0.0, index=mean.index)
163
+ xu = mean.index.to_numpy()
164
+ return xu, mean.to_numpy(), err.to_numpy()
165
+
166
+
167
+ # =========================
168
+ # Plotting
169
+ # =========================
170
+ def make_plot(
171
+ df: pd.DataFrame,
172
+ kind: str,
173
+ xcol: Optional[str],
174
+ ycols: List[str],
175
+ hue: Optional[str],
176
+ agg: str,
177
+ fill_band: bool,
178
+ title: str,
179
+ xlabel: str,
180
+ ylabel: str,
181
+ logx: bool,
182
+ logy: bool,
183
+ legend_mode: str,
184
+ size_preset: str,
185
+ hist_bins: int,
186
+ hist_density: bool,
187
+ exaggerated_text: bool,
188
+ ):
189
+ set_paper_style(exaggerated=exaggerated_text)
190
+ w, h = figure_size(size_preset)
191
+ fig, ax = plt.subplots(figsize=(w, h), constrained_layout=True)
192
+ colors = palette()
193
+
194
+ def _plot_series(label, x, y, color):
195
+ if kind == "line":
196
+ if agg in ("std", "sem"):
197
+ xu, ym, ye = aggregate_xy(x, y, agg)
198
+ ax.plot(xu, ym, marker="o", label=label, color=color)
199
+ if fill_band and np.any(ye > 0):
200
+ ax.fill_between(xu, ym - ye, ym + ye, alpha=0.18, color=color, linewidth=0)
201
+ else:
202
+ ax.plot(x, y, marker="o", label=label, color=color)
203
+
204
+ elif kind == "scatter":
205
+ ax.scatter(x, y, label=label, color=color, s=52, alpha=0.85, edgecolors="none")
206
+
207
+ elif kind == "bar":
208
+ # category bars: mean per category
209
+ tmp = pd.DataFrame({"x": x, "y": y}).dropna()
210
+ means = tmp.groupby("x")["y"].mean()
211
+ xs = means.index.tolist()
212
+ ys = means.values
213
+ # stable positions
214
+ pos = np.arange(len(xs))
215
+ ax.bar(pos, ys, label=label, color=color)
216
+ ax.set_xticks(pos, xs)
217
+
218
+ elif kind == "hist":
219
+ ax.hist(np.asarray(y, dtype=float), bins=hist_bins, density=hist_density,
220
+ alpha=0.35, label=label, color=color)
221
+
222
+ if kind != "hist":
223
+ assert xcol is not None
224
+ x = df[xcol].to_numpy()
225
+ # hue grouping
226
+ if hue and hue in df.columns:
227
+ groups = df[hue].astype(str).unique().tolist()
228
+ ci = 0
229
+ for g in groups:
230
+ sub = df[df[hue].astype(str) == g]
231
+ gx = sub[xcol].to_numpy()
232
+ for yc in ycols:
233
+ _plot_series(f"{yc} | {hue}={g}", gx, sub[yc].to_numpy(), colors[ci % len(colors)])
234
+ ci += 1
235
+ else:
236
+ for i, yc in enumerate(ycols):
237
+ _plot_series(yc, x, df[yc].to_numpy(), colors[i % len(colors)])
238
+ else:
239
+ for i, yc in enumerate(ycols):
240
+ _plot_series(yc, None, df[yc].to_numpy(), colors[i % len(colors)])
241
+
242
+ clean_axes(ax)
243
+ if title.strip():
244
+ ax.set_title(title)
245
+ if kind != "hist":
246
+ ax.set_xlabel(xlabel if xlabel.strip() else xcol)
247
+ else:
248
+ ax.set_xlabel(xlabel if xlabel.strip() else "")
249
+ ax.set_ylabel(ylabel if ylabel.strip() else (", ".join(ycols) if ycols else ""))
250
+
251
+ if logx and kind != "hist":
252
+ ax.set_xscale("log")
253
+ if logy:
254
+ ax.set_yscale("log")
255
+
256
+ if legend_mode == "none":
257
+ if ax.get_legend() is not None:
258
+ ax.get_legend().remove()
259
+ elif legend_mode == "outside":
260
+ ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
261
+ else:
262
+ ax.legend(loc="best", frameon=False)
263
+
264
+ return fig
265
+
266
+
267
+ def fig_to_bytes(fig, fmt: str) -> bytes:
268
+ buf = io.BytesIO()
269
+ fig.savefig(buf, format=fmt)
270
+ buf.seek(0)
271
+ return buf.read()
272
+
273
+
274
+ # =========================
275
+ # Streamlit UI
276
+ # =========================
277
+ st.set_page_config(page_title="PaperPlot (Matplotlib)", layout="wide")
278
+ st.title("PaperPlot: upload data → tweak params → live preview → export")
279
+
280
+ left, right = st.columns([1, 2])
281
+
282
+ with left:
283
+ uploaded = st.file_uploader("Upload data", type=["csv", "json", "npz", "npy"])
284
+ st.caption("Supported: .csv / .json / .npz / .npy")
285
+
286
+ kind = st.selectbox("Plot kind", ["line", "scatter", "bar", "hist"], index=0)
287
+ exaggerated_text = st.toggle("Exaggerate text (paper readability)", value=True)
288
+
289
+ size_preset = st.selectbox(
290
+ "Figure size preset",
291
+ ["single", "single_tall", "double", "double_tall", "square", "wide"],
292
+ index=0
293
+ )
294
+
295
+ title = st.text_input("Title", value="")
296
+ xlabel = st.text_input("X label (optional)", value="")
297
+ ylabel = st.text_input("Y label (optional)", value="")
298
+
299
+ logx = st.toggle("Log X", value=False)
300
+ logy = st.toggle("Log Y", value=False)
301
+
302
+ legend_mode = st.selectbox("Legend", ["best", "outside", "none"], index=0)
303
+
304
+ agg = st.selectbox("Aggregate repeated x (line only)", ["none", "std", "sem"], index=0)
305
+ fill_band = st.toggle("Show error band (line + agg)", value=True)
306
+
307
+ hist_bins = st.slider("Hist bins", 5, 200, 30)
308
+ hist_density = st.toggle("Hist density", value=True)
309
+
310
+ with right:
311
+ if not uploaded:
312
+ st.info("Upload a dataset to start.")
313
+ st.stop()
314
+
315
+ try:
316
+ df = load_to_df(uploaded)
317
+ except Exception as e:
318
+ st.error(f"Failed to load file: {e}")
319
+ st.stop()
320
+
321
+ st.subheader("Data preview")
322
+ st.dataframe(df.head(50), use_container_width=True)
323
+
324
+ cols = df.columns.tolist()
325
+ numeric_cols = [c for c in cols if pd.api.types.is_numeric_dtype(df[c])]
326
+
327
+ if kind != "hist":
328
+ xcol = st.selectbox("X column", options=numeric_cols if numeric_cols else cols)
329
+ else:
330
+ xcol = None
331
+
332
+ if numeric_cols:
333
+ default_y = numeric_cols[:1]
334
+ else:
335
+ default_y = cols[:1]
336
+
337
+ ycols = st.multiselect("Y column(s)", options=numeric_cols if numeric_cols else cols, default=default_y)
338
+
339
+ hue = None
340
+ if kind != "hist":
341
+ hue = st.selectbox("Group / hue (optional)", options=["(none)"] + cols, index=0)
342
+ hue = None if hue == "(none)" else hue
343
+
344
+ if not ycols:
345
+ st.warning("Pick at least one Y column.")
346
+ st.stop()
347
+
348
+ fig = make_plot(
349
+ df=df,
350
+ kind=kind,
351
+ xcol=xcol,
352
+ ycols=ycols,
353
+ hue=hue,
354
+ agg=agg if kind == "line" else "none",
355
+ fill_band=fill_band,
356
+ title=title,
357
+ xlabel=xlabel,
358
+ ylabel=ylabel,
359
+ logx=logx,
360
+ logy=logy,
361
+ legend_mode=legend_mode,
362
+ size_preset=size_preset,
363
+ hist_bins=hist_bins,
364
+ hist_density=hist_density,
365
+ exaggerated_text=exaggerated_text,
366
+ )
367
+
368
+ st.subheader("Live preview")
369
+ st.pyplot(fig, use_container_width=True)
370
+
371
+ c1, c2 = st.columns(2)
372
+ with c1:
373
+ st.download_button(
374
+ "Download PDF",
375
+ data=fig_to_bytes(fig, "pdf"),
376
+ file_name="figure.pdf",
377
+ mime="application/pdf",
378
+ )
379
+ with c2:
380
+ st.download_button(
381
+ "Download PNG",
382
+ data=fig_to_bytes(fig, "png"),
383
+ file_name="figure.png",
384
+ mime="image/png",
385
+ )