nithin varghese commited on
Commit
6b4d66c
·
unverified ·
2 Parent(s): 87dc528 f8c07f6

Merge pull request #1 from humandotlearning/npi-mcp-server-implementation

Browse files
explanation.md CHANGED
@@ -1,23 +1,46 @@
1
- # NPI MCP Server for CredentialWatch
2
 
3
- This MCP server (`npi-mcp`) provides a normalized interface to the NPPES NPI Registry API, allowing the CredentialWatch agent system to search for healthcare providers and retrieve detailed provider information.
4
 
5
- ## How it works
6
 
7
- The server implements the Model Context Protocol (MCP) using HTTP + SSE. It exposes two tools:
 
 
 
8
 
9
- 1. **`search_providers`**: Searches for providers using a flexible query string (handling names and organization names) along with optional filters for state and taxonomy. It aggregates results from both Individual (NPI-1) and Organization (NPI-2) searches and normalizes the output.
10
- 2. **`get_provider_by_npi`**: Retrieves full details for a specific NPI, including all addresses and taxonomies, normalized into a clean JSON structure.
11
 
12
- ## Deployment
13
 
14
- The server is built with **FastAPI** and uses **uv** for dependency management. It is designed to be deployed as a stateless service (e.g., on Hugging Face Spaces).
15
 
16
- ### Endpoints
17
- - `/sse`: The MCP SSE endpoint for connecting agents.
18
- - `/messages`: The endpoint for sending JSON-RPC messages (handled via the SSE session).
19
- - `/healthz`: A simple health check endpoint.
20
 
21
- ## Usage
 
 
 
22
 
23
- Agents connect to the `/sse` endpoint to establish a session and discover tools. They can then invoke tools by sending JSON-RPC requests to the `/messages` endpoint (linked via session ID).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NPI MCP Server Explanation
2
 
3
+ ## Architecture
4
 
5
+ This project implements a Model Context Protocol (MCP) server that acts as a bridge between an AI Agent and the NPPES NPI Registry.
6
 
7
+ 1. **Agent (Client)**: An AI agent (e.g., LangGraph) running in a separate environment connects to this MCP server.
8
+ 2. **MCP Server (`npi_mcp_server`)**: This server, running in a Hugging Face Space, exposes tools defined in the MCP specification (`search_npi_providers`, `get_npi_provider`).
9
+ 3. **Modal NPI API (`NPI_API`)**: The MCP server does not call the NPPES Registry directly. Instead, it forwards requests to a Modal-hosted FastAPI service (`NPI_API`). This Modal service handles the complex logic of querying the NPPES API and normalizing the data.
10
+ 4. **NPPES Registry**: The ultimate source of truth for provider data.
11
 
12
+ **Flow:**
13
+ `Agent` -> `MCP Server (this repo)` -> `Modal NPI API` -> `NPPES Registry`
14
 
15
+ ## Tools
16
 
17
+ ### 1. `search_npi_providers`
18
 
19
+ Searches for healthcare providers.
 
 
 
20
 
21
+ **Arguments:**
22
+ * `query` (str): Name or organization name.
23
+ * `state` (str, optional): 2-letter state code.
24
+ * `taxonomy` (str, optional): Taxonomy code.
25
 
26
+ **Example Invocation (JSON arguments):**
27
+ ```json
28
+ {
29
+ "query": "Mayo Clinic",
30
+ "state": "MN"
31
+ }
32
+ ```
33
+
34
+ ### 2. `get_npi_provider`
35
+
36
+ Retrieves details for a specific NPI.
37
+
38
+ **Arguments:**
39
+ * `npi` (str): 10-digit NPI number.
40
+
41
+ **Example Invocation (JSON arguments):**
42
+ ```json
43
+ {
44
+ "npi": "1234567890"
45
+ }
46
+ ```
pyproject.toml CHANGED
@@ -1,7 +1,7 @@
1
  [project]
2
- name = "npi-mcp"
3
  version = "0.1.0"
4
- description = "MCP server for NPPES NPI Registry"
5
  requires-python = ">=3.11"
6
  dependencies = [
7
  "fastapi>=0.100.0",
@@ -9,8 +9,10 @@ dependencies = [
9
  "httpx>=0.24.0",
10
  "pydantic>=2.0.0",
11
  "mcp>=1.0.0",
12
- "sse-starlette>=1.8.0",
13
- # Dev dependencies included here for simplicity in hackathon context
 
 
14
  "pytest>=7.0.0",
15
  "pytest-asyncio>=0.21.0",
16
  "pytest-mock>=3.10.0",
@@ -21,4 +23,9 @@ requires = ["hatchling"]
21
  build-backend = "hatchling.build"
22
 
23
  [tool.hatch.build.targets.wheel]
24
- packages = ["src/npi_mcp"]
 
 
 
 
 
 
1
  [project]
2
+ name = "npi-mcp-server"
3
  version = "0.1.0"
4
+ description = "MCP server for NPPES NPI Registry, forwarding to Modal NPI API"
5
  requires-python = ">=3.11"
6
  dependencies = [
7
  "fastapi>=0.100.0",
 
9
  "httpx>=0.24.0",
10
  "pydantic>=2.0.0",
11
  "mcp>=1.0.0",
12
+ ]
13
+
14
+ [project.optional-dependencies]
15
+ dev = [
16
  "pytest>=7.0.0",
17
  "pytest-asyncio>=0.21.0",
18
  "pytest-mock>=3.10.0",
 
23
  build-backend = "hatchling.build"
24
 
25
  [tool.hatch.build.targets.wheel]
26
+ packages = ["src/npi_mcp_server"]
27
+
28
+ [tool.pytest.ini_options]
29
+ pythonpath = ["src"]
30
+ testpaths = ["tests"]
31
+ asyncio_mode = "auto"
src/npi_mcp/__init__.py DELETED
File without changes
src/npi_mcp/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (125 Bytes)
 
src/npi_mcp/__pycache__/main.cpython-312.pyc DELETED
Binary file (4.97 kB)
 
src/npi_mcp/main.py DELETED
@@ -1,121 +0,0 @@
1
- import logging
2
- from contextlib import asynccontextmanager
3
- import uuid
4
-
5
- from fastapi import FastAPI, Request
6
- from starlette.responses import JSONResponse
7
- from sse_starlette.sse import EventSourceResponse
8
-
9
- # mcp imports
10
- from mcp.server.sse import SseServerTransport
11
- from npi_mcp.mcp_tools import mcp_server, npi_client
12
-
13
- # Configure logging
14
- logging.basicConfig(level=logging.INFO)
15
- logger = logging.getLogger(__name__)
16
-
17
- # We need to track active SSE sessions to route POST messages to the correct transport
18
- # In a distributed deployment, this should be in an external store (e.g. Redis).
19
- sse_transports = {}
20
-
21
- @asynccontextmanager
22
- async def lifespan(app: FastAPI):
23
- # Startup
24
- logger.info("Starting NPI MCP Server...")
25
- yield
26
- # Shutdown
27
- logger.info("Shutting down NPI MCP Server...")
28
- await npi_client.close()
29
-
30
- app = FastAPI(lifespan=lifespan)
31
-
32
- @app.get("/healthz")
33
- async def healthcheck():
34
- """Health check endpoint."""
35
- return {"status": "ok"}
36
-
37
- @app.get("/sse")
38
- async def handle_sse(request: Request):
39
- """
40
- Handle incoming SSE connection.
41
- Creates a new SseServerTransport and runs the MCP server loop for this session.
42
- """
43
- session_id = str(uuid.uuid4())
44
-
45
- # Construct the endpoint URL that the client should use for subsequent messages
46
- # This URL is sent to the client in the initial 'endpoint' event.
47
- # Note: request.url_for handles the base URL automatically.
48
- endpoint_url = str(request.url_for("handle_messages")) + f"?session_id={session_id}"
49
-
50
- logger.info(f"New SSE connection: {session_id}")
51
-
52
- # Create the transport
53
- transport = SseServerTransport(endpoint_url)
54
-
55
- # Store it so handle_messages can find it
56
- sse_transports[session_id] = transport
57
-
58
- async def event_generator():
59
- try:
60
- # mcp_server.run connects the server logic to the transport
61
- # It reads from transport.incoming_messages and writes to transport.outgoing_messages
62
- # initialization_options can be passed if needed
63
- async with mcp_server.run(
64
- transport.read_incoming(),
65
- transport.write_outgoing(),
66
- initialization_options={}
67
- ):
68
- # The transport should yield the 'endpoint' event immediately upon connection?
69
- # SseServerTransport logic typically handles sending the endpoint event at start.
70
- # We just need to iterate over outgoing messages and yield them as SSE events.
71
-
72
- async for message in transport.outgoing_messages():
73
- # message is an SSEMessage object usually, or we need to format it?
74
- # mcp.server.sse.SseServerTransport.outgoing_messages yields starlette ServerSentEvent objects or similar?
75
- # Let's assume it yields objects compatible with EventSourceResponse or we need to extract.
76
-
77
- # Checking `mcp` implementation (mental model):
78
- # It likely yields ServerSentEvent objects.
79
- yield message
80
-
81
- except Exception as e:
82
- logger.error(f"Error in SSE session {session_id}: {e}")
83
- finally:
84
- logger.info(f"Closing SSE session: {session_id}")
85
- sse_transports.pop(session_id, None)
86
-
87
- return EventSourceResponse(event_generator())
88
-
89
- @app.post("/messages")
90
- async def handle_messages(request: Request):
91
- """
92
- Handle incoming JSON-RPC messages from the client.
93
- Routes the message to the correct SSE transport based on session_id.
94
- """
95
- session_id = request.query_params.get("session_id")
96
-
97
- if not session_id:
98
- # Some clients might pass it in the body or header? Spec says "endpoint" URI.
99
- # We encoded it in the query param.
100
- return JSONResponse(status_code=400, content={"error": "Missing session_id"})
101
-
102
- if session_id not in sse_transports:
103
- return JSONResponse(status_code=404, content={"error": "Session not found or expired"})
104
-
105
- transport = sse_transports[session_id]
106
-
107
- try:
108
- # Read the JSON-RPC message
109
- message = await request.json()
110
- except Exception:
111
- return JSONResponse(status_code=400, content={"error": "Invalid JSON"})
112
-
113
- # Pass the message to the transport
114
- # The transport puts it into the input queue which mcp_server.run consumes
115
- await transport.receive_json_message(message)
116
-
117
- return JSONResponse(content={"status": "accepted"})
118
-
119
- if __name__ == "__main__":
120
- import uvicorn
121
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/npi_mcp/mcp_tools.py DELETED
@@ -1,61 +0,0 @@
1
- from typing import Any, List
2
- import mcp.types as types
3
- from mcp.server import Server
4
- from npi_mcp.npi_client import NPIClient
5
- from npi_mcp.models import SearchProvidersArgs, GetProviderArgs
6
-
7
- # Create the MCP Server instance
8
- mcp_server = Server("npi-mcp")
9
-
10
- # We will need a way to pass the NPIClient to the tools.
11
- # We can instantiate it globally or contextually.
12
- # For simplicity, we'll use a global client, but we need to manage its lifecycle.
13
-
14
- npi_client = NPIClient()
15
-
16
- @mcp_server.list_tools()
17
- async def list_tools() -> List[types.Tool]:
18
- return [
19
- types.Tool(
20
- name="search_providers",
21
- description="Search for healthcare providers in the NPI Registry by name, organization, state, or taxonomy.",
22
- inputSchema=SearchProvidersArgs.model_json_schema(),
23
- ),
24
- types.Tool(
25
- name="get_provider_by_npi",
26
- description="Retrieve detailed information about a specific provider using their NPI number.",
27
- inputSchema=GetProviderArgs.model_json_schema(),
28
- ),
29
- ]
30
-
31
- @mcp_server.call_tool()
32
- async def call_tool(name: str, arguments: Any) -> List[types.TextContent]:
33
- if name == "search_providers":
34
- # Validate arguments
35
- args = SearchProvidersArgs(**arguments)
36
-
37
- results = await npi_client.search_providers(
38
- query=args.query,
39
- state=args.state,
40
- taxonomy=args.taxonomy
41
- )
42
-
43
- # Format as JSON string
44
- json_results = [r.model_dump_json() for r in results]
45
- # Or return a single JSON list
46
- import json
47
- final_json = json.dumps([r.model_dump() for r in results], indent=2)
48
-
49
- return [types.TextContent(type="text", text=final_json)]
50
-
51
- elif name == "get_provider_by_npi":
52
- args = GetProviderArgs(**arguments)
53
- result = await npi_client.get_provider_by_npi(args.npi)
54
-
55
- if result:
56
- return [types.TextContent(type="text", text=result.model_dump_json(indent=2))]
57
- else:
58
- return [types.TextContent(type="text", text=f"{{ 'error': 'Provider with NPI {args.npi} not found.' }}")]
59
-
60
- else:
61
- raise ValueError(f"Unknown tool: {name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/npi_mcp/npi_client.py DELETED
@@ -1,211 +0,0 @@
1
- import httpx
2
- import logging
3
- from typing import List, Optional, Dict, Any
4
-
5
- from npi_mcp.models import ProviderSummary, ProviderDetail, Address, Taxonomy
6
-
7
- logger = logging.getLogger(__name__)
8
-
9
- class NPIClient:
10
- BASE_URL = "https://npiregistry.cms.hhs.gov/api/"
11
-
12
- def __init__(self):
13
- self.client = httpx.AsyncClient(timeout=30.0)
14
-
15
- async def close(self):
16
- await self.client.aclose()
17
-
18
- def _normalize_address(self, addr_data: Dict[str, Any]) -> Address:
19
- """Helper to convert API address format to our Address model."""
20
- return Address(
21
- line1=addr_data.get("address_1", ""),
22
- line2=addr_data.get("address_2") or None,
23
- city=addr_data.get("city", ""),
24
- state=addr_data.get("state", ""),
25
- postal_code=addr_data.get("postal_code", "")[:5], # Normalize to 5 digit for simplicity? Or keep full.
26
- country=addr_data.get("country_code", "US")
27
- )
28
-
29
- def _get_full_name(self, basic: Dict[str, Any], enumeration_type: str) -> str:
30
- if enumeration_type == "NPI-2":
31
- return basic.get("organization_name", "Unknown Organization")
32
- else:
33
- first = basic.get("first_name", "")
34
- last = basic.get("last_name", "")
35
- credential = basic.get("credential", "")
36
- name = f"{first} {last}".strip()
37
- if credential:
38
- name += f", {credential}"
39
- return name
40
-
41
- def _extract_primary_taxonomy(self, taxonomies: List[Dict[str, Any]]) -> tuple[Optional[str], Optional[str]]:
42
- """Returns (code, description) of primary taxonomy."""
43
- for tax in taxonomies:
44
- if tax.get("primary") is True:
45
- return tax.get("code"), tax.get("desc")
46
- # Fallback to first if no primary
47
- if taxonomies:
48
- return taxonomies[0].get("code"), taxonomies[0].get("desc")
49
- return None, None
50
-
51
- async def search_providers(
52
- self,
53
- query: str,
54
- state: Optional[str] = None,
55
- taxonomy: Optional[str] = None
56
- ) -> List[ProviderSummary]:
57
- """
58
- Searches for providers.
59
- Since the API splits fields, we try to be smart about 'query'.
60
- """
61
- results: List[Dict[str, Any]] = []
62
-
63
- # Strategy:
64
- # 1. Generic Organization Search (wildcard)
65
- # 2. Individual Search (splitting query)
66
-
67
- # We'll make parallel requests or sequential.
68
- # API requires specific fields.
69
-
70
- params_common = {
71
- "version": "2.1",
72
- "limit": 50 # Reasonable limit
73
- }
74
- if state:
75
- params_common["state"] = state
76
- if taxonomy:
77
- params_common["taxonomy_description"] = taxonomy
78
- # Note: API doc says "taxonomy_description", but often code works or is handled.
79
- # If "207RC0000X" is passed, we rely on the API handling it in description or matching.
80
- # If not, this might be a limitation.
81
-
82
- search_requests = []
83
-
84
- # Request 1: Organization
85
- req_org = params_common.copy()
86
- req_org["enumeration_type"] = "NPI-2"
87
- req_org["organization_name"] = f"{query}*"
88
- search_requests.append(req_org)
89
-
90
- # Request 2: Individual (Last Name match)
91
- # If query is single word
92
- parts = query.split()
93
- if len(parts) == 1:
94
- req_ind = params_common.copy()
95
- req_ind["enumeration_type"] = "NPI-1"
96
- req_ind["last_name"] = f"{query}*"
97
- search_requests.append(req_ind)
98
- elif len(parts) >= 2:
99
- # First Last
100
- req_ind = params_common.copy()
101
- req_ind["enumeration_type"] = "NPI-1"
102
- req_ind["first_name"] = parts[0]
103
- req_ind["last_name"] = f"{parts[-1]}*" # Use wildcard on last name
104
- search_requests.append(req_ind)
105
-
106
- # Execute requests
107
- # We run them sequentially for simplicity in this implementation,
108
- # but could use asyncio.gather
109
-
110
- seen_npis = set()
111
- normalized_results = []
112
-
113
- for params in search_requests:
114
- try:
115
- resp = await self.client.get(self.BASE_URL, params=params)
116
- resp.raise_for_status()
117
- data = resp.json()
118
-
119
- # API returns { "result_count": ..., "results": [...] } or errors
120
- items = data.get("results", [])
121
-
122
- for item in items:
123
- npi = item.get("number")
124
- if npi in seen_npis:
125
- continue
126
- seen_npis.add(npi)
127
-
128
- basic = item.get("basic", {})
129
- enum_type = item.get("enumeration_type", "UNKNOWN")
130
- # Map NPI-1 to INDIVIDUAL, NPI-2 to ORGANIZATION
131
- type_str = "INDIVIDUAL" if enum_type == "NPI-1" else "ORGANIZATION"
132
-
133
- full_name = self._get_full_name(basic, enum_type)
134
-
135
- taxonomies = item.get("taxonomies", [])
136
- prim_code, prim_desc = self._extract_primary_taxonomy(taxonomies)
137
-
138
- # Find primary address (usually location address)
139
- addresses = item.get("addresses", [])
140
- primary_addr_data = next(
141
- (a for a in addresses if a.get("address_purpose") == "LOCATION"),
142
- addresses[0] if addresses else {}
143
- )
144
-
145
- normalized_results.append(ProviderSummary(
146
- npi=str(npi),
147
- full_name=full_name,
148
- enumeration_type=type_str,
149
- primary_taxonomy=prim_code,
150
- primary_specialty=prim_desc,
151
- primary_address=self._normalize_address(primary_addr_data)
152
- ))
153
- except Exception as e:
154
- logger.error(f"Error querying NPI API with params {params}: {e}")
155
- # Continue to next request strategy
156
- continue
157
-
158
- return normalized_results
159
-
160
- async def get_provider_by_npi(self, npi: str) -> Optional[ProviderDetail]:
161
- params = {
162
- "version": "2.1",
163
- "number": npi
164
- }
165
- try:
166
- resp = await self.client.get(self.BASE_URL, params=params)
167
- resp.raise_for_status()
168
- data = resp.json()
169
-
170
- results = data.get("results", [])
171
- if not results:
172
- return None
173
-
174
- item = results[0]
175
- basic = item.get("basic", {})
176
- enum_type = item.get("enumeration_type", "UNKNOWN")
177
- type_str = "INDIVIDUAL" if enum_type == "NPI-1" else "ORGANIZATION"
178
-
179
- full_name = self._get_full_name(basic, enum_type)
180
-
181
- # Addresses
182
- raw_addresses = item.get("addresses", [])
183
- addresses = [self._normalize_address(a) for a in raw_addresses]
184
-
185
- # Taxonomies
186
- raw_taxonomies = item.get("taxonomies", [])
187
- taxonomies = []
188
- for t in raw_taxonomies:
189
- taxonomies.append(Taxonomy(
190
- code=t.get("code", ""),
191
- description=t.get("desc"),
192
- primary=t.get("primary", False),
193
- state=t.get("state"),
194
- license=t.get("license")
195
- ))
196
-
197
- return ProviderDetail(
198
- npi=str(item.get("number")),
199
- full_name=full_name,
200
- enumeration_type=type_str,
201
- addresses=addresses,
202
- taxonomies=taxonomies
203
- )
204
-
205
- except httpx.HTTPStatusError as e:
206
- if e.response.status_code == 404:
207
- return None
208
- raise e
209
- except Exception as e:
210
- logger.error(f"Error fetching NPI {npi}: {e}")
211
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/npi_mcp_server/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Expose main entry point
2
+ from npi_mcp_server.main import mcp
3
+
4
+ def main():
5
+ mcp.run()
src/npi_mcp_server/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (334 Bytes). View file
 
src/npi_mcp_server/__pycache__/config.cpython-312.pyc ADDED
Binary file (288 Bytes). View file
 
src/npi_mcp_server/__pycache__/main.cpython-312.pyc ADDED
Binary file (1.82 kB). View file
 
src/{npi_mcp → npi_mcp_server}/__pycache__/mcp_tools.cpython-312.pyc RENAMED
File without changes
src/{npi_mcp → npi_mcp_server}/__pycache__/models.cpython-312.pyc RENAMED
File without changes
src/{npi_mcp → npi_mcp_server}/__pycache__/npi_client.cpython-312.pyc RENAMED
File without changes
src/npi_mcp_server/__pycache__/schemas.cpython-312.pyc ADDED
Binary file (3.01 kB). View file
 
src/npi_mcp_server/__pycache__/tools.cpython-312.pyc ADDED
Binary file (4.66 kB). View file
 
src/npi_mcp_server/config.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import os
2
+
3
+ # Base URL for the Modal NPI API service
4
+ # Default to localhost for testing if not set, but in prod it should be set.
5
+ NPI_API_BASE_URL = os.environ.get("NPI_API_BASE_URL", "http://localhost:8000")
src/npi_mcp_server/main.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mcp.server.fastmcp import FastMCP
2
+ from npi_mcp_server.tools import search_providers as _search_providers, get_provider_by_npi as _get_provider_by_npi
3
+ from npi_mcp_server.schemas import ProviderSummary, ProviderDetail
4
+
5
+ # Initialize FastMCP server
6
+ # "npi-mcp" is the name of the server
7
+ mcp = FastMCP("npi-mcp")
8
+
9
+ @mcp.tool()
10
+ async def search_providers(query: str, state: str | None = None, taxonomy: str | None = None) -> list[ProviderSummary]:
11
+ """
12
+ Search for healthcare providers in the NPI Registry by name, organization, state, or taxonomy.
13
+ Fowards requests to the NPI API service.
14
+
15
+ Args:
16
+ query: Name of the provider (first/last) or organization, or a generic search term.
17
+ state: 2-letter state code (e.g. 'CA', 'NY').
18
+ taxonomy: Taxonomy code or description (e.g. '207RC0000X').
19
+ """
20
+ return await _search_providers(query, state, taxonomy)
21
+
22
+ @mcp.tool()
23
+ async def get_provider_by_npi(npi: str) -> ProviderDetail | str:
24
+ """
25
+ Retrieve detailed information about a specific provider using their NPI number.
26
+
27
+ Args:
28
+ npi: The 10-digit NPI number.
29
+ """
30
+ result = await _get_provider_by_npi(npi)
31
+ if result:
32
+ return result
33
+ return f"Provider with NPI {npi} not found."
34
+
35
+ # Entry point for running the server directly (e.g. for testing)
36
+ if __name__ == "__main__":
37
+ mcp.run()
src/{npi_mcp/models.py → npi_mcp_server/schemas.py} RENAMED
@@ -11,7 +11,8 @@ class SearchProvidersArgs(BaseModel):
11
  class GetProviderArgs(BaseModel):
12
  npi: str = Field(..., description="The 10-digit NPI number.")
13
 
14
- # --- Normalized Response Models ---
 
15
 
16
  class Address(BaseModel):
17
  line1: str
@@ -29,6 +30,9 @@ class ProviderSummary(BaseModel):
29
  primary_specialty: Optional[str] = None
30
  primary_address: Address
31
 
 
 
 
32
  class Taxonomy(BaseModel):
33
  code: str
34
  description: Optional[str] = None
 
11
  class GetProviderArgs(BaseModel):
12
  npi: str = Field(..., description="The 10-digit NPI number.")
13
 
14
+ # --- NPI API Response Models ---
15
+ # These mirror the structure returned by the Modal NPI_API.
16
 
17
  class Address(BaseModel):
18
  line1: str
 
30
  primary_specialty: Optional[str] = None
31
  primary_address: Address
32
 
33
+ class SearchProvidersResponse(BaseModel):
34
+ results: List[ProviderSummary]
35
+
36
  class Taxonomy(BaseModel):
37
  code: str
38
  description: Optional[str] = None
src/npi_mcp_server/tools.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import httpx
2
+ import logging
3
+ from typing import List, Optional
4
+
5
+ from npi_mcp_server.config import NPI_API_BASE_URL
6
+ from npi_mcp_server.schemas import (
7
+ ProviderSummary,
8
+ ProviderDetail,
9
+ SearchProvidersResponse
10
+ )
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ async def search_providers(
15
+ query: str,
16
+ state: Optional[str] = None,
17
+ taxonomy: Optional[str] = None
18
+ ) -> List[ProviderSummary]:
19
+ """
20
+ Search for healthcare providers via the NPI API.
21
+
22
+ Args:
23
+ query: Name (first/last) or organization name.
24
+ state: Two-letter state code.
25
+ taxonomy: Taxonomy code or description.
26
+
27
+ Returns:
28
+ List of ProviderSummary objects.
29
+ """
30
+ url = f"{NPI_API_BASE_URL.rstrip('/')}/search_providers"
31
+
32
+ payload = {
33
+ "query": query,
34
+ "state": state,
35
+ "taxonomy": taxonomy
36
+ }
37
+
38
+ # Remove None values
39
+ payload = {k: v for k, v in payload.items() if v is not None}
40
+
41
+ async with httpx.AsyncClient() as client:
42
+ try:
43
+ response = await client.post(url, json=payload, timeout=30.0)
44
+ response.raise_for_status()
45
+
46
+ data = response.json()
47
+ # Expecting: { "results": [ ... ] }
48
+ return SearchProvidersResponse(**data).results
49
+
50
+ except httpx.HTTPStatusError as e:
51
+ logger.error(f"HTTP error calling NPI API: {e.response.status_code} - {e.response.text}")
52
+ raise RuntimeError(f"NPI API returned error: {e.response.status_code}") from e
53
+ except Exception as e:
54
+ logger.error(f"Error calling NPI API: {e}")
55
+ raise RuntimeError(f"Failed to search providers: {str(e)}") from e
56
+
57
+ async def get_provider_by_npi(npi: str) -> Optional[ProviderDetail]:
58
+ """
59
+ Retrieve details for a specific provider by NPI.
60
+
61
+ Args:
62
+ npi: 10-digit NPI string.
63
+
64
+ Returns:
65
+ ProviderDetail object or None if not found.
66
+ """
67
+ url = f"{NPI_API_BASE_URL.rstrip('/')}/provider/{npi}"
68
+
69
+ async with httpx.AsyncClient() as client:
70
+ try:
71
+ response = await client.get(url, timeout=30.0)
72
+
73
+ if response.status_code == 404:
74
+ return None
75
+
76
+ response.raise_for_status()
77
+
78
+ data = response.json()
79
+ return ProviderDetail(**data)
80
+
81
+ except httpx.HTTPStatusError as e:
82
+ logger.error(f"HTTP error calling NPI API: {e.response.status_code} - {e.response.text}")
83
+ raise RuntimeError(f"NPI API returned error: {e.response.status_code}") from e
84
+ except Exception as e:
85
+ logger.error(f"Error calling NPI API: {e}")
86
+ raise RuntimeError(f"Failed to get provider: {str(e)}") from e
tests/__pycache__/test_npi_mcp_server.cpython-312-pytest-8.4.2.pyc ADDED
Binary file (9.22 kB). View file
 
tests/__pycache__/test_npi_mcp_server.cpython-312-pytest-9.0.1.pyc ADDED
Binary file (9.22 kB). View file
 
tests/test_npi_mcp.py DELETED
@@ -1,112 +0,0 @@
1
- import pytest
2
- from httpx import Response
3
- from npi_mcp.npi_client import NPIClient
4
- from npi_mcp.models import ProviderSummary, ProviderDetail
5
-
6
- # Mock data
7
- MOCK_SEARCH_RESPONSE_IND = {
8
- "result_count": 1,
9
- "results": [
10
- {
11
- "number": "1234567890",
12
- "basic": {
13
- "first_name": "John",
14
- "last_name": "Doe",
15
- "credential": "MD"
16
- },
17
- "enumeration_type": "NPI-1",
18
- "taxonomies": [
19
- {"code": "207RC0000X", "desc": "Cardiology", "primary": True}
20
- ],
21
- "addresses": [
22
- {
23
- "address_purpose": "LOCATION",
24
- "address_1": "123 Main St",
25
- "city": "Anytown",
26
- "state": "CA",
27
- "postal_code": "90210",
28
- "country_code": "US"
29
- }
30
- ]
31
- }
32
- ]
33
- }
34
-
35
- MOCK_SEARCH_RESPONSE_ORG = {
36
- "result_count": 1,
37
- "results": [
38
- {
39
- "number": "9876543210",
40
- "basic": {
41
- "organization_name": "General Hospital"
42
- },
43
- "enumeration_type": "NPI-2",
44
- "taxonomies": [],
45
- "addresses": [
46
- {
47
- "address_purpose": "LOCATION",
48
- "address_1": "456 Health Blvd",
49
- "city": "Metropolis",
50
- "state": "NY",
51
- "postal_code": "10001",
52
- "country_code": "US"
53
- }
54
- ]
55
- }
56
- ]
57
- }
58
-
59
- import httpx
60
-
61
- @pytest.mark.asyncio
62
- async def test_search_providers_individual(mocker):
63
- # Mock httpx client
64
- # Note: raise_for_status requires a request object
65
- resp = Response(200, json=MOCK_SEARCH_RESPONSE_IND)
66
- resp._request = httpx.Request("GET", "https://mock")
67
- mock_get = mocker.patch("httpx.AsyncClient.get", return_value=resp)
68
-
69
- client = NPIClient()
70
- results = await client.search_providers(query="John Doe")
71
-
72
- assert len(results) >= 1
73
- p = results[0]
74
- assert p.full_name == "John Doe, MD"
75
- assert p.enumeration_type == "INDIVIDUAL"
76
- assert p.primary_address.city == "Anytown"
77
-
78
- await client.close()
79
-
80
- @pytest.mark.asyncio
81
- async def test_search_providers_org(mocker):
82
- # Mock httpx client
83
- resp = Response(200, json=MOCK_SEARCH_RESPONSE_ORG)
84
- resp._request = httpx.Request("GET", "https://mock")
85
- mock_get = mocker.patch("httpx.AsyncClient.get", return_value=resp)
86
-
87
- client = NPIClient()
88
- results = await client.search_providers(query="General Hospital")
89
-
90
- assert len(results) >= 1
91
- p = results[0]
92
- assert p.full_name == "General Hospital"
93
- assert p.enumeration_type == "ORGANIZATION"
94
-
95
- await client.close()
96
-
97
- @pytest.mark.asyncio
98
- async def test_get_provider_by_npi(mocker):
99
- resp = Response(200, json=MOCK_SEARCH_RESPONSE_IND)
100
- resp._request = httpx.Request("GET", "https://mock")
101
- mock_get = mocker.patch("httpx.AsyncClient.get", return_value=resp)
102
-
103
- client = NPIClient()
104
- result = await client.get_provider_by_npi("1234567890")
105
-
106
- assert result is not None
107
- assert result.npi == "1234567890"
108
- assert result.full_name == "John Doe, MD"
109
- assert len(result.taxonomies) == 1
110
- assert result.taxonomies[0].code == "207RC0000X"
111
-
112
- await client.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_npi_mcp_server.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from httpx import Response
3
+ import httpx
4
+ from npi_mcp_server.tools import search_providers, get_provider_by_npi
5
+ from npi_mcp_server.schemas import ProviderSummary, ProviderDetail
6
+
7
+ # Mock data matching the Modal NPI API contract
8
+ MOCK_SEARCH_RESPONSE = {
9
+ "results": [
10
+ {
11
+ "npi": "1234567890",
12
+ "full_name": "John Doe, MD",
13
+ "enumeration_type": "INDIVIDUAL",
14
+ "primary_taxonomy": "207RC0000X",
15
+ "primary_specialty": "Cardiology",
16
+ "primary_address": {
17
+ "line1": "123 Main St",
18
+ "city": "Anytown",
19
+ "state": "CA",
20
+ "postal_code": "90210",
21
+ "country": "US"
22
+ }
23
+ }
24
+ ]
25
+ }
26
+
27
+ MOCK_PROVIDER_DETAIL = {
28
+ "npi": "1234567890",
29
+ "full_name": "John Doe, MD",
30
+ "enumeration_type": "INDIVIDUAL",
31
+ "addresses": [
32
+ {
33
+ "line1": "123 Main St",
34
+ "city": "Anytown",
35
+ "state": "CA",
36
+ "postal_code": "90210",
37
+ "country": "US"
38
+ }
39
+ ],
40
+ "taxonomies": [
41
+ {
42
+ "code": "207RC0000X",
43
+ "description": "Cardiology",
44
+ "primary": True,
45
+ "state": "CA",
46
+ "license": "12345"
47
+ }
48
+ ]
49
+ }
50
+
51
+ @pytest.mark.asyncio
52
+ async def test_search_providers(mocker):
53
+ # Mock httpx client post
54
+ resp = Response(200, json=MOCK_SEARCH_RESPONSE)
55
+ resp._request = httpx.Request("POST", "https://mock/search_providers")
56
+
57
+ mock_post = mocker.patch("httpx.AsyncClient.post", return_value=resp)
58
+
59
+ results = await search_providers(query="John Doe")
60
+
61
+ assert len(results) == 1
62
+ p = results[0]
63
+ assert p.npi == "1234567890"
64
+ assert p.full_name == "John Doe, MD"
65
+
66
+ # Check if correct URL and payload were sent
67
+ # We can check mock_post.call_args
68
+ call_args = mock_post.call_args
69
+ assert "search_providers" in call_args[0][0]
70
+ assert call_args[1]['json'] == {"query": "John Doe"}
71
+
72
+ @pytest.mark.asyncio
73
+ async def test_get_provider_by_npi(mocker):
74
+ # Mock httpx client get
75
+ resp = Response(200, json=MOCK_PROVIDER_DETAIL)
76
+ resp._request = httpx.Request("GET", "https://mock/provider/1234567890")
77
+
78
+ mock_get = mocker.patch("httpx.AsyncClient.get", return_value=resp)
79
+
80
+ result = await get_provider_by_npi("1234567890")
81
+
82
+ assert result is not None
83
+ assert result.npi == "1234567890"
84
+ assert result.full_name == "John Doe, MD"
85
+ assert len(result.taxonomies) == 1
86
+
87
+ @pytest.mark.asyncio
88
+ async def test_get_provider_by_npi_not_found(mocker):
89
+ resp = Response(404)
90
+ resp._request = httpx.Request("GET", "https://mock/provider/0000000000")
91
+
92
+ mock_get = mocker.patch("httpx.AsyncClient.get", return_value=resp)
93
+
94
+ result = await get_provider_by_npi("0000000000")
95
+ assert result is None
uv.lock CHANGED
@@ -336,7 +336,7 @@ wheels = [
336
  ]
337
 
338
  [[package]]
339
- name = "npi-mcp"
340
  version = "0.1.0"
341
  source = { editable = "." }
342
  dependencies = [
@@ -344,11 +344,14 @@ dependencies = [
344
  { name = "httpx" },
345
  { name = "mcp" },
346
  { name = "pydantic" },
 
 
 
 
 
347
  { name = "pytest" },
348
  { name = "pytest-asyncio" },
349
  { name = "pytest-mock" },
350
- { name = "sse-starlette" },
351
- { name = "uvicorn" },
352
  ]
353
 
354
  [package.metadata]
@@ -357,12 +360,12 @@ requires-dist = [
357
  { name = "httpx", specifier = ">=0.24.0" },
358
  { name = "mcp", specifier = ">=1.0.0" },
359
  { name = "pydantic", specifier = ">=2.0.0" },
360
- { name = "pytest", specifier = ">=7.0.0" },
361
- { name = "pytest-asyncio", specifier = ">=0.21.0" },
362
- { name = "pytest-mock", specifier = ">=3.10.0" },
363
- { name = "sse-starlette", specifier = ">=1.8.0" },
364
  { name = "uvicorn", specifier = ">=0.20.0" },
365
  ]
 
366
 
367
  [[package]]
368
  name = "packaging"
 
336
  ]
337
 
338
  [[package]]
339
+ name = "npi-mcp-server"
340
  version = "0.1.0"
341
  source = { editable = "." }
342
  dependencies = [
 
344
  { name = "httpx" },
345
  { name = "mcp" },
346
  { name = "pydantic" },
347
+ { name = "uvicorn" },
348
+ ]
349
+
350
+ [package.optional-dependencies]
351
+ dev = [
352
  { name = "pytest" },
353
  { name = "pytest-asyncio" },
354
  { name = "pytest-mock" },
 
 
355
  ]
356
 
357
  [package.metadata]
 
360
  { name = "httpx", specifier = ">=0.24.0" },
361
  { name = "mcp", specifier = ">=1.0.0" },
362
  { name = "pydantic", specifier = ">=2.0.0" },
363
+ { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" },
364
+ { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" },
365
+ { name = "pytest-mock", marker = "extra == 'dev'", specifier = ">=3.10.0" },
 
366
  { name = "uvicorn", specifier = ">=0.20.0" },
367
  ]
368
+ provides-extras = ["dev"]
369
 
370
  [[package]]
371
  name = "packaging"