# dataset-tool / app/api/datasets.py
# (Hugging Face file-viewer header preserved as a comment: author iaroy,
#  commit fdc5d7a "Deploy full application code", file size 5.23 kB)
from fastapi import APIRouter, Query, HTTPException
from typing import List, Optional, Dict, Any, Set
from pydantic import BaseModel
from fastapi.concurrency import run_in_threadpool
from app.services.hf_datasets import (
get_dataset_commits,
get_dataset_files,
get_file_url,
get_datasets_page_from_zset,
get_dataset_commits_async,
get_dataset_files_async,
get_file_url_async,
get_datasets_page_from_cache,
fetch_and_cache_all_datasets,
)
from app.services.redis_client import cache_get
import logging
import time
from fastapi.responses import JSONResponse
import os
# All routes in this module are mounted under the /datasets prefix.
router = APIRouter(prefix="/datasets", tags=["datasets"])
log = logging.getLogger(__name__)
# Size thresholds in bytes: 100 MiB and 1 GiB. Not referenced within this
# module — presumably imported by other modules (e.g. to bucket datasets
# into impact levels). TODO confirm against callers before removing.
SIZE_LOW = 100 * 1024 * 1024
SIZE_MEDIUM = 1024 * 1024 * 1024
class DatasetInfo(BaseModel):
    """Subset of Hugging Face dataset metadata served by this API."""
    id: str
    name: Optional[str]
    description: Optional[str]
    size_bytes: Optional[int]
    # presumably a bucket derived from SIZE_LOW/SIZE_MEDIUM — TODO confirm
    impact_level: Optional[str]
    downloads: Optional[int]
    likes: Optional[int]
    tags: Optional[List[str]]
    class Config:
        # Silently drop unknown keys from cached payloads instead of erroring.
        extra = "ignore"
class PaginatedDatasets(BaseModel):
    """A page of datasets plus the total count across all pages."""
    total: int
    items: List[DatasetInfo]
class CommitInfo(BaseModel):
    """One entry of a dataset's commit history (response model for /commits)."""
    id: str
    title: Optional[str]
    message: Optional[str]
    # author payload shape comes from the upstream HF API — not fixed here
    author: Optional[Dict[str, Any]]
    date: Optional[str]
class CacheStatus(BaseModel):
    """Freshness report for the datasets cache (see /cache-status)."""
    last_update: Optional[str]
    total_items: int
    # True while the cache has no items yet (still warming up).
    warming_up: bool
def deduplicate_by_id(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep only the first occurrence of each truthy "id", preserving order.

    Entries whose "id" key is missing or falsy are dropped entirely.
    """
    kept: List[Dict[str, Any]] = []
    encountered: Set[str] = set()
    for entry in items:
        key = entry.get("id")
        if not key or key in encountered:
            continue
        encountered.add(key)
        kept.append(entry)
    return kept
@router.get("/cache-status", response_model=CacheStatus)
async def cache_status():
    """Report cache freshness: last update timestamp, item count, warm-up flag.

    Idiom cleanup: replaces the verbose `x["k"] if x and "k" in x else d`
    checks with dict.get() on a defaulted dict — identical results.
    """
    meta = await cache_get("hf:datasets:meta") or {}
    total_items = meta.get("total_items", 0)
    return CacheStatus(
        last_update=meta.get("last_update"),
        total_items=total_items,
        # The cache is considered warming up until it holds at least one item.
        warming_up=not total_items,
    )
@router.get("/", response_model=None)
async def list_datasets(
    limit: int = Query(10, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    search: str = Query(None, description="Search term for dataset id or description"),
    sort_by: str = Query(None, description="Field to sort by (e.g., 'downloads', 'likes', 'created_at')"),
    sort_order: str = Query("desc", regex="^(asc|desc)$", description="Sort order: 'asc' or 'desc'"),
):
    """List cached datasets with optional search, sort, and pagination.

    Returns pagination metadata (total, current_page, total_pages,
    next_page, prev_page) plus the requested page of items.
    """
    # Fetch the full cached list; filtering/sorting are done in memory.
    result, status = get_datasets_page_from_cache(1000000, 0)  # get all for in-memory filtering
    if status != 200:
        return JSONResponse(result, status_code=status)
    items = result["items"]
    # Filtering: case-insensitive substring match on id or description.
    # Bug fix: previously .lower() was applied only to the description
    # half of the concatenation, so matches against the id were
    # unintentionally case-sensitive.
    if search:
        needle = search.lower()
        items = [
            d for d in items
            if needle in (d.get("id", "") + " " + str(d.get("description", ""))).lower()
        ]
    # Sorting: missing values fall back to 0. Mixed value types (e.g. str
    # vs the 0 fallback) are not comparable; report that as a client error
    # instead of letting the TypeError escape as a 500.
    if sort_by:
        try:
            items = sorted(items, key=lambda d: d.get(sort_by) or 0, reverse=(sort_order == "desc"))
        except TypeError:
            raise HTTPException(status_code=400, detail=f"Cannot sort by field: {sort_by}")
    # Pagination (limit >= 1 is enforced by Query, so no division by zero).
    total = len(items)
    page = items[offset:offset + limit]
    total_pages = (total + limit - 1) // limit
    current_page = (offset // limit) + 1
    next_page = current_page + 1 if offset + limit < total else None
    prev_page = current_page - 1 if current_page > 1 else None
    return {
        "total": total,
        "current_page": current_page,
        "total_pages": total_pages,
        "next_page": next_page,
        "prev_page": prev_page,
        "items": page
    }
@router.get("/{dataset_id:path}/commits", response_model=List[CommitInfo])
async def get_commits(dataset_id: str):
    """
    Get commit history for a dataset.
    """
    try:
        commits = await get_dataset_commits_async(dataset_id)
    except Exception as e:
        # Any upstream failure is surfaced to the client as a 404.
        log.error(f"Error fetching commits for {dataset_id}: {e}")
        raise HTTPException(status_code=404, detail=f"Could not fetch commits: {e}")
    return commits
@router.get("/{dataset_id:path}/files", response_model=List[str])
async def list_files(dataset_id: str):
    """
    List files in a dataset.
    """
    try:
        files = await get_dataset_files_async(dataset_id)
    except Exception as e:
        # Any upstream failure is surfaced to the client as a 404.
        log.error(f"Error listing files for {dataset_id}: {e}")
        raise HTTPException(status_code=404, detail=f"Could not list files: {e}")
    return files
@router.get("/{dataset_id:path}/file-url")
async def get_file_url_endpoint(dataset_id: str, filename: str = Query(...), revision: Optional[str] = None):
    """
    Get download URL for a file in a dataset.

    Consistency fix: the sibling /commits and /files endpoints log upstream
    failures and map them to 404; previously this endpoint let exceptions
    escape as unhandled 500s.
    """
    try:
        url = await get_file_url_async(dataset_id, filename, revision)
    except Exception as e:
        log.error(f"Error getting file URL for {dataset_id}/{filename}: {e}")
        raise HTTPException(status_code=404, detail=f"Could not get file URL: {e}")
    return {"download_url": url}
@router.get("/meta")
async def get_datasets_meta():
    """Return the raw datasets-cache metadata, or an empty dict if absent."""
    meta = await cache_get("hf:datasets:meta")
    if meta:
        return meta
    return {}
# Endpoint to trigger cache refresh manually (for admin/testing)
@router.post("/datasets/refresh-cache")
def refresh_cache():
    """Synchronously rebuild the datasets cache from the Hugging Face Hub.

    Requires HUGGINGFACEHUB_API_TOKEN in the environment; returns the
    number of items cached, or a 500 JSON error if the token is missing.
    """
    # NOTE(review): the router already carries prefix="/datasets", so this
    # path resolves to /datasets/datasets/refresh-cache — likely an
    # accidental double prefix, but left unchanged because existing callers
    # may depend on the current URL. Confirm before renaming to "/refresh-cache".
    # Sync (non-async) handler: FastAPI runs it in a threadpool, so the
    # blocking fetch_and_cache_all_datasets call does not stall the event loop.
    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        return JSONResponse({"error": "HUGGINGFACEHUB_API_TOKEN not set"}, status_code=500)
    count = fetch_and_cache_all_datasets(token)
    return {"status": "ok", "cached": count}