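"""FastAPI service that mirrors Hugging Face dataset metadata into a local
SQLite database and exposes a search endpoint over dataset column names.

A background scheduler refreshes the metadata on a cron schedule, uploads the
database file to the librarian-bots/column-db dataset repo, and keeps Hub
collections up to date.
"""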
import asyncio
import concurrent.futures
import json
import logging
import os
import sqlite3
from contextlib import asynccontextmanager
from typing import List

import numpy as np
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from cashews import NOT_NONE, cache
from fastapi import FastAPI, HTTPException, Query
from huggingface_hub import login, upload_file
from pandas import Timestamp
from pydantic import BaseModel
from starlette.responses import RedirectResponse

from create_collections import collections, update_collection_for_dataset
from data_loader import refresh_data
login(token=os.getenv("HF_TOKEN"))

UPDATE_SCHEDULE = {"hour": os.getenv("UPDATE_INTERVAL_HOURS", "*/6")}
COLLECTION_UPDATE_SCHEDULE = {"hour": "0"}  # run at midnight every day

cache.setup("mem://?check_interval=10&size=10000")

logger = logging.getLogger(__name__)
def get_db_connection():
    """Open a SQLite connection with dict-like rows and WAL journaling enabled."""
    conn = sqlite3.connect("datasets.db")
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA journal_mode = WAL")
    conn.execute("PRAGMA synchronous = NORMAL")
    return conn
def setup_database():
    """Create the datasets table and its supporting indexes if they do not exist."""
    conn = get_db_connection()
    c = conn.cursor()
    c.execute(
        """CREATE TABLE IF NOT EXISTS datasets
               (hub_id TEXT PRIMARY KEY,
                likes INTEGER,
                downloads INTEGER,
                tags JSON,
                created_at INTEGER,
                last_modified INTEGER,
                license JSON,
                language JSON,
                config_name TEXT,
                column_names JSON,
                features JSON)"""
    )
    c.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_column_names
        ON datasets(column_names)
        """
    )
    c.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_downloads_likes
        ON datasets(downloads DESC, likes DESC)
        """
    )
    conn.commit()
    # Refresh the query planner's statistics for the table and indexes above.
    c.execute("ANALYZE")
    conn.close()
def serialize_numpy(obj):
    """`json.dumps` default hook for NumPy scalars/arrays and pandas Timestamps."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, Timestamp):
        return int(obj.timestamp())
    logger.error(f"Object of type {type(obj)} is not JSON serializable")
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
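# Illustrative usage of the hook above (values are made up):
#   json.dumps({"likes": np.int64(5)}, default=serialize_numpy) -> '{"likes": 5}'
#   json.dumps({"v": object()}, default=serialize_numpy)        -> raises TypeError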
def background_refresh_data():
    """Run the blocking refresh_data() call; return None on failure."""
    logger.info("Starting background data refresh")
    try:
        return refresh_data()
    except Exception as e:
        logger.error(f"Error in background data refresh: {str(e)}")
        return None
async def update_database():
    """Refresh dataset metadata in a worker thread, then rewrite the SQLite table."""
    logger.info("Starting scheduled data refresh")
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(background_refresh_data)
        try:
            # Wait for the blocking refresh without stalling the event loop.
            datasets = await asyncio.get_event_loop().run_in_executor(
                None, future.result
            )
        except asyncio.CancelledError:
            future.cancel()
            logger.info("Data refresh cancelled")
            return
    if datasets is None:
        logger.error("Data refresh failed, skipping database update")
        return
    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.executemany(
            """
            INSERT OR REPLACE INTO datasets
            (hub_id, likes, downloads, tags, created_at, last_modified, license, language, config_name, column_names, features)
            VALUES (?, ?, ?, json(?), ?, ?, json(?), json(?), ?, json(?), json(?))
            """,
            [
                (
                    data["hub_id"],
                    data.get("likes", 0),
                    data.get("downloads", 0),
                    json.dumps(data.get("tags", []), default=serialize_numpy),
                    int(data["created_at"].timestamp())
                    if isinstance(data["created_at"], Timestamp)
                    else data.get("created_at", 0),
                    int(data["last_modified"].timestamp())
                    if isinstance(data["last_modified"], Timestamp)
                    else data.get("last_modified", 0),
                    json.dumps(data.get("license", []), default=serialize_numpy),
                    json.dumps(data.get("language", []), default=serialize_numpy),
                    data.get("config_name", ""),
                    json.dumps(data.get("column_names", []), default=serialize_numpy),
                    json.dumps(data.get("features", []), default=serialize_numpy),
                )
                for data in datasets
            ],
        )
        conn.commit()
        logger.info("Scheduled data refresh completed")
    except Exception as e:
        logger.error(f"Error during database update: {str(e)}")
        conn.rollback()
    finally:
        conn.close()
    try:
        upload_file(
            path_or_fileobj="datasets.db",
            path_in_repo="datasets.db",
            repo_id="librarian-bots/column-db",
            repo_type="dataset",
        )
        logger.info("Database file uploaded to Hugging Face Hub successfully")
    except Exception as e:
        logger.error(f"Error uploading database file to Hugging Face Hub: {str(e)}")
async def update_collections():
    """Rebuild each configured Hub collection in a worker thread."""
    logger.info("Starting scheduled collection update")
    try:
        for collection in collections:
            result = await asyncio.get_event_loop().run_in_executor(
                None,
                update_collection_for_dataset,
                collection["collection_name"],
                collection["dataset_columns"],
                collection["collection_description"],
                "librarian-bots",
            )
            logger.info(f"Updated collection: {result}")
    except Exception as e:
        logger.error(f"Error during collection update: {str(e)}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the database and scheduler on startup; stop them on shutdown."""
    setup_database()
    logger.info("Performing initial data refresh")
    await update_database()
    scheduler = AsyncIOScheduler()
    scheduler.add_job(update_database, CronTrigger(**UPDATE_SCHEDULE))
    scheduler.add_job(update_collections, CronTrigger(**COLLECTION_UPDATE_SCHEDULE))
    scheduler.start()
    await update_collections()
    yield
    scheduler.shutdown()
app = FastAPI(lifespan=lifespan)


@app.get("/")  # route decorator restored so the redirect is actually served
def root():
    return RedirectResponse(url="/docs")
class SearchResponse(BaseModel):
    total: int
    page: int
    page_size: int
    results: List[dict]
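# Illustrative response shape (field values are made up):
#   {"total": 42, "page": 1, "page_size": 10,
#    "results": [{"hub_id": "user/dataset", "likes": 7, "downloads": 1234, ...}]}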
@app.get("/search", response_model=SearchResponse)  # route path is an assumption
@cache(ttl="1h", condition=NOT_NONE)  # cache only non-None results; TTL is an assumed value
async def search_datasets(
    columns: List[str] = Query(...),
    match_all: bool = Query(False),
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=1000),
):
    """Search datasets whose column_names match any (or all) of the given columns."""
    offset = (page - 1) * page_size
    conn = get_db_connection()
    c = conn.cursor()
    try:
        if match_all:
            # Keep only datasets whose column_names contain every requested column.
            query = """
                SELECT *, (
                    SELECT COUNT(*)
                    FROM json_each(column_names)
                    WHERE json_each.value IN ({})
                ) as match_count
                FROM datasets
                WHERE match_count = ?
                ORDER BY downloads DESC, likes DESC
                LIMIT ? OFFSET ?
            """.format(",".join("?" * len(columns)))
            c.execute(query, (*columns, len(columns), page_size, offset))
        else:
            # Keep datasets matching at least one requested column.
            query = """
                SELECT * FROM datasets
                WHERE EXISTS (
                    SELECT 1
                    FROM json_each(column_names)
                    WHERE json_each.value IN ({})
                )
                ORDER BY downloads DESC, likes DESC
                LIMIT ? OFFSET ?
            """.format(",".join("?" * len(columns)))
            c.execute(query, (*columns, page_size, offset))
        results = [dict(row) for row in c.fetchall()]
        if match_all:
            count_query = """
                SELECT COUNT(*) as total FROM datasets
                WHERE (
                    SELECT COUNT(*)
                    FROM json_each(column_names)
                    WHERE json_each.value IN ({})
                ) = ?
            """.format(",".join("?" * len(columns)))
            c.execute(count_query, (*columns, len(columns)))
        else:
            count_query = """
                SELECT COUNT(*) as total FROM datasets
                WHERE EXISTS (
                    SELECT 1
                    FROM json_each(column_names)
                    WHERE json_each.value IN ({})
                )
            """.format(",".join("?" * len(columns)))
            c.execute(count_query, columns)
        total = c.fetchone()["total"]
        # JSON columns come back from SQLite as strings; decode them for the response.
        for result in results:
            for field in ("tags", "license", "language", "column_names", "features"):
                result[field] = json.loads(result[field])
        return SearchResponse(
            total=total, page=page, page_size=page_size, results=results
        )
    except sqlite3.Error as e:
        logger.error(f"Database error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e
    finally:
        conn.close()
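# Illustrative request against the endpoint above (the /search path is the
# assumption noted at its decorator):
#   GET /search?columns=question&columns=answer&match_all=true&page=1&page_size=10
# returns datasets whose column_names include both "question" and "answer",
# ordered by downloads and likes.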
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)