Spaces:
Sleeping
Sleeping
| """LawForge Data API - HuggingFace Space | |
| FastAPI service to query CourtListener parquet data directly. | |
| Uses DuckDB to query ALL parquet shards. | |
| """ | |
import json
import os
import re
from pathlib import Path

import duckdb
import numpy as np
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
# ASGI application object; the metadata below is passed straight to FastAPI.
app = FastAPI(
    title="LawForge Data API",
    description="Query CourtListener legal data",
    version="2.0.0",
)

# Allow cross-origin requests from any site.
# Credentials are disabled on purpose: a wildcard origin must not be combined
# with allow_credentials=True (the FastAPI CORS documentation requires the
# origins to be listed explicitly in that case), and nothing in this file
# relies on cookies or other browser-held credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Configuration
# Dataset repository the manifest and parquet shards are downloaded from.
DATASET_ID = "jonathanagustin/courtlistener-1"
# Access token read from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Local directory handed to hf_hub_download as its cache; created at import.
CACHE_DIR = Path("/tmp") / "hf_cache"
CACHE_DIR.mkdir(exist_ok=True, parents=True)

# Cache
# Parsed manifest.json; empty until the first successful load.
_manifest_cache: dict = {}
# Maps a config name to the local paths of its downloaded parquet shards.
_shard_cache: dict[str, list[str]] = {}
def get_manifest() -> dict:
    """Return the dataset manifest, downloading and caching it on first use.

    The manifest is fetched as ``manifest.json`` from the dataset repository
    and parsed as JSON. A successfully loaded manifest is kept in the
    module-level ``_manifest_cache`` and reused by later calls.

    Returns:
        The parsed manifest, or ``{"tables": {}}`` when it cannot be
        downloaded or parsed. The fallback is deliberately NOT cached, so a
        transient failure does not disable the manifest for the lifetime of
        the process; the next call simply retries.
    """
    global _manifest_cache
    if _manifest_cache:
        return _manifest_cache

    try:
        path = hf_hub_download(
            repo_id=DATASET_ID,
            filename="manifest.json",
            repo_type="dataset",
            token=HF_TOKEN,
            cache_dir=str(CACHE_DIR),
        )
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(path, encoding="utf-8") as f:
            manifest = json.load(f)
        if not isinstance(manifest, dict):
            raise ValueError("manifest.json does not contain a JSON object")
    except Exception as e:
        # Best-effort: callers treat a missing manifest as "no tables known".
        print(f"Error loading manifest: {e}")
        return {"tables": {}}

    _manifest_cache = manifest
    return _manifest_cache
def get_shard_count(config: str) -> int:
    """Look up how many shards the manifest lists for ``config``.

    Falls back to 1 when the manifest has no entry for the config or the
    entry carries no ``shard_count`` key.
    """
    tables = get_manifest().get("tables", {})
    entry = tables.get(config, {})
    return entry.get("shard_count", 1)
def download_all_shards(config: str) -> list[str]:
    """Download every parquet shard of ``config`` and return the local paths.

    Shards are expected at ``data/{config}/{config}-{i:05d}.parquet`` for
    ``i`` in ``range(shard_count)``. Individual download failures are logged
    and skipped, so the returned list may be shorter than the shard count
    (or empty).

    Only a complete download is stored in ``_shard_cache``. A partial or empty
    result is returned but not cached, so a transient failure is retried on
    the next call instead of permanently serving incomplete data.

    Args:
        config: Name of the dataset config (table) to fetch.

    Returns:
        Local filesystem paths of the shards that were downloaded.
    """
    cached = _shard_cache.get(config)
    if cached is not None:
        return cached

    shard_count = get_shard_count(config)
    print(f"Downloading {shard_count} shards for {config}...")

    paths = []
    for i in range(shard_count):
        filename = f"data/{config}/{config}-{i:05d}.parquet"
        try:
            local_path = hf_hub_download(
                repo_id=DATASET_ID,
                filename=filename,
                repo_type="dataset",
                token=HF_TOKEN,
                cache_dir=str(CACHE_DIR),
            )
        except Exception as e:
            # Best-effort: keep going so the remaining shards are still tried.
            print(f"Error downloading {filename}: {e}")
        else:
            paths.append(local_path)

    print(f"Downloaded {len(paths)}/{shard_count} shards for {config}")

    if paths and len(paths) == shard_count:
        _shard_cache[config] = paths
    return paths
def _clean_value(value):
    """Convert one result cell into a JSON-friendly Python value."""
    if isinstance(value, np.generic):
        # NumPy scalar -> plain Python scalar.
        value = value.item()
    if isinstance(value, np.ndarray):
        value = value.tolist()
    if isinstance(value, float):
        # NaN and +/-inf cannot be represented in JSON; report them as null.
        return value if np.isfinite(value) else None
    if isinstance(value, (list, tuple)):
        return [_clean_value(item) for item in value]
    if isinstance(value, dict):
        return {key: _clean_value(item) for key, item in value.items()}
    return value


def query_config(config: str, sql_template: str) -> list[dict]:
    """Execute a SQL query across all shards of a config.

    All shards of ``config`` are exposed to the query as a view named
    ``data`` inside a fresh in-memory DuckDB database, and ``sql_template``
    is executed against it.

    Rows are fetched as native Python values rather than through a pandas
    DataFrame: iterating DataFrame rows coerces every value in a row to one
    common dtype (turning integer ids into floats in all-numeric rows) and
    turns SQL NULLs into NaN.

    Args:
        config: Name of the dataset config (table) to query.
        sql_template: Complete SQL statement that reads from ``data``.

    Returns:
        One dict per result row, keyed by column name. Non-finite floats are
        replaced by ``None``.

    Raises:
        HTTPException: 404 when no shard is available for ``config``;
            500 when DuckDB fails to execute the query.
    """
    paths = download_all_shards(config)
    if not paths:
        raise HTTPException(status_code=404, detail=f"No data found for config: {config}")

    # The paths are embedded in SQL text, so escape any single quote in them.
    quoted_paths = ", ".join("'" + str(p).replace("'", "''") + "'" for p in paths)

    try:
        # The context manager closes the connection even when a query fails.
        with duckdb.connect(":memory:") as conn:
            conn.execute(
                f"CREATE VIEW data AS SELECT * FROM read_parquet([{quoted_paths}])"
            )
            cursor = conn.execute(sql_template)
            columns = [column[0] for column in cursor.description]
            records = cursor.fetchall()
        return [
            {name: _clean_value(cell) for name, cell in zip(columns, record)}
            for record in records
        ]
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query error: {str(e)}") from e
# NOTE(review): no route decorator is visible on this handler here; confirm
# that it is registered on `app`.
def root():
    """Describe the service: name, version, table names and endpoint summary.

    Table names are the keys of the manifest's ``tables`` mapping.
    """
    table_names = [*get_manifest().get("tables", {})]
    endpoint_help = {
        "/health": "Health check",
        "/rows/{config}": "Get rows (all shards)",
        "/search/{config}": "Full-text search",
        "/filter/{config}": "SQL WHERE filter",
        "/stats": "Dataset statistics",
    }
    return {
        "name": "LawForge Data API",
        "version": "2.0.0",
        "tables": table_names,
        "endpoints": endpoint_help,
    }
# NOTE(review): no route decorator is visible on this handler here; confirm
# that it is registered on `app`.
def health():
    """Report liveness and whether an HF token is configured.

    NOTE(review): the response includes the token's length; confirm that
    exposing it to callers is intended.
    """
    token_present = bool(HF_TOKEN)
    return {
        "status": "ok",
        "hf_token": "set" if token_present else "not set",
        "token_len": len(HF_TOKEN) if token_present else 0,
    }
# NOTE(review): no route decorator is visible on this handler here; confirm
# that it is registered on `app`.
def stats():
    """Summarize the manifest: per-table row and shard counts.

    Missing ``total_rows`` / ``shard_count`` entries are reported as 0.
    """
    manifest = get_manifest()
    summary = {}
    for name, info in manifest.get("tables", {}).items():
        summary[name] = {
            "total_rows": info.get("total_rows", 0),
            "shard_count": info.get("shard_count", 0),
        }
    return {"updated_at": manifest.get("updated_at"), "tables": summary}
# NOTE(review): no route decorator is visible on this handler here; confirm
# that it is registered on `app`.
def get_rows(config: str, offset: int = Query(0, ge=0), limit: int = Query(20, ge=1, le=1000)):
    """Return one page of rows from ``config``.

    ``total`` comes from the manifest's ``total_rows`` entry for the config
    (0 when absent); it is not counted from the data.
    """
    table_info = get_manifest().get("tables", {}).get(config, {})
    total = table_info.get("total_rows", 0)
    sql = f"SELECT * FROM data LIMIT {limit} OFFSET {offset}"
    page = query_config(config, sql)
    return {"rows": page, "total": total, "offset": offset, "limit": limit}
# NOTE(review): no route decorator is visible on this handler here; confirm
# that it is registered on `app`.
def search(config: str, q: str = Query(..., min_length=1), offset: int = Query(0, ge=0), limit: int = Query(20, ge=1, le=100)):
    """Case-insensitive substring search over the text columns of ``config``.

    The searched columns depend on the config; unknown configs fall back to
    matching against ``id``. A row matches when any searched column contains
    ``q`` (SQL ``ILIKE '%q%'``).

    Args:
        config: Name of the dataset config (table) to search.
        q: Text to look for. LIKE wildcards (``%``, ``_``) in ``q`` keep
            their pattern meaning.
        offset: Number of matching rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        Dict with the matching ``rows`` plus the echoed ``query``, ``offset``
        and ``limit``.

    Raises:
        HTTPException: Propagated from ``query_config``.
    """
    searchable = {
        "opinions": ["plain_text", "html", "author_str"],
        "opinion-clusters": ["case_name", "case_name_full", "syllabus", "judges"],
        "dockets": ["case_name", "case_name_full", "docket_number"],
    }
    cols = searchable.get(config, ["id"])

    # `q` is untrusted and ends up inside a single-quoted SQL string literal:
    # double every single quote so it cannot terminate the literal.
    pattern = q.replace("'", "''")

    where = " OR ".join(
        f"COALESCE(CAST({c} AS VARCHAR), '') ILIKE '%{pattern}%'" for c in cols
    )
    rows = query_config(config, f"SELECT * FROM data WHERE {where} LIMIT {limit} OFFSET {offset}")
    return {"rows": rows, "query": q, "offset": offset, "limit": limit}
# SQL keywords rejected in a user-supplied WHERE fragment. They are matched as
# whole words so that column names such as `date_created` or `date_updated`
# are not mistaken for CREATE / UPDATE. SELECT is included so the fragment
# cannot contain subqueries or set operations.
_FORBIDDEN_KEYWORDS = ("DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", "SELECT")
# Statement separator and comment openers, matched as plain substrings.
_FORBIDDEN_TOKENS = (";", "--", "/*")


# NOTE(review): no route decorator is visible on this handler here; confirm
# that it is registered on `app`.
def filter_rows(config: str, where: str = Query(..., min_length=1), offset: int = Query(0, ge=0), limit: int = Query(20, ge=1, le=1000)):
    """Return rows of ``config`` that satisfy a caller-supplied WHERE clause.

    The clause is inserted verbatim into the query, so it is screened first.
    This is a deny-list, not a SQL parser: it also rejects the listed words
    when they appear inside string literals.

    Args:
        config: Name of the dataset config (table) to filter.
        where: SQL boolean expression over the columns of ``config``.
        offset: Number of matching rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        Dict with the matching ``rows`` plus the echoed ``where``, ``offset``
        and ``limit``.

    Raises:
        HTTPException: 400 when ``where`` contains a forbidden keyword or
            token; other errors are propagated from ``query_config``.
    """
    for word in _FORBIDDEN_KEYWORDS:
        if re.search(rf"\b{word}\b", where, re.IGNORECASE):
            raise HTTPException(status_code=400, detail=f"Forbidden: {word}")
    for token in _FORBIDDEN_TOKENS:
        if token in where:
            raise HTTPException(status_code=400, detail=f"Forbidden: {token}")

    rows = query_config(config, f"SELECT * FROM data WHERE {where} LIMIT {limit} OFFSET {offset}")
    return {"rows": rows, "where": where, "offset": offset, "limit": limit}
# NOTE(review): no route decorator is visible on this handler here; confirm
# that it is registered on `app`.
def get_opinion(opinion_id: int):
    """Fetch the first row of the ``opinions`` config whose id matches.

    Raises:
        HTTPException: 404 when no row has that id.
    """
    matches = query_config("opinions", f"SELECT * FROM data WHERE id = '{opinion_id}'")
    if matches:
        return matches[0]
    raise HTTPException(status_code=404, detail="Opinion not found")
# NOTE(review): no route decorator is visible on this handler here; confirm
# that it is registered on `app`.
def get_cluster(cluster_id: int):
    """Fetch the first row of the ``opinion-clusters`` config whose id matches.

    Raises:
        HTTPException: 404 when no row has that id.
    """
    matches = query_config("opinion-clusters", f"SELECT * FROM data WHERE id = '{cluster_id}'")
    if matches:
        return matches[0]
    raise HTTPException(status_code=404, detail="Cluster not found")
# NOTE(review): no route decorator is visible on this handler here; confirm
# that it is registered on `app`.
def get_docket(docket_id: int):
    """Fetch the first row of the ``dockets`` config whose id matches.

    Raises:
        HTTPException: 404 when no row has that id.
    """
    matches = query_config("dockets", f"SELECT * FROM data WHERE id = '{docket_id}'")
    if matches:
        return matches[0]
    raise HTTPException(status_code=404, detail="Docket not found")
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)