jonathanagustin commited on
Commit
15bf590
·
verified ·
1 Parent(s): d8953ff

fix: query ALL parquet shards, not just shard 0

Browse files
Files changed (1) hide show
  1. app.py +128 -212
app.py CHANGED
@@ -1,23 +1,23 @@
1
  """LawForge Data API - HuggingFace Space
2
 
3
  FastAPI service to query CourtListener parquet data directly.
4
- Bypasses datasets-server limitations for private datasets.
5
  """
6
 
7
  import os
8
- from functools import lru_cache
9
- from typing import Optional
10
 
11
  import duckdb
 
12
  from fastapi import FastAPI, HTTPException, Query
13
  from fastapi.middleware.cors import CORSMiddleware
14
  from huggingface_hub import hf_hub_download
15
- import pandas as pd
16
 
17
  app = FastAPI(
18
  title="LawForge Data API",
19
  description="Query CourtListener legal data",
20
- version="1.0.0"
21
  )
22
 
23
  app.add_middleware(
@@ -31,50 +31,87 @@ app.add_middleware(
31
  # Configuration
32
  DATASET_ID = "jonathanagustin/courtlistener-1"
33
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
 
34
 
35
- # Cache for DuckDB connections
36
- _db_cache = {}
 
37
 
38
 
39
- def get_parquet_path(config: str, shard: int = 0) -> str:
40
- """Download and cache parquet file, return local path."""
41
- cache_key = f"{config}_{shard}"
42
- if cache_key not in _db_cache:
43
- filename = f"data/{config}/{config}-{shard:05d}.parquet"
44
- print(f"Downloading: {filename}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  try:
46
  local_path = hf_hub_download(
47
  repo_id=DATASET_ID,
48
  filename=filename,
49
  repo_type="dataset",
50
- token=HF_TOKEN
 
51
  )
52
- print(f"Downloaded to: {local_path}")
53
- _db_cache[cache_key] = local_path
54
  except Exception as e:
55
  print(f"Error downloading {filename}: {e}")
56
- raise HTTPException(status_code=404, detail=f"Parquet file not found: {filename}. Error: {str(e)}")
57
- return _db_cache[cache_key]
58
-
59
-
60
- def query_parquet(config: str, sql: str, params: dict = None) -> list:
61
- """Execute SQL query on parquet file."""
62
- path = get_parquet_path(config)
 
 
 
 
 
63
  try:
64
  conn = duckdb.connect(":memory:")
65
- conn.execute(f"CREATE VIEW data AS SELECT * FROM read_parquet('{path}')")
66
-
67
- if params:
68
- result = conn.execute(sql, params).fetchdf()
69
  else:
70
- result = conn.execute(sql).fetchdf()
71
-
 
 
72
  conn.close()
73
-
74
- # Convert to JSON-safe format
75
- import json
76
- import numpy as np
77
-
78
  def clean_value(v):
79
  if v is None:
80
  return None
@@ -85,224 +122,103 @@ def query_parquet(config: str, sql: str, params: dict = None) -> list:
85
  if isinstance(v, (np.floating, np.float64)):
86
  return float(v)
87
  return v
88
-
89
- records = []
90
- for _, row in result.iterrows():
91
- record = {k: clean_value(v) for k, v in row.items()}
92
- records.append(record)
93
-
94
- return records
95
  except Exception as e:
96
- import traceback
97
- raise HTTPException(status_code=500, detail=f"Query error: {str(e)}. Traceback: {traceback.format_exc()}")
98
 
99
 
100
  @app.get("/")
101
  def root():
 
 
102
  return {
103
  "name": "LawForge Data API",
104
- "version": "1.0.0",
 
105
  "endpoints": {
106
  "/health": "Health check",
107
- "/rows/{config}": "Get rows from a config",
108
  "/search/{config}": "Full-text search",
109
  "/filter/{config}": "SQL WHERE filter",
 
110
  }
111
  }
112
 
113
 
114
  @app.get("/health")
115
  def health():
116
- token_status = "set" if HF_TOKEN else "not set"
117
- token_len = len(HF_TOKEN) if HF_TOKEN else 0
118
- return {"status": "ok", "hf_token": token_status, "token_len": token_len}
119
 
120
 
121
- @app.get("/test-download")
122
- def test_download():
123
- """Test downloading a parquet file."""
124
- from huggingface_hub import hf_hub_download
125
- try:
126
- local_path = hf_hub_download(
127
- repo_id=DATASET_ID,
128
- filename="data/courts/courts-00000.parquet",
129
- repo_type="dataset",
130
- token=HF_TOKEN
131
- )
132
- import os
133
- size = os.path.getsize(local_path)
134
- return {"status": "ok", "path": local_path, "size_bytes": size}
135
- except Exception as e:
136
- return {"status": "error", "error": str(e), "type": type(e).__name__}
137
-
138
-
139
- @app.get("/test-query")
140
- def test_query():
141
- """Test querying a parquet file."""
142
- try:
143
- path = get_parquet_path("courts")
144
- conn = duckdb.connect(":memory:")
145
- conn.execute(f"CREATE VIEW data AS SELECT * FROM read_parquet('{path}')")
146
- result = conn.execute("SELECT COUNT(*) as cnt FROM data").fetchdf()
147
- count = int(result['cnt'].iloc[0])
148
-
149
- # Get one row
150
- row = conn.execute("SELECT * FROM data LIMIT 1").fetchdf()
151
- conn.close()
152
-
153
- # Convert to dict
154
- row_dict = row.to_dict(orient="records")[0] if len(row) > 0 else {}
155
-
156
- return {"status": "ok", "count": count, "sample_row_keys": list(row_dict.keys())}
157
- except Exception as e:
158
- import traceback
159
- return {"status": "error", "error": str(e), "type": type(e).__name__, "traceback": traceback.format_exc()}
160
 
161
 
162
  @app.get("/rows/{config}")
163
- def get_rows(
164
- config: str,
165
- offset: int = Query(0, ge=0),
166
- limit: int = Query(20, ge=1, le=100)
167
- ):
168
- """Get paginated rows from a config."""
169
- import traceback
170
- try:
171
- sql = f"SELECT * FROM data LIMIT {limit} OFFSET {offset}"
172
- rows = query_parquet(config, sql)
173
-
174
- # Get total count
175
- count_sql = "SELECT COUNT(*) as cnt FROM data"
176
- count_result = query_parquet(config, count_sql)
177
- total = count_result[0]["cnt"] if count_result else 0
178
-
179
- return {
180
- "rows": rows,
181
- "total": total,
182
- "offset": offset,
183
- "limit": limit
184
- }
185
- except HTTPException:
186
- raise
187
- except Exception as e:
188
- return {"error": str(e), "traceback": traceback.format_exc()}
189
 
190
 
191
  @app.get("/search/{config}")
192
- def search(
193
- config: str,
194
- q: str = Query(..., min_length=1),
195
- offset: int = Query(0, ge=0),
196
- limit: int = Query(20, ge=1, le=100)
197
- ):
198
- """Full-text search on a config."""
199
- try:
200
- # Build search query based on config
201
- if config == "opinions":
202
- search_cols = ["plain_text", "html"]
203
- elif config == "opinion-clusters":
204
- search_cols = ["case_name", "case_name_full", "syllabus"]
205
- elif config == "dockets":
206
- search_cols = ["case_name", "case_name_full", "docket_number"]
207
- else:
208
- search_cols = ["*"]
209
-
210
- # Create WHERE clause for text search
211
- where_clauses = []
212
- for col in search_cols:
213
- if col == "*":
214
- where_clauses.append(f"CAST(data AS VARCHAR) ILIKE '%{q}%'")
215
- else:
216
- where_clauses.append(f"COALESCE({col}, '') ILIKE '%{q}%'")
217
-
218
- where = " OR ".join(where_clauses)
219
- sql = f"SELECT * FROM data WHERE {where} LIMIT {limit} OFFSET {offset}"
220
- rows = query_parquet(config, sql)
221
-
222
- return {
223
- "rows": rows,
224
- "query": q,
225
- "offset": offset,
226
- "limit": limit
227
- }
228
- except Exception as e:
229
- raise HTTPException(status_code=500, detail=str(e))
230
 
231
 
232
  @app.get("/filter/{config}")
233
- def filter_rows(
234
- config: str,
235
- where: str = Query(..., min_length=1),
236
- offset: int = Query(0, ge=0),
237
- limit: int = Query(20, ge=1, le=100)
238
- ):
239
- """Filter rows using SQL WHERE clause."""
240
- try:
241
- # Sanitize WHERE clause (basic protection)
242
- forbidden = ["DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", ";"]
243
- where_upper = where.upper()
244
- for word in forbidden:
245
- if word in where_upper:
246
- raise HTTPException(status_code=400, detail=f"Forbidden SQL keyword: {word}")
247
-
248
- sql = f"SELECT * FROM data WHERE {where} LIMIT {limit} OFFSET {offset}"
249
- rows = query_parquet(config, sql)
250
-
251
- return {
252
- "rows": rows,
253
- "where": where,
254
- "offset": offset,
255
- "limit": limit
256
- }
257
- except HTTPException:
258
- raise
259
- except Exception as e:
260
- raise HTTPException(status_code=500, detail=str(e))
261
 
262
 
263
  @app.get("/opinion/{opinion_id}")
264
  def get_opinion(opinion_id: int):
265
- """Get a specific opinion by ID."""
266
- try:
267
- sql = f"SELECT * FROM data WHERE id = {opinion_id}"
268
- rows = query_parquet("opinions", sql)
269
- if not rows:
270
- raise HTTPException(status_code=404, detail="Opinion not found")
271
- return rows[0]
272
- except HTTPException:
273
- raise
274
- except Exception as e:
275
- raise HTTPException(status_code=500, detail=str(e))
276
 
277
 
278
  @app.get("/cluster/{cluster_id}")
279
  def get_cluster(cluster_id: int):
280
- """Get a specific opinion cluster by ID."""
281
- try:
282
- sql = f"SELECT * FROM data WHERE id = {cluster_id}"
283
- rows = query_parquet("opinion-clusters", sql)
284
- if not rows:
285
- raise HTTPException(status_code=404, detail="Cluster not found")
286
- return rows[0]
287
- except HTTPException:
288
- raise
289
- except Exception as e:
290
- raise HTTPException(status_code=500, detail=str(e))
291
 
292
 
293
  @app.get("/docket/{docket_id}")
294
  def get_docket(docket_id: int):
295
- """Get a specific docket by ID."""
296
- try:
297
- sql = f"SELECT * FROM data WHERE id = {docket_id}"
298
- rows = query_parquet("dockets", sql)
299
- if not rows:
300
- raise HTTPException(status_code=404, detail="Docket not found")
301
- return rows[0]
302
- except HTTPException:
303
- raise
304
- except Exception as e:
305
- raise HTTPException(status_code=500, detail=str(e))
306
 
307
 
308
  if __name__ == "__main__":
 
1
  """LawForge Data API - HuggingFace Space
2
 
3
  FastAPI service to query CourtListener parquet data directly.
4
+ Uses DuckDB to query ALL parquet shards.
5
  """
6
 
7
  import os
8
+ import json
9
+ from pathlib import Path
10
 
11
  import duckdb
12
+ import numpy as np
13
  from fastapi import FastAPI, HTTPException, Query
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from huggingface_hub import hf_hub_download
 
16
 
17
  app = FastAPI(
18
  title="LawForge Data API",
19
  description="Query CourtListener legal data",
20
+ version="2.0.0"
21
  )
22
 
23
  app.add_middleware(
 
31
  # Configuration
32
  DATASET_ID = "jonathanagustin/courtlistener-1"
33
  HF_TOKEN = os.environ.get("HF_TOKEN")
34
+ CACHE_DIR = Path("/tmp/hf_cache")
35
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
36
 
37
+ # Cache
38
+ _shard_cache: dict[str, list[str]] = {}
39
+ _manifest_cache: dict = {}
40
 
41
 
42
+ def get_manifest() -> dict:
43
+ """Download and cache the manifest."""
44
+ global _manifest_cache
45
+ if not _manifest_cache:
46
+ try:
47
+ path = hf_hub_download(
48
+ repo_id=DATASET_ID,
49
+ filename="manifest.json",
50
+ repo_type="dataset",
51
+ token=HF_TOKEN,
52
+ cache_dir=str(CACHE_DIR)
53
+ )
54
+ with open(path) as f:
55
+ _manifest_cache = json.load(f)
56
+ except Exception as e:
57
+ print(f"Error loading manifest: {e}")
58
+ _manifest_cache = {"tables": {}}
59
+ return _manifest_cache
60
+
61
+
62
+ def get_shard_count(config: str) -> int:
63
+ """Get number of shards for a config from manifest."""
64
+ manifest = get_manifest()
65
+ table_info = manifest.get("tables", {}).get(config, {})
66
+ return table_info.get("shard_count", 1)
67
+
68
+
69
+ def download_all_shards(config: str) -> list[str]:
70
+ """Download all parquet shards for a config."""
71
+ if config in _shard_cache:
72
+ return _shard_cache[config]
73
+
74
+ shard_count = get_shard_count(config)
75
+ print(f"Downloading {shard_count} shards for {config}...")
76
+
77
+ paths = []
78
+ for i in range(shard_count):
79
+ filename = f"data/{config}/{config}-{i:05d}.parquet"
80
  try:
81
  local_path = hf_hub_download(
82
  repo_id=DATASET_ID,
83
  filename=filename,
84
  repo_type="dataset",
85
+ token=HF_TOKEN,
86
+ cache_dir=str(CACHE_DIR)
87
  )
88
+ paths.append(local_path)
 
89
  except Exception as e:
90
  print(f"Error downloading {filename}: {e}")
91
+
92
+ print(f"Downloaded {len(paths)}/{shard_count} shards for {config}")
93
+ _shard_cache[config] = paths
94
+ return paths
95
+
96
+
97
+ def query_config(config: str, sql_template: str) -> list[dict]:
98
+ """Execute SQL query across all shards of a config."""
99
+ paths = download_all_shards(config)
100
+ if not paths:
101
+ raise HTTPException(status_code=404, detail=f"No data found for config: {config}")
102
+
103
  try:
104
  conn = duckdb.connect(":memory:")
105
+
106
+ if len(paths) == 1:
107
+ conn.execute(f"CREATE VIEW data AS SELECT * FROM read_parquet('{paths[0]}')")
 
108
  else:
109
+ paths_str = ", ".join(f"'{p}'" for p in paths)
110
+ conn.execute(f"CREATE VIEW data AS SELECT * FROM read_parquet([{paths_str}])")
111
+
112
+ result = conn.execute(sql_template).fetchdf()
113
  conn.close()
114
+
 
 
 
 
115
  def clean_value(v):
116
  if v is None:
117
  return None
 
122
  if isinstance(v, (np.floating, np.float64)):
123
  return float(v)
124
  return v
125
+
126
+ return [{k: clean_value(v) for k, v in row.items()} for _, row in result.iterrows()]
127
+
128
+ except HTTPException:
129
+ raise
 
 
130
  except Exception as e:
131
+ raise HTTPException(status_code=500, detail=f"Query error: {str(e)}")
 
132
 
133
 
134
  @app.get("/")
135
  def root():
136
+ manifest = get_manifest()
137
+ tables = list(manifest.get("tables", {}).keys())
138
  return {
139
  "name": "LawForge Data API",
140
+ "version": "2.0.0",
141
+ "tables": tables,
142
  "endpoints": {
143
  "/health": "Health check",
144
+ "/rows/{config}": "Get rows (all shards)",
145
  "/search/{config}": "Full-text search",
146
  "/filter/{config}": "SQL WHERE filter",
147
+ "/stats": "Dataset statistics",
148
  }
149
  }
150
 
151
 
152
  @app.get("/health")
153
  def health():
154
+ return {"status": "ok", "hf_token": "set" if HF_TOKEN else "not set", "token_len": len(HF_TOKEN) if HF_TOKEN else 0}
 
 
155
 
156
 
157
+ @app.get("/stats")
158
+ def stats():
159
+ manifest = get_manifest()
160
+ tables = {name: {"total_rows": info.get("total_rows", 0), "shard_count": info.get("shard_count", 0)}
161
+ for name, info in manifest.get("tables", {}).items()}
162
+ return {"updated_at": manifest.get("updated_at"), "tables": tables}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
 
165
  @app.get("/rows/{config}")
166
+ def get_rows(config: str, offset: int = Query(0, ge=0), limit: int = Query(20, ge=1, le=1000)):
167
+ manifest = get_manifest()
168
+ total = manifest.get("tables", {}).get(config, {}).get("total_rows", 0)
169
+ rows = query_config(config, f"SELECT * FROM data LIMIT {limit} OFFSET {offset}")
170
+ return {"rows": rows, "total": total, "offset": offset, "limit": limit}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
 
173
  @app.get("/search/{config}")
174
+ def search(config: str, q: str = Query(..., min_length=1), offset: int = Query(0, ge=0), limit: int = Query(20, ge=1, le=100)):
175
+ if config == "opinions":
176
+ cols = ["plain_text", "html", "author_str"]
177
+ elif config == "opinion-clusters":
178
+ cols = ["case_name", "case_name_full", "syllabus", "judges"]
179
+ elif config == "dockets":
180
+ cols = ["case_name", "case_name_full", "docket_number"]
181
+ else:
182
+ cols = ["id"]
183
+
184
+ where = " OR ".join(f"COALESCE(CAST({c} AS VARCHAR), '') ILIKE '%{q}%'" for c in cols)
185
+ rows = query_config(config, f"SELECT * FROM data WHERE {where} LIMIT {limit} OFFSET {offset}")
186
+ return {"rows": rows, "query": q, "offset": offset, "limit": limit}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
 
189
  @app.get("/filter/{config}")
190
+ def filter_rows(config: str, where: str = Query(..., min_length=1), offset: int = Query(0, ge=0), limit: int = Query(20, ge=1, le=1000)):
191
+ forbidden = ["DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", ";", "--"]
192
+ for word in forbidden:
193
+ if word in where.upper():
194
+ raise HTTPException(status_code=400, detail=f"Forbidden: {word}")
195
+
196
+ rows = query_config(config, f"SELECT * FROM data WHERE {where} LIMIT {limit} OFFSET {offset}")
197
+ return {"rows": rows, "where": where, "offset": offset, "limit": limit}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
 
200
  @app.get("/opinion/{opinion_id}")
201
  def get_opinion(opinion_id: int):
202
+ rows = query_config("opinions", f"SELECT * FROM data WHERE id = '{opinion_id}'")
203
+ if not rows:
204
+ raise HTTPException(status_code=404, detail="Opinion not found")
205
+ return rows[0]
 
 
 
 
 
 
 
206
 
207
 
208
  @app.get("/cluster/{cluster_id}")
209
  def get_cluster(cluster_id: int):
210
+ rows = query_config("opinion-clusters", f"SELECT * FROM data WHERE id = '{cluster_id}'")
211
+ if not rows:
212
+ raise HTTPException(status_code=404, detail="Cluster not found")
213
+ return rows[0]
 
 
 
 
 
 
 
214
 
215
 
216
  @app.get("/docket/{docket_id}")
217
  def get_docket(docket_id: int):
218
+ rows = query_config("dockets", f"SELECT * FROM data WHERE id = '{docket_id}'")
219
+ if not rows:
220
+ raise HTTPException(status_code=404, detail="Docket not found")
221
+ return rows[0]
 
 
 
 
 
 
 
222
 
223
 
224
  if __name__ == "__main__":