import logging
import os
from typing import Any, Dict, List, Optional, Set

from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from app.services.hf_datasets import (
    fetch_and_cache_all_datasets,
    get_dataset_commits_async,
    get_dataset_files_async,
    get_datasets_page_from_cache,
    get_file_url_async,
)
from app.services.redis_client import cache_get

router = APIRouter(prefix="/datasets", tags=["datasets"])
log = logging.getLogger(__name__)
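# Size thresholds in bytes (100 MiB and 1 GiB). Assumption: these back the
# impact_level classification surfaced on DatasetInfo below; the service layer
# that would use them is not shown in this file.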
SIZE_LOW = 100 * 1024 * 1024
SIZE_MEDIUM = 1024 * 1024 * 1024
class DatasetInfo(BaseModel):
    id: str
    name: Optional[str]
    description: Optional[str]
    size_bytes: Optional[int]
    impact_level: Optional[str]
    downloads: Optional[int]
    likes: Optional[int]
    tags: Optional[List[str]]

    class Config:
        extra = "ignore"


class PaginatedDatasets(BaseModel):
    total: int
    items: List[DatasetInfo]


class CommitInfo(BaseModel):
    id: str
    title: Optional[str]
    message: Optional[str]
    author: Optional[Dict[str, Any]]
    date: Optional[str]


class CacheStatus(BaseModel):
    last_update: Optional[str]
    total_items: int
    warming_up: bool
def deduplicate_by_id(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop duplicate items by "id", keeping the first occurrence."""
    seen: Set[str] = set()
    unique_items = []
    for item in items:
        item_id = item.get("id")
        if item_id and item_id not in seen:
            seen.add(item_id)
            unique_items.append(item)
    return unique_items
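# Illustrative behavior (hypothetical values):
#   deduplicate_by_id([{"id": "a"}, {"id": "b"}, {"id": "a"}, {"name": "no-id"}])
#   -> [{"id": "a"}, {"id": "b"}]
# Later duplicates are dropped, as are items without an "id" key.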
# NOTE: the route paths on the decorators below are assumptions; they were not
# present in the recovered source.
@router.get("/cache-status")
async def cache_status():
    """Report cache freshness; warming_up stays True until the first refresh lands."""
    meta = await cache_get("hf:datasets:meta")
    last_update = meta.get("last_update") if meta else None
    total_items = meta.get("total_items", 0) if meta else 0
    warming_up = not bool(total_items)
    return CacheStatus(last_update=last_update, total_items=total_items, warming_up=warming_up)
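# Assumed cache layout (keys taken from the code above, values illustrative):
# cache_get("hf:datasets:meta") yields a dict such as
#   {"last_update": "2024-01-01T00:00:00Z", "total_items": 1234}
# or None before the cache has ever been filled.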
@router.get("/")
async def list_datasets(
    limit: int = Query(10, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    search: Optional[str] = Query(None, description="Search term for dataset id or description"),
    sort_by: Optional[str] = Query(None, description="Field to sort by (e.g., 'downloads', 'likes', 'created_at')"),
    sort_order: str = Query("desc", regex="^(asc|desc)$", description="Sort order: 'asc' or 'desc'"),
):
    # Fetch the full list from cache for in-memory filtering.
    result, status = get_datasets_page_from_cache(1000000, 0)
    if status != 200:
        return JSONResponse(result, status_code=status)
    items = result["items"]
    # Filtering: match the search term case-insensitively against id and description.
    if search:
        needle = search.lower()
        items = [
            d for d in items
            if needle in (d.get("id", "") + " " + str(d.get("description", ""))).lower()
        ]
    # Sorting: the (is-missing, value) key keeps items lacking the sort field
    # from raising TypeError; values of a given field are assumed to share one
    # type (ints for 'downloads', strings for 'created_at').
    if sort_by:
        items = sorted(
            items,
            key=lambda d: (d.get(sort_by) is None, d.get(sort_by)),
            reverse=(sort_order == "desc"),
        )
    # Pagination
    total = len(items)
    page = items[offset:offset + limit]
    total_pages = (total + limit - 1) // limit
    current_page = (offset // limit) + 1
    next_page = current_page + 1 if offset + limit < total else None
    prev_page = current_page - 1 if current_page > 1 else None
    return {
        "total": total,
        "current_page": current_page,
        "total_pages": total_pages,
        "next_page": next_page,
        "prev_page": prev_page,
        "items": page,
    }
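# Pagination arithmetic (illustrative): with total=25, limit=10, offset=20 the
# response carries total_pages=3, current_page=3, prev_page=2, next_page=None,
# and "items" holds the final 5 entries.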
@router.get("/{dataset_id}/commits")
async def get_commits(dataset_id: str):
    """
    Get commit history for a dataset.
    """
    try:
        return await get_dataset_commits_async(dataset_id)
    except Exception as e:
        log.error(f"Error fetching commits for {dataset_id}: {e}")
        raise HTTPException(status_code=404, detail=f"Could not fetch commits: {e}")
@router.get("/{dataset_id}/files")
async def list_files(dataset_id: str):
    """
    List files in a dataset.
    """
    try:
        return await get_dataset_files_async(dataset_id)
    except Exception as e:
        log.error(f"Error listing files for {dataset_id}: {e}")
        raise HTTPException(status_code=404, detail=f"Could not list files: {e}")
@router.get("/{dataset_id}/file-url")
async def get_file_url_endpoint(dataset_id: str, filename: str = Query(...), revision: Optional[str] = None):
    """
    Get download URL for a file in a dataset.
    """
    url = await get_file_url_async(dataset_id, filename, revision)
    return {"download_url": url}
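# Illustrative response, assuming the service builds Hub-style resolve URLs:
#   {"download_url": "https://huggingface.co/datasets/<dataset_id>/resolve/<revision>/<filename>"}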
@router.get("/meta")
async def get_datasets_meta():
    meta = await cache_get("hf:datasets:meta")
    return meta if meta else {}
# Endpoint to trigger a cache refresh manually (for admin/testing). Declared as
# a plain `def` on purpose: FastAPI runs sync endpoints in a threadpool, so the
# blocking fetch does not stall the event loop.
@router.post("/refresh-cache")
def refresh_cache():
    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        return JSONResponse({"error": "HUGGINGFACEHUB_API_TOKEN not set"}, status_code=500)
    count = fetch_and_cache_all_datasets(token)
    return {"status": "ok", "cached": count}
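# Illustrative wiring (entry-point module and import path assumed, not part of
# this file):
#   from fastapi import FastAPI
#   app = FastAPI()
#   app.include_router(router)
# With the assumed paths above, a manual refresh would then be:
#   curl -X POST http://localhost:8000/datasets/refresh-cache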