datapass-server / server.py
waroca's picture
Upload folder using huggingface_hub
cfe9850 verified
from mcp.server.fastmcp import FastMCP, Context
from fastapi import Request
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
from dotenv import load_dotenv
import os
import json
import datasets_registry
import subscriptions_ledger
import stripe_webhook
import auth
# Load variables from a local .env file (if one exists) before the os.getenv()
# lookups below run.
load_dotenv()
# Initialize MCP Server
mcp = FastMCP("datapass")
# Admin API Secret for secure admin endpoints.
# Read once at import time; it is None when the variable is not set.
ADMIN_API_SECRET = os.getenv("ADMIN_API_SECRET")
def verify_admin_secret(provided_secret: str) -> bool:
    """Return True only when *provided_secret* equals the configured admin secret.

    Fails closed: when ``ADMIN_API_SECRET`` is unset or empty, every request is
    rejected. Non-string input (for example a JSON number or ``null``) is
    rejected as well.

    Args:
        provided_secret: Secret supplied by the caller.

    Returns:
        True if the secrets match, False otherwise.
    """
    import hmac  # function-scope import, as this file does for duckdb/datetime

    if not ADMIN_API_SECRET:
        return False  # Fail closed if not configured
    if not isinstance(provided_secret, str):
        return False
    # Constant-time comparison so response timing does not reveal how many
    # leading characters matched. Bytes are compared because compare_digest
    # rejects non-ASCII str; "surrogatepass" keeps the encoding one-to-one
    # even for lone surrogates that can arrive through JSON.
    return hmac.compare_digest(
        provided_secret.encode("utf-8", "surrogatepass"),
        ADMIN_API_SECRET.encode("utf-8", "surrogatepass"),
    )
def get_admin_secret_from_request(request: Request, body_data: dict = None) -> str:
"""Extract admin secret from header (preferred) or body (fallback)."""
header_secret = request.headers.get("X-Admin-Secret", "")
if header_secret:
return header_secret
if body_data:
return body_data.get("admin_secret", "")
return ""
def _get_access_token_from_context(ctx: Context) -> str | None:
"""Extract access token from the Authorization header in the request context."""
try:
if ctx.request_context and ctx.request_context.request:
auth_header = ctx.request_context.request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
return auth_header[7:] # Remove "Bearer " prefix
except Exception as e:
print(f"Error extracting access token: {e}")
return None
def _validate_token_for_dataset(access_token: str, dataset_id: str = None):
"""Validate access token and optionally check dataset access."""
if not access_token:
return None, "No access token provided. Please include Authorization header with Bearer token."
token_info = subscriptions_ledger.validate_access_token(access_token)
if not token_info:
return None, "Invalid or expired access token."
if dataset_id and token_info["dataset_id"] != dataset_id:
return None, f"Access token not valid for dataset '{dataset_id}'. Your token is for '{token_info['dataset_id']}'."
return token_info, None
# =============================================================================
# Public MCP Tools (no auth required)
# =============================================================================
@mcp.tool()
def get_dataset_catalog():
    """
    Returns the catalog of all available datasets in the marketplace.
    Use this to discover datasets you can subscribe to.
    """
    # Docstring kept verbatim: it is the text attached to the registered tool.
    catalog = []
    for entry in datasets_registry.load_registry():
        # Only entries with a truthy "is_active" flag are listed.
        if entry.get("is_active"):
            catalog.append(entry)
    return catalog
# =============================================================================
# Authenticated MCP Tools (require access token via Authorization header)
# =============================================================================
@mcp.tool()
def get_my_datasets(ctx: Context):
    """
    Returns the list of datasets you have access to based on your subscription.
    The access token is automatically read from your Authorization header.
    """
    # Docstring kept verbatim: it is the text attached to the registered tool.
    # With no dataset_id, the shared validator performs the same two checks
    # (token present, token valid) and yields the same error messages.
    token_info, error = _validate_token_for_dataset(_get_access_token_from_context(ctx))
    if error:
        return {"error": error}

    # A token grants access to exactly one dataset.
    granted_id = token_info["dataset_id"]
    sub = token_info.get("subscription", {})

    # Look up the registry entry for display details; fall back to the bare
    # dataset id when the registry has no entry for it.
    entry = None
    for candidate in datasets_registry.load_registry():
        if candidate["dataset_id"] == granted_id:
            entry = candidate
            break
    if entry:
        display_name = entry.get("display_name", granted_id)
        description = entry.get("description", "")
    else:
        display_name = granted_id
        description = ""

    return {
        "datasets": [{
            "dataset_id": granted_id,
            "display_name": display_name,
            "description": description,
            "subscription_end": sub.get("subscription_end"),
            "plan_id": sub.get("plan_id")
        }],
        "user": token_info["hf_user"]
    }
def _get_duckdb_connection(server_hf_token: str):
    """Create a DuckDB connection with HF token configured.

    Opens an in-memory database, installs and loads the ``httpfs`` extension,
    and registers *server_hf_token* as a ``huggingface`` secret named
    ``hf_token``.

    Args:
        server_hf_token: Hugging Face token to store in the secret.

    Returns:
        The open connection. The caller is responsible for closing it.

    Raises:
        Exception: Whatever DuckDB raises during setup; the connection is
            closed before the error propagates.
    """
    import duckdb

    con = duckdb.connect(database=':memory:')
    try:
        con.execute("INSTALL httpfs;")
        con.execute("LOAD httpfs;")
        # The token is spliced into a SQL string literal, so double any single
        # quote to keep it inside the literal.
        escaped_token = server_hf_token.replace("'", "''")
        con.execute(f"CREATE SECRET hf_token (TYPE huggingface, TOKEN '{escaped_token}');")
    except Exception:
        # Do not leak a half-configured connection when setup fails.
        con.close()
        raise
    return con
@mcp.tool()
def get_dataset_schema(dataset_id: str, ctx: Context):
    """
    Returns the schema (column names and types) of a dataset you have access to.
    Use this to understand the data structure before querying.
    Args:
        dataset_id: The ID of the dataset (e.g., 'waroca/prompts')
    """
    # Docstring wording kept as-is: it is the text attached to the registered tool.
    access_token = _get_access_token_from_context(ctx)
    _, error = _validate_token_for_dataset(access_token, dataset_id)
    if error:
        return {"error": error}
    # Use server's HF token to access the dataset
    server_hf_token = os.getenv("HF_TOKEN")
    if not server_hf_token:
        return {"error": "Server HF token not configured"}
    # Get the correct parquet URL pattern for this dataset
    dataset_url = datasets_registry.get_parquet_url(dataset_id)
    con = None
    try:
        con = _get_duckdb_connection(server_hf_token)
        # Get schema
        result = con.execute(f"DESCRIBE SELECT * FROM read_parquet('{dataset_url}')").fetchall()
        schema = [{"column": row[0], "type": row[1]} for row in result]
        # Get row count
        count_result = con.execute(f"SELECT COUNT(*) FROM read_parquet('{dataset_url}')").fetchone()
        row_count = count_result[0] if count_result else 0
        return {
            "dataset_id": dataset_id,
            "schema": schema,
            "row_count": row_count
        }
    except Exception as e:
        return {"error": f"Failed to get schema: {str(e)}"}
    finally:
        # Each call opens its own connection; release it on every exit path.
        if con is not None:
            con.close()
@mcp.tool()
def get_dataset_sample(dataset_id: str, ctx: Context, num_rows: int = 5):
    """
    Returns a sample of rows from a dataset you have access to.
    Use this to preview the data before running queries.
    Args:
        dataset_id: The ID of the dataset (e.g., 'waroca/prompts')
        num_rows: Number of sample rows to return (default: 5, max: 20)
    """
    # Docstring wording kept as-is: it is the text attached to the registered tool.
    access_token = _get_access_token_from_context(ctx)
    _, error = _validate_token_for_dataset(access_token, dataset_id)
    if error:
        return {"error": error}
    # Clamp the requested row count to the range 1..20.
    num_rows = min(max(1, num_rows), 20)
    server_hf_token = os.getenv("HF_TOKEN")
    if not server_hf_token:
        return {"error": "Server HF token not configured"}
    # Get the correct parquet URL pattern for this dataset
    dataset_url = datasets_registry.get_parquet_url(dataset_id)
    con = None
    try:
        con = _get_duckdb_connection(server_hf_token)
        result = con.execute(f"SELECT * FROM read_parquet('{dataset_url}') LIMIT {num_rows}").fetchdf()
        return {
            "dataset_id": dataset_id,
            "sample": result.to_dict(orient='records'),
            "num_rows": len(result)
        }
    except Exception as e:
        return {"error": f"Failed to get sample: {str(e)}"}
    finally:
        # Each call opens its own connection; release it on every exit path.
        if con is not None:
            con.close()
@mcp.tool()
def query_dataset(dataset_id: str, query: str, ctx: Context):
    """
    Execute a SQL query against a dataset you have access to.
    The dataset is available as a table named 'data'.
    Args:
        dataset_id: The ID of the dataset (e.g., 'waroca/prompts')
        query: SQL query to execute. Use 'data' as the table name.
            Example: "SELECT * FROM data WHERE column > 10 LIMIT 100"
    """
    # Docstring wording kept as-is: it is the text attached to the registered tool.
    access_token = _get_access_token_from_context(ctx)
    _, error = _validate_token_for_dataset(access_token, dataset_id)
    if error:
        return {"error": error}
    if not query or not query.strip():
        # Reject a blank query before opening a connection.
        return {"error": "Query execution failed: empty query"}
    server_hf_token = os.getenv("HF_TOKEN")
    if not server_hf_token:
        return {"error": "Server HF token not configured"}
    # Get the correct parquet URL pattern for this dataset
    dataset_url = datasets_registry.get_parquet_url(dataset_id)
    con = None
    try:
        con = _get_duckdb_connection(server_hf_token)
        con.execute(f"CREATE OR REPLACE VIEW data AS SELECT * FROM read_parquet('{dataset_url}');")
        # NOTE(security): `query` is caller-supplied SQL and runs unmodified on
        # a connection that holds the server's HF secret. Nothing here limits
        # it to the `data` view, so it can reference other sources the
        # connection is able to read. Needs a sandboxing decision
        # (statement allow-list / restricted external access) - TODO review.
        result = con.execute(query).fetchdf()
        # Wrap in dict to ensure proper JSON serialization
        return {"results": result.to_dict(orient='records'), "row_count": len(result)}
    except Exception as e:
        return {"error": f"Query execution failed: {str(e)}"}
    finally:
        # Each call opens its own connection; release it on every exit path.
        if con is not None:
            con.close()
@mcp.tool()
def query_dataset_natural_language(dataset_id: str, question: str, ctx: Context):
    """
    Query a dataset using natural language. The question will be converted to SQL automatically.
    Args:
        dataset_id: The ID of the dataset (e.g., 'waroca/prompts')
        question: Your question in plain English.
            Example: "What are the top 10 most common values in the category column?"
    """
    # Docstring wording kept as-is: it is the text attached to the registered tool.
    access_token = _get_access_token_from_context(ctx)
    _, error = _validate_token_for_dataset(access_token, dataset_id)
    if error:
        return {"error": error}
    server_hf_token = os.getenv("HF_TOKEN")
    if not server_hf_token:
        return {"error": "Server HF token not configured"}
    # Get the correct parquet URL pattern for this dataset
    dataset_url = datasets_registry.get_parquet_url(dataset_id)
    con = None
    try:
        con = _get_duckdb_connection(server_hf_token)
        # Get schema for context
        schema_result = con.execute(f"DESCRIBE SELECT * FROM read_parquet('{dataset_url}')").fetchall()
        schema_str = "\n".join([f" - {row[0]}: {row[1]}" for row in schema_result])
        # Generate SQL using LLM
        from huggingface_hub import InferenceClient
        client = InferenceClient(token=server_hf_token)
        system_prompt = "You are a SQL expert. Convert natural language questions into DuckDB SQL queries. Return ONLY the SQL query, nothing else. Do not explain."
        user_prompt = (
            "Table name: data\n"
            "Schema:\n"
            f"{schema_str}\n"
            f"Question: {question}\n"
            "Return ONLY the SQL query, nothing else."
        )
        response = client.chat_completion(
            model="Qwen/Qwen2.5-Coder-32B-Instruct",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            max_tokens=300
        )
        # The completion text may be missing; treat that like an empty answer.
        generated_sql = (response.choices[0].message.content or "").strip()
        # Clean up if model adds markdown
        if "```sql" in generated_sql:
            generated_sql = generated_sql.split("```sql")[1].split("```")[0].strip()
        elif "```" in generated_sql:
            generated_sql = generated_sql.split("```")[1].split("```")[0].strip()
        # Remove trailing semicolon if present
        generated_sql = generated_sql.rstrip(';').strip()
        if not generated_sql:
            return {"error": "Natural language query failed: model returned no SQL"}
        # Execute the query
        con.execute(f"CREATE OR REPLACE VIEW data AS SELECT * FROM read_parquet('{dataset_url}');")
        # NOTE(security): the generated SQL is derived from caller text and
        # runs unmodified on a connection that holds the server's HF secret;
        # nothing here limits it to the `data` view - TODO review sandboxing.
        result = con.execute(generated_sql).fetchdf()
        return {
            "question": question,
            "generated_sql": generated_sql,
            "result": result.to_dict(orient='records')
        }
    except Exception as e:
        return {"error": f"Natural language query failed: {str(e)}"}
    finally:
        # Each call opens its own connection; release it on every exit path.
        if con is not None:
            con.close()
# =============================================================================
# Admin Helper Functions (used by HTTP API endpoints, not exposed as MCP tools)
# =============================================================================
def _admin_get_subscriber_stats():
    """Returns subscriber statistics.

    Returns:
        A dict with ``by_dataset`` (per-dataset ``total`` / ``active`` counts),
        ``total_active`` and ``total_subscriptions``. A subscription counts as
        active when its ``subscription_end`` parses as an ISO timestamp that
        lies in the future; missing or unparseable end dates count as inactive.
    """
    from datetime import datetime, timezone

    all_subs = subscriptions_ledger.get_all_subscriptions()
    # One clock read for the whole pass so every row is judged against the
    # same instant.
    now = datetime.now(timezone.utc)
    stats = {}
    active_count = 0
    for sub in all_subs.values():
        d_id = sub["dataset_id"]
        bucket = stats.setdefault(d_id, {"total": 0, "active": 0})
        bucket["total"] += 1
        end_str = sub.get("subscription_end", "")
        if not end_str:
            continue
        try:
            # A trailing "Z" is stripped because fromisoformat() only accepts
            # it on newer Python versions.
            end_date = datetime.fromisoformat(end_str[:-1] if end_str.endswith("Z") else end_str)
        except (AttributeError, TypeError, ValueError):
            # Best effort: a malformed end date simply counts as inactive.
            continue
        if end_date.tzinfo is None:
            # Timestamps without an offset are treated as UTC, matching the
            # "Z"-suffixed values. Making both sides aware also avoids the
            # naive-vs-aware TypeError for values that carry an explicit offset.
            end_date = end_date.replace(tzinfo=timezone.utc)
        if end_date > now:
            bucket["active"] += 1
            active_count += 1
    return {
        "by_dataset": stats,
        "total_active": active_count,
        "total_subscriptions": len(all_subs)
    }
def _admin_get_subscribers_for_dataset(dataset_id: str):
    """Returns list of subscribers for a dataset.

    Args:
        dataset_id: Dataset whose ledger entries are returned.

    Returns:
        A list with one dict per matching ledger entry: the entry's own fields
        plus ``is_active``, which is True when ``subscription_end`` parses as
        an ISO timestamp in the future. Missing or unparseable end dates give
        ``is_active`` False.
    """
    from datetime import datetime, timezone

    all_subs = subscriptions_ledger.get_all_subscriptions()
    # One clock read so every row is judged against the same instant.
    now = datetime.now(timezone.utc)
    subscribers = []
    for sub in all_subs.values():
        if sub["dataset_id"] != dataset_id:
            continue
        is_active = False
        end_str = sub.get("subscription_end", "")
        if end_str:
            try:
                # A trailing "Z" is stripped because fromisoformat() only
                # accepts it on newer Python versions.
                end_date = datetime.fromisoformat(end_str[:-1] if end_str.endswith("Z") else end_str)
                if end_date.tzinfo is None:
                    # Offset-less timestamps are treated as UTC; comparing two
                    # aware values avoids the naive-vs-aware TypeError.
                    end_date = end_date.replace(tzinfo=timezone.utc)
                is_active = end_date > now
            except (AttributeError, TypeError, ValueError):
                # Best effort: a malformed end date counts as inactive.
                is_active = False
        subscribers.append({
            **sub,
            "is_active": is_active
        })
    return subscribers
def _admin_update_dataset_registry(dataset_id: str, action: str, data: str):
"""Updates the dataset registry."""
registry = datasets_registry.load_registry()
if action == "add":
new_dataset = json.loads(data)
registry.append(new_dataset)
elif action == "toggle_active":
for d in registry:
if d["dataset_id"] == dataset_id:
d["is_active"] = not d.get("is_active", False)
elif action == "delete":
registry = [d for d in registry if d["dataset_id"] != dataset_id]
elif action == "update":
updated_dataset = json.loads(data)
for i, d in enumerate(registry):
if d["dataset_id"] == dataset_id:
registry[i] = updated_dataset
break
success = datasets_registry.save_registry(registry)
if success:
return {"status": "success"}
else:
return {"error": "Failed to save registry"}
# =============================================================================
# HTTP API Routes (for frontend, not MCP)
# =============================================================================
@mcp.custom_route("/webhook", methods=["POST"])
async def webhook_handler(request: Request):
    """Forward the POSTed request to the Stripe webhook module and return its response."""
    response = await stripe_webhook.handle_stripe_webhook(request)
    return response
async def _subscribe_free_logic(dataset_id: str, hf_token: str):
user_info = auth.validate_hf_token(hf_token)
if not user_info:
return {"error": "Invalid HF token"}
hf_user = user_info["name"]
registry = datasets_registry.load_registry()
dataset = next((d for d in registry if d["dataset_id"] == dataset_id), None)
if not dataset or not dataset.get("is_active"):
return {"error": "Dataset not found or inactive"}
free_plan = datasets_registry.get_free_plan(dataset_id)
if not free_plan:
return {"error": "No free plan available for this dataset."}
from datetime import datetime, timedelta
duration_days = free_plan.get("access_duration_days", 1) # Default to 1-day trial
end_date = datetime.utcnow() + timedelta(days=duration_days)
access_token = subscriptions_ledger.generate_access_token()
ledger_entry = {
"event_id": f"free_{hf_user}_{datetime.utcnow().timestamp()}",
"hf_user": hf_user,
"dataset_id": dataset_id,
"plan_id": free_plan["plan_id"],
"subscription_start": datetime.utcnow().isoformat() + "Z",
"subscription_end": end_date.isoformat() + "Z",
"source": "free_tier",
"access_token": access_token,
"created_at": datetime.utcnow().isoformat() + "Z"
}
subscriptions_ledger.append_subscription_event(ledger_entry)
return {"status": "success", "message": "Subscribed successfully", "access_token": access_token}
async def _create_checkout_session_logic(dataset_id: str, hf_token: str):
    """Create a Stripe subscription checkout session for the dataset's first plan.

    Returns ``{"checkout_url": ...}`` on success and ``{"error": <message>}``
    when the token, dataset, plan or Stripe call is rejected.
    """
    user_info = auth.validate_hf_token(hf_token)
    if not user_info:
        return {"error": "Invalid HF token"}
    hf_user = user_info["name"]

    # Find the registry entry for the requested dataset.
    dataset = None
    for candidate in datasets_registry.load_registry():
        if candidate["dataset_id"] == dataset_id:
            dataset = candidate
            break
    if not (dataset and dataset.get("is_active")):
        return {"error": "Dataset not found or inactive"}

    plans = dataset.get("plans")
    if not plans:
        return {"error": "No plans available for this dataset"}

    # Only the first listed plan is offered for checkout.
    price_id = plans[0].get("stripe_price_id")
    if price_id in ("free", "0", 0):
        return {"error": "This is a free dataset, use subscribe_free instead"}
    if not price_id:
        return {"error": "Price ID not configured"}

    try:
        import stripe
        stripe.api_key = os.getenv("STRIPE_SECRET_KEY")
        session_params = {
            "payment_method_types": ['card'],
            "line_items": [{'price': price_id, 'quantity': 1}],
            "mode": 'subscription',
            # Carried on the session so the purchase can be tied back to the
            # user and dataset.
            "metadata": {'hf_user': hf_user, 'dataset_id': dataset_id},
            "success_url": 'https://huggingface.co/spaces/waroca/monetization-frontend?success=true',
            "cancel_url": 'https://huggingface.co/spaces/waroca/monetization-frontend?canceled=true',
        }
        checkout_session = stripe.checkout.Session.create(**session_params)
        return {"checkout_url": checkout_session.url}
    except Exception as e:
        return {"error": str(e)}
@mcp.custom_route("/api/subscribe_free", methods=["POST"])
async def api_subscribe_free(request: Request):
    """Subscribe a user to a dataset's free plan.

    Expects a JSON object with ``dataset_id`` and ``hf_token``. A body that is
    not valid JSON, or not a JSON object, is answered with HTTP 400. Otherwise
    the result dict of the subscription logic is returned as JSON.
    """
    try:
        data = await request.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    if not isinstance(data, dict):
        return JSONResponse({"error": "JSON body must be an object"}, status_code=400)
    result = await _subscribe_free_logic(data.get("dataset_id"), data.get("hf_token"))
    return JSONResponse(result)
@mcp.custom_route("/api/create_checkout_session", methods=["POST"])
async def api_create_checkout_session(request: Request):
    """Create a Stripe checkout session for a paid dataset.

    Expects a JSON object with ``dataset_id`` and ``hf_token``. A body that is
    not valid JSON, or not a JSON object, is answered with HTTP 400. Otherwise
    the result dict of the checkout logic is returned as JSON.
    """
    try:
        data = await request.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    if not isinstance(data, dict):
        return JSONResponse({"error": "JSON body must be an object"}, status_code=400)
    result = await _create_checkout_session_logic(data.get("dataset_id"), data.get("hf_token"))
    return JSONResponse(result)
@mcp.custom_route("/api/user_subscriptions", methods=["POST"])
async def api_user_subscriptions(request: Request):
    """Get subscriptions for the current user.

    Expects a JSON object with ``hf_user`` and ``hf_token``. Responds with
    400 for a malformed body or missing ``hf_user``, 401 for a missing or
    invalid token, 403 when the token belongs to a different user, and the
    user's subscriptions otherwise.
    """
    try:
        data = await request.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    if not isinstance(data, dict):
        return JSONResponse({"error": "JSON body must be an object"}, status_code=400)
    hf_user = data.get("hf_user")
    hf_token = data.get("hf_token")
    if not hf_user:
        return JSONResponse({"error": "hf_user required"}, status_code=400)
    if not hf_token:
        return JSONResponse({"error": "hf_token required"}, status_code=401)
    user_info = auth.validate_hf_token(hf_token)
    if not user_info:
        return JSONResponse({"error": "Invalid HF token"}, status_code=401)
    # The token must belong to the user whose subscriptions are requested.
    if user_info.get("name") != hf_user:
        return JSONResponse({"error": "Token does not match user"}, status_code=403)
    subs = subscriptions_ledger.get_user_subscriptions(hf_user)
    return JSONResponse(subs)
@mcp.custom_route("/api/catalog", methods=["GET"])
async def api_catalog(request: Request):
    """Public HTTP endpoint returning the same active-dataset list as the catalog tool."""
    return JSONResponse(get_dataset_catalog())
# =============================================================================
# Admin API Routes
# =============================================================================
@mcp.custom_route("/api/admin/subscriber_stats", methods=["POST"])
async def api_admin_subscriber_stats(request: Request):
    """Admin endpoint: return subscriber statistics.

    The admin secret is taken from the ``X-Admin-Secret`` header or, as a
    fallback, from ``admin_secret`` in the JSON body. Responds with 401 when
    the secret is missing or wrong.
    """
    # The body is optional when the secret arrives in the header, so an empty
    # or malformed body is treated as "no body" instead of failing the request.
    try:
        data = await request.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        data = {}
    if not isinstance(data, dict):
        data = {}
    admin_secret = get_admin_secret_from_request(request, data)
    if not verify_admin_secret(admin_secret):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    result = _admin_get_subscriber_stats()
    return JSONResponse(result)
@mcp.custom_route("/api/admin/subscribers_for_dataset", methods=["POST"])
async def api_admin_subscribers_for_dataset(request: Request):
    """Admin endpoint: list the subscribers of one dataset.

    Expects ``dataset_id`` in the JSON body. The admin secret is taken from
    the ``X-Admin-Secret`` header or, as a fallback, from ``admin_secret`` in
    the body. Responds with 401 for a bad secret and 400 when ``dataset_id``
    is missing.
    """
    # An empty or malformed body is treated as an empty object so the request
    # reaches the 401/400 checks instead of failing while parsing.
    try:
        data = await request.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        data = {}
    if not isinstance(data, dict):
        data = {}
    admin_secret = get_admin_secret_from_request(request, data)
    if not verify_admin_secret(admin_secret):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    dataset_id = data.get("dataset_id")
    if not dataset_id:
        return JSONResponse({"error": "dataset_id required"}, status_code=400)
    result = _admin_get_subscribers_for_dataset(dataset_id)
    return JSONResponse(result)
@mcp.custom_route("/api/admin/update_registry", methods=["POST"])
async def api_admin_update_registry(request: Request):
    """Admin endpoint: add, update, toggle or delete a registry entry.

    Expects ``action`` (required), ``dataset_id`` and ``data`` in the JSON
    body. The admin secret is taken from the ``X-Admin-Secret`` header or, as
    a fallback, from ``admin_secret`` in the body. Responds with 401 for a bad
    secret and 400 when ``action`` is missing.
    """
    # An empty or malformed body is treated as an empty object so the request
    # reaches the 401/400 checks instead of failing while parsing.
    try:
        data = await request.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        data = {}
    if not isinstance(data, dict):
        data = {}
    admin_secret = get_admin_secret_from_request(request, data)
    if not verify_admin_secret(admin_secret):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    action = data.get("action")
    if not action:
        return JSONResponse({"error": "action required"}, status_code=400)
    dataset_id = data.get("dataset_id", "")
    payload = data.get("data", "")
    result = _admin_update_dataset_registry(dataset_id, action, payload)
    return JSONResponse(result)
@mcp.custom_route("/api/admin/catalog", methods=["POST"])
async def api_admin_catalog(request: Request):
    """Admin endpoint: return the full registry, including inactive datasets.

    The admin secret is taken from the ``X-Admin-Secret`` header or, as a
    fallback, from ``admin_secret`` in the JSON body. Responds with 401 when
    the secret is missing or wrong.
    """
    # The body is optional when the secret arrives in the header, so an empty
    # or malformed body is treated as "no body" instead of failing the request.
    try:
        data = await request.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        data = {}
    if not isinstance(data, dict):
        data = {}
    admin_secret = get_admin_secret_from_request(request, data)
    if not verify_admin_secret(admin_secret):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    registry = datasets_registry.load_registry()
    return JSONResponse(registry)
@mcp.custom_route("/api/admin/detect_format", methods=["POST"])
async def api_admin_detect_format(request: Request):
    """Detect the format and parquet path for a HuggingFace dataset.

    Expects ``dataset_id`` in the JSON body. The admin secret is taken from
    the ``X-Admin-Secret`` header or, as a fallback, from ``admin_secret`` in
    the body. Responds with 401 for a bad secret and 400 when ``dataset_id``
    is missing.
    """
    # An empty or malformed body is treated as an empty object so the request
    # reaches the 401/400 checks instead of failing while parsing.
    try:
        data = await request.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        data = {}
    if not isinstance(data, dict):
        data = {}
    admin_secret = get_admin_secret_from_request(request, data)
    if not verify_admin_secret(admin_secret):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    dataset_id = data.get("dataset_id")
    if not dataset_id:
        return JSONResponse({"error": "dataset_id required"}, status_code=400)
    result = datasets_registry.detect_dataset_format(dataset_id)
    return JSONResponse(result)
# =============================================================================
# Landing Page
# =============================================================================
# URL the landing page's call-to-action button links to; overridable through the
# FRONTEND_URL environment variable.
FRONTEND_URL = os.getenv("FRONTEND_URL", "https://huggingface.co/spaces/waroca/datapass")
@mcp.custom_route("/logo.png", methods=["GET"])
async def serve_logo(request: Request):
    """Serve the DataPass logo.

    The image is read from ``datapass_logo.png`` next to this module. When the
    file is not there, a 404 JSON response is returned instead of failing
    while the file response is being sent.
    """
    logo_path = os.path.join(os.path.dirname(__file__), "datapass_logo.png")
    if not os.path.isfile(logo_path):
        return JSONResponse({"error": "Logo not found"}, status_code=404)
    return FileResponse(logo_path, media_type="image/png")
@mcp.custom_route("/", methods=["GET"])
async def landing_page(request: Request):
    """Landing page with DataPass info and redirect to frontend.

    Renders a static HTML page containing a link to ``FRONTEND_URL`` and the
    SSE endpoint URL built from the incoming request's scheme and host.
    """
    from html import escape

    # Both values end up inside HTML. The endpoint is derived from the request
    # URL (client-influenced), and the frontend URL comes from the environment,
    # so escape them before interpolation.
    frontend_url = escape(FRONTEND_URL, quote=True)
    sse_endpoint = escape(f"{request.url.scheme}://{request.url.netloc}/sse", quote=True)
    # Doubled braces below are literal CSS braces inside the f-string.
    page = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>DataPass - MCP Server</title>
<style>
* {{ margin: 0; padding: 0; box-sizing: border-box; }}
body {{
font-family: -apple-system, BlinkMacSystemFont, 'SF Pro Display', 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 50%, #0f3460 100%);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
color: #fff;
}}
.container {{
text-align: center;
padding: 2rem;
max-width: 600px;
}}
.logo {{
margin-bottom: 1rem;
}}
.logo img {{
max-width: 200px;
height: auto;
}}
h1 {{
font-size: 2.5rem;
font-weight: 700;
margin-bottom: 0.5rem;
letter-spacing: -0.02em;
}}
.tagline {{
font-size: 1.25rem;
color: rgba(255,255,255,0.7);
margin-bottom: 2rem;
}}
.card {{
background: rgba(255,255,255,0.1);
backdrop-filter: blur(10px);
border: 1px solid rgba(255,255,255,0.2);
border-radius: 16px;
padding: 2rem;
margin-bottom: 2rem;
}}
.card h2 {{
font-size: 1.125rem;
margin-bottom: 1rem;
color: rgba(255,255,255,0.9);
}}
.features {{
display: grid;
grid-template-columns: 1fr 1fr;
gap: 1rem;
text-align: left;
}}
.feature {{
display: flex;
align-items: center;
gap: 0.5rem;
font-size: 0.9rem;
color: rgba(255,255,255,0.8);
}}
.btn {{
display: inline-block;
background: #fff;
color: #1a1a2e;
padding: 0.875rem 2rem;
border-radius: 50px;
text-decoration: none;
font-weight: 600;
font-size: 1rem;
transition: transform 0.2s, box-shadow 0.2s;
}}
.btn:hover {{
transform: translateY(-2px);
box-shadow: 0 10px 30px rgba(0,0,0,0.3);
}}
.mcp-info {{
margin-top: 2rem;
padding: 1rem;
background: rgba(0,0,0,0.2);
border-radius: 8px;
font-size: 0.875rem;
color: rgba(255,255,255,0.6);
}}
code {{
background: rgba(255,255,255,0.1);
padding: 0.125rem 0.375rem;
border-radius: 4px;
font-family: 'SF Mono', monospace;
}}
@media (max-width: 480px) {{
.features {{ grid-template-columns: 1fr; }}
h1 {{ font-size: 2rem; }}
}}
</style>
</head>
<body>
<div class="container">
<div class="logo"><img src="/logo.png" alt="DataPass Logo"></div>
<p class="tagline">Your pass to private datasets.</p>
<div class="card">
<h2>Query private datasets without downloading them</h2>
<div class="features">
<div class="feature">🔒 Data stays on HF</div>
<div class="feature">💬 Natural language queries</div>
<div class="feature">⚡ SQL via DuckDB</div>
<div class="feature">🎫 Time-limited access</div>
</div>
</div>
<a href="{frontend_url}" class="btn" target="_blank">Get a DataPass →</a>
<div class="mcp-info">
<strong>MCP Server Endpoint:</strong><br>
<code>sse: {sse_endpoint}</code>
</div>
</div>
</body>
</html>
"""
    return HTMLResponse(content=page)
# =============================================================================
# App Initialization
# =============================================================================
# Module-level application object built from the MCP server's SSE app; this is
# what gets passed to uvicorn below.
app = mcp.sse_app()
if __name__ == "__main__":
    # When run directly as a script, serve the app on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)