Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import traceback | |
| import urllib.parse | |
| from datetime import datetime, timezone | |
| from typing import Any | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| from gradio.routes import App as GradioApp | |
| from starlette.applications import Starlette | |
| from starlette.requests import Request | |
| from starlette.responses import JSONResponse, PlainTextResponse, Response | |
| from starlette.routing import Route, Mount | |
| import uvicorn | |
# ---------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------
# Hugging Face dataset repo holding all CSV files; override via the
# DATASET_REPO environment variable.
DATASET_REPO = os.environ.get("DATASET_REPO", "yhayashi1986/PolyOmics")
# Base URL for raw-file downloads ("resolve/main") from that repo.
HF_BASE_URL = f"https://huggingface.co/datasets/{DATASET_REPO}/resolve/main"
# Solvents for which a chi-parameter table is expected to exist in the repo
# (loaded from chi_parameter/Summary_<solvent>.csv by preload_data).
SOLVENTS = [
    "Water", "Methanol", "Ethanol", "Acetone", "Benzene", "Toluene",
    "Chloroform", "Hexane", "Cyclohexane", "DMSO", "Tetrahydrofuran",
    "Dimethylformamide", "Ethyl_acetate", "Diethyl_ether", "1,4-Dioxane",
    "DEGDB", "DEHA", "DEHP", "Acetyl_tributyl_citrate",
]
# Dataset key -> CSV filename in the repo root.
DATASET_FILES = {
    "general": "general_polymers_with_sp_abbe_dynamic-dielectric.csv",
    "pfas": "PFAS.csv",
    "biodegradable": "biodegradable_candidate_polyester.csv",
    "cellulose": "cellulose.csv",
    "ladder": "ladder_polymers.csv",
    "small_molecules": "small_molecules.csv",
    "non_ladder": "corresponding_non_ladder_polymers.csv",
}
# Known numeric property columns — used as defaults for statistics / filtering
# (these are coerced to numeric at load time in preload_data).
NUMERIC_PROPERTY_COLUMNS = [
    "density", "thermal_conductivity", "thermal_diffusivity",
    "tg", "refractive_index", "static_dielectric_const", "dielectric_const_dc",
    "bulk_modulus", "isentropic_bulk_modulus", "compressibility",
    "volume_expansion", "linear_expansion", "Cp", "Cv",
    "sp_total", "sp_vdw", "sp_ele", "sp_ced",
    "mol_weight", "Mn", "Mw", "Rg", "self-diffusion",
    "abbe_number_sos", "refractive_index_sos",
    "efdp_permittivity_real", "efdp_permittivity_imaginary", "efdp_dielectric_loss_tan",
    "nematic_order_parameter",
]
# Default compact field set returned by search_polymers when the caller
# does not pass an explicit `fields` list.
_SEARCH_DEFAULT_FIELDS = [
    "UUID", "smiles_list", "polymer_class", "_source",
    "density", "thermal_conductivity", "thermal_diffusivity",
    "tg", "refractive_index", "static_dielectric_const",
    "bulk_modulus", "sp_total", "abbe_number_sos",
]
# ---------------------------------------------------------------------
# Data layer
# ---------------------------------------------------------------------
# Module-level caches populated once by preload_data().
_datasets: dict[str, pd.DataFrame] = {}        # dataset key -> loaded DataFrame
_main_df: pd.DataFrame | None = None           # "general" dataset, or concat fallback
_chi_dfs: dict[str, pd.DataFrame | None] = {}  # solvent -> chi table (None if load failed)
def _hf_url(path_in_repo: str) -> str:
    """Build the direct-download URL for a file inside the dataset repo."""
    quoted = urllib.parse.quote(path_in_repo, safe="/")
    return "/".join([HF_BASE_URL, quoted])
def _load_csv(path_in_repo: str) -> pd.DataFrame | None:
    """Download one CSV from the dataset repo; return None on any failure."""
    target = _hf_url(path_in_repo)
    try:
        frame = pd.read_csv(target, low_memory=False)
    except Exception as exc:
        # Best-effort loading: log and continue so one bad file does not
        # prevent the rest of the datasets from loading.
        print(f"[WARN] Failed to load {path_in_repo}: {exc}")
        return None
    return frame
def preload_data() -> None:
    """Download every polymer dataset and chi table into the module caches.

    Populates ``_datasets``, ``_main_df`` and ``_chi_dfs``. Intended to run
    once at startup. Failed polymer datasets are skipped; failed chi tables
    are stored as None so lookups stay explicit.
    """
    global _main_df
    print("Loading polymer datasets ...")
    for key, fname in DATASET_FILES.items():
        df = _load_csv(fname)
        if df is not None:
            # Tag each row with its originating dataset so rows remain
            # traceable after any later concatenation.
            df["_source"] = key
            # Coerce known numeric columns at load time for reliable downstream ops
            for col in NUMERIC_PROPERTY_COLUMNS:
                if col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors="coerce")
            _datasets[key] = df
            print(f" [{key}] {len(df):,} rows, {len(df.columns)} cols")
        else:
            print(f" [{key}] FAILED")
    # Prefer the "general" dataset as the main frame; otherwise fall back
    # to a concatenation of whatever loaded, or an empty frame if nothing did.
    if "general" in _datasets:
        _main_df = _datasets["general"]
    elif _datasets:
        _main_df = pd.concat(list(_datasets.values()), ignore_index=True)
    else:
        _main_df = pd.DataFrame()
    print("Loading chi parameter datasets ...")
    loaded = []
    for solvent in SOLVENTS:
        df = _load_csv(f"chi_parameter/Summary_{solvent}.csv")
        # Store None for failures so get_chi_df can distinguish "unknown
        # solvent" from "known but failed to load".
        _chi_dfs[solvent] = df
        if df is not None:
            loaded.append(solvent)
    print(f"Chi parameter solvents loaded: {loaded}")
def get_main_df() -> pd.DataFrame:
    """Return the primary dataframe, or an empty frame before preload ran."""
    if _main_df is None:
        return pd.DataFrame()
    return _main_df
def get_chi_df(solvent: str) -> pd.DataFrame | None:
    """Return the chi table for *solvent*; None when unknown or not loaded."""
    try:
        return _chi_dfs[solvent]
    except KeyError:
        return None
| # --------------------------------------------------------------------- | |
| # JSON / serialisation helpers | |
| # --------------------------------------------------------------------- | |
| def _safe_value(v: Any) -> Any: | |
| """Convert non-JSON-serialisable values to Python-native types.""" | |
| if isinstance(v, float) and (np.isnan(v) or np.isinf(v)): | |
| return None | |
| if isinstance(v, np.integer): | |
| return int(v) | |
| if isinstance(v, np.floating): | |
| return float(v) | |
| return v | |
def records_to_json(df: pd.DataFrame, fields: list[str] | None = None) -> list[dict[str, Any]]:
    """Convert a DataFrame to a clean list of JSON-serialisable dicts."""
    if df.empty:
        return []
    if fields:
        # Silently drop requested fields that do not exist in the frame.
        df = df[[name for name in fields if name in df.columns]]
    cleaned: list[dict[str, Any]] = []
    for row in df.to_dict(orient="records"):
        cleaned.append({key: _safe_value(val) for key, val in row.items()})
    return cleaned
def stringify_records(records: list[dict[str, Any]]) -> str:
    """Pretty-print records as JSON; placeholder message when empty."""
    if records:
        return json.dumps(records, ensure_ascii=False, indent=2, default=str)
    return "No records found."
def safe_jsonrpc_result(req_id: Any, result: Any) -> JSONResponse:
    """Wrap a successful JSON-RPC 2.0 response envelope."""
    envelope = {"jsonrpc": "2.0", "id": req_id, "result": result}
    return JSONResponse(envelope)
def safe_jsonrpc_error(req_id: Any, code: int, message: str) -> JSONResponse:
    """Wrap a JSON-RPC 2.0 error envelope with the given code and message."""
    envelope = {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}}
    return JSONResponse(envelope)
| # --------------------------------------------------------------------- | |
| # Numeric filter engine | |
| # | |
| # Syntax: | |
| # "density>1.0" strict greater-than | |
| # "density>=1.0" greater-or-equal | |
| # "thermal_conductivity!=" column must be non-null | |
| # --------------------------------------------------------------------- | |
| _OPERATORS = [">=", "<=", "!=", ">", "<", "=="] | |
| def _apply_numeric_filters( | |
| df: pd.DataFrame, filters: list[str] | |
| ) -> tuple[pd.DataFrame, list[str]]: | |
| """Apply numeric range filters; return (filtered_df, warning_list).""" | |
| warnings: list[str] = [] | |
| for f in filters: | |
| f = f.strip() | |
| parsed = False | |
| for op in _OPERATORS: | |
| idx = f.find(op) | |
| if idx == -1: | |
| continue | |
| col = f[:idx].strip() | |
| val_str = f[idx + len(op):].strip() | |
| if col not in df.columns: | |
| warnings.append(f"Column '{col}' not found β filter '{f}' skipped.") | |
| parsed = True | |
| break | |
| if df[col].dtype == object: | |
| df = df.copy() | |
| df[col] = pd.to_numeric(df[col], errors="coerce") | |
| if op == "!=" and val_str == "": | |
| df = df[df[col].notna()] | |
| parsed = True | |
| break | |
| try: | |
| val = float(val_str) | |
| except ValueError: | |
| warnings.append(f"Cannot parse value '{val_str}' in filter '{f}' β skipped.") | |
| parsed = True | |
| break | |
| ops_map = { | |
| ">": df[col] > val, | |
| ">=": df[col] >= val, | |
| "<": df[col] < val, | |
| "<=": df[col] <= val, | |
| "==": df[col] == val, | |
| "!=": df[col] != val, | |
| } | |
| df = df[ops_map[op]] | |
| parsed = True | |
| break | |
| if not parsed: | |
| warnings.append(f"Could not parse filter '{f}' β skipped.") | |
| return df, warnings | |
# ---------------------------------------------------------------------
# Tool: list_datasets
# ---------------------------------------------------------------------
def tool_list_datasets() -> dict[str, Any]:
    """List all loaded datasets with row counts and column names."""
    try:
        summary: dict[str, Any] = {}
        for name, frame in _datasets.items():
            summary[name] = {"rows": int(len(frame)), "columns": list(frame.columns)}
        return {"datasets": summary}
    except Exception as e:
        return {"error": str(e)}
# ---------------------------------------------------------------------
# Tool: search_polymers
# ---------------------------------------------------------------------
def tool_search_polymers(
    query: str = "",
    dataset: str = "general",
    limit: int = 10,
    filters: list[str] | None = None,
    sort_by: str | None = None,
    sort_ascending: bool = True,
    fields: list[str] | None = None,
    require_numeric: list[str] | None = None,
) -> dict[str, Any]:
    """Search and filter polymers.

    - limit is clamped to 0..1000; use the aggregation tools for
      whole-dataset analysis.
    - Empty query with filters returns all rows passing those filters.
    """
    try:
        if dataset not in _datasets:
            return {"error": f"Unknown dataset: '{dataset}'. Available: {list(_datasets.keys())}"}
        # Clamp to [0, 1000]: a negative value passed to head() would
        # silently return all rows EXCEPT the last n instead of limiting.
        limit = max(0, min(int(limit), 1000))
        df = _datasets[dataset].copy()
        q = (query or "").strip()
        if q:
            # Case-insensitive substring match across all text columns.
            q_lower = q.lower()
            mask = pd.Series(False, index=df.index)
            for col in df.columns:
                try:
                    if df[col].dtype == object:
                        mask = mask | df[col].astype(str).str.lower().str.contains(q_lower, na=False)
                except Exception:
                    continue
            df = df[mask]
        if require_numeric:
            # Convenience: "col!=" filters require a non-null value.
            filters = list(filters or []) + [f"{col}!=" for col in require_numeric]
        warnings: list[str] = []
        if filters:
            df, warnings = _apply_numeric_filters(df, filters)
        total_after_filter = len(df)
        if sort_by and sort_by in df.columns:
            df = df.sort_values(by=sort_by, ascending=sort_ascending, na_position="last")
        if fields is None:
            fields = [f for f in _SEARCH_DEFAULT_FIELDS if f in df.columns]
        records = records_to_json(df.head(limit), fields=fields)
        return {
            "dataset": dataset,
            "query": query,
            "filters": filters or [],
            "total_matched": total_after_filter,
            "returned": len(records),
            "sort_by": sort_by,
            "warnings": warnings,
            "records": records,
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
# ---------------------------------------------------------------------
# Tool: get_dataset_columns
# ---------------------------------------------------------------------
def tool_get_dataset_columns(dataset: str = "general") -> dict[str, Any]:
    """Return column metadata for a dataset: every column, the numeric
    columns, and which known property columns are present."""
    try:
        if dataset not in _datasets:
            return {"error": f"Unknown dataset: '{dataset}'"}
        frame = _datasets[dataset]
        numeric = [name for name in frame.columns if pd.api.types.is_numeric_dtype(frame[name])]
        known = [name for name in NUMERIC_PROPERTY_COLUMNS if name in frame.columns]
        return {
            "dataset": dataset,
            "num_columns": int(len(frame.columns)),
            "columns": list(frame.columns),
            "numeric_columns": numeric,
            "known_property_columns": known,
        }
    except Exception as e:
        return {"error": str(e)}
# ---------------------------------------------------------------------
# Tool: get_statistics — whole-dataset aggregation
# ---------------------------------------------------------------------
def tool_get_statistics(
    dataset: str = "general",
    columns: list[str] | None = None,
    filters: list[str] | None = None,
) -> dict[str, Any]:
    """Descriptive statistics over ALL rows (or a filtered subset).

    No per-row limit — aggregation runs server-side.
    """
    try:
        if dataset not in _datasets:
            return {"error": f"Unknown dataset: '{dataset}'"}
        df = _datasets[dataset].copy()
        warnings: list[str] = []
        if filters:
            df, warnings = _apply_numeric_filters(df, filters)
        if columns is None:
            # Default to every known property column present in this frame.
            columns = [c for c in NUMERIC_PROPERTY_COLUMNS if c in df.columns]
        stats: dict[str, Any] = {}
        for col in columns:
            if col not in df.columns:
                stats[col] = {"error": "column not found"}
                continue
            series = pd.to_numeric(df[col], errors="coerce").dropna()
            if series.empty:
                stats[col] = {"count": 0, "note": "no numeric data"}
                continue
            stats[col] = {
                "count": int(series.count()),
                "mean": round(float(series.mean()), 6),
                "std": round(float(series.std()), 6),
                "min": round(float(series.min()), 6),
                "p25": round(float(series.quantile(0.25)), 6),
                "median": round(float(series.median()), 6),
                "p75": round(float(series.quantile(0.75)), 6),
                "max": round(float(series.max()), 6),
            }
        return {
            "dataset": dataset,
            "total_rows": len(_datasets[dataset]),
            "rows_after_filter": len(df),
            "filters": filters or [],
            "warnings": warnings,
            "statistics": stats,
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
# ---------------------------------------------------------------------
# Tool: get_correlation — whole-dataset correlation
# ---------------------------------------------------------------------
def tool_get_correlation(
    dataset: str = "general",
    col_x: str = "density",
    col_y: str = "thermal_conductivity",
    filters: list[str] | None = None,
    sample_limit: int = 500,
) -> dict[str, Any]:
    """Pearson + Spearman correlation computed on ALL matching rows.

    Returns a random sample of up to `sample_limit` points for plotting.
    Fixes over the previous version:
    - col_x == col_y no longer builds a duplicate-column frame (which made
      pair[col_x] a DataFrame and crashed Series.corr).
    - Undefined correlations (e.g. a constant column makes corr() return
      NaN) are reported as None instead of NaN, which is not valid strict JSON.
    """
    try:
        if dataset not in _datasets:
            return {"error": f"Unknown dataset: '{dataset}'"}
        df = _datasets[dataset].copy()
        warnings: list[str] = []
        if filters:
            df, warnings = _apply_numeric_filters(df, filters)
        # Deduplicate the column selection so col_x == col_y stays a
        # single-column frame.
        cols = [col_x] if col_x == col_y else [col_x, col_y]
        for col in cols:
            if col not in df.columns:
                return {"error": f"Column '{col}' not found in dataset '{dataset}'."}
            df[col] = pd.to_numeric(df[col], errors="coerce")
        pair = df[cols].dropna()
        n = len(pair)
        if n < 2:
            return {
                "dataset": dataset, "col_x": col_x, "col_y": col_y, "n_total": n,
                "warnings": warnings, "error": "Not enough data points after filtering.",
            }
        x, y = pair[col_x], pair[col_y]

        def _rounded(r: Any) -> float | None:
            # corr() yields NaN for constant input — report None instead.
            return round(float(r), 4) if pd.notna(r) else None

        pearson_r = x.corr(y, method="pearson")
        spearman_r = x.corr(y, method="spearman")
        sample_n = min(sample_limit, n)
        sample = pair.sample(n=sample_n, random_state=42).sort_values(col_x)
        sample_points = [
            {col_x: _safe_value(row[col_x]), col_y: _safe_value(row[col_y])}
            for _, row in sample.iterrows()
        ]
        return {
            "dataset": dataset,
            "col_x": col_x,
            "col_y": col_y,
            "n_total": n,
            "pearson_r": _rounded(pearson_r),
            "spearman_r": _rounded(spearman_r),
            "filters": filters or [],
            "warnings": warnings,
            "sample_n": sample_n,
            "sample_points": sample_points,
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
# ---------------------------------------------------------------------
# Tool: get_distribution — whole-dataset histogram
# ---------------------------------------------------------------------
def tool_get_distribution(
    dataset: str = "general",
    column: str = "density",
    bins: int = 20,
    filters: list[str] | None = None,
) -> dict[str, Any]:
    """Histogram (bin counts + frequencies) for a numeric column over ALL rows.

    No per-row limit — computed server-side.
    """
    try:
        if dataset not in _datasets:
            return {"error": f"Unknown dataset: '{dataset}'"}
        df = _datasets[dataset].copy()
        warnings: list[str] = []
        if filters:
            df, warnings = _apply_numeric_filters(df, filters)
        if column not in df.columns:
            return {"error": f"Column '{column}' not found in dataset '{dataset}'."}
        values = pd.to_numeric(df[column], errors="coerce").dropna()
        if values.empty:
            return {"error": f"No numeric data in column '{column}' after filtering."}
        # Clamp the bin count to a sane range.
        bins = max(1, min(int(bins), 200))
        counts, edges = np.histogram(values.values, bins=bins)
        total = len(values)
        histogram = []
        for i, count in enumerate(counts):
            lo = float(edges[i])
            hi = float(edges[i + 1])
            histogram.append({
                "bin_start": round(lo, 6),
                "bin_end": round(hi, 6),
                "bin_mid": round((lo + hi) / 2, 6),
                "count": int(count),
                "frequency": round(float(count / total), 6),
            })
        return {
            "dataset": dataset,
            "column": column,
            "total_rows": len(_datasets[dataset]),
            "rows_used": int(total),
            "bins": bins,
            "filters": filters or [],
            "warnings": warnings,
            "min": round(float(values.min()), 6),
            "max": round(float(values.max()), 6),
            "mean": round(float(values.mean()), 6),
            "median": round(float(values.median()), 6),
            "std": round(float(values.std()), 6),
            "histogram": histogram,
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
# ---------------------------------------------------------------------
# Tool: get_group_stats — GROUP BY aggregation
# ---------------------------------------------------------------------
def tool_get_group_stats(
    dataset: str = "general",
    group_by: str = "polymer_class",
    value_column: str = "thermal_conductivity",
    filters: list[str] | None = None,
    min_group_size: int = 5,
) -> dict[str, Any]:
    """Aggregate a numeric column by a categorical column over ALL rows.

    Returns count, mean, std, min, p25, median, p75, max per group,
    sorted by descending count.
    """
    try:
        if dataset not in _datasets:
            return {"error": f"Unknown dataset: '{dataset}'"}
        df = _datasets[dataset].copy()
        warnings: list[str] = []
        if filters:
            df, warnings = _apply_numeric_filters(df, filters)
        for col in (group_by, value_column):
            if col not in df.columns:
                return {"error": f"Column '{col}' not found in dataset '{dataset}'."}
        df[value_column] = pd.to_numeric(df[value_column], errors="coerce")
        df_clean = df[[group_by, value_column]].dropna(subset=[value_column])
        groups: list[dict[str, Any]] = []
        for group_name, grp in df_clean.groupby(group_by, sort=False):
            vals = grp[value_column]
            # Omit small groups whose statistics would be unreliable.
            if len(vals) < min_group_size:
                continue
            groups.append({
                "group": _safe_value(group_name),
                "count": int(len(vals)),
                "mean": round(float(vals.mean()), 6),
                "std": round(float(vals.std()), 6),
                "min": round(float(vals.min()), 6),
                "p25": round(float(vals.quantile(0.25)), 6),
                "median": round(float(vals.median()), 6),
                "p75": round(float(vals.quantile(0.75)), 6),
                "max": round(float(vals.max()), 6),
            })
        groups.sort(key=lambda g: g["count"], reverse=True)
        return {
            "dataset": dataset,
            "group_by": group_by,
            "value_column": value_column,
            "total_rows": len(_datasets[dataset]),
            "rows_after_filter": len(df),
            "rows_with_value": int(len(df_clean)),
            "num_groups": len(groups),
            "min_group_size": min_group_size,
            "filters": filters or [],
            "warnings": warnings,
            "groups": groups,
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
# ---------------------------------------------------------------------
# Tool: get_chi_solvents / search_chi
# ---------------------------------------------------------------------
def tool_get_chi_solvents() -> dict[str, Any]:
    """List solvents whose chi-parameter tables loaded successfully."""
    try:
        available = [name for name in SOLVENTS if get_chi_df(name) is not None]
        return {"count": len(available), "solvents": available}
    except Exception as e:
        return {"error": str(e)}
def tool_search_chi(
    solvent: str,
    query: str = "",
    limit: int = 10,
    fields: list[str] | None = None,
) -> dict[str, Any]:
    """Search a chi-parameter table by case-insensitive full-text query.

    `limit` is now int-cast and clamped to 0..1000 for consistency with
    search_polymers (previously unchecked: a negative value passed to
    head() would drop rows from the END of the frame instead of limiting).
    """
    try:
        df = get_chi_df(solvent)
        if df is None:
            return {"error": f"Chi dataset not found for solvent: '{solvent}'"}
        limit = max(0, min(int(limit), 1000))
        if query:
            # Substring match across all text columns.
            q_lower = query.lower()
            mask = pd.Series(False, index=df.index)
            for col in df.columns:
                try:
                    if df[col].dtype == object:
                        mask = mask | df[col].astype(str).str.lower().str.contains(q_lower, na=False)
                except Exception:
                    continue
            result_df = df[mask]
        else:
            result_df = df
        records = records_to_json(result_df.head(limit), fields=fields)
        return {
            "solvent": solvent,
            "query": query,
            "total_matched": len(result_df),
            "returned": len(records),
            "records": records,
        }
    except Exception as e:
        return {"error": str(e)}
# ---------------------------------------------------------------------
# MCP metadata
# ---------------------------------------------------------------------
# Advertised to clients in the JSON-RPC "initialize" response.
SERVER_INFO = {
    "name": "polyomics-mcp-server",
    "version": "0.3.0",
}
# MCP tool declarations returned by tools/list. Each entry's inputSchema is
# a JSON Schema describing the arguments accepted by handle_tool_call.
# (Fixes mojibake in the client-facing descriptions: 'β' was a garbled dash.)
TOOLS = [
    {
        "name": "list_datasets",
        "description": "List available polymer datasets with row counts and column names.",
        "inputSchema": {"type": "object", "properties": {}, "additionalProperties": False},
    },
    {
        "name": "search_polymers",
        "description": (
            "Search and filter polymers in a dataset. Supports full-text search, "
            "numeric range filters (e.g. 'density>1.0'), sorting, and field selection. "
            "Returns up to 1000 rows. For whole-dataset aggregations use "
            "get_statistics, get_correlation, get_distribution, or get_group_stats."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Full-text search string. Empty string returns all rows.",
                    "default": "",
                },
                "dataset": {"type": "string", "default": "general"},
                "limit": {
                    "type": "integer",
                    "default": 10,
                    "description": "Max rows to return (max 1000).",
                },
                "filters": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": (
                        "Numeric filter expressions, e.g. ['density>1.0', 'tg>=300', "
                        "'thermal_conductivity!=']. "
                        "Operators: >, >=, <, <=, ==, !=. "
                        "Use 'col!=' to require a non-null value."
                    ),
                    "default": [],
                },
                "sort_by": {"type": "string", "description": "Column name to sort by."},
                "sort_ascending": {"type": "boolean", "default": True},
                "fields": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Columns to include in output. Defaults to a compact property set.",
                },
                "require_numeric": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Convenience: columns that must have a non-null numeric value.",
                },
            },
            "additionalProperties": False,
        },
    },
    {
        "name": "get_dataset_columns",
        "description": (
            "Return all column names for a dataset, plus lists of numeric columns "
            "and known property columns useful for filtering."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {"dataset": {"type": "string", "default": "general"}},
            "additionalProperties": False,
        },
    },
    {
        "name": "get_statistics",
        "description": (
            "Compute descriptive statistics (count, mean, std, min, p25, median, p75, max) "
            "for numeric property columns using ALL rows in the dataset (server-side aggregation). "
            "Optional filters are applied before aggregation. "
            "Use this instead of search_polymers for whole-dataset summaries."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "default": "general"},
                "columns": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Columns to summarise. Defaults to all known property columns.",
                },
                "filters": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Numeric filters applied before computing statistics.",
                    "default": [],
                },
            },
            "additionalProperties": False,
        },
    },
    {
        "name": "get_correlation",
        "description": (
            "Compute Pearson and Spearman correlation between two numeric columns "
            "using ALL matching rows (server-side — no row limit). "
            "Returns correlation coefficients plus a random scatter-plot sample."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "default": "general"},
                "col_x": {"type": "string", "default": "density"},
                "col_y": {"type": "string", "default": "thermal_conductivity"},
                "filters": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Numeric filters applied before computing the correlation.",
                    "default": [],
                },
                "sample_limit": {
                    "type": "integer",
                    "default": 500,
                    "description": "Max data points to return for scatter-plot visualisation.",
                },
            },
            "required": ["col_x", "col_y"],
            "additionalProperties": False,
        },
    },
    {
        "name": "get_distribution",
        "description": (
            "Compute a histogram (bin counts and frequencies) for a numeric column "
            "over ALL matching rows (server-side — no row limit). "
            "Ideal for visualising the distribution of density, TC, Tg, etc."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "default": "general"},
                "column": {
                    "type": "string",
                    "default": "density",
                    "description": "Numeric column to compute the histogram for.",
                },
                "bins": {
                    "type": "integer",
                    "default": 20,
                    "description": "Number of histogram bins (1–200).",
                },
                "filters": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Numeric filters applied before computing the histogram.",
                    "default": [],
                },
            },
            "additionalProperties": False,
        },
    },
    {
        "name": "get_group_stats",
        "description": (
            "Aggregate a numeric column grouped by a categorical column (e.g. polymer_class). "
            "Runs on ALL matching rows server-side. "
            "Returns count, mean, std, min, p25, median, p75, max per group. "
            "Useful for comparing thermal_conductivity or density across polymer classes."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "default": "general"},
                "group_by": {
                    "type": "string",
                    "default": "polymer_class",
                    "description": "Categorical column to group on (e.g. 'polymer_class', '_source').",
                },
                "value_column": {
                    "type": "string",
                    "default": "thermal_conductivity",
                    "description": "Numeric column to aggregate.",
                },
                "filters": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Numeric filters applied before grouping.",
                    "default": [],
                },
                "min_group_size": {
                    "type": "integer",
                    "default": 5,
                    "description": "Groups with fewer rows than this are omitted.",
                },
            },
            "additionalProperties": False,
        },
    },
    {
        "name": "get_chi_solvents",
        "description": "List available solvents for chi parameter tables.",
        "inputSchema": {"type": "object", "properties": {}, "additionalProperties": False},
    },
    {
        "name": "search_chi",
        "description": "Search a chi parameter table for a given solvent.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "solvent": {"type": "string"},
                "query": {"type": "string", "default": ""},
                "limit": {"type": "integer", "default": 10},
                "fields": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Columns to include in output.",
                },
            },
            "required": ["solvent"],
            "additionalProperties": False,
        },
    },
]
# ---------------------------------------------------------------------
# Tool dispatcher
# ---------------------------------------------------------------------
def handle_tool_call(name: str, arguments: dict[str, Any]) -> dict[str, Any]:
    """Dispatch an MCP tools/call request to the matching tool function.

    Unknown tool names and any exception raised while invoking a tool are
    returned as error dicts rather than propagated.
    """
    try:
        args = arguments
        # Lazy lambdas: argument coercion (int casts etc.) only runs for
        # the tool that is actually invoked.
        dispatch: dict[str, Any] = {
            "list_datasets": lambda: tool_list_datasets(),
            "search_polymers": lambda: tool_search_polymers(
                query=args.get("query", ""),
                dataset=args.get("dataset", "general"),
                limit=int(args.get("limit", 10)),
                filters=args.get("filters") or [],
                sort_by=args.get("sort_by"),
                sort_ascending=bool(args.get("sort_ascending", True)),
                fields=args.get("fields"),
                require_numeric=args.get("require_numeric"),
            ),
            "get_dataset_columns": lambda: tool_get_dataset_columns(
                dataset=args.get("dataset", "general"),
            ),
            "get_statistics": lambda: tool_get_statistics(
                dataset=args.get("dataset", "general"),
                columns=args.get("columns"),
                filters=args.get("filters") or [],
            ),
            "get_correlation": lambda: tool_get_correlation(
                dataset=args.get("dataset", "general"),
                col_x=args.get("col_x", "density"),
                col_y=args.get("col_y", "thermal_conductivity"),
                filters=args.get("filters") or [],
                sample_limit=int(args.get("sample_limit", 500)),
            ),
            "get_distribution": lambda: tool_get_distribution(
                dataset=args.get("dataset", "general"),
                column=args.get("column", "density"),
                bins=int(args.get("bins", 20)),
                filters=args.get("filters") or [],
            ),
            "get_group_stats": lambda: tool_get_group_stats(
                dataset=args.get("dataset", "general"),
                group_by=args.get("group_by", "polymer_class"),
                value_column=args.get("value_column", "thermal_conductivity"),
                filters=args.get("filters") or [],
                min_group_size=int(args.get("min_group_size", 5)),
            ),
            "get_chi_solvents": lambda: tool_get_chi_solvents(),
            "search_chi": lambda: tool_search_chi(
                solvent=args.get("solvent", ""),
                query=args.get("query", ""),
                limit=int(args.get("limit", 10)),
                fields=args.get("fields"),
            ),
        }
        handler = dispatch.get(name)
        if handler is None:
            return {"error": f"Unknown tool: '{name}'"}
        return handler()
    except Exception as e:
        return {
            "error": f"Unhandled exception in tool '{name}': {e}",
            "traceback": traceback.format_exc(),
        }
# ---------------------------------------------------------------------
# MCP HTTP handlers
# ---------------------------------------------------------------------
async def mcp_sse_get(request: Request) -> Response:
    """Minimal SSE endpoint: emit a single keep-alive ping event."""
    print(">>> GET /mcp/sse")
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Access-Control-Allow-Origin": "*",
    }
    return PlainTextResponse(
        "event: ping\ndata: alive\n\n",
        media_type="text/event-stream",
        headers=sse_headers,
    )
async def mcp_sse_post(request: Request) -> JSONResponse:
    """Handle one JSON-RPC 2.0 message on the MCP endpoint.

    Supported methods: initialize, notifications/initialized, ping,
    tools/list, tools/call. Anything else yields -32601 (method not found).
    (Also fixes mojibake in the invalid-JSON log line: 'β' was a garbled dash.)
    """
    try:
        payload = await request.json()
    except Exception:
        # Log the first 500 chars of the unparseable body for debugging.
        raw = await request.body()
        print(">>> POST /mcp/sse — invalid JSON")
        print(raw.decode("utf-8", errors="replace")[:500])
        return safe_jsonrpc_error(None, -32700, "Parse error")
    print(">>> POST /mcp/sse")
    print(json.dumps(payload, ensure_ascii=False)[:1000])
    req_id = payload.get("id")
    method = payload.get("method")
    params = payload.get("params") or {}
    if method == "initialize":
        # Echo the client's protocol version (defaulting to a known revision).
        return safe_jsonrpc_result(req_id, {
            "protocolVersion": params.get("protocolVersion", "2025-03-26"),
            "capabilities": {"tools": {}},
            "serverInfo": SERVER_INFO,
        })
    if method == "notifications/initialized":
        # Notifications expect no JSON-RPC result body.
        return JSONResponse({}, status_code=202)
    if method == "ping":
        return safe_jsonrpc_result(req_id, {})
    if method == "tools/list":
        return safe_jsonrpc_result(req_id, {"tools": TOOLS})
    if method == "tools/call":
        name = params.get("name")
        arguments = params.get("arguments") or {}
        result = handle_tool_call(name, arguments)
        # Tool-level failures are reported via isError, not a JSON-RPC error.
        is_error = isinstance(result, dict) and "error" in result
        return safe_jsonrpc_result(req_id, {
            "content": [{"type": "text", "text": stringify_records([result])}],
            "isError": is_error,
        })
    return safe_jsonrpc_error(req_id, -32601, f"Method not found: {method}")
| # --------------------------------------------------------------------- | |
| # Auxiliary HTTP endpoints | |
| # --------------------------------------------------------------------- | |
async def health(request: Request) -> JSONResponse:
    """Report server liveness plus dataset and chi-table load status."""
    loaded_solvents = sum(1 for s in SOLVENTS if get_chi_df(s) is not None)
    status = {
        "status": "ok",
        "server": SERVER_INFO,
        "dataset_repo": DATASET_REPO,
        "datasets": {name: len(frame) for name, frame in _datasets.items()},
        "chi_solvents_loaded": loaded_solvents,
    }
    return JSONResponse(status)
async def oauth_protected_resource(request: Request) -> Response:
    """Return an empty 204 for the OAuth protected-resource discovery probe."""
    # NOTE(review): presumably 204 signals "no auth metadata" to MCP clients
    # probing .well-known — confirm against the connecting client's behavior.
    empty = Response(status_code=204)
    return empty
async def oauth_authorization_server(request: Request) -> Response:
    """Return an empty 204 for the OAuth authorization-server discovery probe."""
    empty = Response(status_code=204)
    return empty
async def register(request: Request) -> Response:
    """Accept a dynamic client-registration POST; log a snippet, reply 204."""
    raw = await request.body()
    snippet = raw.decode("utf-8", errors="replace")[:200]
    print(">>> POST /register:", snippet)
    return Response(status_code=204)
async def options_handler(request: Request) -> Response:
    """Answer a CORS preflight with permissive allow-all headers and no body."""
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
        "Access-Control-Allow-Headers": "*",
    }
    return Response(status_code=204, headers=cors_headers)
| # --------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------- | |
def build_gradio_app() -> GradioApp:
    """Build the informational Gradio UI that fronts the MCP server.

    The page shows the connector endpoint, the MCP tool surface, and a JSON
    summary of the loaded datasets.
    """
    with gr.Blocks(title="PolyOmics MCP Server") as demo:
        gr.Markdown(
            "# PolyOmics MCP Server\n\n"
            f"**Dataset repo:** `{DATASET_REPO}`\n\n"
            "**Claude connector endpoint:** `https://mohnishi-polyomics-mcp-server.hf.space/mcp/sse`\n\n"
            # Fix: the heading hard-coded "v0.3.0"; derive the version from
            # SERVER_INFO (the value shown at /health and in the JSON panel
            # below) so the UI cannot drift from the actual server version.
            # Also repaired mojibake dashes ("β") in the bullet text.
            f"**Available MCP tools (v{SERVER_INFO['version']}):**\n"
            "- `list_datasets` — list datasets and row counts\n"
            "- `search_polymers` — full-text + numeric filter + sort + field selection (up to 1000 rows)\n"
            "- `get_dataset_columns` — column names, numeric columns, known property columns\n"
            "- `get_statistics` — **whole-dataset** descriptive stats per property column\n"
            "- `get_correlation` — **whole-dataset** Pearson/Spearman + scatter sample\n"
            "- `get_distribution` — **whole-dataset** histogram for any numeric column\n"
            "- `get_group_stats` — **whole-dataset** GROUP BY aggregation (e.g. TC by polymer_class)\n"
            "- `get_chi_solvents` — list loaded chi-parameter solvents\n"
            "- `search_chi` — search chi parameter tables\n"
        )
        df = get_main_df()
        gr.JSON({
            "dataset_repo": DATASET_REPO,
            "server_version": SERVER_INFO["version"],
            "main_rows": int(len(df)),
            "main_columns": int(len(df.columns)) if not df.empty else 0,
            "datasets": list(_datasets.keys()),
        })
    # Wrap the Blocks demo in a FastAPI-compatible app so it can be mounted.
    return GradioApp.create_app(demo, app_kwargs={"docs_url": "/docs"})
| # --------------------------------------------------------------------- | |
| # Starlette app assembly | |
| # --------------------------------------------------------------------- | |
def build_app() -> Starlette:
    """Assemble the Starlette app: MCP endpoint, OAuth stubs, and Gradio UI.

    The Gradio app is mounted last at "/" so it catches everything the
    explicit routes do not.
    """
    gradio_app = build_gradio_app()
    mcp_path = "/mcp/sse"
    routes = [
        Route("/health", endpoint=health, methods=["GET"]),
        Route(mcp_path, endpoint=mcp_sse_get, methods=["GET"]),
        Route(mcp_path, endpoint=mcp_sse_post, methods=["POST"]),
        Route(mcp_path, endpoint=options_handler, methods=["OPTIONS"]),
        Route("/.well-known/oauth-protected-resource", endpoint=oauth_protected_resource, methods=["GET"]),
        Route("/.well-known/oauth-protected-resource/mcp/sse", endpoint=oauth_protected_resource, methods=["GET"]),
        Route("/.well-known/oauth-authorization-server", endpoint=oauth_authorization_server, methods=["GET"]),
        Route("/register", endpoint=register, methods=["POST"]),
        Mount("/", app=gradio_app),
    ]
    app = Starlette(routes=routes)
    print(f"MCP JSON-RPC server v{SERVER_INFO['version']} ready at /mcp/sse")
    # Dump the route table at startup for quick sanity checking in Space logs.
    for route in app.routes:
        print(f" {type(route).__name__}: {getattr(route, 'path', None)}")
    return app
| # --------------------------------------------------------------------- | |
| # Entry point | |
| # --------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Timezone-aware UTC stamp so Space restarts are visible in container logs.
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
    print(f"===== Application Startup at {stamp} =====")
    preload_data()
    application = build_app()
    # HF Spaces expects 7860 unless PORT is set in the environment.
    listen_port = int(os.environ.get("PORT", 7860))
    print(f"Starting uvicorn on 0.0.0.0:{listen_port} ...")
    uvicorn.run(application, host="0.0.0.0", port=listen_port)