# Commit 048e5cf by mohnishi: "brush up app.py"
import json
import os
import traceback
import urllib.parse
from datetime import datetime, timezone
from typing import Any
import numpy as np
import pandas as pd
import gradio as gr
from gradio.routes import App as GradioApp
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.routing import Route, Mount
import uvicorn
# ---------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------
# Hugging Face dataset repository the CSVs are read from; can be overridden
# with the DATASET_REPO environment variable.
DATASET_REPO = os.environ.get("DATASET_REPO", "yhayashi1986/PolyOmics")
# Base URL for raw file downloads from the repo's `main` revision.
HF_BASE_URL = f"https://huggingface.co/datasets/{DATASET_REPO}/resolve/main"
# Solvents with a chi-parameter table; each one is loaded from
# chi_parameter/Summary_<solvent>.csv in the dataset repo by preload_data().
SOLVENTS = [
"Water", "Methanol", "Ethanol", "Acetone", "Benzene", "Toluene",
"Chloroform", "Hexane", "Cyclohexane", "DMSO", "Tetrahydrofuran",
"Dimethylformamide", "Ethyl_acetate", "Diethyl_ether", "1,4-Dioxane",
"DEGDB", "DEHA", "DEHP", "Acetyl_tributyl_citrate",
]
# Dataset key -> CSV filename at the repo root. The key is what the tools
# accept as `dataset` and is stored in each row's `_source` column.
DATASET_FILES = {
"general": "general_polymers_with_sp_abbe_dynamic-dielectric.csv",
"pfas": "PFAS.csv",
"biodegradable": "biodegradable_candidate_polyester.csv",
"cellulose": "cellulose.csv",
"ladder": "ladder_polymers.csv",
"small_molecules": "small_molecules.csv",
"non_ladder": "corresponding_non_ladder_polymers.csv",
}
# Known numeric property columns — coerced to numeric at load time and used as
# defaults for statistics / filtering. Names absent from a dataset are skipped.
NUMERIC_PROPERTY_COLUMNS = [
"density", "thermal_conductivity", "thermal_diffusivity",
"tg", "refractive_index", "static_dielectric_const", "dielectric_const_dc",
"bulk_modulus", "isentropic_bulk_modulus", "compressibility",
"volume_expansion", "linear_expansion", "Cp", "Cv",
"sp_total", "sp_vdw", "sp_ele", "sp_ced",
"mol_weight", "Mn", "Mw", "Rg", "self-diffusion",
"abbe_number_sos", "refractive_index_sos",
"efdp_permittivity_real", "efdp_permittivity_imaginary", "efdp_dielectric_loss_tan",
"nematic_order_parameter",
]
# Default compact field set returned by search_polymers when the caller does
# not pass `fields`; only names present in the dataset are used.
_SEARCH_DEFAULT_FIELDS = [
"UUID", "smiles_list", "polymer_class", "_source",
"density", "thermal_conductivity", "thermal_diffusivity",
"tg", "refractive_index", "static_dielectric_const",
"bulk_modulus", "sp_total", "abbe_number_sos",
]
# ---------------------------------------------------------------------
# Data layer
# ---------------------------------------------------------------------
# Loaded polymer datasets keyed by DATASET_FILES key; filled by preload_data().
# Datasets that failed to download are simply absent.
_datasets: dict[str, pd.DataFrame] = {}
# Dataset summarised in the Gradio page: "general" if it loaded, otherwise all
# loaded datasets concatenated, otherwise empty (set by preload_data()).
_main_df: pd.DataFrame | None = None
# Chi-parameter tables keyed by solvent name; the value is None when that
# solvent's table failed to load.
_chi_dfs: dict[str, pd.DataFrame | None] = {}
def _hf_url(path_in_repo: str) -> str:
    """Return the download URL of *path_in_repo* inside DATASET_REPO.

    Each path segment is percent-encoded (so e.g. the comma in
    "1,4-Dioxane" becomes %2C), while "/" separators are kept so nested
    paths such as "chi_parameter/..." still resolve.
    """
    quoted_path = urllib.parse.quote(path_in_repo, safe="/")
    return "/".join((HF_BASE_URL, quoted_path))
def _load_csv(path_in_repo: str) -> pd.DataFrame | None:
    """Download and parse one CSV file from the dataset repo.

    Best-effort by design: any failure (network, HTTP, parsing) is printed as
    a warning and reported to the caller as ``None`` instead of raising.
    """
    source_url = _hf_url(path_in_repo)
    try:
        frame = pd.read_csv(source_url, low_memory=False)
    except Exception as exc:  # deliberate: callers treat None as "not available"
        print(f"[WARN] Failed to load {path_in_repo}: {exc}")
        frame = None
    return frame
def preload_data() -> None:
    """Fill the module-level caches from the Hugging Face dataset repo.

    - ``_datasets``: one frame per DATASET_FILES entry that loaded. Each frame
      gets a ``_source`` column holding its key, and the known property
      columns are coerced to numeric (unparseable cells become NaN).
    - ``_main_df``: the "general" dataset if present, else every loaded
      dataset concatenated, else an empty frame.
    - ``_chi_dfs``: one entry per solvent in SOLVENTS (None when its table
      failed to load).
    Progress is printed to stdout.
    """
    global _main_df
    print("Loading polymer datasets ...")
    for key, fname in DATASET_FILES.items():
        frame = _load_csv(fname)
        if frame is None:
            print(f" [{key}] FAILED")
            continue
        frame["_source"] = key
        # Coerce known numeric columns once here so later filters and
        # statistics can rely on numeric dtypes.
        present = [c for c in NUMERIC_PROPERTY_COLUMNS if c in frame.columns]
        for col in present:
            frame[col] = pd.to_numeric(frame[col], errors="coerce")
        _datasets[key] = frame
        print(f" [{key}] {len(frame):,} rows, {len(frame.columns)} cols")

    if "general" in _datasets:
        _main_df = _datasets["general"]
    elif _datasets:
        _main_df = pd.concat(list(_datasets.values()), ignore_index=True)
    else:
        _main_df = pd.DataFrame()

    print("Loading chi parameter datasets ...")
    solvents_ok = []
    for solvent in SOLVENTS:
        table = _load_csv(f"chi_parameter/Summary_{solvent}.csv")
        _chi_dfs[solvent] = table
        if table is not None:
            solvents_ok.append(solvent)
    print(f"Chi parameter solvents loaded: {solvents_ok}")
def get_main_df() -> pd.DataFrame:
    """Return the primary dataset, or an empty DataFrame if none is loaded yet."""
    if _main_df is None:
        return pd.DataFrame()
    return _main_df
def get_chi_df(solvent: str) -> pd.DataFrame | None:
    """Return the chi-parameter table for *solvent*.

    Returns None when the solvent is unknown or its table failed to load.
    """
    table = _chi_dfs.get(solvent, None)
    return table
# ---------------------------------------------------------------------
# JSON / serialisation helpers
# ---------------------------------------------------------------------
def _safe_value(v: Any) -> Any:
"""Convert non-JSON-serialisable values to Python-native types."""
if isinstance(v, float) and (np.isnan(v) or np.isinf(v)):
return None
if isinstance(v, np.integer):
return int(v)
if isinstance(v, np.floating):
return float(v)
return v
def records_to_json(df: pd.DataFrame, fields: list[str] | None = None) -> list[dict[str, Any]]:
    """Convert a DataFrame to a clean list of JSON-serialisable row dicts.

    Args:
        df: rows to convert.
        fields: optional column whitelist. Names not present in *df* are
            ignored and duplicates are collapsed (first occurrence wins). A
            falsy value (None or []) keeps every column.

    Returns:
        One dict per row with every value passed through ``_safe_value``;
        ``[]`` when *df* is empty.
    """
    if df.empty:
        return []
    if fields:
        # dict.fromkeys de-duplicates while preserving order. Selecting the same
        # column twice would build a duplicate-column frame, and to_dict() then
        # warns and silently drops columns.
        wanted = [f for f in dict.fromkeys(fields) if f in df.columns]
        df = df[wanted]
    return [{k: _safe_value(v) for k, v in row.items()} for row in df.to_dict(orient="records")]
def stringify_records(records: list[dict[str, Any]]) -> str:
    """Render *records* as pretty-printed JSON text for an MCP text item.

    Values json cannot encode fall back to ``str()``. An empty list becomes a
    human-readable placeholder instead of ``[]``.
    """
    return (
        json.dumps(records, ensure_ascii=False, indent=2, default=str)
        if records
        else "No records found."
    )
def safe_jsonrpc_result(req_id: Any, result: Any) -> JSONResponse:
    """Wrap *result* in a JSON-RPC 2.0 success envelope for request *req_id*."""
    envelope = {"jsonrpc": "2.0", "id": req_id, "result": result}
    return JSONResponse(envelope)
def safe_jsonrpc_error(req_id: Any, code: int, message: str) -> JSONResponse:
    """Build a JSON-RPC 2.0 error envelope with the given *code* and *message*."""
    error_obj = {"code": code, "message": message}
    return JSONResponse({"jsonrpc": "2.0", "id": req_id, "error": error_obj})
# ---------------------------------------------------------------------
# Numeric filter engine
#
# Syntax:
# "density>1.0" strict greater-than
# "density>=1.0" greater-or-equal
# "thermal_conductivity!=" column must be non-null
# ---------------------------------------------------------------------
_OPERATORS = [">=", "<=", "!=", ">", "<", "=="]
def _apply_numeric_filters(
df: pd.DataFrame, filters: list[str]
) -> tuple[pd.DataFrame, list[str]]:
"""Apply numeric range filters; return (filtered_df, warning_list)."""
warnings: list[str] = []
for f in filters:
f = f.strip()
parsed = False
for op in _OPERATORS:
idx = f.find(op)
if idx == -1:
continue
col = f[:idx].strip()
val_str = f[idx + len(op):].strip()
if col not in df.columns:
warnings.append(f"Column '{col}' not found β€” filter '{f}' skipped.")
parsed = True
break
if df[col].dtype == object:
df = df.copy()
df[col] = pd.to_numeric(df[col], errors="coerce")
if op == "!=" and val_str == "":
df = df[df[col].notna()]
parsed = True
break
try:
val = float(val_str)
except ValueError:
warnings.append(f"Cannot parse value '{val_str}' in filter '{f}' β€” skipped.")
parsed = True
break
ops_map = {
">": df[col] > val,
">=": df[col] >= val,
"<": df[col] < val,
"<=": df[col] <= val,
"==": df[col] == val,
"!=": df[col] != val,
}
df = df[ops_map[op]]
parsed = True
break
if not parsed:
warnings.append(f"Could not parse filter '{f}' β€” skipped.")
return df, warnings
# ---------------------------------------------------------------------
# Tool: list_datasets
# ---------------------------------------------------------------------
def tool_list_datasets() -> dict[str, Any]:
    """Summarise every loaded dataset as ``{"datasets": {name: {rows, columns}}}``.

    Any exception is returned as ``{"error": message}``.
    """
    try:
        summary: dict[str, Any] = {}
        for ds_name, frame in _datasets.items():
            summary[ds_name] = {"rows": int(len(frame)), "columns": list(frame.columns)}
        return {"datasets": summary}
    except Exception as e:
        return {"error": str(e)}
# ---------------------------------------------------------------------
# Tool: search_polymers
# ---------------------------------------------------------------------
def tool_search_polymers(
    query: str = "",
    dataset: str = "general",
    limit: int = 10,
    filters: list[str] | None = None,
    sort_by: str | None = None,
    sort_ascending: bool = True,
    fields: list[str] | None = None,
    require_numeric: list[str] | None = None,
) -> dict[str, Any]:
    """Search and filter polymers.

    Steps, in order:
    1. Case-insensitive *literal* substring match of *query* against every
       object-dtype column.
    2. Numeric *filters*, plus a non-null filter for each column listed in
       *require_numeric*.
    3. Optional sort.
    4. Truncation to *limit* rows.

    - limit is clamped to 0..1000; use the aggregation tools for whole-dataset analysis.
    - Empty query with filters returns all rows passing those filters.
    - ``total_matched`` counts rows after query + filters, before truncation.
    - Problems such as an unknown sort column are reported in ``warnings``.
    Errors are returned as an ``{"error": ...}`` dict instead of raising.
    """
    try:
        if dataset not in _datasets:
            return {"error": f"Unknown dataset: '{dataset}'. Available: {list(_datasets.keys())}"}
        # Clamp from below too: a negative n makes DataFrame.head() drop rows
        # from the end instead of capping, which would bypass the 1000-row limit.
        limit = max(0, min(int(limit), 1000))
        # No defensive copy: every step below returns a new frame and
        # _apply_numeric_filters copies before coercing a column.
        df = _datasets[dataset]
        q = (query or "").strip()
        if q:
            q_lower = q.lower()
            mask = pd.Series(False, index=df.index)
            for pos, dtype in enumerate(df.dtypes):
                if dtype != object:
                    continue
                # regex=False: SMILES strings are full of regex metacharacters
                # ("[", "(", "*", "+") that must be matched literally.
                text = df.iloc[:, pos].astype(str).str.lower()
                mask = mask | text.str.contains(q_lower, na=False, regex=False)
            df = df[mask]
        if require_numeric:
            filters = list(filters or []) + [f"{col}!=" for col in require_numeric]
        warnings: list[str] = []
        if filters:
            df, warnings = _apply_numeric_filters(df, filters)
        total_after_filter = len(df)
        if sort_by:
            if sort_by in df.columns:
                df = df.sort_values(by=sort_by, ascending=sort_ascending, na_position="last")
            else:
                warnings.append(f"Sort column '{sort_by}' not found — results left unsorted.")
        if fields is None:
            fields = [f for f in _SEARCH_DEFAULT_FIELDS if f in df.columns]
        records = records_to_json(df.head(limit), fields=fields)
        return {
            "dataset": dataset,
            "query": query,
            "filters": filters or [],
            "total_matched": total_after_filter,
            "returned": len(records),
            "sort_by": sort_by,
            "warnings": warnings,
            "records": records,
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
# ---------------------------------------------------------------------
# Tool: get_dataset_columns
# ---------------------------------------------------------------------
def tool_get_dataset_columns(dataset: str = "general") -> dict[str, Any]:
    """Describe the columns of *dataset*.

    Returns all column names, the subset with a numeric dtype, and the known
    property columns (NUMERIC_PROPERTY_COLUMNS) present in the dataset.
    Errors come back as ``{"error": message}``.
    """
    try:
        if dataset not in _datasets:
            return {"error": f"Unknown dataset: '{dataset}'"}
        frame = _datasets[dataset]
        all_cols = list(frame.columns)
        return {
            "dataset": dataset,
            "num_columns": int(len(all_cols)),
            "columns": all_cols,
            "numeric_columns": [c for c in all_cols if pd.api.types.is_numeric_dtype(frame[c])],
            "known_property_columns": [c for c in NUMERIC_PROPERTY_COLUMNS if c in frame.columns],
        }
    except Exception as e:
        return {"error": str(e)}
# ---------------------------------------------------------------------
# Tool: get_statistics — whole-dataset aggregation
# ---------------------------------------------------------------------
def tool_get_statistics(
    dataset: str = "general",
    columns: list[str] | None = None,
    filters: list[str] | None = None,
) -> dict[str, Any]:
    """Descriptive statistics over ALL rows (or a filtered subset).

    No per-row limit — aggregation runs server-side.

    Args:
        dataset: key of a loaded dataset.
        columns: columns to summarise (a single column-name string is also
            accepted); defaults to every known property column in the dataset.
        filters: numeric filter expressions applied before aggregation.

    Returns:
        For each column: ``count/mean/std/min/p25/median/p75/max``, rounded to
        6 decimals. Statistics that are not finite (e.g. ``std`` of a single
        value, or a mean over infinite values) are reported as ``None`` so
        the output stays valid JSON. Errors come back as an ``{"error": ...}``
        dict.
    """
    def _r(x: Any) -> float | None:
        # Round finite values; map NaN/inf to None.
        f = float(x)
        return round(f, 6) if np.isfinite(f) else None

    try:
        if dataset not in _datasets:
            return {"error": f"Unknown dataset: '{dataset}'"}
        # No defensive copy: nothing below mutates df in place.
        df = _datasets[dataset]
        warnings: list[str] = []
        if filters:
            df, warnings = _apply_numeric_filters(df, filters)
        if columns is None:
            columns = [c for c in NUMERIC_PROPERTY_COLUMNS if c in df.columns]
        elif isinstance(columns, str):
            # Iterating a bare string would summarise one "column" per character.
            columns = [columns]
        stats: dict[str, Any] = {}
        for col in columns:
            if col not in df.columns:
                stats[col] = {"error": "column not found"}
                continue
            s = pd.to_numeric(df[col], errors="coerce").dropna()
            if s.empty:
                stats[col] = {"count": 0, "note": "no numeric data"}
                continue
            stats[col] = {
                "count": int(s.count()),
                "mean": _r(s.mean()),
                "std": _r(s.std()),
                "min": _r(s.min()),
                "p25": _r(s.quantile(0.25)),
                "median": _r(s.median()),
                "p75": _r(s.quantile(0.75)),
                "max": _r(s.max()),
            }
        return {
            "dataset": dataset,
            "total_rows": len(_datasets[dataset]),
            "rows_after_filter": len(df),
            "filters": filters or [],
            "warnings": warnings,
            "statistics": stats,
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
# ---------------------------------------------------------------------
# Tool: get_correlation — whole-dataset correlation
# ---------------------------------------------------------------------
def tool_get_correlation(
    dataset: str = "general",
    col_x: str = "density",
    col_y: str = "thermal_conductivity",
    filters: list[str] | None = None,
    sample_limit: int = 500,
) -> dict[str, Any]:
    """Pearson + Spearman correlation computed on ALL matching rows.

    Only rows where both columns are numeric are used. The response also
    includes a reproducible random sample (seed 42) of up to `sample_limit`
    points (negative values treated as 0), sorted by x, for plotting.
    If a coefficient is undefined (e.g. one column is constant), it is
    returned as ``None`` and a warning is added.
    Errors come back as an ``{"error": ...}`` dict.
    """
    def _round_or_none(x: Any, ndigits: int) -> float | None:
        f = float(x)
        return round(f, ndigits) if np.isfinite(f) else None

    try:
        if dataset not in _datasets:
            return {"error": f"Unknown dataset: '{dataset}'"}
        df = _datasets[dataset]
        warnings: list[str] = []
        if filters:
            df, warnings = _apply_numeric_filters(df, filters)
        for col in [col_x, col_y]:
            if col not in df.columns:
                return {"error": f"Column '{col}' not found in dataset '{dataset}'."}
        # Build a fresh two-column numeric frame instead of copying and mutating
        # the whole dataset. The generic names "x"/"y" also keep this working
        # when col_x == col_y.
        pair = pd.DataFrame({
            "x": pd.to_numeric(df[col_x], errors="coerce"),
            "y": pd.to_numeric(df[col_y], errors="coerce"),
        }).dropna()
        n = len(pair)
        if n < 2:
            return {
                "dataset": dataset, "col_x": col_x, "col_y": col_y, "n_total": n,
                "warnings": warnings, "error": "Not enough data points after filtering.",
            }
        pearson_r = _round_or_none(pair["x"].corr(pair["y"], method="pearson"), 4)
        spearman_r = _round_or_none(pair["x"].corr(pair["y"], method="spearman"), 4)
        if pearson_r is None or spearman_r is None:
            warnings.append("Correlation is undefined (a column is constant over the matching rows).")
        # Clamp: DataFrame.sample raises on a negative n.
        sample_n = max(0, min(int(sample_limit), n))
        sample = pair.sample(n=sample_n, random_state=42).sort_values("x")
        sample_points = [
            {col_x: _safe_value(x), col_y: _safe_value(y)}
            for x, y in zip(sample["x"], sample["y"])
        ]
        return {
            "dataset": dataset,
            "col_x": col_x,
            "col_y": col_y,
            "n_total": n,
            "pearson_r": pearson_r,
            "spearman_r": spearman_r,
            "filters": filters or [],
            "warnings": warnings,
            "sample_n": sample_n,
            "sample_points": sample_points,
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
# ---------------------------------------------------------------------
# Tool: get_distribution — whole-dataset histogram
# ---------------------------------------------------------------------
def tool_get_distribution(
    dataset: str = "general",
    column: str = "density",
    bins: int = 20,
    filters: list[str] | None = None,
) -> dict[str, Any]:
    """Histogram (bin counts + frequencies) for a numeric column over ALL rows.

    No per-row limit — computed server-side. ``bins`` is clamped to 1..200.
    Infinite values are excluded and counted in ``warnings``, because they
    would make NumPy's automatic bin range non-finite. Summary statistics
    that are not finite (e.g. ``std`` of a single value) are reported as
    ``None``. Errors come back as an ``{"error": ...}`` dict.
    """
    def _r(x: Any) -> float | None:
        f = float(x)
        return round(f, 6) if np.isfinite(f) else None

    try:
        if dataset not in _datasets:
            return {"error": f"Unknown dataset: '{dataset}'"}
        # No defensive copy: nothing below mutates df in place.
        df = _datasets[dataset]
        warnings: list[str] = []
        if filters:
            df, warnings = _apply_numeric_filters(df, filters)
        if column not in df.columns:
            return {"error": f"Column '{column}' not found in dataset '{dataset}'."}
        s = pd.to_numeric(df[column], errors="coerce").dropna()
        finite = np.isfinite(s)
        n_infinite = int((~finite).sum())
        if n_infinite:
            warnings.append(f"Excluded {n_infinite} infinite value(s) in '{column}'.")
            s = s[finite]
        if s.empty:
            return {"error": f"No numeric data in column '{column}' after filtering."}
        bins = max(1, min(int(bins), 200))
        counts, edges = np.histogram(s.to_numpy(), bins=bins)
        n_used = len(s)
        histogram = [
            {
                "bin_start": round(float(lo), 6),
                "bin_end": round(float(hi), 6),
                "bin_mid": round(float((lo + hi) / 2), 6),
                "count": int(c),
                "frequency": round(float(c / n_used), 6),
            }
            for c, lo, hi in zip(counts, edges[:-1], edges[1:])
        ]
        return {
            "dataset": dataset,
            "column": column,
            "total_rows": len(_datasets[dataset]),
            "rows_used": int(n_used),
            "bins": bins,
            "filters": filters or [],
            "warnings": warnings,
            "min": _r(s.min()),
            "max": _r(s.max()),
            "mean": _r(s.mean()),
            "median": _r(s.median()),
            "std": _r(s.std()),
            "histogram": histogram,
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
# ---------------------------------------------------------------------
# Tool: get_group_stats — GROUP BY aggregation
# ---------------------------------------------------------------------
def tool_get_group_stats(
    dataset: str = "general",
    group_by: str = "polymer_class",
    value_column: str = "thermal_conductivity",
    filters: list[str] | None = None,
    min_group_size: int = 5,
) -> dict[str, Any]:
    """Aggregate a numeric column by a categorical column over ALL rows.

    Returns count, mean, std, min, p25, median, p75, max per group, sorted by
    descending count (ties keep first-appearance order).

    Rows whose *value_column* is not numeric are ignored, and rows with a
    missing group key are dropped by the groupby. Groups with fewer than
    *min_group_size* values are omitted. Statistics that are not finite
    (e.g. ``std`` of a one-row group) are reported as ``None``.
    Errors come back as an ``{"error": ...}`` dict.
    """
    def _r(x: Any) -> float | None:
        f = float(x)
        return round(f, 6) if np.isfinite(f) else None

    try:
        if dataset not in _datasets:
            return {"error": f"Unknown dataset: '{dataset}'"}
        df = _datasets[dataset]
        warnings: list[str] = []
        if filters:
            df, warnings = _apply_numeric_filters(df, filters)
        for col in [group_by, value_column]:
            if col not in df.columns:
                return {"error": f"Column '{col}' not found in dataset '{dataset}'."}
        # Work on a fresh two-column frame instead of copying and mutating the
        # whole dataset. The generic names also cope with group_by == value_column.
        df_clean = pd.DataFrame({
            "group": df[group_by],
            "value": pd.to_numeric(df[value_column], errors="coerce"),
        }).dropna(subset=["value"])
        groups: list[dict[str, Any]] = []
        for name, grp in df_clean.groupby("group", sort=False):
            s = grp["value"]
            if len(s) < min_group_size:
                continue
            groups.append({
                "group": _safe_value(name),
                "count": int(len(s)),
                "mean": _r(s.mean()),
                "std": _r(s.std()),
                "min": _r(s.min()),
                "p25": _r(s.quantile(0.25)),
                "median": _r(s.median()),
                "p75": _r(s.quantile(0.75)),
                "max": _r(s.max()),
            })
        groups.sort(key=lambda g: g["count"], reverse=True)
        return {
            "dataset": dataset,
            "group_by": group_by,
            "value_column": value_column,
            "total_rows": len(_datasets[dataset]),
            "rows_after_filter": len(df),
            "rows_with_value": int(len(df_clean)),
            "num_groups": len(groups),
            "min_group_size": min_group_size,
            "filters": filters or [],
            "warnings": warnings,
            "groups": groups,
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
# ---------------------------------------------------------------------
# Tool: get_chi_solvents / search_chi
# ---------------------------------------------------------------------
def tool_get_chi_solvents() -> dict[str, Any]:
    """List the solvents whose chi-parameter table is loaded, in SOLVENTS order."""
    try:
        available = list(filter(lambda s: get_chi_df(s) is not None, SOLVENTS))
        return {"count": len(available), "solvents": available}
    except Exception as e:
        return {"error": str(e)}
def tool_search_chi(
    solvent: str,
    query: str = "",
    limit: int = 10,
    fields: list[str] | None = None,
) -> dict[str, Any]:
    """Search the chi-parameter table of one solvent.

    *query* is matched case-insensitively as a literal substring against
    every object-dtype column; an empty query matches all rows. At most
    *limit* rows are returned (negative values are treated as 0), restricted
    to *fields* when given. Errors come back as an ``{"error": ...}`` dict.
    """
    try:
        df = get_chi_df(solvent)
        if df is None:
            return {"error": f"Chi dataset not found for solvent: '{solvent}'"}
        # Clamp: DataFrame.head() with a negative n drops rows from the end
        # instead of limiting the result.
        limit = max(0, int(limit))
        result_df = df
        if query:
            q_lower = query.lower()
            mask = pd.Series(False, index=df.index)
            for pos, dtype in enumerate(df.dtypes):
                if dtype != object:
                    continue
                # regex=False: SMILES contain regex metacharacters ("[", "(",
                # "*", "+") that must be matched literally.
                text = df.iloc[:, pos].astype(str).str.lower()
                mask = mask | text.str.contains(q_lower, na=False, regex=False)
            result_df = df[mask]
        records = records_to_json(result_df.head(limit), fields=fields)
        return {
            "solvent": solvent,
            "query": query,
            "total_matched": len(result_df),
            "returned": len(records),
            "records": records,
        }
    except Exception as e:
        return {"error": str(e)}
# ---------------------------------------------------------------------
# MCP metadata
# ---------------------------------------------------------------------
# Server identity: sent as `serverInfo` in the MCP `initialize` response,
# included in /health, and its version is shown on the Gradio page and in
# the startup log.
SERVER_INFO = {
"name": "polyomics-mcp-server",
"version": "0.3.0",
}
# MCP tool catalogue returned by `tools/list`. Each entry's `name` must match a
# branch in handle_tool_call, and each `inputSchema` is a JSON Schema object.
TOOLS = [
    {
        "name": "list_datasets",
        "description": "List available polymer datasets with row counts and column names.",
        "inputSchema": {"type": "object", "properties": {}, "additionalProperties": False},
    },
    {
        "name": "search_polymers",
        "description": (
            "Search and filter polymers in a dataset. Supports full-text search, "
            "numeric range filters (e.g. 'density>1.0'), sorting, and field selection. "
            "Returns up to 1000 rows. For whole-dataset aggregations use "
            "get_statistics, get_correlation, get_distribution, or get_group_stats."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Full-text search string. Empty string returns all rows.",
                    "default": "",
                },
                "dataset": {"type": "string", "default": "general"},
                "limit": {
                    "type": "integer",
                    "default": 10,
                    "description": "Max rows to return (max 1000).",
                },
                "filters": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": (
                        "Numeric filter expressions, e.g. ['density>1.0', 'tg>=300', "
                        "'thermal_conductivity!=']. "
                        "Operators: >, >=, <, <=, ==, !=. "
                        "Use 'col!=' to require a non-null value."
                    ),
                    "default": [],
                },
                "sort_by": {"type": "string", "description": "Column name to sort by."},
                "sort_ascending": {"type": "boolean", "default": True},
                "fields": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Columns to include in output. Defaults to a compact property set.",
                },
                "require_numeric": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Convenience: columns that must have a non-null numeric value.",
                },
            },
            "additionalProperties": False,
        },
    },
    {
        "name": "get_dataset_columns",
        "description": (
            "Return all column names for a dataset, plus lists of numeric columns "
            "and known property columns useful for filtering."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {"dataset": {"type": "string", "default": "general"}},
            "additionalProperties": False,
        },
    },
    {
        "name": "get_statistics",
        "description": (
            "Compute descriptive statistics (count, mean, std, min, p25, median, p75, max) "
            "for numeric property columns using ALL rows in the dataset (server-side aggregation). "
            "Optional filters are applied before aggregation. "
            "Use this instead of search_polymers for whole-dataset summaries."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "default": "general"},
                "columns": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Columns to summarise. Defaults to all known property columns.",
                },
                "filters": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Numeric filters applied before computing statistics.",
                    "default": [],
                },
            },
            "additionalProperties": False,
        },
    },
    {
        "name": "get_correlation",
        "description": (
            "Compute Pearson and Spearman correlation between two numeric columns "
            "using ALL matching rows (server-side — no row limit). "
            "Returns correlation coefficients plus a random scatter-plot sample."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "default": "general"},
                "col_x": {"type": "string", "default": "density"},
                "col_y": {"type": "string", "default": "thermal_conductivity"},
                "filters": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Numeric filters applied before computing the correlation.",
                    "default": [],
                },
                "sample_limit": {
                    "type": "integer",
                    "default": 500,
                    "description": "Max data points to return for scatter-plot visualisation.",
                },
            },
            "required": ["col_x", "col_y"],
            "additionalProperties": False,
        },
    },
    {
        "name": "get_distribution",
        "description": (
            "Compute a histogram (bin counts and frequencies) for a numeric column "
            "over ALL matching rows (server-side — no row limit). "
            "Ideal for visualising the distribution of density, TC, Tg, etc."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "default": "general"},
                "column": {
                    "type": "string",
                    "default": "density",
                    "description": "Numeric column to compute the histogram for.",
                },
                "bins": {
                    "type": "integer",
                    "default": 20,
                    "description": "Number of histogram bins (1–200).",
                },
                "filters": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Numeric filters applied before computing the histogram.",
                    "default": [],
                },
            },
            "additionalProperties": False,
        },
    },
    {
        "name": "get_group_stats",
        "description": (
            "Aggregate a numeric column grouped by a categorical column (e.g. polymer_class). "
            "Runs on ALL matching rows server-side. "
            "Returns count, mean, std, min, p25, median, p75, max per group. "
            "Useful for comparing thermal_conductivity or density across polymer classes."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "default": "general"},
                "group_by": {
                    "type": "string",
                    "default": "polymer_class",
                    "description": "Categorical column to group on (e.g. 'polymer_class', '_source').",
                },
                "value_column": {
                    "type": "string",
                    "default": "thermal_conductivity",
                    "description": "Numeric column to aggregate.",
                },
                "filters": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Numeric filters applied before grouping.",
                    "default": [],
                },
                "min_group_size": {
                    "type": "integer",
                    "default": 5,
                    "description": "Groups with fewer rows than this are omitted.",
                },
            },
            "additionalProperties": False,
        },
    },
    {
        "name": "get_chi_solvents",
        "description": "List available solvents for chi parameter tables.",
        "inputSchema": {"type": "object", "properties": {}, "additionalProperties": False},
    },
    {
        "name": "search_chi",
        "description": "Search a chi parameter table for a given solvent.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "solvent": {
                    "type": "string",
                    "description": "Solvent name as listed by get_chi_solvents (e.g. 'Water').",
                },
                "query": {
                    "type": "string",
                    "default": "",
                    "description": "Case-insensitive text search. Empty string returns all rows.",
                },
                "limit": {
                    "type": "integer",
                    "default": 10,
                    "description": "Max rows to return.",
                },
                "fields": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Columns to include in output.",
                },
            },
            "required": ["solvent"],
            "additionalProperties": False,
        },
    },
]
# ---------------------------------------------------------------------
# Tool dispatcher
# ---------------------------------------------------------------------
def handle_tool_call(name: str, arguments: dict[str, Any]) -> dict[str, Any]:
    """Dispatch an MCP ``tools/call`` request to the matching ``tool_*`` function.

    Args:
        name: tool name from the request.
        arguments: tool arguments. Missing keys fall back to the same defaults
            the tool schemas advertise; integer arguments pass through ``int()``.

    Returns:
        The tool's result dict. An unknown *name*, or any exception raised
        while reading arguments or running the tool (e.g. ``int("abc")``),
        is returned as an ``{"error": ...}`` dict instead of propagating.
    """
    def arg(key: str, default: Any = None) -> Any:
        return arguments.get(key, default)

    # Each handler reads its arguments only when invoked, i.e. inside the try.
    handlers = {
        "list_datasets": lambda: tool_list_datasets(),
        "search_polymers": lambda: tool_search_polymers(
            query=arg("query", ""),
            dataset=arg("dataset", "general"),
            limit=int(arg("limit", 10)),
            filters=arg("filters") or [],
            sort_by=arg("sort_by"),
            sort_ascending=bool(arg("sort_ascending", True)),
            fields=arg("fields"),
            require_numeric=arg("require_numeric"),
        ),
        "get_dataset_columns": lambda: tool_get_dataset_columns(dataset=arg("dataset", "general")),
        "get_statistics": lambda: tool_get_statistics(
            dataset=arg("dataset", "general"),
            columns=arg("columns"),
            filters=arg("filters") or [],
        ),
        "get_correlation": lambda: tool_get_correlation(
            dataset=arg("dataset", "general"),
            col_x=arg("col_x", "density"),
            col_y=arg("col_y", "thermal_conductivity"),
            filters=arg("filters") or [],
            sample_limit=int(arg("sample_limit", 500)),
        ),
        "get_distribution": lambda: tool_get_distribution(
            dataset=arg("dataset", "general"),
            column=arg("column", "density"),
            bins=int(arg("bins", 20)),
            filters=arg("filters") or [],
        ),
        "get_group_stats": lambda: tool_get_group_stats(
            dataset=arg("dataset", "general"),
            group_by=arg("group_by", "polymer_class"),
            value_column=arg("value_column", "thermal_conductivity"),
            filters=arg("filters") or [],
            min_group_size=int(arg("min_group_size", 5)),
        ),
        "get_chi_solvents": lambda: tool_get_chi_solvents(),
        "search_chi": lambda: tool_search_chi(
            solvent=arg("solvent", ""),
            query=arg("query", ""),
            limit=int(arg("limit", 10)),
            fields=arg("fields"),
        ),
    }
    try:
        # Match by equality (not hashing) so an unhashable `name` from a
        # malformed request still yields "Unknown tool".
        for tool_name, run in handlers.items():
            if name == tool_name:
                return run()
        return {"error": f"Unknown tool: '{name}'"}
    except Exception as e:
        return {
            "error": f"Unhandled exception in tool '{name}': {e}",
            "traceback": traceback.format_exc(),
        }
# ---------------------------------------------------------------------
# MCP HTTP handlers
# ---------------------------------------------------------------------
async def mcp_sse_get(request: Request) -> Response:
    """Answer GET /mcp/sse with a single Server-Sent-Events ``ping`` event.

    The body is one complete event and the response then ends; no long-lived
    stream is kept open. JSON-RPC traffic is handled by the POST route.
    """
    print(">>> GET /mcp/sse")
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Access-Control-Allow-Origin": "*",
    }
    ping_event = "event: ping\ndata: alive\n\n"
    return PlainTextResponse(ping_event, media_type="text/event-stream", headers=sse_headers)
async def mcp_sse_post(request: Request) -> JSONResponse:
    """Handle one MCP JSON-RPC 2.0 message POSTed to /mcp/sse.

    Supported methods: ``initialize``, ``ping``, ``tools/list``, ``tools/call``.
    Any ``notifications/*`` message is acknowledged with HTTP 202, because
    JSON-RPC notifications must not receive a response object.
    Error codes:
      -32700  body is not JSON
      -32600  body is not a single request object (e.g. a batch array)
      -32602  ``params`` is not an object
      -32601  unknown method
    """
    # Local import: only this handler needs it.
    from starlette.concurrency import run_in_threadpool

    try:
        payload = await request.json()
    except Exception:
        raw = await request.body()
        print(">>> POST /mcp/sse — invalid JSON")
        print(raw.decode("utf-8", errors="replace")[:500])
        return safe_jsonrpc_error(None, -32700, "Parse error")
    print(">>> POST /mcp/sse")
    print(json.dumps(payload, ensure_ascii=False)[:1000])
    if not isinstance(payload, dict):
        # Batches (JSON arrays) and bare scalars are not supported; without this
        # check payload.get() would raise and surface as an HTTP 500.
        return safe_jsonrpc_error(None, -32600, "Invalid Request")
    req_id = payload.get("id")
    method = payload.get("method")
    params = payload.get("params") or {}
    if not isinstance(params, dict):
        return safe_jsonrpc_error(req_id, -32602, "Invalid params")
    if isinstance(method, str) and method.startswith("notifications/"):
        return JSONResponse({}, status_code=202)
    if method == "initialize":
        return safe_jsonrpc_result(req_id, {
            "protocolVersion": params.get("protocolVersion", "2025-03-26"),
            "capabilities": {"tools": {}},
            "serverInfo": SERVER_INFO,
        })
    if method == "ping":
        return safe_jsonrpc_result(req_id, {})
    if method == "tools/list":
        return safe_jsonrpc_result(req_id, {"tools": TOOLS})
    if method == "tools/call":
        name = params.get("name")
        arguments = params.get("arguments") or {}
        # The tools run synchronous pandas work. Run it in a worker thread so a
        # heavy aggregation does not block the event loop, which also serves
        # the Gradio UI and other requests.
        result = await run_in_threadpool(handle_tool_call, name, arguments)
        is_error = isinstance(result, dict) and "error" in result
        return safe_jsonrpc_result(req_id, {
            "content": [{"type": "text", "text": stringify_records([result])}],
            "isError": is_error,
        })
    return safe_jsonrpc_error(req_id, -32601, f"Method not found: {method}")
# ---------------------------------------------------------------------
# Auxiliary HTTP endpoints
# ---------------------------------------------------------------------
async def health(request: Request) -> JSONResponse:
    """Status probe: server identity, dataset repo and how much data loaded."""
    rows_per_dataset = {ds_name: len(frame) for ds_name, frame in _datasets.items()}
    chi_count = sum(1 for s in SOLVENTS if get_chi_df(s) is not None)
    return JSONResponse({
        "status": "ok",
        "server": SERVER_INFO,
        "dataset_repo": DATASET_REPO,
        "datasets": rows_per_dataset,
        "chi_solvents_loaded": chi_count,
    })
async def oauth_protected_resource(request: Request) -> Response:
    """Stub for the OAuth protected-resource metadata path: empty 204 reply.

    NOTE(review): presumably this lets MCP clients that probe the path proceed
    without auth; confirm against client behaviour.
    """
    no_content = 204
    return Response(status_code=no_content)
async def oauth_authorization_server(request: Request) -> Response:
    """Stub for the OAuth authorization-server metadata path: empty 204 reply."""
    no_content = 204
    return Response(status_code=no_content)
async def register(request: Request) -> Response:
    """Stub for POST /register: log the first 200 chars of the body, reply 204."""
    body_text = (await request.body()).decode("utf-8", errors="replace")
    print(">>> POST /register:", body_text[:200])
    return Response(status_code=204)
async def options_handler(request: Request) -> Response:
    """Answer CORS preflight (OPTIONS) requests with permissive allow-headers."""
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
        "Access-Control-Allow-Headers": "*",
    }
    return Response(status_code=204, headers=cors_headers)
# ---------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------
def build_gradio_app() -> GradioApp:
    """Build the informational Gradio landing page and wrap it as an ASGI app.

    The page lists the MCP tools and shows a snapshot of the loaded data, so
    it should be built after preload_data().
    - The tool-list version comes from SERVER_INFO, so it cannot drift from
      what ``initialize`` reports.
    - The advertised connector URL uses the ``SPACE_HOST`` environment
      variable when set, falling back to the original Space host.
    """
    space_host = os.environ.get("SPACE_HOST", "mohnishi-polyomics-mcp-server.hf.space")
    version = SERVER_INFO["version"]
    with gr.Blocks(title="PolyOmics MCP Server") as demo:
        gr.Markdown(
            "# PolyOmics MCP Server\n\n"
            f"**Dataset repo:** `{DATASET_REPO}`\n\n"
            f"**Claude connector endpoint:** `https://{space_host}/mcp/sse`\n\n"
            f"**Available MCP tools (v{version}):**\n"
            "- `list_datasets` — list datasets and row counts\n"
            "- `search_polymers` — full-text + numeric filter + sort + field selection (up to 1000 rows)\n"
            "- `get_dataset_columns` — column names, numeric columns, known property columns\n"
            "- `get_statistics` — **whole-dataset** descriptive stats per property column\n"
            "- `get_correlation` — **whole-dataset** Pearson/Spearman + scatter sample\n"
            "- `get_distribution` — **whole-dataset** histogram for any numeric column\n"
            "- `get_group_stats` — **whole-dataset** GROUP BY aggregation (e.g. TC by polymer_class)\n"
            "- `get_chi_solvents` — list loaded chi-parameter solvents\n"
            "- `search_chi` — search chi parameter tables\n"
        )
        df = get_main_df()
        gr.JSON({
            "dataset_repo": DATASET_REPO,
            "server_version": version,
            "main_rows": int(len(df)),
            "main_columns": int(len(df.columns)) if not df.empty else 0,
            "datasets": list(_datasets.keys()),
        })
    return GradioApp.create_app(demo, app_kwargs={"docs_url": "/docs"})
# ---------------------------------------------------------------------
# Starlette app assembly
# ---------------------------------------------------------------------
def build_app() -> Starlette:
    """Create the ASGI application served by uvicorn.

    The explicit routes are listed before the catch-all Gradio mount at "/":
    - the health check;
    - the MCP endpoint /mcp/sse, with GET, POST and OPTIONS handlers;
    - the OAuth discovery and registration stubs.
    The resulting route table is printed at startup.
    """
    ui_app = build_gradio_app()
    mcp_path = "/mcp/sse"
    route_table = [
        Route("/health", endpoint=health, methods=["GET"]),
        Route(mcp_path, endpoint=mcp_sse_get, methods=["GET"]),
        Route(mcp_path, endpoint=mcp_sse_post, methods=["POST"]),
        Route(mcp_path, endpoint=options_handler, methods=["OPTIONS"]),
        Route("/.well-known/oauth-protected-resource", endpoint=oauth_protected_resource, methods=["GET"]),
        Route("/.well-known/oauth-protected-resource/mcp/sse", endpoint=oauth_protected_resource, methods=["GET"]),
        Route("/.well-known/oauth-authorization-server", endpoint=oauth_authorization_server, methods=["GET"]),
        Route("/register", endpoint=register, methods=["POST"]),
        # Catch-all mount stays last so the explicit routes above are tried first.
        Mount("/", app=ui_app),
    ]
    app = Starlette(routes=route_table)
    print(f"MCP JSON-RPC server v{SERVER_INFO['version']} ready at /mcp/sse")
    for route in app.routes:
        print(f" {type(route).__name__}: {getattr(route, 'path', None)}")
    return app
# ---------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------
if __name__ == "__main__":
    started_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
    print(f"===== Application Startup at {started_at} =====")
    # Load data first: build_app() builds the Gradio page, which reads the
    # loaded datasets at build time.
    preload_data()
    app = build_app()
    listen_port = int(os.environ.get("PORT", 7860))
    print(f"Starting uvicorn on 0.0.0.0:{listen_port} ...")
    uvicorn.run(app, host="0.0.0.0", port=listen_port)