Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from transformers import AutoTokenizer, AutoModel | |
| import torch | |
| import lancedb | |
| import os | |
# Local testing convenience: pull variables from a .env file if one exists.
# Variables already present in the process environment win (override=False).
from dotenv import load_dotenv
load_dotenv(override=False)
app = FastAPI()

# Allow cross-origin requests from the React frontend.
# NOTE(review): origins, methods and headers are all wildcarded, so any site may
# call this API from a browser; this relies on the hosting platform (Hugging Face
# Spaces) for protection — confirm that is acceptable before deploying elsewhere.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
class BookRequest(BaseModel):
    """JSON request body for the book search handler."""

    # Free-text search string; it is embedded and used for the vector search.
    query: str
    # Maximum number of hits to return; defaults to 5. No bounds are enforced here.
    limit: int = 5
# Embedding model used for queries; loaded once at import so every request
# reuses the same tokenizer/model objects.
_EMBEDDING_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_MODEL_ID)
model = AutoModel.from_pretrained(_EMBEDDING_MODEL_ID)
def encode_query(text):
    """Embed query text with the module-level MiniLM tokenizer and model.

    Token embeddings from the model's last hidden state are averaged into one
    vector. Only positions whose attention-mask entry is 1 are counted, so
    padding tokens do not dilute the average. Padding is only added when
    several strings of different lengths are encoded together. A single
    string has no padding, so its result matches a plain mean over all tokens.

    Args:
        text: The query string. The tokenizer also accepts a list of strings;
            one pooled vector per string is then produced.

    Returns:
        numpy.ndarray: The pooled embedding(s), with size-1 dimensions squeezed
        away (a 1-D vector for a single string).

    NOTE(review): the sentence-transformers pipeline for this model also
    L2-normalizes its embeddings; no normalization happens here. If the stored
    vectors were normalized, the reported distances are on a different scale.
    Confirm how the table was built.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        token_embeddings = model(**inputs).last_hidden_state
        # attention_mask is presumably (batch, seq); add a trailing axis so it
        # broadcasts across the hidden dimension.
        mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
        summed = (token_embeddings * mask).sum(dim=1)
        # clamp avoids division by zero for a row with no unmasked tokens.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        pooled = summed / counts
    return pooled.squeeze().numpy()
# Connect to LanceDB once at import time; request handlers reuse this object.
# Settings come from the environment (optionally populated from a .env file).
if not os.getenv("LANCEDB_URI"):
    # NOTE(review): a non-"db://" URI such as the placeholder is presumably
    # treated by lancedb as a local directory, which would hide the misconfiguration
    # until a table lookup fails — hence the explicit warning.
    print("WARNING: LANCEDB_URI is not set; falling back to placeholder 'your-lancedb-uri'")
db = lancedb.connect(
    uri=os.getenv("LANCEDB_URI", "your-lancedb-uri"),
    api_key=os.getenv("LANCEDB_API_KEY", "your-api-key"),
    # Region was hard-coded; it stays "us-east-1" unless LANCEDB_REGION overrides it.
    region=os.getenv("LANCEDB_REGION", "us-east-1"),
)
# --- API route handlers ---
# NOTE(review): no route decorators are visible on the handlers in this copy of
# the file; confirm they are actually registered on `app`.
def health_check():
    """Liveness check that always reports the service as healthy."""
    payload = dict(status="healthy")
    return payload
def list_tables():
    """Debug endpoint: list the table names in the connected LanceDB database.

    Returns ``{"tables": [...]}`` on success. Any exception is caught and its
    message is returned as ``{"error": "<message>"}`` instead of being raised.
    """
    try:
        names = db.table_names()
    except Exception as exc:  # deliberate: a debug endpoint reports, never raises
        return {"error": str(exc)}
    return {"tables": names}
def test_endpoint():
    """Smoke-test handler: print a marker line to stdout and return a fixed message."""
    marker = "TEST ENDPOINT HIT!"
    print(marker)
    return dict(message="API is working")
def search_books(book_request: BookRequest):
    """Embed the request's query and return the nearest books from LanceDB.

    Progress is printed to stdout step by step so a failure can be located in
    the server logs.

    Args:
        book_request: Parsed request body carrying ``query`` and ``limit``.

    Returns:
        dict: ``{"results": [...]}``, one dict per hit. Each dict holds the
        selected book columns plus LanceDB's ``_distance``.

    Raises:
        HTTPException: 400 when the query is empty or whitespace-only, or when
            ``limit`` is less than 1. 500 when embedding, opening the table, or
            the vector search fails.
    """
    import traceback

    query = book_request.query
    # A whitespace-only query carries no content; reject it like an empty one.
    if not query or not query.strip():
        raise HTTPException(status_code=400, detail="Query string is required")
    # Reject a zero/negative limit up front as a client error instead of
    # handing it to the database query.
    if book_request.limit < 1:
        raise HTTPException(status_code=400, detail="limit must be at least 1")

    # Tracks the current stage so the error log can name where things failed.
    step = "start"
    try:
        print(f"1. Starting search for query: '{query}'")

        step = "embedding"
        print("2. Generating embeddings...")
        query_embedding = encode_query(query)
        print(f"3. Embeddings generated successfully, shape: {query_embedding.shape}")

        step = "open_table"
        print("4. Connecting to database table...")
        table = db.open_table(os.getenv("LANCEDB_TABLE", "book_db"))
        print("5. Database table opened successfully")

        step = "vector_search"
        print("6. Starting vector search...")
        results = (table.search(query_embedding)
                   .select([
                       "id", "title", "primary_author", "description",
                       "publisher", "published_date", "page_count",
                       "primary_category", "avg_rating", "ratings_count",
                       "thumbnail_url", "preview_link", "list_price", "buy_link",
                       "_distance",
                   ])
                   .limit(book_request.limit)
                   .to_list())
        print(f"7. Search completed successfully, found {len(results)} results")
        return {"results": results}
    except Exception as e:
        print(f"ERROR: Search failed at step '{step}' with error: {e}")
        traceback.print_exc()
        # Chain the original exception so it stays attached for debugging.
        raise HTTPException(status_code=500, detail=f"Search failed: {e}") from e
def check_env():
    """Report which LanceDB-related environment variables are set.

    Values are not echoed in full. The report has one ``<NAME>_set`` boolean
    per variable, plus ``LANCEDB_URI_preview``. The preview is the first 20
    characters of LANCEDB_URI followed by "...", or "NOT_SET" when the URI is
    unset or empty.
    """
    report = {
        f"{name}_set": bool(os.getenv(name))
        for name in ("LANCEDB_URI", "LANCEDB_API_KEY", "LANCEDB_TABLE")
    }
    uri = os.getenv("LANCEDB_URI")
    if uri:
        report["LANCEDB_URI_preview"] = uri[:20] + "..."
    else:
        report["LANCEDB_URI_preview"] = "NOT_SET"
    return report
# Serve the compiled React bundle's static assets (JS, CSS, images) under /assets.
# The directory path is relative to the process working directory.
# NOTE(review): Starlette's StaticFiles checks by default that the directory exists,
# so startup presumably fails if the frontend build is missing — confirm the
# deployment always ships dist/assets.
app.mount(
    "/assets",
    StaticFiles(directory="dist/assets"),
    name="assets",
)
# Catch-all for client-side routes; it must be registered after every API route.
async def serve_react_app(full_path: str):
    """Return the React app's ``index.html`` for any non-API path.

    API paths get a 404 instead of HTML, so a mistyped API URL fails loudly
    rather than silently receiving the single-page app.

    NOTE(review): no route decorator is visible for this function here;
    ``full_path`` presumably comes from a ``/{full_path:path}`` catch-all route.

    Args:
        full_path: Request path relative to the site root (no leading slash).

    Returns:
        FileResponse: ``dist/index.html``, resolved against the working directory.

    Raises:
        HTTPException: 404 for ``api`` and ``api/...``, and for the reserved
            names ``health``, ``tables``, ``docs`` and ``redoc``.
    """
    # Treat the bare "api" path like "api/..." so it also gets a 404 instead of HTML.
    is_api_path = full_path == "api" or full_path.startswith("api/")
    if is_api_path or full_path in ("health", "tables", "docs", "redoc"):
        raise HTTPException(status_code=404, detail="Not found")
    return FileResponse('dist/index.html')
if __name__ == "__main__":
    # Direct `python` entry point: serve the app with uvicorn on all
    # interfaces, port 7860.
    import uvicorn

    listen_host = "0.0.0.0"
    listen_port = 7860
    uvicorn.run(app, host=listen_host, port=listen_port)