applemuncy commited on
Commit
4b7309f
·
verified ·
1 Parent(s): 0cf3b3a

Upload 3 files

Browse files
Files changed (3) hide show
  1. auth_server.py +187 -0
  2. requirements.txt +1 -0
  3. simple_auth_provider.py +271 -0
auth_server.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Authorization Server for MCP Split Demo.
3
+
4
+ This server handles OAuth flows, client registration, and token issuance.
5
+ Can be replaced with enterprise authorization servers like Auth0, Entra ID, etc.
6
+
7
+ NOTE: this is a simplified example for demonstration purposes.
8
+ This is not a production-ready implementation.
9
+
10
+ """
11
+
12
+ import asyncio
13
+ import logging
14
+ import time
15
+
16
+ import click
17
+ from pydantic import AnyHttpUrl, BaseModel
18
+ from starlette.applications import Starlette
19
+ from starlette.exceptions import HTTPException
20
+ from starlette.requests import Request
21
+ from starlette.responses import JSONResponse, Response
22
+ from starlette.routing import Route
23
+ from uvicorn import Config, Server
24
+
25
+ from mcp.server.auth.routes import cors_middleware, create_auth_routes
26
+ from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
27
+
28
+ from simple_auth_provider import SimpleAuthSettings, SimpleOAuthProvider
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ class AuthServerSettings(BaseModel):
34
+ """Settings for the Authorization Server."""
35
+
36
+ # Server settings
37
+ host: str = "localhost"
38
+ port: int = 9000
39
+ server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:9000")
40
+ auth_callback_path: str = "http://localhost:9000/login/callback"
41
+
42
+
43
+ class SimpleAuthProvider(SimpleOAuthProvider):
44
+ """
45
+ Authorization Server provider with simple demo authentication.
46
+
47
+ This provider:
48
+ 1. Issues MCP tokens after simple credential authentication
49
+ 2. Stores token state for introspection by Resource Servers
50
+ """
51
+
52
+ def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str):
53
+ super().__init__(auth_settings, auth_callback_path, server_url)
54
+
55
+
56
+ def create_authorization_server(server_settings: AuthServerSettings, auth_settings: SimpleAuthSettings) -> Starlette:
57
+ """Create the Authorization Server application."""
58
+ oauth_provider = SimpleAuthProvider(
59
+ auth_settings, server_settings.auth_callback_path, str(server_settings.server_url)
60
+ )
61
+
62
+ mcp_auth_settings = AuthSettings(
63
+ issuer_url=server_settings.server_url,
64
+ client_registration_options=ClientRegistrationOptions(
65
+ enabled=True,
66
+ valid_scopes=[auth_settings.mcp_scope],
67
+ default_scopes=[auth_settings.mcp_scope],
68
+ ),
69
+ required_scopes=[auth_settings.mcp_scope],
70
+ resource_server_url=None,
71
+ )
72
+
73
+ # Create OAuth routes
74
+ routes = create_auth_routes(
75
+ provider=oauth_provider,
76
+ issuer_url=mcp_auth_settings.issuer_url,
77
+ service_documentation_url=mcp_auth_settings.service_documentation_url,
78
+ client_registration_options=mcp_auth_settings.client_registration_options,
79
+ revocation_options=mcp_auth_settings.revocation_options,
80
+ )
81
+
82
+ # Add login page route (GET)
83
+ async def login_page_handler(request: Request) -> Response:
84
+ """Show login form."""
85
+ state = request.query_params.get("state")
86
+ if not state:
87
+ raise HTTPException(400, "Missing state parameter")
88
+ return await oauth_provider.get_login_page(state)
89
+
90
+ routes.append(Route("/login", endpoint=login_page_handler, methods=["GET"]))
91
+
92
+ # Add login callback route (POST)
93
+ async def login_callback_handler(request: Request) -> Response:
94
+ """Handle simple authentication callback."""
95
+ return await oauth_provider.handle_login_callback(request)
96
+
97
+ routes.append(Route("/login/callback", endpoint=login_callback_handler, methods=["POST"]))
98
+
99
+ # Add token introspection endpoint (RFC 7662) for Resource Servers
100
+ async def introspect_handler(request: Request) -> Response:
101
+ """
102
+ Token introspection endpoint for Resource Servers.
103
+
104
+ Resource Servers call this endpoint to validate tokens without
105
+ needing direct access to token storage.
106
+ """
107
+ form = await request.form()
108
+ token = form.get("token")
109
+ if not token or not isinstance(token, str):
110
+ return JSONResponse({"active": False}, status_code=400)
111
+
112
+ # Look up token in provider
113
+ access_token = await oauth_provider.load_access_token(token)
114
+ if not access_token:
115
+ return JSONResponse({"active": False})
116
+
117
+ return JSONResponse(
118
+ {
119
+ "active": True,
120
+ "client_id": access_token.client_id,
121
+ "scope": " ".join(access_token.scopes),
122
+ "exp": access_token.expires_at,
123
+ "iat": int(time.time()),
124
+ "token_type": "Bearer",
125
+ "aud": access_token.resource, # RFC 8707 audience claim
126
+ }
127
+ )
128
+
129
+ routes.append(
130
+ Route(
131
+ "/introspect",
132
+ endpoint=cors_middleware(introspect_handler, ["POST", "OPTIONS"]),
133
+ methods=["POST", "OPTIONS"],
134
+ )
135
+ )
136
+
137
+ return Starlette(routes=routes)
138
+
139
+
140
+ async def run_server(server_settings: AuthServerSettings, auth_settings: SimpleAuthSettings):
141
+ """Run the Authorization Server."""
142
+ auth_server = create_authorization_server(server_settings, auth_settings)
143
+
144
+ config = Config(
145
+ auth_server,
146
+ host=server_settings.host,
147
+ port=server_settings.port,
148
+ log_level="info",
149
+ )
150
+ server = Server(config)
151
+
152
+ logger.info(f"🚀 MCP Authorization Server running on {server_settings.server_url}")
153
+
154
+ await server.serve()
155
+
156
+
157
+ @click.command()
158
+ @click.option("--port", default=9000, help="Port to listen on")
159
+ def main(port: int) -> int:
160
+ """
161
+ Run the MCP Authorization Server.
162
+
163
+ This server handles OAuth flows and can be used by multiple Resource Servers.
164
+
165
+ Uses simple hardcoded credentials for demo purposes.
166
+ """
167
+ logging.basicConfig(level=logging.INFO)
168
+
169
+ # Load simple auth settings
170
+ auth_settings = SimpleAuthSettings()
171
+
172
+ # Create server settings
173
+ host = "localhost"
174
+ server_url = f"http://{host}:{port}"
175
+ server_settings = AuthServerSettings(
176
+ host=host,
177
+ port=port,
178
+ server_url=AnyHttpUrl(server_url),
179
+ auth_callback_path=f"{server_url}/login",
180
+ )
181
+
182
+ asyncio.run(run_server(server_settings, auth_settings))
183
+ return 0
184
+
185
+
186
+ if __name__ == "__main__":
187
+ main() # type: ignore[call-arg]
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ mcp==1.11.0
simple_auth_provider.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple OAuth provider for MCP servers.
3
+
4
+ This module contains a basic OAuth implementation using hardcoded user credentials
5
+ for demonstration purposes. No external authentication provider is required.
6
+
7
+ NOTE: this is a simplified example for demonstration purposes.
8
+ This is not a production-ready implementation.
9
+
10
+ """
11
+
12
+ import logging
13
+ import secrets
14
+ import time
15
+ from typing import Any
16
+ import os
17
+ from pydantic import AnyHttpUrl
18
+ from pydantic_settings import BaseSettings, SettingsConfigDict
19
+ from starlette.exceptions import HTTPException
20
+ from starlette.requests import Request
21
+ from starlette.responses import HTMLResponse, RedirectResponse, Response
22
+
23
+ from mcp.server.auth.provider import (
24
+ AccessToken,
25
+ AuthorizationCode,
26
+ AuthorizationParams,
27
+ OAuthAuthorizationServerProvider,
28
+ RefreshToken,
29
+ construct_redirect_uri,
30
+ )
31
+ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ class SimpleAuthSettings(BaseSettings):
37
+ """Simple OAuth settings for demo purposes."""
38
+
39
+ model_config = SettingsConfigDict(env_prefix="MCP_")
40
+
41
+ # Demo user credentials
42
+ demo_username: str = os.getenv('DEMO_USER',"demo_user")
43
+ demo_password: str = os.getenv('DEMO_PASSWORD',"demo_password")
44
+
45
+
46
+ # MCP OAuth scope
47
+ mcp_scope: str = "user"
48
+
49
+
50
+ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
51
+ """
52
+ Simple OAuth provider for demo purposes.
53
+
54
+ This provider handles the OAuth flow by:
55
+ 1. Providing a simple login form for demo credentials
56
+ 2. Issuing MCP tokens after successful authentication
57
+ 3. Maintaining token state for introspection
58
+ """
59
+
60
+ def __init__(self, settings: SimpleAuthSettings, auth_callback_url: str, server_url: str):
61
+ self.settings = settings
62
+ self.auth_callback_url = auth_callback_url
63
+ self.server_url = server_url
64
+ self.clients: dict[str, OAuthClientInformationFull] = {}
65
+ self.auth_codes: dict[str, AuthorizationCode] = {}
66
+ self.tokens: dict[str, AccessToken] = {}
67
+ self.state_mapping: dict[str, dict[str, str | None]] = {}
68
+ # Store authenticated user information
69
+ self.user_data: dict[str, dict[str, Any]] = {}
70
+
71
+ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
72
+ """Get OAuth client information."""
73
+ return self.clients.get(client_id)
74
+
75
+ async def register_client(self, client_info: OAuthClientInformationFull):
76
+ """Register a new OAuth client."""
77
+ self.clients[client_info.client_id] = client_info
78
+
79
+ async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
80
+ """Generate an authorization URL for simple login flow."""
81
+ state = params.state or secrets.token_hex(16)
82
+
83
+ # Store state mapping for callback
84
+ self.state_mapping[state] = {
85
+ "redirect_uri": str(params.redirect_uri),
86
+ "code_challenge": params.code_challenge,
87
+ "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly),
88
+ "client_id": client.client_id,
89
+ "resource": params.resource, # RFC 8707
90
+ }
91
+
92
+ # Build simple login URL that points to login page
93
+ auth_url = f"{self.auth_callback_url}?state={state}&client_id={client.client_id}"
94
+
95
+ return auth_url
96
+
97
+ async def get_login_page(self, state: str) -> HTMLResponse:
98
+ """Generate login page HTML for the given state."""
99
+ if not state:
100
+ raise HTTPException(400, "Missing state parameter")
101
+
102
+ # Create simple login form HTML
103
+ html_content = f"""
104
+ <!DOCTYPE html>
105
+ <html>
106
+ <head>
107
+ <title>MCP Demo Authentication</title>
108
+ <style>
109
+ body {{ font-family: Arial, sans-serif; max-width: 500px; margin: 0 auto; padding: 20px; }}
110
+ .form-group {{ margin-bottom: 15px; }}
111
+ input {{ width: 100%; padding: 8px; margin-top: 5px; }}
112
+ button {{ background-color: #4CAF50; color: white; padding: 10px 15px; border: none; cursor: pointer; }}
113
+ </style>
114
+ </head>
115
+ <body>
116
+ <h2>MCP Demo Authentication</h2>
117
+ <p>This is a simplified authentication demo. Use the demo credentials below:</p>
118
+ <p><strong>Username:</strong> demo_user<br>
119
+ <strong>Password:</strong> demo_password</p>
120
+
121
+ <form action="{self.server_url.rstrip("/")}/login/callback" method="post">
122
+ <input type="hidden" name="state" value="{state}">
123
+ <div class="form-group">
124
+ <label>Username:</label>
125
+ <input type="text" name="username" value="demo_user" required>
126
+ </div>
127
+ <div class="form-group">
128
+ <label>Password:</label>
129
+ <input type="password" name="password" value="demo_password" required>
130
+ </div>
131
+ <button type="submit">Sign In</button>
132
+ </form>
133
+ </body>
134
+ </html>
135
+ """
136
+
137
+ return HTMLResponse(content=html_content)
138
+
139
+ async def handle_login_callback(self, request: Request) -> Response:
140
+ """Handle login form submission callback."""
141
+ form = await request.form()
142
+ username = form.get("username")
143
+ password = form.get("password")
144
+ state = form.get("state")
145
+
146
+ if not username or not password or not state:
147
+ raise HTTPException(400, "Missing username, password, or state parameter")
148
+
149
+ # Ensure we have strings, not UploadFile objects
150
+ if not isinstance(username, str) or not isinstance(password, str) or not isinstance(state, str):
151
+ raise HTTPException(400, "Invalid parameter types")
152
+
153
+ redirect_uri = await self.handle_simple_callback(username, password, state)
154
+ return RedirectResponse(url=redirect_uri, status_code=302)
155
+
156
+ async def handle_simple_callback(self, username: str, password: str, state: str) -> str:
157
+ """Handle simple authentication callback and return redirect URI."""
158
+ state_data = self.state_mapping.get(state)
159
+ if not state_data:
160
+ raise HTTPException(400, "Invalid state parameter")
161
+
162
+ redirect_uri = state_data["redirect_uri"]
163
+ code_challenge = state_data["code_challenge"]
164
+ redirect_uri_provided_explicitly = state_data["redirect_uri_provided_explicitly"] == "True"
165
+ client_id = state_data["client_id"]
166
+ resource = state_data.get("resource") # RFC 8707
167
+
168
+ # These are required values from our own state mapping
169
+ assert redirect_uri is not None
170
+ assert code_challenge is not None
171
+ assert client_id is not None
172
+
173
+ # Validate demo credentials
174
+ if username != self.settings.demo_username or password != self.settings.demo_password:
175
+ raise HTTPException(401, "Invalid credentials")
176
+
177
+ # Create MCP authorization code
178
+ new_code = f"mcp_{secrets.token_hex(16)}"
179
+ auth_code = AuthorizationCode(
180
+ code=new_code,
181
+ client_id=client_id,
182
+ redirect_uri=AnyHttpUrl(redirect_uri),
183
+ redirect_uri_provided_explicitly=redirect_uri_provided_explicitly,
184
+ expires_at=time.time() + 300,
185
+ scopes=[self.settings.mcp_scope],
186
+ code_challenge=code_challenge,
187
+ resource=resource, # RFC 8707
188
+ )
189
+ self.auth_codes[new_code] = auth_code
190
+
191
+ # Store user data
192
+ self.user_data[username] = {
193
+ "username": username,
194
+ "user_id": f"user_{secrets.token_hex(8)}",
195
+ "authenticated_at": time.time(),
196
+ }
197
+
198
+ del self.state_mapping[state]
199
+ return construct_redirect_uri(redirect_uri, code=new_code, state=state)
200
+
201
+ async def load_authorization_code(
202
+ self, client: OAuthClientInformationFull, authorization_code: str
203
+ ) -> AuthorizationCode | None:
204
+ """Load an authorization code."""
205
+ return self.auth_codes.get(authorization_code)
206
+
207
+ async def exchange_authorization_code(
208
+ self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
209
+ ) -> OAuthToken:
210
+ """Exchange authorization code for tokens."""
211
+ if authorization_code.code not in self.auth_codes:
212
+ raise ValueError("Invalid authorization code")
213
+
214
+ # Generate MCP access token
215
+ mcp_token = f"mcp_{secrets.token_hex(32)}"
216
+
217
+ # Store MCP token
218
+ self.tokens[mcp_token] = AccessToken(
219
+ token=mcp_token,
220
+ client_id=client.client_id,
221
+ scopes=authorization_code.scopes,
222
+ expires_at=int(time.time()) + 3600,
223
+ resource=authorization_code.resource, # RFC 8707
224
+ )
225
+
226
+ # Store user data mapping for this token
227
+ self.user_data[mcp_token] = {
228
+ "username": self.settings.demo_username,
229
+ "user_id": f"user_{secrets.token_hex(8)}",
230
+ "authenticated_at": time.time(),
231
+ }
232
+
233
+ del self.auth_codes[authorization_code.code]
234
+
235
+ return OAuthToken(
236
+ access_token=mcp_token,
237
+ token_type="Bearer",
238
+ expires_in=3600,
239
+ scope=" ".join(authorization_code.scopes),
240
+ )
241
+
242
+ async def load_access_token(self, token: str) -> AccessToken | None:
243
+ """Load and validate an access token."""
244
+ access_token = self.tokens.get(token)
245
+ if not access_token:
246
+ return None
247
+
248
+ # Check if expired
249
+ if access_token.expires_at and access_token.expires_at < time.time():
250
+ del self.tokens[token]
251
+ return None
252
+
253
+ return access_token
254
+
255
+ async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
256
+ """Load a refresh token - not supported in this example."""
257
+ return None
258
+
259
+ async def exchange_refresh_token(
260
+ self,
261
+ client: OAuthClientInformationFull,
262
+ refresh_token: RefreshToken,
263
+ scopes: list[str],
264
+ ) -> OAuthToken:
265
+ """Exchange refresh token - not supported in this example."""
266
+ raise NotImplementedError("Refresh tokens not supported")
267
+
268
+ async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None:
269
+ """Revoke a token."""
270
+ if token in self.tokens:
271
+ del self.tokens[token]