# Uploaded to Hugging Face by yjernite (HF Staff).
# Commit 1b21566 "Upload 5 files" (verified).
from fastapi import FastAPI, HTTPException, Depends, Query
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Optional
import uvicorn
from contextlib import asynccontextmanager
from data_loader import DataLoader
from models import ArticleResponse, ArticleDetail, FiltersResponse
# Single shared DataLoader used by every endpoint below; its dataset is
# loaded by the lifespan handler when the application starts.
data_loader: DataLoader = DataLoader()
# Dependency functions for API parameters
def get_filter_params(
document_type: Optional[List[str]] = Query(None, description="Filter by document types"),
author_type: Optional[List[str]] = Query(None, description="Filter by author types"),
min_relevance: Optional[float] = Query(None, ge=0, le=10, description="Minimum AI labor relevance score"),
max_relevance: Optional[float] = Query(None, ge=0, le=10, description="Maximum AI labor relevance score"),
start_date: Optional[str] = Query(None, description="Start date (YYYY-MM-DD)"),
end_date: Optional[str] = Query(None, description="End date (YYYY-MM-DD)"),
topic: Optional[List[str]] = Query(None, description="Filter by document topics"),
search_query: Optional[str] = Query(None, description="Search query for text matching"),
search_type: Optional[str] = Query("exact", description="Search type: 'exact' or 'dense'"),
) -> dict:
return {
'document_types': document_type,
'author_types': author_type,
'min_relevance': min_relevance,
'max_relevance': max_relevance,
'start_date': start_date,
'end_date': end_date,
'topics': topic,
'search_query': search_query,
'search_type': search_type,
}
def get_pagination_params(
    page: int = Query(default=1, ge=1, description="Page number"),
    limit: int = Query(default=20, ge=1, le=100, description="Items per page"),
    sort_by: Optional[str] = Query(default="date", description="Sort by 'date' or 'score'"),
) -> dict:
    """Collect paging and ordering query parameters into one dict.

    FastAPI enforces ``page >= 1`` and ``1 <= limit <= 100``; ``sort_by``
    is passed through unvalidated.

    Returns:
        dict with keys ``page``, ``limit`` and ``sort_by``.
    """
    return dict(page=page, limit=limit, sort_by=sort_by)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the dataset before the application begins serving requests.

    FastAPI runs the code before ``yield`` on startup and the code after it
    on shutdown.
    """
    print("Loading dataset from HuggingFace...")
    await data_loader.load_dataset()
    article_total = len(data_loader.articles)
    print(f"Dataset loaded: {article_total} articles")
    yield
    # No shutdown work is required.
app = FastAPI(
    title="Archive Explorer API: AI, Labor and the Economy",
    version="1.0.0",
    lifespan=lifespan,
)

# Browser origins allowed to call this API cross-origin: two localhost ports
# (presumably local frontend dev servers) and the hosted hf.space deployment.
_CORS_ALLOWED_ORIGINS = [
    "http://localhost:3000",
    "http://localhost:5173",
    "https://yjernite-labor-archive-backend.hf.space",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Report the API name and how many articles are currently loaded."""
    article_total = len(data_loader.articles)
    return {
        "message": "Archive Explorer API: AI, Labor and the Economy",
        "articles_count": article_total,
    }
@app.get("/filters", response_model=FiltersResponse)
async def get_filters():
    """Return every selectable filter option, shaped as ``FiltersResponse``."""
    options = data_loader.get_filter_options()
    return options
@app.get("/articles", response_model=List[ArticleResponse])
async def get_articles(
    pagination: dict = Depends(get_pagination_params),
    filters: dict = Depends(get_filter_params),
):
    """Return one page of articles matching the request's filters.

    The paging dict (``page``, ``limit``, ``sort_by``) and the filter dict
    produced by the dependency providers are forwarded to the data loader as
    keyword arguments.
    """
    page_of_articles = data_loader.get_articles(**pagination, **filters)
    return page_of_articles
@app.get("/articles/count")
async def get_articles_count(
    filters: dict = Depends(get_filter_params),
):
    """Return ``{"count": n}`` where n is the number of articles matching the filters."""
    matching_total = data_loader.get_articles_count(**filters)
    return {"count": matching_total}
# Filter dimensions for which per-option counts can be requested.
_COUNTABLE_FILTER_TYPES = ('document_types', 'author_types', 'topics')


@app.get("/filter-counts/{filter_type}")
async def get_filter_counts(
    filter_type: str,
    filters: dict = Depends(get_filter_params),
):
    """Return per-option counts for one filter dimension.

    Args:
        filter_type: one of ``document_types``, ``author_types`` or ``topics``.
        filters: the currently active filters, forwarded to the data loader.

    Raises:
        HTTPException: 400 when ``filter_type`` is not a countable dimension.
    """
    if filter_type not in _COUNTABLE_FILTER_TYPES:
        raise HTTPException(status_code=400, detail="Invalid filter type")
    return data_loader.get_filter_counts(filter_type=filter_type, **filters)
@app.get("/articles/{article_id}", response_model=ArticleDetail)
async def get_article(article_id: int):
    """Get the detailed record for a single article.

    Args:
        article_id: integer article identifier taken from the URL path.

    Returns:
        The value from ``data_loader.get_article_detail``, serialized through
        the ``ArticleDetail`` response model.

    Raises:
        HTTPException: 404 when the data loader returns ``None`` for the ID.
    """
    article = data_loader.get_article_detail(article_id)
    # NOTE(review): assumes the loader signals "not found" by returning None.
    # Without this guard, None would fail ArticleDetail response validation
    # and reach the client as a 500 instead of a 404.
    if article is None:
        raise HTTPException(status_code=404, detail="Article not found")
    return article
@app.get("/test-search")
async def test_search(q: str):
    """Debug endpoint: run the loader's search for ``q`` in 'exact' mode.

    NOTE(review): this calls the private ``_search_articles`` helper and
    returns its raw result without a response model; consider removing or
    restricting it outside development.
    """
    search_mode = 'exact'
    return data_loader._search_articles(q, search_mode)
if __name__ == "__main__":
    # Direct-execution entry point: serve the app on all interfaces, port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )