| | |
| | import io |
| | import json |
| | import os |
| | from typing import Dict, Any, Optional, Tuple, List |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import matplotlib.pyplot as plt |
| | import streamlit as st |
| |
|
| |
|
| | |
| | |
| | |
# Render all figure text in a monospace font family.
plt.rcParams.update({"font.family": "monospace"})
| |
|
# Base colours as RGB triples, rescaled from 0-255 integers to 0-1 floats.
PRIMARY = np.array((166, 0, 0), dtype=float) / 255.0
CONTRARY = np.array((0, 166, 166), dtype=float) / 255.0
NEUTRAL_MEDIUM_GREY = np.array((128, 128, 128), dtype=float) / 255.0
NEUTRAL_DARK_GREY = np.array((64, 64, 64), dtype=float) / 255.0
| |
|
| |
|
| | def _mix(c1, c2, t: float): |
| | c1 = np.array(c1, dtype=float) |
| | c2 = np.array(c2, dtype=float) |
| | return (1 - t) * c1 + t * c2 |
| |
|
| |
|
def palette():
    """Return the eight-colour plotting cycle as a list of RGB arrays.

    The four base colours come first, followed by one lighter tint of each,
    obtained by blending the base colour towards white.
    """
    white = np.array([1.0, 1.0, 1.0])
    base_colours = [PRIMARY, CONTRARY, NEUTRAL_DARK_GREY, NEUTRAL_MEDIUM_GREY]
    # Fraction of white mixed into each base colour, in the same order.
    white_fractions = [0.35, 0.35, 0.45, 0.35]
    tints = [
        _mix(colour, white, fraction)
        for colour, fraction in zip(base_colours, white_fractions)
    ]
    return base_colours + tints
| |
|
| |
|
def set_paper_style(exaggerated: bool = True):
    """Apply the paper-figure style to matplotlib's global ``rcParams``.

    ``exaggerated`` selects the large font-size set (18-24 pt) instead of the
    regular one (12-16 pt). Line widths, tick geometry, grid and save
    settings are the same in both modes. Nothing is returned; ``plt.rcParams``
    is updated in place.
    """
    # Font sizes in points: (base, axis label, title, tick label, legend).
    if exaggerated:
        base, label, title, tick, legend = 18, 22, 24, 18, 18
    else:
        base, label, title, tick, legend = 12, 14, 16, 12, 12

    text_sizes = {
        "font.size": base,
        "axes.titlesize": title,
        "axes.labelsize": label,
        "xtick.labelsize": tick,
        "ytick.labelsize": tick,
        "legend.fontsize": legend,
    }
    strokes_and_grid = {
        "axes.linewidth": 1.6,
        "lines.linewidth": 2.8,
        "lines.markersize": 7.0,
        "grid.alpha": 0.25,
        "grid.linewidth": 1.0,
    }
    output = {
        "figure.dpi": 120,
        "savefig.dpi": 600,
        "savefig.bbox": "tight",
        "savefig.pad_inches": 0.03,
    }
    tick_marks = {
        "xtick.direction": "out",
        "ytick.direction": "out",
        "xtick.major.size": 6.0,
        "ytick.major.size": 6.0,
        "xtick.major.width": 1.4,
        "ytick.major.width": 1.4,
    }

    plt.rcParams.update({**text_sizes, **strokes_and_grid, **output, **tick_marks})
| |
|
| |
|
def clean_axes(ax):
    """Enable the major grid on both axes, hide the top and right spines.

    Returns the same ``ax`` object so the call can be chained.
    """
    ax.grid(True, which="major", axis="both")
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    return ax
| |
|
| |
|
def figure_size(preset: str) -> Tuple[float, float]:
    """Return the ``(width, height)`` pair for a named figure-size preset.

    Raises ``KeyError`` when ``preset`` is not one of the known names.
    """
    sizes = dict(
        single=(3.45, 2.60),
        single_tall=(3.45, 3.20),
        double=(7.10, 2.90),
        double_tall=(7.10, 3.80),
        square=(4.00, 4.00),
        wide=(7.10, 2.40),
    )
    return sizes[preset]
| |
|
| |
|
| | |
| | |
| | |
def load_to_df(uploaded_file) -> pd.DataFrame:
    """Parse an uploaded data file into a DataFrame, dispatching on extension.

    ``uploaded_file`` must expose ``.name`` (only its extension is used,
    case-insensitively) and ``.getvalue()`` returning the raw bytes.

    Supported formats:
      * ``.csv``  - parsed with ``pd.read_csv``.
      * ``.json`` - a top-level dict or list, handed to ``pd.DataFrame``.
      * ``.npz``  - every 1-D array becomes a column; other arrays are skipped.
      * ``.npy``  - structured array (one column per field), 1-D array
        (single column ``y``) or 2-D array (columns ``y0``, ``y1``, ...).

    Raises:
        ValueError: unsupported extension, unsupported JSON top-level type,
            unsupported ``.npy`` shape, or an ``.npz`` without 1-D arrays.
    """
    ext = os.path.splitext(uploaded_file.name)[1].lower()
    data = uploaded_file.getvalue()

    if ext == ".csv":
        return pd.read_csv(io.BytesIO(data))

    if ext == ".json":
        # "utf-8-sig" decodes plain UTF-8 too, but also strips a leading
        # byte-order mark, which json.loads would otherwise reject.
        obj = json.loads(data.decode("utf-8-sig"))
        if isinstance(obj, (dict, list)):
            return pd.DataFrame(obj)
        raise ValueError("Unsupported JSON: use dict-of-lists or list-of-dicts.")

    if ext == ".npz":
        # SECURITY NOTE(review): allow_pickle=True lets NumPy unpickle object
        # arrays, and unpickling can execute arbitrary code. The input here is
        # a user upload; switch to allow_pickle=False unless object arrays
        # from trusted users really must be supported.
        with np.load(io.BytesIO(data), allow_pickle=True) as z:
            # Materialise every member while the archive is still open.
            cols: Dict[str, Any] = {k: np.asarray(z[k]) for k in z.files}

        df = pd.DataFrame()
        for k, v in cols.items():
            if v.ndim == 1:
                df[k] = v
        if df.columns.empty:
            raise ValueError(".npz has no 1D arrays to treat as columns.")
        return df

    if ext == ".npy":
        # SECURITY NOTE(review): same allow_pickle concern as for .npz above.
        arr = np.asarray(np.load(io.BytesIO(data), allow_pickle=True))
        if arr.dtype.names:
            return pd.DataFrame({n: arr[n] for n in arr.dtype.names})
        if arr.ndim == 1:
            return pd.DataFrame({"y": arr})
        if arr.ndim == 2:
            return pd.DataFrame(arr, columns=[f"y{i}" for i in range(arr.shape[1])])
        raise ValueError("Unsupported .npy shape. Use 1D or 2D array or structured array.")

    raise ValueError(f"Unsupported file extension: {ext}")
| |
|
| |
|
| | |
| | |
| | |
def aggregate_xy(x: np.ndarray, y: np.ndarray, mode: str):
    """Collapse repeated x values into a per-x mean and error of y.

    Rows where x or y is missing are dropped before grouping. ``mode``
    selects the error estimate: ``"std"`` is the sample standard deviation
    (``ddof=1``), ``"sem"`` is that value divided by ``sqrt(count)``, and any
    other value gives a zero error. Where the estimate is NaN (for example a
    group with a single sample) the error is reported as 0.

    Returns ``(unique_x, mean_y, err_y)`` as NumPy arrays, in the group order
    produced by ``groupby``.
    """
    pairs = pd.DataFrame({"x": x, "y": y}).dropna()
    by_x = pairs.groupby("x")["y"]
    centre = by_x.mean()

    if mode in ("std", "sem"):
        spread = by_x.std(ddof=1)
        if mode == "sem":
            spread = spread / np.sqrt(by_x.count())
        spread = spread.fillna(0.0)
    else:
        spread = pd.Series(0.0, index=centre.index)

    return centre.index.to_numpy(), centre.to_numpy(), spread.to_numpy()
| |
|
| |
|
| | |
| | |
| | |
def make_plot(
    df: pd.DataFrame,
    kind: str,
    xcol: Optional[str],
    ycols: List[str],
    hue: Optional[str],
    agg: str,
    fill_band: bool,
    title: str,
    xlabel: str,
    ylabel: str,
    logx: bool,
    logy: bool,
    legend_mode: str,
    size_preset: str,
    hist_bins: int,
    hist_density: bool,
    exaggerated_text: bool,
):
    """Build the figure for the chosen plot kind and return it.

    Args:
        df: Source data.
        kind: ``"line"``, ``"scatter"``, ``"bar"`` or ``"hist"``.
        xcol: Column used for x; required for every kind except ``"hist"``.
        ycols: Columns plotted as y (one series each).
        hue: Optional column; when present in ``df`` (and ``kind`` is not
            ``"hist"``) every distinct value, compared as a string, becomes
            its own series per y column.
        agg: ``"std"`` or ``"sem"`` aggregates repeated x values for line
            plots; anything else plots the raw points.
        fill_band: Shade mean +/- error for aggregated line plots.
        title, xlabel, ylabel: Text; blank labels fall back to column names
            (blank x label stays empty for histograms).
        logx, logy: Log scales; ``logx`` is ignored for histograms.
        legend_mode: ``"outside"`` puts the legend right of the axes,
            ``"none"`` draws no legend, anything else uses ``loc="best"``.
        size_preset: Name passed to ``figure_size``.
        hist_bins, hist_density: Histogram settings.
        exaggerated_text: Passed to ``set_paper_style``.

    Returns:
        The matplotlib figure. It is created through ``plt.subplots``; the
        caller should ``plt.close`` it once it is no longer needed.

    Raises:
        ValueError: ``xcol`` is ``None`` for a kind that needs an x column.
    """
    # Validate before any figure exists so a bad call cannot leave one open.
    if kind != "hist" and xcol is None:
        raise ValueError("An x column is required for every plot kind except 'hist'.")

    set_paper_style(exaggerated=exaggerated_text)
    w, h = figure_size(size_preset)
    fig, ax = plt.subplots(figsize=(w, h), constrained_layout=True)
    colors = palette()

    series = _collect_series(df, kind, xcol, ycols, hue)

    if kind == "bar":
        # Bars need to know about every series at once to sit side by side.
        _draw_grouped_bars(ax, series, colors)
    else:
        for i, (label, x, y) in enumerate(series):
            color = colors[i % len(colors)]
            if kind == "line":
                _draw_line(ax, label, x, y, color, agg, fill_band)
            elif kind == "scatter":
                ax.scatter(x, y, label=label, color=color, s=52, alpha=0.85, edgecolors="none")
            elif kind == "hist":
                values = np.asarray(y, dtype=float)
                # Missing values carry no information for a histogram; dropping
                # them up front mirrors the dropna() used for bars/aggregation
                # and keeps an all-NaN column from raising.
                values = values[~np.isnan(values)]
                ax.hist(values, bins=hist_bins, density=hist_density,
                        alpha=0.35, label=label, color=color)

    clean_axes(ax)
    if title.strip():
        ax.set_title(title)
    if kind != "hist":
        ax.set_xlabel(xlabel if xlabel.strip() else xcol)
    else:
        ax.set_xlabel(xlabel if xlabel.strip() else "")
    ax.set_ylabel(ylabel if ylabel.strip() else (", ".join(ycols) if ycols else ""))

    if logx and kind != "hist":
        ax.set_xscale("log")
    if logy:
        ax.set_yscale("log")

    if legend_mode == "none":
        existing = ax.get_legend()
        if existing is not None:
            existing.remove()
    elif legend_mode == "outside":
        ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
    else:
        ax.legend(loc="best", frameon=False)

    return fig


def _collect_series(df, kind, xcol, ycols, hue):
    """Return the ``(label, x, y)`` triples to draw, in colour-cycle order."""
    if kind == "hist":
        return [(yc, None, df[yc].to_numpy()) for yc in ycols]

    if hue and hue in df.columns:
        # Convert once; the original recomputed this for every group.
        hue_as_text = df[hue].astype(str)
        series = []
        for g in hue_as_text.unique().tolist():
            sub = df[hue_as_text == g]
            gx = sub[xcol].to_numpy()
            for yc in ycols:
                series.append((f"{yc} | {hue}={g}", gx, sub[yc].to_numpy()))
        return series

    x = df[xcol].to_numpy()
    return [(yc, x, df[yc].to_numpy()) for yc in ycols]


def _draw_line(ax, label, x, y, color, agg, fill_band):
    """Plot one line series, optionally aggregated with a shaded error band."""
    if agg not in ("std", "sem"):
        ax.plot(x, y, marker="o", label=label, color=color)
        return
    xu, ym, ye = aggregate_xy(x, y, agg)
    ax.plot(xu, ym, marker="o", label=label, color=color)
    if fill_band and np.any(ye > 0):
        ax.fill_between(xu, ym - ye, ym + ye, alpha=0.18, color=color, linewidth=0)


def _draw_grouped_bars(ax, series, colors):
    """Draw per-category means of every series as side-by-side bars.

    All series share one category axis (the union of their x values), so
    bars of different series neither overlap nor end up under the wrong
    tick. With a single series this reduces to bars of width 0.8 centred on
    0, 1, 2, ...
    """
    if not series:
        return

    means = [
        pd.DataFrame({"x": x, "y": y}).dropna().groupby("x")["y"].mean()
        for _, x, y in series
    ]

    categories = means[0].index
    for m in means[1:]:
        categories = categories.union(m.index)

    n = len(means)
    slot = 0.8 / n  # the n bars of one category together span 0.8 units
    for i, ((label, _, _), m) in enumerate(zip(series, means)):
        offset = (i - (n - 1) / 2) * slot
        centres = categories.get_indexer(m.index) + offset
        ax.bar(centres, m.to_numpy(), width=slot, label=label,
               color=colors[i % len(colors)])

    ax.set_xticks(np.arange(len(categories)), categories.tolist())
| |
|
| |
|
def fig_to_bytes(fig, fmt: str) -> bytes:
    """Save ``fig`` in format ``fmt`` to memory and return the encoded bytes."""
    with io.BytesIO() as sink:
        fig.savefig(sink, format=fmt)
        return sink.getvalue()
| |
|
| |
|
| | |
| | |
| | |
# Streamlit page: controls in the left column, data preview / plot / export
# in the right column.
st.set_page_config(page_title="PaperPlot (Matplotlib)", layout="wide")
st.title("PaperPlot: upload data → tweak params → live preview → export")

left, right = st.columns([1, 2])

with left:
    uploaded = st.file_uploader("Upload data", type=["csv", "json", "npz", "npy"])
    st.caption("Supported: .csv / .json / .npz / .npy")

    kind = st.selectbox("Plot kind", ["line", "scatter", "bar", "hist"], index=0)
    exaggerated_text = st.toggle("Exaggerate text (paper readability)", value=True)

    size_preset = st.selectbox(
        "Figure size preset",
        ["single", "single_tall", "double", "double_tall", "square", "wide"],
        index=0,
    )

    title = st.text_input("Title", value="")
    xlabel = st.text_input("X label (optional)", value="")
    ylabel = st.text_input("Y label (optional)", value="")

    logx = st.toggle("Log X", value=False)
    logy = st.toggle("Log Y", value=False)

    legend_mode = st.selectbox("Legend", ["best", "outside", "none"], index=0)

    agg = st.selectbox("Aggregate repeated x (line only)", ["none", "std", "sem"], index=0)
    fill_band = st.toggle("Show error band (line + agg)", value=True)

    hist_bins = st.slider("Hist bins", 5, 200, 30)
    hist_density = st.toggle("Hist density", value=True)

with right:
    if not uploaded:
        st.info("Upload a dataset to start.")
        st.stop()

    # Top-level boundary: any load failure is reported in the page.
    try:
        df = load_to_df(uploaded)
    except Exception as e:
        st.error(f"Failed to load file: {e}")
        st.stop()

    st.subheader("Data preview")
    st.dataframe(df.head(50), use_container_width=True)

    cols = df.columns.tolist()
    numeric_cols = [c for c in cols if pd.api.types.is_numeric_dtype(df[c])]
    # Offer numeric columns when there are any, otherwise every column.
    plottable_cols = numeric_cols if numeric_cols else cols

    # Histograms have no x column.
    if kind != "hist":
        xcol = st.selectbox("X column", options=plottable_cols)
    else:
        xcol = None

    default_y = plottable_cols[:1]
    ycols = st.multiselect("Y column(s)", options=plottable_cols, default=default_y)

    hue = None
    if kind != "hist":
        hue = st.selectbox("Group / hue (optional)", options=["(none)"] + cols, index=0)
        hue = None if hue == "(none)" else hue

    if not ycols:
        st.warning("Pick at least one Y column.")
        st.stop()

    fig = make_plot(
        df=df,
        kind=kind,
        xcol=xcol,
        ycols=ycols,
        hue=hue,
        agg=agg if kind == "line" else "none",  # aggregation applies to line plots only
        fill_band=fill_band,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        logx=logx,
        logy=logy,
        legend_mode=legend_mode,
        size_preset=size_preset,
        hist_bins=hist_bins,
        hist_density=hist_density,
        exaggerated_text=exaggerated_text,
    )

    try:
        st.subheader("Live preview")
        st.pyplot(fig, use_container_width=True)

        c1, c2 = st.columns(2)
        with c1:
            st.download_button(
                "Download PDF",
                data=fig_to_bytes(fig, "pdf"),
                file_name="figure.pdf",
                mime="application/pdf",
            )
        with c2:
            st.download_button(
                "Download PNG",
                data=fig_to_bytes(fig, "png"),
                file_name="figure.png",
                mime="image/png",
            )
    finally:
        # Figures made through pyplot stay registered until closed, and
        # Streamlit re-runs this script on every widget interaction, so an
        # unclosed figure would accumulate with each rerun.
        plt.close(fig)
| |
|