waroca committed on
Commit
f1b8a40
·
verified ·
1 Parent(s): 00f4bcb

Upload folder using huggingface_hub

Browse files
Files changed (9) hide show
  1. Dockerfile +22 -0
  2. __init__.py +0 -0
  3. auth.py +11 -0
  4. datasets.json +16 -0
  5. datasets_registry.py +243 -0
  6. requirements.txt +9 -0
  7. server.py +629 -0
  8. stripe_webhook.py +142 -0
  9. subscriptions_ledger.py +252 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

WORKDIR /app

# Install uv (--no-cache-dir keeps pip's wheel cache out of the image layer)
RUN pip install --no-cache-dir uv

# Copy requirements first so the dependency layer is cached across code-only changes
COPY requirements.txt .

# Install dependencies
RUN uv pip install --system -r requirements.txt

# Copy code
COPY . .

# Expose port
EXPOSE 7860

# Run the server
# HF Spaces expect the app to run on port 7860
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860"]
__init__.py ADDED
File without changes
auth.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from huggingface_hub import whoami
3
+
4
def validate_hf_token(token):
    """Resolve an HF token to its user info via whoami(); None if invalid."""
    try:
        return whoami(token=token)
    except Exception as e:
        # Any failure (bad token, network error) is treated as "not authenticated".
        print(f"Token validation failed: {e}")
        return None
datasets.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "dataset_id": "org/dataset_A",
4
+ "slug": "daily-signals",
5
+ "display_name": "Daily Signals Dataset",
6
+ "description": "Daily updated financial signals.",
7
+ "plans": [
8
+ {
9
+ "plan_id": "pro_monthly",
10
+ "stripe_price_id": "price_123",
11
+ "access_duration_days": 30
12
+ }
13
+ ],
14
+ "is_active": true
15
+ }
16
+ ]
datasets_registry.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import tempfile
4
+ from typing import List, Dict, Any, Optional
5
+ from huggingface_hub import HfApi, hf_hub_download
6
+ from huggingface_hub.utils import EntryNotFoundError
7
+
8
+ # Configuration - uses same ledger repo as subscriptions
9
+ LEDGER_REPO = os.getenv("LEDGER_DATASET_ID", "")
10
+ REGISTRY_FILE = "datasets.json"
11
+ HF_TOKEN = os.getenv("HF_TOKEN")
12
+
13
+ # Fallback to local file if LEDGER_DATASET_ID not set (for local dev)
14
+ LOCAL_REGISTRY_FILE = "datasets.json"
15
+
16
+ # Initialize HF API
17
+ api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
18
+
19
+
20
def _use_hf_storage() -> bool:
    """True when the HF Dataset repo, token, and API client are all configured."""
    return all((LEDGER_REPO, HF_TOKEN, api))
23
+
24
+
25
def _download_registry() -> Optional[str]:
    """Fetch the registry file from the HF Dataset; local cache path or None."""
    if not _use_hf_storage():
        return None

    try:
        return hf_hub_download(
            repo_id=LEDGER_REPO,
            repo_type="dataset",
            filename=REGISTRY_FILE,
            token=HF_TOKEN,
        )
    except EntryNotFoundError:
        # The registry has never been written to this dataset yet.
        return None
    except Exception as e:
        print(f"Error downloading registry: {e}")
        return None
44
+
45
+
46
def _upload_registry(local_path: str) -> bool:
    """Upload the registry file at *local_path* to the HF Dataset repo.

    Returns True on success; False when HF storage is unconfigured or the
    upload fails (the error is printed, not raised).
    """
    if not _use_hf_storage():
        return False

    try:
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=REGISTRY_FILE,
            repo_id=LEDGER_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            # Fix: was f"Update dataset registry" — an f-string with no placeholders.
            commit_message="Update dataset registry"
        )
        return True
    except Exception as e:
        print(f"Error uploading registry: {e}")
        return False
64
+
65
+
66
def load_registry() -> List[Dict[str, Any]]:
    """Loads the dataset registry from the HF Dataset, else the local file.

    Returns [] when nothing is readable or the JSON is malformed.
    """
    def _read_json(path):
        # None signals a decode failure so the caller can choose the fallback.
        try:
            with open(path, "r") as f:
                return json.load(f)
        except json.JSONDecodeError:
            print(f"Error decoding {path}")
            return None

    if _use_hf_storage():
        hf_path = _download_registry()
        if hf_path:
            loaded = _read_json(hf_path)
            # A corrupt remote registry yields [], matching the original behavior
            # (no silent fallback to the local file in that case).
            return loaded if loaded is not None else []

    if not os.path.exists(LOCAL_REGISTRY_FILE):
        return []
    loaded = _read_json(LOCAL_REGISTRY_FILE)
    return loaded if loaded is not None else []
88
+
89
+
90
def save_registry(registry: List[Dict[str, Any]]) -> bool:
    """Saves the dataset registry to the HF Dataset or to the local file.

    Returns True on success. When HF storage is configured, the registry is
    serialized to a temp file that is always removed afterwards.
    """
    if not _use_hf_storage():
        # Local file storage (dev mode)
        with open(LOCAL_REGISTRY_FILE, "w") as f:
            json.dump(registry, f, indent=2)
        return True

    # Serialize to a temp file, upload it, and always clean up.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp:
        tmp_path = tmp.name
        json.dump(registry, tmp, indent=2)

    try:
        return _upload_registry(tmp_path)
    finally:
        # Fix: was a bare `except:`; only swallow filesystem errors here,
        # and run cleanup even if the upload path raises.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
113
+
114
def get_dataset_by_id(dataset_id: str) -> Optional[Dict[str, Any]]:
    """Return the registry entry whose dataset_id matches, else None."""
    return next(
        (entry for entry in load_registry() if entry.get("dataset_id") == dataset_id),
        None,
    )
121
+
122
def get_dataset_by_slug(slug: str) -> Optional[Dict[str, Any]]:
    """Return the registry entry whose slug matches, else None."""
    return next(
        (entry for entry in load_registry() if entry.get("slug") == slug),
        None,
    )
129
+
130
def get_plan_by_price_id(price_id: str) -> Optional[Dict[str, Any]]:
    """Finds a plan and its dataset by Stripe price ID.

    Returns {"dataset": ..., "plan": ...} for the first match, else None.
    """
    for entry in load_registry():
        matching = next(
            (p for p in entry.get("plans", []) if p.get("stripe_price_id") == price_id),
            None,
        )
        if matching is not None:
            return {"dataset": entry, "plan": matching}
    return None
138
+
139
def get_free_plan(dataset_id: str) -> Optional[Dict[str, Any]]:
    """
    Securely finds a free plan for a dataset.
    Returns the plan dict if found, None otherwise.
    """
    entry = get_dataset_by_id(dataset_id)
    if entry is None:
        return None

    # Only these explicit markers count as "free" — never an absent price.
    free_markers = ("free", "0", 0)
    return next(
        (p for p in entry.get("plans", []) if p.get("stripe_price_id") in free_markers),
        None,
    )
154
+
155
+
156
def detect_dataset_format(dataset_id: str) -> Dict[str, Any]:
    """
    Detects the format and parquet path for a HuggingFace dataset.
    Returns info about the dataset including the correct parquet URL pattern.
    """
    if not api:
        return {
            "dataset_id": dataset_id,
            "error": "HF API not initialized (HF_TOKEN not set)",
            "parquet_url_pattern": None
        }

    def _parquet_files(ds_info):
        # All .parquet files listed among the repo's siblings.
        return [
            s.rfilename
            for s in (ds_info.siblings or [])
            if s.rfilename.endswith('.parquet')
        ]

    try:
        info = api.dataset_info(dataset_id, token=HF_TOKEN)

        # Native parquet files on the main branch.
        native = _parquet_files(info)

        # Auto-converted parquet on refs/convert/parquet, if that ref exists.
        converted = []
        try:
            converted = _parquet_files(
                api.dataset_info(dataset_id, token=HF_TOKEN, revision='refs/convert/parquet')
            )
        except Exception:
            # refs/convert/parquet doesn't exist for this dataset
            pass

        if native:
            pattern = f"hf://datasets/{dataset_id}/**/*.parquet"
            count = len(native)
        elif converted:
            # Note: The revision path must be URL-encoded for DuckDB
            pattern = f"hf://datasets/{dataset_id}@refs%2Fconvert%2Fparquet/**/*.parquet"
            count = len(converted)
        else:
            pattern = None
            count = 0

        return {
            "dataset_id": dataset_id,
            "has_native_parquet": bool(native),
            "has_converted_parquet": bool(converted),
            "parquet_url_pattern": pattern,
            "parquet_files_count": count,
            "card_data": info.card_data.__dict__ if info.card_data else None,
        }
    except Exception as e:
        return {
            "dataset_id": dataset_id,
            "error": str(e),
            "parquet_url_pattern": None
        }
225
+
226
+
227
def get_parquet_url(dataset_id: str) -> str:
    """
    Gets the best parquet URL pattern for a dataset.
    Checks registry first, then tries to detect automatically.
    """
    entry = get_dataset_by_id(dataset_id)
    stored = entry.get("parquet_url_pattern") if entry else None
    if stored:
        return stored

    detected = detect_dataset_format(dataset_id).get("parquet_url_pattern")
    if detected:
        return detected

    # Fallback to standard pattern
    return f"hf://datasets/{dataset_id}/**/*.parquet"
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ mcp
2
+ huggingface_hub
3
+ stripe
4
+ fastapi
5
+ uvicorn
6
+ python-dotenv
7
+ duckdb
8
+ pandas
9
+ numpy
server.py ADDED
@@ -0,0 +1,629 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mcp.server.fastmcp import FastMCP, Context
2
+ from fastapi import Request
3
+ from fastapi.responses import JSONResponse
4
+ from dotenv import load_dotenv
5
+ import os
6
+ import json
7
+ import datasets_registry
8
+ import subscriptions_ledger
9
+ import stripe_webhook
10
+ import auth
11
+
12
+ load_dotenv()
13
+
14
+ # Initialize MCP Server
15
+ mcp = FastMCP("datapass")
16
+
17
+ # Admin API Secret for secure admin endpoints
18
+ ADMIN_API_SECRET = os.getenv("ADMIN_API_SECRET")
19
+
20
+
21
def verify_admin_secret(provided_secret: str) -> bool:
    """Verify the admin secret matches the configured ADMIN_API_SECRET.

    Fix: uses hmac.compare_digest (constant-time) instead of `==`, so the
    comparison does not leak secret length/prefix via response timing.
    """
    import hmac

    if not ADMIN_API_SECRET:
        return False  # Fail closed if not configured
    # Encode both sides: compare_digest on str requires ASCII-only input.
    return hmac.compare_digest(provided_secret.encode(), ADMIN_API_SECRET.encode())
26
+
27
+
28
def get_admin_secret_from_request(request: Request, body_data: dict = None) -> str:
    """Extract admin secret from header (preferred) or body (fallback)."""
    secret = request.headers.get("X-Admin-Secret", "")
    if not secret and body_data:
        secret = body_data.get("admin_secret", "")
    return secret
36
+
37
+
38
def _get_access_token_from_context(ctx: Context) -> str | None:
    """Extract access token from the Authorization header in the request context."""
    prefix = "Bearer "
    try:
        req_ctx = ctx.request_context
        if req_ctx and req_ctx.request:
            header = req_ctx.request.headers.get("Authorization", "")
            if header.startswith(prefix):
                return header[len(prefix):]
    except Exception as e:
        print(f"Error extracting access token: {e}")
    return None
48
+
49
+
50
def _validate_token_for_dataset(access_token: str, dataset_id: str = None):
    """Validate access token and optionally check dataset access.

    Returns (token_info, None) on success or (None, error_message) on failure.
    """
    if not access_token:
        return None, "No access token provided. Please include Authorization header with Bearer token."

    token_info = subscriptions_ledger.validate_access_token(access_token)
    if not token_info:
        return None, "Invalid or expired access token."

    # A token is scoped to exactly one dataset; reject mismatches when asked.
    if dataset_id and token_info["dataset_id"] != dataset_id:
        return None, f"Access token not valid for dataset '{dataset_id}'. Your token is for '{token_info['dataset_id']}'."

    return token_info, None
63
+
64
+
65
+ # =============================================================================
66
+ # Public MCP Tools (no auth required)
67
+ # =============================================================================
68
+
69
@mcp.tool()
def get_dataset_catalog():
    """
    Returns the catalog of all available datasets in the marketplace.
    Use this to discover datasets you can subscribe to.
    """
    # Only actively listed datasets are exposed publicly.
    return [entry for entry in datasets_registry.load_registry() if entry.get("is_active")]
78
+
79
+
80
+ # =============================================================================
81
+ # Authenticated MCP Tools (require access token via Authorization header)
82
+ # =============================================================================
83
+
84
@mcp.tool()
def get_my_datasets(ctx: Context):
    """
    Returns the list of datasets you have access to based on your subscription.
    The access token is automatically read from your Authorization header.
    """
    access_token = _get_access_token_from_context(ctx)
    if not access_token:
        return {"error": "No access token provided. Please include Authorization header with Bearer token."}

    token_info = subscriptions_ledger.validate_access_token(access_token)
    if not token_info:
        return {"error": "Invalid or expired access token."}

    dataset_id = token_info["dataset_id"]
    subscription = token_info.get("subscription", {})

    # Enrich with display metadata from the registry when the entry exists.
    entry = None
    for candidate in datasets_registry.load_registry():
        if candidate["dataset_id"] == dataset_id:
            entry = candidate
            break

    display_name = entry.get("display_name", dataset_id) if entry else dataset_id
    description = entry.get("description", "") if entry else ""

    return {
        "datasets": [{
            "dataset_id": dataset_id,
            "display_name": display_name,
            "description": description,
            "subscription_end": subscription.get("subscription_end"),
            "plan_id": subscription.get("plan_id")
        }],
        "user": token_info["hf_user"]
    }
116
+
117
+
118
def _get_duckdb_connection(server_hf_token: str):
    """Create an in-memory DuckDB connection with httpfs + the HF token secret."""
    import duckdb

    con = duckdb.connect(database=':memory:')
    for statement in (
        "INSTALL httpfs;",
        "LOAD httpfs;",
        f"CREATE SECRET hf_token (TYPE huggingface, TOKEN '{server_hf_token}');",
    ):
        con.execute(statement)
    return con
126
+
127
+
128
@mcp.tool()
def get_dataset_schema(dataset_id: str, ctx: Context):
    """
    Returns the schema (column names and types) of a dataset you have access to.
    Use this to understand the data structure before querying.

    Args:
        dataset_id: The ID of the dataset (e.g., 'waroca/prompts')
    """
    access_token = _get_access_token_from_context(ctx)
    token_info, error = _validate_token_for_dataset(access_token, dataset_id)
    if error:
        return {"error": error}

    # Use server's HF token to access the dataset
    server_hf_token = os.getenv("HF_TOKEN")
    if not server_hf_token:
        return {"error": "Server HF token not configured"}

    # Get the correct parquet URL pattern for this dataset
    dataset_url = datasets_registry.get_parquet_url(dataset_id)

    con = None
    try:
        con = _get_duckdb_connection(server_hf_token)

        # Column names/types as reported by DuckDB's DESCRIBE
        result = con.execute(f"DESCRIBE SELECT * FROM read_parquet('{dataset_url}')").fetchall()
        schema = [{"column": row[0], "type": row[1]} for row in result]

        # Total row count across the matched parquet files
        count_result = con.execute(f"SELECT COUNT(*) FROM read_parquet('{dataset_url}')").fetchone()
        row_count = count_result[0] if count_result else 0

        return {
            "dataset_id": dataset_id,
            "schema": schema,
            "row_count": row_count
        }
    except Exception as e:
        return {"error": f"Failed to get schema: {str(e)}"}
    finally:
        # Fix: the in-memory DuckDB connection was previously leaked on every call.
        if con is not None:
            con.close()
168
+
169
+
170
@mcp.tool()
def get_dataset_sample(dataset_id: str, ctx: Context, num_rows: int = 5):
    """
    Returns a sample of rows from a dataset you have access to.
    Use this to preview the data before running queries.

    Args:
        dataset_id: The ID of the dataset (e.g., 'waroca/prompts')
        num_rows: Number of sample rows to return (default: 5, max: 20)
    """
    access_token = _get_access_token_from_context(ctx)
    token_info, error = _validate_token_for_dataset(access_token, dataset_id)
    if error:
        return {"error": error}

    # Clamp to the documented 1..20 range
    num_rows = min(max(1, num_rows), 20)

    server_hf_token = os.getenv("HF_TOKEN")
    if not server_hf_token:
        return {"error": "Server HF token not configured"}

    # Get the correct parquet URL pattern for this dataset
    dataset_url = datasets_registry.get_parquet_url(dataset_id)

    con = None
    try:
        con = _get_duckdb_connection(server_hf_token)
        result = con.execute(f"SELECT * FROM read_parquet('{dataset_url}') LIMIT {num_rows}").fetchdf()
        return {
            "dataset_id": dataset_id,
            "sample": result.to_dict(orient='records'),
            "num_rows": len(result)
        }
    except Exception as e:
        return {"error": f"Failed to get sample: {str(e)}"}
    finally:
        # Fix: close the DuckDB connection instead of leaking it per request.
        if con is not None:
            con.close()
205
+
206
+
207
@mcp.tool()
def query_dataset(dataset_id: str, query: str, ctx: Context):
    """
    Execute a SQL query against a dataset you have access to.
    The dataset is available as a table named 'data'.

    Args:
        dataset_id: The ID of the dataset (e.g., 'waroca/prompts')
        query: SQL query to execute. Use 'data' as the table name.
               Example: "SELECT * FROM data WHERE column > 10 LIMIT 100"
    """
    access_token = _get_access_token_from_context(ctx)
    token_info, error = _validate_token_for_dataset(access_token, dataset_id)
    if error:
        return {"error": error}

    server_hf_token = os.getenv("HF_TOKEN")
    if not server_hf_token:
        return {"error": "Server HF token not configured"}

    # Get the correct parquet URL pattern for this dataset
    dataset_url = datasets_registry.get_parquet_url(dataset_id)

    con = None
    try:
        con = _get_duckdb_connection(server_hf_token)
        # Expose the parquet files under the stable name 'data' for the user's SQL.
        con.execute(f"CREATE OR REPLACE VIEW data AS SELECT * FROM read_parquet('{dataset_url}');")
        result = con.execute(query).fetchdf()
        # Wrap in dict to ensure proper JSON serialization
        return {"results": result.to_dict(orient='records'), "row_count": len(result)}
    except Exception as e:
        return {"error": f"Query execution failed: {str(e)}"}
    finally:
        # Fix: close the DuckDB connection instead of leaking it per request.
        if con is not None:
            con.close()
238
+
239
+
240
@mcp.tool()
def query_dataset_natural_language(dataset_id: str, question: str, ctx: Context):
    """
    Query a dataset using natural language. The question will be converted to SQL automatically.

    Args:
        dataset_id: The ID of the dataset (e.g., 'waroca/prompts')
        question: Your question in plain English.
                  Example: "What are the top 10 most common values in the category column?"
    """
    access_token = _get_access_token_from_context(ctx)
    token_info, error = _validate_token_for_dataset(access_token, dataset_id)
    if error:
        return {"error": error}

    server_hf_token = os.getenv("HF_TOKEN")
    if not server_hf_token:
        return {"error": "Server HF token not configured"}

    # Get the correct parquet URL pattern for this dataset
    dataset_url = datasets_registry.get_parquet_url(dataset_id)

    con = None
    try:
        con = _get_duckdb_connection(server_hf_token)

        # Schema is passed to the LLM so it can generate valid column references.
        schema_result = con.execute(f"DESCRIBE SELECT * FROM read_parquet('{dataset_url}')").fetchall()
        schema_str = "\n".join([f" - {row[0]}: {row[1]}" for row in schema_result])

        # Generate SQL using LLM
        from huggingface_hub import InferenceClient
        client = InferenceClient(token=server_hf_token)

        system_prompt = "You are a SQL expert. Convert natural language questions into DuckDB SQL queries. Return ONLY the SQL query, nothing else. Do not explain."

        user_prompt = f"""Table name: data
Schema:
{schema_str}

Question: {question}

Return ONLY the SQL query, nothing else."""

        response = client.chat_completion(
            model="Qwen/Qwen2.5-Coder-32B-Instruct",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            max_tokens=300
        )
        generated_sql = response.choices[0].message.content.strip()

        # Clean up if model adds markdown
        if "```sql" in generated_sql:
            generated_sql = generated_sql.split("```sql")[1].split("```")[0].strip()
        elif "```" in generated_sql:
            generated_sql = generated_sql.split("```")[1].split("```")[0].strip()

        # Remove trailing semicolon if present (we'll add it if needed)
        generated_sql = generated_sql.rstrip(';').strip()

        # Execute the generated query against a view over the parquet files.
        con.execute(f"CREATE OR REPLACE VIEW data AS SELECT * FROM read_parquet('{dataset_url}');")
        result = con.execute(generated_sql).fetchdf()

        return {
            "question": question,
            "generated_sql": generated_sql,
            "result": result.to_dict(orient='records')
        }
    except Exception as e:
        return {"error": f"Natural language query failed: {str(e)}"}
    finally:
        # Fix: close the DuckDB connection instead of leaking it per request.
        if con is not None:
            con.close()
313
+
314
+
315
+ # =============================================================================
316
+ # Admin Helper Functions (used by HTTP API endpoints, not exposed as MCP tools)
317
+ # =============================================================================
318
+
319
def _admin_get_subscriber_stats():
    """Returns subscriber statistics grouped by dataset.

    A subscription counts as active when its subscription_end timestamp
    (ISO-8601, optional trailing 'Z') lies in the future (UTC).
    """
    from datetime import datetime

    all_subs = subscriptions_ledger.get_all_subscriptions()
    stats = {}
    active_count = 0
    now = datetime.utcnow()

    for sub in all_subs.values():
        d_id = sub["dataset_id"]
        bucket = stats.setdefault(d_id, {"total": 0, "active": 0})
        bucket["total"] += 1

        end_str = sub.get("subscription_end", "")
        if not end_str:
            continue
        try:
            end_date = datetime.fromisoformat(end_str.removesuffix("Z"))
        except ValueError:
            # Fix: was a bare `except:` that could hide unrelated errors;
            # only an unparseable timestamp should be skipped.
            continue
        if end_date > now:
            bucket["active"] += 1
            active_count += 1

    return {
        "by_dataset": stats,
        "total_active": active_count,
        "total_subscriptions": len(all_subs)
    }
351
+
352
+
353
def _admin_get_subscribers_for_dataset(dataset_id: str):
    """Returns subscribers of *dataset_id*, each annotated with is_active.

    is_active is True when subscription_end (ISO-8601, optional trailing 'Z')
    is in the future (UTC); missing or unparseable dates count as inactive.
    """
    from datetime import datetime

    all_subs = subscriptions_ledger.get_all_subscriptions()
    subscribers = []
    now = datetime.utcnow()

    for sub in all_subs.values():
        if sub["dataset_id"] != dataset_id:
            continue

        is_active = False
        end_str = sub.get("subscription_end", "")
        if end_str:
            try:
                is_active = datetime.fromisoformat(end_str.removesuffix("Z")) > now
            except ValueError:
                # Fix: was a bare `except:`; only swallow parse failures.
                pass

        subscribers.append({**sub, "is_active": is_active})

    return subscribers
379
+
380
+
381
def _admin_update_dataset_registry(dataset_id: str, action: str, data: str):
    """Updates the dataset registry.

    Actions:
        add           — *data* is a JSON dataset entry to append.
        toggle_active — flip is_active on the matching entry.
        delete        — remove the matching entry.
        update        — *data* is a JSON replacement for the matching entry.

    Returns {"status": "success"} or {"error": ...}.
    """
    registry = datasets_registry.load_registry()

    try:
        if action == "add":
            registry.append(json.loads(data))
        elif action == "toggle_active":
            for d in registry:
                if d["dataset_id"] == dataset_id:
                    d["is_active"] = not d.get("is_active", False)
        elif action == "delete":
            registry = [d for d in registry if d["dataset_id"] != dataset_id]
        elif action == "update":
            updated_dataset = json.loads(data)
            for i, d in enumerate(registry):
                if d["dataset_id"] == dataset_id:
                    registry[i] = updated_dataset
                    break
    except json.JSONDecodeError as e:
        # Fix: malformed JSON previously raised out of this function and
        # surfaced to the admin endpoint as an unhandled 500.
        return {"error": f"Invalid JSON payload: {e}"}

    if datasets_registry.save_registry(registry):
        return {"status": "success"}
    return {"error": "Failed to save registry"}
406
+
407
+
408
+ # =============================================================================
409
+ # HTTP API Routes (for frontend, not MCP)
410
+ # =============================================================================
411
+
412
@mcp.custom_route("/webhook", methods=["POST"])
async def webhook_handler(request: Request):
    """Stripe webhook entry point; verification/dispatch lives in stripe_webhook."""
    response = await stripe_webhook.handle_stripe_webhook(request)
    return response
415
+
416
+
417
async def _subscribe_free_logic(dataset_id: str, hf_token: str):
    """Grant a free-tier subscription for dataset_id to the owner of hf_token.

    Returns a status dict with the newly minted access_token, or an error dict.
    """
    user_info = auth.validate_hf_token(hf_token)
    if not user_info:
        return {"error": "Invalid HF token"}

    hf_user = user_info["name"]

    # Locate the (active) registry entry for the requested dataset.
    entry = None
    for candidate in datasets_registry.load_registry():
        if candidate["dataset_id"] == dataset_id:
            entry = candidate
            break

    if not entry or not entry.get("is_active"):
        return {"error": "Dataset not found or inactive"}

    free_plan = datasets_registry.get_free_plan(dataset_id)
    if free_plan is None:
        return {"error": "No free plan available for this dataset."}

    from datetime import datetime, timedelta

    duration_days = free_plan.get("access_duration_days", 1)  # Default to 1-day trial
    end_date = datetime.utcnow() + timedelta(days=duration_days)
    access_token = subscriptions_ledger.generate_access_token()

    ledger_entry = {
        "event_id": f"free_{hf_user}_{datetime.utcnow().timestamp()}",
        "hf_user": hf_user,
        "dataset_id": dataset_id,
        "plan_id": free_plan["plan_id"],
        "subscription_start": datetime.utcnow().isoformat() + "Z",
        "subscription_end": end_date.isoformat() + "Z",
        "source": "free_tier",
        "access_token": access_token,
        "created_at": datetime.utcnow().isoformat() + "Z"
    }

    subscriptions_ledger.append_subscription_event(ledger_entry)
    return {"status": "success", "message": "Subscribed successfully", "access_token": access_token}
454
+
455
+
456
async def _create_checkout_session_logic(dataset_id: str, hf_token: str):
    """Create a Stripe Checkout session for the dataset's first configured plan.

    Returns {"checkout_url": ...} on success, or an error dict.
    """
    user_info = auth.validate_hf_token(hf_token)
    if not user_info:
        return {"error": "Invalid HF token"}

    hf_user = user_info["name"]
    entry = next(
        (d for d in datasets_registry.load_registry() if d["dataset_id"] == dataset_id),
        None,
    )

    if not entry or not entry.get("is_active"):
        return {"error": "Dataset not found or inactive"}

    plans = entry.get("plans")
    if not plans:
        return {"error": "No plans available for this dataset"}

    # NOTE: only the first plan is offered at checkout.
    price_id = plans[0].get("stripe_price_id")

    if price_id in ["free", "0", 0]:
        return {"error": "This is a free dataset, use subscribe_free instead"}
    if not price_id:
        return {"error": "Price ID not configured"}

    try:
        import stripe
        stripe.api_key = os.getenv("STRIPE_SECRET_KEY")

        session = stripe.checkout.Session.create(
            payment_method_types=['card'],
            line_items=[{'price': price_id, 'quantity': 1}],
            mode='subscription',
            # The webhook relies on this metadata to attribute the purchase.
            metadata={'hf_user': hf_user, 'dataset_id': dataset_id},
            success_url='https://huggingface.co/spaces/waroca/monetization-frontend?success=true',
            cancel_url='https://huggingface.co/spaces/waroca/monetization-frontend?canceled=true',
        )
        return {"checkout_url": session.url}
    except Exception as e:
        return {"error": str(e)}
495
+
496
+
497
@mcp.custom_route("/api/subscribe_free", methods=["POST"])
async def api_subscribe_free(request: Request):
    """HTTP wrapper around the free-tier subscription flow."""
    payload = await request.json()
    outcome = await _subscribe_free_logic(payload.get("dataset_id"), payload.get("hf_token"))
    return JSONResponse(outcome)
502
+
503
+
504
@mcp.custom_route("/api/create_checkout_session", methods=["POST"])
async def api_create_checkout_session(request: Request):
    """HTTP wrapper around the Stripe checkout-session flow."""
    payload = await request.json()
    outcome = await _create_checkout_session_logic(payload.get("dataset_id"), payload.get("hf_token"))
    return JSONResponse(outcome)
509
+
510
+
511
@mcp.custom_route("/api/user_subscriptions", methods=["POST"])
async def api_user_subscriptions(request: Request):
    """Get subscriptions for the current user.

    Requires a valid hf_token whose account name matches hf_user.
    """
    payload = await request.json()
    hf_user = payload.get("hf_user")
    hf_token = payload.get("hf_token")

    if not hf_user:
        return JSONResponse({"error": "hf_user required"}, status_code=400)
    if not hf_token:
        return JSONResponse({"error": "hf_token required"}, status_code=401)

    user_info = auth.validate_hf_token(hf_token)
    if not user_info:
        return JSONResponse({"error": "Invalid HF token"}, status_code=401)
    if user_info.get("name") != hf_user:
        return JSONResponse({"error": "Token does not match user"}, status_code=403)

    return JSONResponse(subscriptions_ledger.get_user_subscriptions(hf_user))
533
+
534
+
535
@mcp.custom_route("/api/catalog", methods=["GET"])
async def api_catalog(request: Request):
    """Public endpoint for getting the dataset catalog."""
    return JSONResponse(get_dataset_catalog())
540
+
541
+
542
+ # =============================================================================
543
+ # Admin API Routes
544
+ # =============================================================================
545
+
546
@mcp.custom_route("/api/admin/subscriber_stats", methods=["POST"])
async def api_admin_subscriber_stats(request: Request):
    """Admin-only: aggregate subscriber statistics."""
    payload = await request.json()
    if not verify_admin_secret(get_admin_secret_from_request(request, payload)):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    return JSONResponse(_admin_get_subscriber_stats())
556
+
557
+
558
@mcp.custom_route("/api/admin/subscribers_for_dataset", methods=["POST"])
async def api_admin_subscribers_for_dataset(request: Request):
    """Admin-only: list subscribers of one dataset."""
    payload = await request.json()
    if not verify_admin_secret(get_admin_secret_from_request(request, payload)):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)

    dataset_id = payload.get("dataset_id")
    if not dataset_id:
        return JSONResponse({"error": "dataset_id required"}, status_code=400)
    return JSONResponse(_admin_get_subscribers_for_dataset(dataset_id))
572
+
573
+
574
@mcp.custom_route("/api/admin/update_registry", methods=["POST"])
async def api_admin_update_registry(request: Request):
    """Admin-only: mutate the dataset registry (add/update/delete/toggle)."""
    payload = await request.json()
    if not verify_admin_secret(get_admin_secret_from_request(request, payload)):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)

    action = payload.get("action")
    if not action:
        return JSONResponse({"error": "action required"}, status_code=400)

    result = _admin_update_dataset_registry(
        payload.get("dataset_id", ""), action, payload.get("data", "")
    )
    return JSONResponse(result)
590
+
591
+
592
@mcp.custom_route("/api/admin/catalog", methods=["POST"])
async def api_admin_catalog(request: Request):
    """Admin-only: full registry, including inactive datasets."""
    payload = await request.json()
    if not verify_admin_secret(get_admin_secret_from_request(request, payload)):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    return JSONResponse(datasets_registry.load_registry())
602
+
603
+
604
@mcp.custom_route("/api/admin/detect_format", methods=["POST"])
async def api_admin_detect_format(request: Request):
    """Detect the format and parquet path for a HuggingFace dataset."""
    payload = await request.json()
    if not verify_admin_secret(get_admin_secret_from_request(request, payload)):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)

    dataset_id = payload.get("dataset_id")
    if not dataset_id:
        return JSONResponse({"error": "dataset_id required"}, status_code=400)
    return JSONResponse(datasets_registry.detect_dataset_format(dataset_id))
619
+
620
+
621
+ # =============================================================================
622
+ # App Initialization
623
+ # =============================================================================
624
+
625
+ app = mcp.sse_app()
626
+
627
+ if __name__ == "__main__":
628
+ import uvicorn
629
+ uvicorn.run(app, host="0.0.0.0", port=8000)
stripe_webhook.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import stripe
2
+ import os
3
+ import json
4
+ from fastapi import Request, HTTPException
5
+ from datetime import datetime, timedelta
6
+ import datasets_registry
7
+ import subscriptions_ledger
8
+
9
+ from fastapi.responses import JSONResponse
10
+
11
async def handle_stripe_webhook(request: Request):
    """Verify and dispatch a Stripe webhook event.

    Verifies the payload signature against STRIPE_WEBHOOK_SECRET, then routes
    checkout completions and paid invoices to their handlers. Unrecognized
    event types are acknowledged with success so Stripe does not retry them.

    Raises:
        HTTPException: 500 if the webhook secret is not configured,
            400 for a malformed payload or an invalid signature.
    """
    payload = await request.body()
    sig_header = request.headers.get("stripe-signature")
    webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET")

    if not webhook_secret:
        # Fail loudly instead of passing None into the signature check.
        raise HTTPException(status_code=500, detail="Webhook secret not configured")

    try:
        event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
    except ValueError as e:
        raise HTTPException(status_code=400, detail="Invalid payload") from e
    except stripe.error.SignatureVerificationError as e:
        raise HTTPException(status_code=400, detail="Invalid signature") from e

    event_type = event["type"]
    data_object = event["data"]["object"]

    if event_type == "checkout.session.completed":
        await handle_checkout_session(data_object, event["id"])
    elif event_type == "invoice.paid":
        await handle_invoice_paid(data_object, event["id"])

    return JSONResponse({"status": "success"})
34
+
35
async def handle_checkout_session(session, event_id):
    """Record a new subscription from a completed Stripe Checkout session.

    Expects `hf_user` and `dataset_id` in the session metadata. Only
    subscription-mode checkouts are supported; one-time payments are logged
    and skipped. On success, appends a ledger entry carrying a freshly
    generated access token.
    """
    # Metadata should contain hf_user and dataset_id
    metadata = session.get("metadata", {})
    hf_user = metadata.get("hf_user")
    dataset_id = metadata.get("dataset_id")

    if not hf_user or not dataset_id:
        print(f"Missing metadata in session {session['id']}")
        return

    subscription_id = session.get("subscription")
    if not subscription_id:
        # One-time payments carry no billing period; only subscriptions are supported.
        print("Non-subscription checkout not fully supported yet.")
        return

    # Fetch the subscription to learn the price and the current billing period.
    sub = stripe.Subscription.retrieve(subscription_id)
    price_id = sub["items"]["data"][0]["price"]["id"]
    # Stripe timestamps are Unix epoch seconds (UTC). Convert as UTC so the
    # "Z"-suffixed ISO string below is truthful regardless of server timezone
    # (plain fromtimestamp() would have produced local time labeled as UTC).
    end_date = datetime.utcfromtimestamp(sub["current_period_end"])

    plan_info = datasets_registry.get_plan_by_price_id(price_id)
    if not plan_info:
        print(f"Unknown price_id {price_id}")
        return

    plan_id = plan_info["plan"]["plan_id"]

    # Generate unique access token for this subscription
    access_token = subscriptions_ledger.generate_access_token()

    ledger_entry = {
        "event_id": event_id,
        "hf_user": hf_user,
        "dataset_id": dataset_id,
        "plan_id": plan_id,
        "subscription_start": datetime.utcnow().isoformat() + "Z",
        "subscription_end": end_date.isoformat() + "Z",
        "source": "stripe",
        "access_token": access_token,
        "created_at": datetime.utcnow().isoformat() + "Z",
        "stripe_customer_id": session.get("customer"),
        "stripe_subscription_id": subscription_id,
    }

    subscriptions_ledger.append_subscription_event(ledger_entry)
90
+
91
async def handle_invoice_paid(invoice, event_id):
    """Extend a subscription in the ledger when a renewal invoice is paid.

    Looks up `hf_user`/`dataset_id` from the Stripe subscription's metadata
    (assumed to be preserved on the subscription object). The existing access
    token is reused when an active subscription is found, so clients keep
    working across renewals; otherwise a fresh token is generated.
    """
    subscription_id = invoice.get("subscription")
    if not subscription_id:
        return

    # The user/dataset association is carried on the subscription metadata.
    sub = stripe.Subscription.retrieve(subscription_id)
    metadata = sub.get("metadata", {})
    hf_user = metadata.get("hf_user")
    dataset_id = metadata.get("dataset_id")

    if not hf_user or not dataset_id:
        print(f"Missing metadata in subscription {subscription_id}")
        return

    price_id = sub["items"]["data"][0]["price"]["id"]
    # Stripe timestamps are epoch seconds (UTC). Convert as UTC so the "Z"
    # suffix below is truthful (plain fromtimestamp() would use local time).
    end_date = datetime.utcfromtimestamp(sub["current_period_end"])

    plan_info = datasets_registry.get_plan_by_price_id(price_id)
    if not plan_info:
        return

    plan_id = plan_info["plan"]["plan_id"]

    # For renewals, try to preserve the existing access token
    existing_sub = subscriptions_ledger.get_active_subscription(hf_user, dataset_id)
    if existing_sub and existing_sub.get("access_token"):
        access_token = existing_sub["access_token"]
    else:
        access_token = subscriptions_ledger.generate_access_token()

    ledger_entry = {
        "event_id": event_id,
        "hf_user": hf_user,
        "dataset_id": dataset_id,
        "plan_id": plan_id,
        "subscription_start": datetime.utcnow().isoformat() + "Z",  # Or period start
        "subscription_end": end_date.isoformat() + "Z",
        "source": "stripe",
        "access_token": access_token,
        "created_at": datetime.utcnow().isoformat() + "Z",
        "stripe_customer_id": invoice.get("customer"),
        "stripe_subscription_id": subscription_id,
    }

    subscriptions_ledger.append_subscription_event(ledger_entry)
subscriptions_ledger.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import secrets
4
+ import tempfile
5
+ from datetime import datetime
6
+ from typing import Dict, Any, Optional
7
+ from huggingface_hub import HfApi, hf_hub_download
8
+ from huggingface_hub.utils import EntryNotFoundError
9
+
10
# Configuration for HF Dataset-based ledger
LEDGER_REPO = os.getenv("LEDGER_DATASET_ID", "")  # HF dataset repo id that stores the ledger
LEDGER_FILE = "subscriptions.jsonl"  # append-only JSONL file inside that repo
HF_TOKEN = os.getenv("HF_TOKEN")  # token used for both download and upload

# Fallback to local file if LEDGER_DATASET_ID not set (for local dev)
LOCAL_LEDGER_FILE = "../subscriptions_ledger/subscriptions.jsonl"

# Initialize HF API
# None when no token is configured; _use_hf_storage() gates remote calls on this.
api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
20
+
21
+
22
def _use_hf_storage() -> bool:
    """True when the ledger repo, HF token, and API client are all configured."""
    return all((LEDGER_REPO, HF_TOKEN, api))
25
+
26
+
27
def _download_ledger() -> Optional[str]:
    """Fetch the current ledger file from the HF Dataset.

    Returns the local cache path of the downloaded file, or None when remote
    storage is disabled, the file does not exist in the repo yet, or the
    download fails for any other reason (logged, not raised).
    """
    if not _use_hf_storage():
        return None

    try:
        return hf_hub_download(
            repo_id=LEDGER_REPO,
            filename=LEDGER_FILE,
            repo_type="dataset",
            token=HF_TOKEN,
        )
    except EntryNotFoundError:
        # The ledger has not been created in the dataset yet.
        return None
    except Exception as e:
        print(f"Error downloading ledger: {e}")
        return None
46
+
47
+
48
def _upload_ledger(local_path: str) -> bool:
    """Push the ledger file at `local_path` to the HF Dataset.

    Returns True on success, False when remote storage is disabled or the
    upload fails (the error is logged, not raised).
    """
    if not _use_hf_storage():
        return False

    try:
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=LEDGER_FILE,
            repo_id=LEDGER_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            commit_message=f"Update subscriptions ledger - {datetime.utcnow().isoformat()}",
        )
    except Exception as e:
        print(f"Error uploading ledger: {e}")
        return False
    return True
66
+
67
+
68
def _get_ledger_path() -> str:
    """Resolve the path to read the ledger from (HF cache, else local file)."""
    if _use_hf_storage():
        remote_copy = _download_ledger()
        if remote_copy:
            return remote_copy
    # Local-development fallback; may not exist yet.
    return LOCAL_LEDGER_FILE
77
+
78
+
79
def append_subscription_event(event: Dict[str, Any]) -> bool:
    """Append a subscription event to the ledger.

    With HF storage configured, downloads the current ledger, appends the
    event to a staging temp file, and re-uploads the whole file. NOTE: this
    read-modify-write cycle is not atomic — concurrent writers can lose
    events. Without HF storage, appends to the local JSONL file.

    Returns True on success (upload result for HF storage; local writes
    return True or raise on I/O failure).
    """
    # Ensure timestamp is present
    if "created_at" not in event:
        event["created_at"] = datetime.utcnow().isoformat() + "Z"

    if not _use_hf_storage():
        # Local file storage
        parent_dir = os.path.dirname(LOCAL_LEDGER_FILE)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(LOCAL_LEDGER_FILE, "a") as f:
            f.write(json.dumps(event) + "\n")
        return True

    # Download current ledger (None if it doesn't exist yet)
    current_path = _download_ledger()

    # Stage existing content plus the new event in a temp file, then upload.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as tmp:
        tmp_path = tmp.name
        if current_path and os.path.exists(current_path):
            with open(current_path, 'r') as f:
                tmp.write(f.read())
        tmp.write(json.dumps(event) + "\n")

    try:
        return _upload_ledger(tmp_path)
    finally:
        # Always remove the staging file (even if the upload raised);
        # ignore only filesystem errors — a bare `except:` hid real bugs.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
122
+
123
+
124
def get_all_subscriptions() -> Dict[tuple, Dict[str, Any]]:
    """Fold the append-only ledger into the latest record per (user, dataset).

    Returns a dict keyed by (hf_user, dataset_id); later ledger events
    overwrite earlier ones, so each value is the current subscription state.
    """
    latest: Dict[tuple, Dict[str, Any]] = {}

    ledger_path = _get_ledger_path()
    if not ledger_path or not os.path.exists(ledger_path):
        return latest

    try:
        with open(ledger_path, "r") as f:
            for raw in f:
                raw = raw.strip()
                if not raw:
                    continue
                try:
                    record = json.loads(raw)
                except json.JSONDecodeError:
                    # Skip corrupt lines rather than failing the whole read.
                    continue
                user = record.get("hf_user")
                dataset = record.get("dataset_id")
                if user and dataset:
                    # Append-only ledger: the last record wins.
                    latest[(user, dataset)] = record
    except Exception as e:
        print(f"Error reading ledger: {e}")

    return latest
157
+
158
+
159
def get_active_subscription(hf_user: str, dataset_id: str) -> Optional[Dict[str, Any]]:
    """Return the user's subscription record for a dataset if it is unexpired."""
    record = get_all_subscriptions().get((hf_user, dataset_id))
    if record is None:
        return None

    end_str = record.get("subscription_end")
    if not end_str:
        return None

    # Ledger timestamps carry a trailing "Z", which fromisoformat can't parse.
    if end_str.endswith("Z"):
        end_str = end_str[:-1]
    try:
        expires = datetime.fromisoformat(end_str)
    except ValueError:
        print(f"Error parsing date: {end_str}")
        return None

    return record if expires > datetime.utcnow() else None
184
+
185
+
186
def get_user_subscriptions(hf_user: str) -> list:
    """Return all of a user's subscription records, each tagged with `is_active`.

    `is_active` is True only when the record has a parseable
    `subscription_end` timestamp in the future; a missing or malformed end
    date counts as inactive.
    """
    all_subs = get_all_subscriptions()
    user_subs = []

    for (user, dataset_id), sub in all_subs.items():
        if user != hf_user:
            continue

        is_active = False
        end_str = sub.get("subscription_end", "")
        if end_str:
            # Ledger timestamps end in "Z", which fromisoformat can't parse.
            if end_str.endswith("Z"):
                end_str = end_str[:-1]
            try:
                is_active = datetime.fromisoformat(end_str) > datetime.utcnow()
            except ValueError:
                # Malformed date → inactive. (Was a bare `except:`, which
                # also swallowed unrelated errors.)
                is_active = False

        user_subs.append({**sub, "is_active": is_active})

    return user_subs
211
+
212
+
213
def generate_access_token() -> str:
    """Create a new unguessable bearer token carrying the service prefix."""
    return "hfdata_" + secrets.token_urlsafe(32)
216
+
217
+
218
def validate_access_token(access_token: str) -> Optional[Dict[str, Any]]:
    """Validate an access token and return the subscription info if valid.

    Returns {"hf_user", "dataset_id", "subscription"} for the matching,
    unexpired subscription. Returns None when the token is unknown, the
    subscription has expired, or its end date is missing or malformed.
    """
    all_subs = get_all_subscriptions()

    for (hf_user, dataset_id), sub in all_subs.items():
        if sub.get("access_token") != access_token:
            continue

        # Token matched — check if the subscription is still active.
        end_str = sub.get("subscription_end", "")
        if end_str:
            # Ledger timestamps end in "Z", which fromisoformat can't parse.
            if end_str.endswith("Z"):
                end_str = end_str[:-1]
            try:
                if datetime.fromisoformat(end_str) > datetime.utcnow():
                    return {
                        "hf_user": hf_user,
                        "dataset_id": dataset_id,
                        "subscription": sub,
                    }
            except ValueError:
                # Malformed end date → treat as expired. (Was a bare
                # `except:`, which also swallowed unrelated errors.)
                pass
        return None  # Token found but subscription expired

    return None  # Token not found
245
+
246
+
247
def get_subscription_by_token(access_token: str) -> Optional[Dict[str, Any]]:
    """Look up subscription details for a token (alias of validate_access_token)."""
    return validate_access_token(access_token)