import hmac
import json
import os

from dotenv import load_dotenv
from fastapi import Request
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
from mcp.server.fastmcp import FastMCP, Context

import auth
import datasets_registry
import stripe_webhook
import subscriptions_ledger

load_dotenv()

# Initialize MCP Server
mcp = FastMCP("datapass")

# Admin API Secret for secure admin endpoints; unset means admin access is disabled.
ADMIN_API_SECRET = os.getenv("ADMIN_API_SECRET")


def verify_admin_secret(provided_secret: str) -> bool:
    """Verify the admin secret matches.

    Fails closed when ADMIN_API_SECRET is not configured.  Uses a
    constant-time comparison so an attacker cannot recover the secret
    byte-by-byte through response-timing differences.
    """
    if not ADMIN_API_SECRET:
        return False  # Fail closed if not configured
    return hmac.compare_digest(provided_secret.encode(), ADMIN_API_SECRET.encode())


def get_admin_secret_from_request(request: Request, body_data: dict | None = None) -> str:
    """Extract admin secret from header (preferred) or body (fallback).

    Returns an empty string when neither the X-Admin-Secret header nor an
    'admin_secret' body field is present.
    """
    header_secret = request.headers.get("X-Admin-Secret", "")
    if header_secret:
        return header_secret
    if body_data:
        return body_data.get("admin_secret", "")
    return ""


def _get_access_token_from_context(ctx: Context) -> str | None:
    """Extract access token from the Authorization header in the request context.

    Returns None when no bearer token is present or the context has no
    underlying HTTP request.
    """
    try:
        if ctx.request_context and ctx.request_context.request:
            auth_header = ctx.request_context.request.headers.get("Authorization", "")
            if auth_header.startswith("Bearer "):
                return auth_header[7:]  # Remove "Bearer " prefix
    except Exception as e:
        # Best-effort: a malformed context should not crash the tool call.
        print(f"Error extracting access token: {e}")
    return None


def _validate_token_for_dataset(access_token: str, dataset_id: str | None = None):
    """Validate access token and optionally check dataset access.

    Returns a (token_info, error) pair: exactly one of the two is None.
    """
    if not access_token:
        return None, "No access token provided. Please include Authorization header with Bearer token."
    token_info = subscriptions_ledger.validate_access_token(access_token)
    if not token_info:
        return None, "Invalid or expired access token."
    if dataset_id and token_info["dataset_id"] != dataset_id:
        return None, f"Access token not valid for dataset '{dataset_id}'. Your token is for '{token_info['dataset_id']}'."
    return token_info, None


# =============================================================================
# Public MCP Tools (no auth required)
# =============================================================================

@mcp.tool()
def get_dataset_catalog():
    """
    Returns the catalog of all available datasets in the marketplace.
    Use this to discover datasets you can subscribe to.
    """
    registry = datasets_registry.load_registry()
    active_datasets = [d for d in registry if d.get("is_active")]
    return active_datasets


# =============================================================================
# Authenticated MCP Tools (require access token via Authorization header)
# =============================================================================

@mcp.tool()
def get_my_datasets(ctx: Context):
    """
    Returns the list of datasets you have access to based on your subscription.
    The access token is automatically read from your Authorization header.
    """
    access_token = _get_access_token_from_context(ctx)
    if not access_token:
        return {"error": "No access token provided. Please include Authorization header with Bearer token."}
    token_info = subscriptions_ledger.validate_access_token(access_token)
    if not token_info:
        return {"error": "Invalid or expired access token."}
    # Return info about the dataset this token grants access to
    dataset_id = token_info["dataset_id"]
    subscription = token_info.get("subscription", {})
    # Get dataset details from registry (the token may outlive a registry entry,
    # so fall back to the bare dataset_id when it is gone).
    registry = datasets_registry.load_registry()
    dataset = next((d for d in registry if d["dataset_id"] == dataset_id), None)
    return {
        "datasets": [{
            "dataset_id": dataset_id,
            "display_name": dataset.get("display_name", dataset_id) if dataset else dataset_id,
            "description": dataset.get("description", "") if dataset else "",
            "subscription_end": subscription.get("subscription_end"),
            "plan_id": subscription.get("plan_id")
        }],
        "user": token_info["hf_user"]
    }


def _get_duckdb_connection(server_hf_token: str):
    """Create a DuckDB connection with HF token configured.

    Installs/loads httpfs so read_parquet can fetch hf:// / https:// paths.
    NOTE(review): the token is interpolated into the CREATE SECRET statement;
    HF tokens are expected to be quote-free — confirm if that ever changes.
    """
    import duckdb
    con = duckdb.connect(database=':memory:')
    con.execute("INSTALL httpfs;")
    con.execute("LOAD httpfs;")
    con.execute(f"CREATE SECRET hf_token (TYPE huggingface, TOKEN '{server_hf_token}');")
    return con


@mcp.tool()
def get_dataset_schema(dataset_id: str, ctx: Context):
    """
    Returns the schema (column names and types) of a dataset you have access to.
    Use this to understand the data structure before querying.

    Args:
        dataset_id: The ID of the dataset (e.g., 'waroca/prompts')
    """
    access_token = _get_access_token_from_context(ctx)
    token_info, error = _validate_token_for_dataset(access_token, dataset_id)
    if error:
        return {"error": error}
    # Use server's HF token to access the dataset
    server_hf_token = os.getenv("HF_TOKEN")
    if not server_hf_token:
        return {"error": "Server HF token not configured"}
    # Get the correct parquet URL pattern for this dataset
    dataset_url = datasets_registry.get_parquet_url(dataset_id)
    try:
        con = _get_duckdb_connection(server_hf_token)
        # Get schema
        result = con.execute(f"DESCRIBE SELECT * FROM read_parquet('{dataset_url}')").fetchall()
        schema = [{"column": row[0], "type": row[1]} for row in result]
        # Get row count
        count_result = con.execute(f"SELECT COUNT(*) FROM read_parquet('{dataset_url}')").fetchone()
        row_count = count_result[0] if count_result else 0
        return {
            "dataset_id": dataset_id,
            "schema": schema,
            "row_count": row_count
        }
    except Exception as e:
        return {"error": f"Failed to get schema: {str(e)}"}


@mcp.tool()
def get_dataset_sample(dataset_id: str, ctx: Context, num_rows: int = 5):
    """
    Returns a sample of rows from a dataset you have access to.
    Use this to preview the data before running queries.

    Args:
        dataset_id: The ID of the dataset (e.g., 'waroca/prompts')
        num_rows: Number of sample rows to return (default: 5, max: 20)
    """
    access_token = _get_access_token_from_context(ctx)
    token_info, error = _validate_token_for_dataset(access_token, dataset_id)
    if error:
        return {"error": error}
    # Clamp to the allowed 1..20 range before interpolating into SQL
    num_rows = min(max(1, num_rows), 20)
    server_hf_token = os.getenv("HF_TOKEN")
    if not server_hf_token:
        return {"error": "Server HF token not configured"}
    # Get the correct parquet URL pattern for this dataset
    dataset_url = datasets_registry.get_parquet_url(dataset_id)
    try:
        con = _get_duckdb_connection(server_hf_token)
        result = con.execute(f"SELECT * FROM read_parquet('{dataset_url}') LIMIT {num_rows}").fetchdf()
        return {
            "dataset_id": dataset_id,
            "sample": result.to_dict(orient='records'),
            "num_rows": len(result)
        }
    except Exception as e:
        return {"error": f"Failed to get sample: {str(e)}"}


@mcp.tool()
def query_dataset(dataset_id: str, query: str, ctx: Context):
    """
    Execute a SQL query against a dataset you have access to.
    The dataset is available as a table named 'data'.

    Args:
        dataset_id: The ID of the dataset (e.g., 'waroca/prompts')
        query: SQL query to execute. Use 'data' as the table name.
               Example: "SELECT * FROM data WHERE column > 10 LIMIT 100"
    """
    access_token = _get_access_token_from_context(ctx)
    token_info, error = _validate_token_for_dataset(access_token, dataset_id)
    if error:
        return {"error": error}
    server_hf_token = os.getenv("HF_TOKEN")
    if not server_hf_token:
        return {"error": "Server HF token not configured"}
    # Get the correct parquet URL pattern for this dataset
    dataset_url = datasets_registry.get_parquet_url(dataset_id)
    try:
        con = _get_duckdb_connection(server_hf_token)
        # Arbitrary SQL is by design here: the caller is an authorized
        # subscriber and the connection is an in-memory, per-call sandbox.
        con.execute(f"CREATE OR REPLACE VIEW data AS SELECT * FROM read_parquet('{dataset_url}');")
        result = con.execute(query).fetchdf()
        # Wrap in dict to ensure proper JSON serialization
        return {"results": result.to_dict(orient='records'), "row_count": len(result)}
    except Exception as e:
        return {"error": f"Query execution failed: {str(e)}"}


@mcp.tool()
def query_dataset_natural_language(dataset_id: str, question: str, ctx: Context):
    """
    Query a dataset using natural language. The question will be converted to SQL automatically.

    Args:
        dataset_id: The ID of the dataset (e.g., 'waroca/prompts')
        question: Your question in plain English.
                  Example: "What are the top 10 most common values in the category column?"
    """
    access_token = _get_access_token_from_context(ctx)
    token_info, error = _validate_token_for_dataset(access_token, dataset_id)
    if error:
        return {"error": error}
    server_hf_token = os.getenv("HF_TOKEN")
    if not server_hf_token:
        return {"error": "Server HF token not configured"}
    # Get the correct parquet URL pattern for this dataset
    dataset_url = datasets_registry.get_parquet_url(dataset_id)
    try:
        con = _get_duckdb_connection(server_hf_token)
        # Get schema for context
        schema_result = con.execute(f"DESCRIBE SELECT * FROM read_parquet('{dataset_url}')").fetchall()
        schema_str = "\n".join([f" - {row[0]}: {row[1]}" for row in schema_result])
        # Generate SQL using LLM
        from huggingface_hub import InferenceClient
        client = InferenceClient(token=server_hf_token)
        system_prompt = "You are a SQL expert. Convert natural language questions into DuckDB SQL queries. Return ONLY the SQL query, nothing else. Do not explain."
        user_prompt = f"""Table name: data
Schema:
{schema_str}
Question: {question}
Return ONLY the SQL query, nothing else."""
        response = client.chat_completion(
            model="Qwen/Qwen2.5-Coder-32B-Instruct",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            max_tokens=300
        )
        generated_sql = response.choices[0].message.content.strip()
        # Clean up if model adds markdown
        if "```sql" in generated_sql:
            generated_sql = generated_sql.split("```sql")[1].split("```")[0].strip()
        elif "```" in generated_sql:
            generated_sql = generated_sql.split("```")[1].split("```")[0].strip()
        # Remove trailing semicolon if present (we'll add it if needed)
        generated_sql = generated_sql.rstrip(';').strip()
        # Execute the query
        con.execute(f"CREATE OR REPLACE VIEW data AS SELECT * FROM read_parquet('{dataset_url}');")
        result = con.execute(generated_sql).fetchdf()
        return {
            "question": question,
            "generated_sql": generated_sql,
            "result": result.to_dict(orient='records')
        }
    except Exception as e:
        return {"error": f"Natural language query failed: {str(e)}"}

#
============================================================================= # Admin Helper Functions (used by HTTP API endpoints, not exposed as MCP tools) # ============================================================================= def _admin_get_subscriber_stats(): """Returns subscriber statistics.""" all_subs = subscriptions_ledger.get_all_subscriptions() stats = {} active_count = 0 from datetime import datetime for sub in all_subs.values(): d_id = sub["dataset_id"] if d_id not in stats: stats[d_id] = {"total": 0, "active": 0} stats[d_id]["total"] += 1 end_str = sub.get("subscription_end", "") if end_str: try: if end_str.endswith("Z"): end_str = end_str[:-1] end_date = datetime.fromisoformat(end_str) if end_date > datetime.utcnow(): stats[d_id]["active"] += 1 active_count += 1 except: pass return { "by_dataset": stats, "total_active": active_count, "total_subscriptions": len(all_subs) } def _admin_get_subscribers_for_dataset(dataset_id: str): """Returns list of subscribers for a dataset.""" all_subs = subscriptions_ledger.get_all_subscriptions() subscribers = [] from datetime import datetime for sub in all_subs.values(): if sub["dataset_id"] == dataset_id: end_str = sub.get("subscription_end", "") is_active = False if end_str: try: if end_str.endswith("Z"): end_str = end_str[:-1] end_date = datetime.fromisoformat(end_str) is_active = end_date > datetime.utcnow() except: pass subscribers.append({ **sub, "is_active": is_active }) return subscribers def _admin_update_dataset_registry(dataset_id: str, action: str, data: str): """Updates the dataset registry.""" registry = datasets_registry.load_registry() if action == "add": new_dataset = json.loads(data) registry.append(new_dataset) elif action == "toggle_active": for d in registry: if d["dataset_id"] == dataset_id: d["is_active"] = not d.get("is_active", False) elif action == "delete": registry = [d for d in registry if d["dataset_id"] != dataset_id] elif action == "update": updated_dataset = json.loads(data) 
for i, d in enumerate(registry): if d["dataset_id"] == dataset_id: registry[i] = updated_dataset break success = datasets_registry.save_registry(registry) if success: return {"status": "success"} else: return {"error": "Failed to save registry"} # ============================================================================= # HTTP API Routes (for frontend, not MCP) # ============================================================================= @mcp.custom_route("/webhook", methods=["POST"]) async def webhook_handler(request: Request): return await stripe_webhook.handle_stripe_webhook(request) async def _subscribe_free_logic(dataset_id: str, hf_token: str): user_info = auth.validate_hf_token(hf_token) if not user_info: return {"error": "Invalid HF token"} hf_user = user_info["name"] registry = datasets_registry.load_registry() dataset = next((d for d in registry if d["dataset_id"] == dataset_id), None) if not dataset or not dataset.get("is_active"): return {"error": "Dataset not found or inactive"} free_plan = datasets_registry.get_free_plan(dataset_id) if not free_plan: return {"error": "No free plan available for this dataset."} from datetime import datetime, timedelta duration_days = free_plan.get("access_duration_days", 1) # Default to 1-day trial end_date = datetime.utcnow() + timedelta(days=duration_days) access_token = subscriptions_ledger.generate_access_token() ledger_entry = { "event_id": f"free_{hf_user}_{datetime.utcnow().timestamp()}", "hf_user": hf_user, "dataset_id": dataset_id, "plan_id": free_plan["plan_id"], "subscription_start": datetime.utcnow().isoformat() + "Z", "subscription_end": end_date.isoformat() + "Z", "source": "free_tier", "access_token": access_token, "created_at": datetime.utcnow().isoformat() + "Z" } subscriptions_ledger.append_subscription_event(ledger_entry) return {"status": "success", "message": "Subscribed successfully", "access_token": access_token} async def _create_checkout_session_logic(dataset_id: str, hf_token: str): 
user_info = auth.validate_hf_token(hf_token) if not user_info: return {"error": "Invalid HF token"} hf_user = user_info["name"] registry = datasets_registry.load_registry() dataset = next((d for d in registry if d["dataset_id"] == dataset_id), None) if not dataset or not dataset.get("is_active"): return {"error": "Dataset not found or inactive"} if not dataset.get("plans"): return {"error": "No plans available for this dataset"} plan = dataset["plans"][0] price_id = plan.get("stripe_price_id") if price_id in ["free", "0", 0]: return {"error": "This is a free dataset, use subscribe_free instead"} if not price_id: return {"error": "Price ID not configured"} try: import stripe stripe.api_key = os.getenv("STRIPE_SECRET_KEY") checkout_session = stripe.checkout.Session.create( payment_method_types=['card'], line_items=[{'price': price_id, 'quantity': 1}], mode='subscription', metadata={'hf_user': hf_user, 'dataset_id': dataset_id}, success_url='https://huggingface.co/spaces/waroca/monetization-frontend?success=true', cancel_url='https://huggingface.co/spaces/waroca/monetization-frontend?canceled=true', ) return {"checkout_url": checkout_session.url} except Exception as e: return {"error": str(e)} @mcp.custom_route("/api/subscribe_free", methods=["POST"]) async def api_subscribe_free(request: Request): data = await request.json() result = await _subscribe_free_logic(data.get("dataset_id"), data.get("hf_token")) return JSONResponse(result) @mcp.custom_route("/api/create_checkout_session", methods=["POST"]) async def api_create_checkout_session(request: Request): data = await request.json() result = await _create_checkout_session_logic(data.get("dataset_id"), data.get("hf_token")) return JSONResponse(result) @mcp.custom_route("/api/user_subscriptions", methods=["POST"]) async def api_user_subscriptions(request: Request): """Get subscriptions for the current user.""" data = await request.json() hf_user = data.get("hf_user") hf_token = data.get("hf_token") if not hf_user: 
return JSONResponse({"error": "hf_user required"}, status_code=400) if not hf_token: return JSONResponse({"error": "hf_token required"}, status_code=401) user_info = auth.validate_hf_token(hf_token) if not user_info: return JSONResponse({"error": "Invalid HF token"}, status_code=401) if user_info.get("name") != hf_user: return JSONResponse({"error": "Token does not match user"}, status_code=403) subs = subscriptions_ledger.get_user_subscriptions(hf_user) return JSONResponse(subs) @mcp.custom_route("/api/catalog", methods=["GET"]) async def api_catalog(request: Request): """Public endpoint for getting the dataset catalog.""" catalog = get_dataset_catalog() return JSONResponse(catalog) # ============================================================================= # Admin API Routes # ============================================================================= @mcp.custom_route("/api/admin/subscriber_stats", methods=["POST"]) async def api_admin_subscriber_stats(request: Request): data = await request.json() admin_secret = get_admin_secret_from_request(request, data) if not verify_admin_secret(admin_secret): return JSONResponse({"error": "Unauthorized"}, status_code=401) result = _admin_get_subscriber_stats() return JSONResponse(result) @mcp.custom_route("/api/admin/subscribers_for_dataset", methods=["POST"]) async def api_admin_subscribers_for_dataset(request: Request): data = await request.json() admin_secret = get_admin_secret_from_request(request, data) dataset_id = data.get("dataset_id") if not verify_admin_secret(admin_secret): return JSONResponse({"error": "Unauthorized"}, status_code=401) if not dataset_id: return JSONResponse({"error": "dataset_id required"}, status_code=400) result = _admin_get_subscribers_for_dataset(dataset_id) return JSONResponse(result) @mcp.custom_route("/api/admin/update_registry", methods=["POST"]) async def api_admin_update_registry(request: Request): data = await request.json() admin_secret = get_admin_secret_from_request(request, 
data) dataset_id = data.get("dataset_id", "") action = data.get("action") payload = data.get("data", "") if not verify_admin_secret(admin_secret): return JSONResponse({"error": "Unauthorized"}, status_code=401) if not action: return JSONResponse({"error": "action required"}, status_code=400) result = _admin_update_dataset_registry(dataset_id, action, payload) return JSONResponse(result) @mcp.custom_route("/api/admin/catalog", methods=["POST"]) async def api_admin_catalog(request: Request): data = await request.json() admin_secret = get_admin_secret_from_request(request, data) if not verify_admin_secret(admin_secret): return JSONResponse({"error": "Unauthorized"}, status_code=401) registry = datasets_registry.load_registry() return JSONResponse(registry) @mcp.custom_route("/api/admin/detect_format", methods=["POST"]) async def api_admin_detect_format(request: Request): """Detect the format and parquet path for a HuggingFace dataset.""" data = await request.json() admin_secret = get_admin_secret_from_request(request, data) dataset_id = data.get("dataset_id") if not verify_admin_secret(admin_secret): return JSONResponse({"error": "Unauthorized"}, status_code=401) if not dataset_id: return JSONResponse({"error": "dataset_id required"}, status_code=400) result = datasets_registry.detect_dataset_format(dataset_id) return JSONResponse(result) # ============================================================================= # Landing Page # ============================================================================= FRONTEND_URL = os.getenv("FRONTEND_URL", "https://huggingface.co/spaces/waroca/datapass") @mcp.custom_route("/logo.png", methods=["GET"]) async def serve_logo(request: Request): """Serve the DataPass logo.""" logo_path = os.path.join(os.path.dirname(__file__), "datapass_logo.png") return FileResponse(logo_path, media_type="image/png") @mcp.custom_route("/", methods=["GET"]) async def landing_page(request: Request): """Landing page with DataPass info and redirect 
to frontend.""" html = f"""

Your pass to private datasets.
sse: {request.url.scheme}://{request.url.netloc}/sse