import datetime as dt
import math
import numbers
from typing import Any, Dict, List

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
from plotly.graph_objs import Figure, Scatter

from common import get_db

# Number of chart slots rendered per page by update_trends() / reset_trends().
MAX_CHARTS_IN_PAGE = 40
# Default for update_trends(num_cols=...); the value is not used by its logic.
NUM_COLS = 4
# Fixed number of vitals chart/label slots built or cleared by the vitals helpers.
MAX_VITALS_CHARTS = 20


def coerce_to_number(val):
    """Best-effort conversion of *val* to an ``int`` or ``float``.

    - ``None`` is returned unchanged.
    - Integral values (``int``, ``bool``, NumPy integers) become ``int``.
    - Strings are tried as ``int`` first (``"7"`` -> ``7``), then as ``float``
      (``"7.5"`` -> ``7.5``).
    - Any other value is converted with ``float()`` so fractional numbers such
      as ``7.5`` are never truncated to ``7``.
    - If no conversion works, the original value is returned unchanged
      (e.g. ``"positive"`` or ``"5 mg/dL"``).
    """
    if val is None:
        return None
    if isinstance(val, numbers.Integral):
        return int(val)
    if isinstance(val, str):
        try:
            return int(val)
        except ValueError:
            pass  # not an integer literal; fall through to float parsing
    try:
        return float(val)
    except (ValueError, TypeError):
        return val  # fallback to original (string, unit, etc.)


def _reference_bound(raw):
    """Coerce one reference-range bound; return it only if it is a real, non-NaN number."""
    val = coerce_to_number(raw)
    if isinstance(val, numbers.Real) and not math.isnan(val):
        return val
    return None


def _hidden_updates(count: int) -> List[Any]:
    """Return *count* Gradio updates that hide a component and clear its value and label."""
    return [gr.update(visible=False, value=None, label="") for _ in range(count)]


def _total_pages(item_count: int) -> int:
    """Return how many pages of MAX_CHARTS_IN_PAGE charts *item_count* needs (at least 1)."""
    return max(1, -(-item_count // MAX_CHARTS_IN_PAGE))


def build_trend_figure(trend_doc: Dict[str, Any]) -> Figure:
    """Build a Plotly line chart for one test's trend document.

    Reads these keys from *trend_doc*:
    - ``"trend_data"``: a list of mappings with ``"date"`` and ``"value"`` keys.
    - ``"test_reference_range"`` (optional): a mapping with optional ``"min"``
      and ``"max"`` keys.
    - ``"test_name"``: used for the trace name and the figure title.

    Points are skipped when the date cannot be parsed or the value is ``None``.
    The remaining points are plotted in date order. If no point survives, a
    placeholder figure titled "No trend data" is returned.

    Reference range:
    - Both bounds numeric: a light-green band is drawn behind the data.
    - Only one bound numeric: a dotted horizontal line is drawn.
    - Non-numeric bounds are ignored.

    Returns the figure after passing it through sanitize_plotly_figure().
    """
    points = trend_doc.get("trend_data") or []
    ref = trend_doc.get("test_reference_range") or {}  # safe default {}

    # Parse dates and values, dropping unusable points.
    date_value_pairs = []
    for point in points:
        date = pd.to_datetime(point.get("date"), errors="coerce")
        value = coerce_to_number(point.get("value"))
        if pd.notna(date) and value is not None:
            date_value_pairs.append((date, value))

    # Checked after parsing (not only on the raw list) so documents whose
    # points are all unparseable also get the placeholder, and the x-extent
    # below is never computed on empty data.
    if not date_value_pairs:
        fig = Figure()
        fig.update_layout(
            title="No trend data", xaxis_title="Date", yaxis_title="Value"
        )
        return fig

    date_value_pairs.sort(key=lambda pair: pair[0])
    dates, values = zip(*date_value_pairs)
    # Sorted, so the first/last dates are the x-extent for reference markers.
    x_start, x_end = dates[0], dates[-1]

    fig = Figure()

    # === Reference Range Logic (only if present and numeric) ===
    ref_min = _reference_bound(ref.get("min"))
    ref_max = _reference_bound(ref.get("max"))

    if ref_min is not None and ref_max is not None:
        fig.add_shape(
            type="rect",
            x0=x_start,
            x1=x_end,
            y0=ref_min,
            y1=ref_max,
            fillcolor="rgba(0,200,0,0.1)",  # light green
            line=dict(width=0),
            layer="below",
        )
    elif ref_min is not None:
        fig.add_trace(
            Scatter(
                x=[x_start, x_end],
                y=[ref_min, ref_min],
                mode="lines",
                name="Min Ref",
                line=dict(color="green", dash="dot"),
            )
        )
    elif ref_max is not None:
        fig.add_trace(
            Scatter(
                x=[x_start, x_end],
                y=[ref_max, ref_max],
                mode="lines",
                name="Max Ref",
                line=dict(color="red", dash="dot"),
            )
        )

    # === Actual Trend Data ===
    fig.add_trace(
        Scatter(
            x=dates,
            y=values,
            mode="lines+markers",
            name=trend_doc.get("test_name", "Trend"),
        )
    )

    fig.update_layout(
        margin=dict(l=30, r=20, t=40, b=30),
        xaxis_title="Date",
        yaxis_title="Value",
        title=f"{trend_doc.get('test_name','')}",
    )
    fig.update_yaxes(autorange=True)
    fig.update_xaxes(
        autorange=True, tickformat="%Y-%m-%d", tickangle=-45, tickmode="auto"
    )

    return sanitize_plotly_figure(fig)


async def load_all_trend_figures(patient_id: str) -> List[Figure]:
    """Fetch every trend document for *patient_id* and build one figure per document.

    Queries ``get_db().trends`` with ``{"patient_id": ObjectId(patient_id)}``.
    The query is awaited, so the collection is presumably an async Mongo client
    (verify). Falsy documents are skipped. Returns ``[]`` when *patient_id* is
    empty.
    """
    if not patient_id:
        return []

    from bson import ObjectId  # local import: same lazy resolution as before

    db = get_db()
    cursor = db.trends.find({"patient_id": ObjectId(patient_id)})
    docs = await cursor.to_list(length=None)
    return [build_trend_figure(doc) for doc in docs if doc]


async def update_trends(patient_id, page=0, num_cols=NUM_COLS):
    """Render one page of trend charts for *patient_id*.

    *num_cols* is accepted for caller compatibility but is not used.

    *page* is clamped into ``[0, total_pages - 1]``. This lets a stale index,
    ``None``, or a float from a numeric input still select a valid page.

    Returns a tuple of:
    - MAX_CHARTS_IN_PAGE plot updates: visible ones carry a figure and use the
      figure's title as the label, the rest are hidden;
    - the (clamped) page number;
    - the page-info text ``"Page X / Y"``;
    - an update for the Prev button;
    - an update for the Next button.
    """
    figures = await load_all_trend_figures(patient_id)
    total_pages = _total_pages(len(figures))
    page = min(max(int(page or 0), 0), total_pages - 1)

    start = page * MAX_CHARTS_IN_PAGE
    page_figures = figures[start:start + MAX_CHARTS_IN_PAGE]

    outputs = []
    for fig in page_figures:
        # Move the chart title into the component label so it isn't shown twice.
        title = fig.layout.title.text
        fig.update_layout(title="")
        outputs.append(gr.update(value=fig, visible=True, label=title))
    outputs.extend(_hidden_updates(MAX_CHARTS_IN_PAGE - len(page_figures)))

    return (
        *outputs,  # plots
        page,  # page number
        f"Page {page+1} / {total_pages}",  # page info
        gr.update(interactive=page > 0),  # Prev button
        gr.update(interactive=page < total_pages - 1),  # Next button
    )


async def reset_trends():
    """Hide all trend plots and reset the pager.

    Returns a tuple shaped like update_trends' outputs:
    - MAX_CHARTS_IN_PAGE hidden plot updates;
    - page ``0``;
    - ``"Page 0 / 0"``;
    - disabled Prev and Next button updates.
    """
    return (
        *_hidden_updates(MAX_CHARTS_IN_PAGE),
        0,
        "Page 0 / 0",
        gr.update(interactive=False),
        gr.update(interactive=False),
    )


def reset_vitals_plots():
    """Return a tuple of MAX_VITALS_CHARTS updates that hide and clear every vitals plot."""
    return tuple(_hidden_updates(MAX_VITALS_CHARTS))


def reset_latest_vitals_labels():
    """Return a tuple of MAX_VITALS_CHARTS updates that hide and clear every latest-vitals label."""
    return tuple(_hidden_updates(MAX_VITALS_CHARTS))


def _to_jsonable_dt(x):
    """Convert pandas/NumPy datetime scalars to ``datetime.datetime``; pass anything else through."""
    if isinstance(x, pd.Timestamp):
        return x.to_pydatetime()  # or x.isoformat()
    if isinstance(x, np.datetime64):
        return pd.to_datetime(x).to_pydatetime()
    return x


def sanitize_plotly_figure(fig):
    """Replace pandas/NumPy datetimes in *fig* with ``datetime`` objects, in place.

    Covers the following; returns the same *fig* object:
    - trace ``x`` values;
    - shape ``x0``/``x1``;
    - annotation ``x``;
    - the x-axis ``range``.
    """
    # traces (x values)
    for trace in fig.data:
        if hasattr(trace, "x") and trace.x is not None:
            try:
                trace.x = [_to_jsonable_dt(v) for v in trace.x]
            except TypeError:
                # x may be a scalar
                trace.x = _to_jsonable_dt(trace.x)

    # shapes (x0/x1)
    if fig.layout.shapes:
        for shape in list(fig.layout.shapes):
            if getattr(shape, "x0", None) is not None:
                shape.x0 = _to_jsonable_dt(shape.x0)
            if getattr(shape, "x1", None) is not None:
                shape.x1 = _to_jsonable_dt(shape.x1)

    # annotations (x)
    if fig.layout.annotations:
        for annotation in list(fig.layout.annotations):
            if getattr(annotation, "x", None) is not None:
                annotation.x = _to_jsonable_dt(annotation.x)

    # axes ranges (range can contain datetimes)
    if getattr(fig.layout, "xaxis", None) and getattr(fig.layout.xaxis, "range", None):
        fig.layout.xaxis.range = [_to_jsonable_dt(v) for v in fig.layout.xaxis.range]

    return fig


def next_page(page, figures_len):
    """Return the next page index, capped at the last page (never negative, even with no figures)."""
    return min(page + 1, _total_pages(figures_len) - 1)


def prev_page(page):
    """Return the previous page index, floored at 0."""
    return max(page - 1, 0)


async def render_vitals_plot_layout(patient_id):
    """Build exactly MAX_VITALS_CHARTS ``gr.Plot`` components for *patient_id*'s vitals.

    How it works:
    - Documents come from ``get_db().get_vitals_trends_by_patient(patient_id)``
      and are drawn with build_trend_figure(). This presumably uses the same
      document shape as the lab trends; confirm against the DB helper.
    - Extra charts are dropped.
    - Missing slots are filled with "No Data" placeholders.
    - Each figure's title is moved into the component label.
    """
    docs = await get_db().get_vitals_trends_by_patient(patient_id)
    figures = [build_trend_figure(doc) for doc in (docs or []) if doc]
    figures = figures[:MAX_VITALS_CHARTS]

    while len(figures) < MAX_VITALS_CHARTS:
        empty_fig = Figure()
        empty_fig.update_layout(
            title="No Data",
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            margin=dict(l=30, r=20, t=40, b=30),
        )
        figures.append(empty_fig)

    plots = []
    for fig in figures:
        label = fig.layout.title.text
        # Clear the title BEFORE handing the figure to gr.Plot (as update_trends
        # does), so the title is shown only once, as the label.
        fig.update_layout(title=None)
        plots.append(gr.Plot(value=fig, label=label))
    return plots