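"""FastAPI Space serving semantic search over Hugging Face dataset cards.

Dataset summaries (from davanstrien/datasets_with_metadata_and_summaries) are
embedded with nomic-ai/modernbert-embed-base and stored in a persistent
ChromaDB collection; the endpoints below expose text search and
"find similar datasets" queries over that collection.
"""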
import asyncio
import logging
import os
import sys
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Generator, List

import chromadb
import datasets
import polars as pl
import stamina
from cashews import cache
from chromadb.utils import embedding_functions
from datasets import Dataset, Sequence, Value
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import (
    AsyncInferenceClient,
    DatasetCard,
    DatasetInfo,
    HfApi,
    InferenceClient,
    ModelCard,
    ModelInfo,
    dataset_info,
    get_inference_endpoint,
    hf_hub_url,
)
from pydantic import BaseModel
from tqdm.contrib.concurrent import thread_map
from transformers import AutoTokenizer

hf_api = HfApi()
tokenizer = AutoTokenizer.from_pretrained(
    "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on HF_TRANSFER

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use a local data directory when developing on macOS
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"

# Configure cache
cache.setup("mem://", size_limit="4gb")

# Initialize ChromaDB client
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Setup
    setup_database()
    yield
    # Cleanup
    await cache.close()


# Initialize FastAPI app
app = FastAPI(lifespan=lifespan)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    # Starlette's CORSMiddleware matches allow_origins literally, so wildcard
    # subdomains are expressed as a regex instead.
    allow_origin_regex=r"https://.*\.(hf\.space|huggingface\.co)",  # Hugging Face Spaces and domains
    # allow_origins=["http://localhost:5500"],  # TODO remove before prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define the embedding function at module level
def get_embedding_function():
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="nomic-ai/modernbert-embed-base"
    )
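# Note: modernbert-embed-base follows the nomic prefix convention
# ("search_query: " / "search_document: "), which is why the search endpoint
# below prefixes incoming queries with "search_query: ".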
def setup_database():
    try:
        embedding_function = get_embedding_function()
        # Create collection with embedding function
        dataset_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="dataset_cards",
            metadata={"hnsw:space": "cosine"},
        )
        # TODO incremental updates
        df = pl.scan_parquet(
            "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
        )
        df = df.filter(
            pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
        )
        row_count = df.select(pl.len()).collect().item()
        logger.info(f"Row count of new data: {row_count}")
        if dataset_collection.count() < row_count:
            # Load parquet files and upsert into ChromaDB
            df = df.select(
                ["datasetId", "summary", "likes", "downloads", "last_modified"]
            )
            df = df.collect()
            BATCH_SIZE = 1000
            total_rows = len(df)
            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
                dataset_collection.upsert(
                    ids=batch_df.select(["datasetId"]).to_series().to_list(),
                    documents=batch_df.select(["summary"]).to_series().to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df.select(["likes"]).to_series().to_list(),
                            batch_df.select(["downloads"]).to_series().to_list(),
                            batch_df.select(["last_modified"]).to_series().to_list(),
                        )
                    ],
                )
                logger.info(f"Processed {i + len(batch_df):,} / {total_rows:,} rows")
        logger.info(f"Database initialized with {dataset_collection.count():,} rows")
        # model_collection = client.get_or_create_collection(
        #     embedding_function=embedding_function,
        #     name="model_cards",
        #     metadata={"hnsw:space": "cosine"},
        # )
        # # If collection is empty, load data from parquet files
        # if model_collection.count() == 0:
        #     # Load parquet files and insert into ChromaDB
        #     df = pl.scan_parquet(
        #         "hf://datasets/librarian-bots/model_cards_with_metadata/data/train-*.parquet"
        #     )
        #     df = df.select(["modelId", "likes", "downloads"])
        #     df = df.collect()
        #     df = df.sample(n=1000)  # TODO remove for prod
        #     # Process in batches of 1000
        #     BATCH_SIZE = 1000
        #     total_rows = len(df)
        #     for i in range(0, total_rows, BATCH_SIZE):
        #         batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
        #         model_collection.add(
        #             ids=batch_df.select(["modelId"]).to_series().to_list(),
        #             documents=batch_df.select(["summary"]).to_series().to_list(),
        #             metadatas=[
        #                 {"likes": int(likes), "downloads": int(downloads)}
        #                 for likes, downloads in zip(
        #                     batch_df.select(["likes"]).to_series().to_list(),
        #                     batch_df.select(["downloads"]).to_series().to_list(),
        #                 )
        #             ],
        #         )
        #         logger.info(f"Processed {i + len(batch_df):,} / {total_rows:,} rows")
        # logger.info(f"Database initialized with {model_collection.count():,} rows")
    except Exception as e:
        logger.error(f"Setup error: {e}")


# Run setup on startup
setup_database()
class QueryResult(BaseModel):
    dataset_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class QueryResponse(BaseModel):
    results: List[QueryResult]
@app.get("/")  # path assumed; the handler just redirects to the interactive docs
async def redirect_to_docs():
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")
@app.get("/search/datasets")  # route path assumed
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        # Get collection with proper embedding function
        collection = client.get_collection(
            name="dataset_cards", embedding_function=get_embedding_function()
        )
        # Query ChromaDB
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        # Process results
        query_results = []
        for i in range(len(results["ids"][0])):
            query_results.append(
                QueryResult(
                    dataset_id=results["ids"][0][i],
                    # ChromaDB returns cosine distance; report it as a similarity
                    similarity=1 - float(results["distances"][0][i]),
                    summary=results["documents"][0][i],
                    likes=results["metadatas"][0][i]["likes"],
                    downloads=results["metadatas"][0][i]["downloads"],
                )
            )
        # Sort results if needed
        if sort_by != "similarity":
            query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
            query_results = query_results[:k]
        return QueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed")
@app.get("/similarity/datasets")  # route path assumed
async def find_similar_datasets(
    dataset_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection("dataset_cards")
        # Get the reference document
        results = collection.get(ids=[dataset_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )
        # Query using the embedding
        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4
            if sort_by != "similarity"
            else k + 1,  # +1 to account for self-match
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        # Process results (excluding the query dataset itself)
        query_results = []
        for i in range(len(results["ids"][0])):
            if results["ids"][0][i] != dataset_id:
                query_results.append(
                    QueryResult(
                        dataset_id=results["ids"][0][i],
                        # ChromaDB returns cosine distance; report it as a similarity
                        similarity=1 - float(results["distances"][0][i]),
                        summary=results["documents"][0][i],
                        likes=results["metadatas"][0][i]["likes"],
                        downloads=results["metadatas"][0][i]["downloads"],
                    )
                )
        # Sort results if needed, then trim to k
        if sort_by != "similarity":
            query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
        query_results = query_results[:k]
        return QueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed")
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
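# Example requests against a local run (using the route paths assumed above):
#   curl "http://localhost:8000/search/datasets?query=medical%20reasoning&k=5"
#   curl "http://localhost:8000/similarity/datasets?dataset_id=<datasetId>&k=5"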