Spaces:
Sleeping
Sleeping
File size: 5,061 Bytes
8fa1db6 b32f886 8fa1db6 b32f886 8fa1db6 6e5a18b 51c08bb 8fa1db6 b32f886 8fa1db6 51c08bb 8fa1db6 4b5fbc0 8fa1db6 c6d1805 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import os
import traceback

import lancedb
import torch
# For local testing
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModel
# Read a local .env file (if one is present) so the LANCEDB_* settings can be
# supplied without exporting them in the shell; intended for local testing.
load_dotenv()

app = FastAPI()

# CORS for the React frontend: every origin, method and header is allowed.
# NOTE(review): this is fully open regardless of where the app is hosted;
# confirm that a public, any-origin API is intended.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
class BookRequest(BaseModel):
    """Request body for ``POST /api/search``.

    Attributes:
        query: Free-text search string that is embedded and matched against
            the stored book vectors.
        limit: Maximum number of results to return (1-100, default 5).
    """

    query: str
    # Bounded so one request cannot ask the vector store for a non-positive or
    # arbitrarily large number of rows. Out-of-range values are rejected by
    # FastAPI request validation (HTTP 422) before the handler runs.
    limit: int = Field(5, ge=1, le=100)
# Hugging Face model used to embed search queries. Presumably the same model
# produced the vectors stored in the LanceDB table, since query vectors are
# compared against them; confirm before changing it.
EMBEDDING_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'

tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
model = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME)
def encode_query(text):
    """Embed *text* by attention-masked mean pooling over the model's token states.

    Args:
        text: A query string, or a list of strings. A list is padded by the
            tokenizer to a common length, truncated at 512 tokens.

    Returns:
        numpy.ndarray: The pooled embedding(s). ``squeeze`` drops size-1 axes,
        so a single string yields a 1-D vector. A list of several strings
        yields one row per string.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    token_states = outputs.last_hidden_state
    # Weight each token by its attention mask so padding tokens (added when a
    # batch of differently sized texts is padded) do not dilute the average.
    # For a single unpadded string the mask is all ones, so this equals a
    # plain mean over tokens.
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_states.dtype)
    summed = (token_states * mask).sum(dim=1)
    # clamp guards against division by zero for an all-padding row.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).squeeze().numpy()
# LanceDB connection configured from environment variables. The fallback
# strings are placeholders, not working credentials. If LANCEDB_URI or
# LANCEDB_API_KEY is unset, the connection is still attempted with them, so
# warn at startup instead of failing confusingly later. LANCEDB_REGION is
# optional and defaults to the previously hard-coded "us-east-1".
_missing_lancedb_vars = [
    name for name in ("LANCEDB_URI", "LANCEDB_API_KEY") if not os.getenv(name)
]
if _missing_lancedb_vars:
    print(
        f"WARNING: {', '.join(_missing_lancedb_vars)} not set; using placeholder "
        "values, so database calls will likely fail."
    )

db = lancedb.connect(
    uri=os.getenv("LANCEDB_URI", "your-lancedb-uri"),
    api_key=os.getenv("LANCEDB_API_KEY", "your-api-key"),
    region=os.getenv("LANCEDB_REGION", "us-east-1"),
)
# API Routes
@app.get("/health")
def health_check():
    """Report that the process is up; no database or model checks are made."""
    status = "healthy"
    return {"status": status}
@app.get("/tables")
def list_tables():
    """Debug endpoint: list the table names on the LanceDB connection.

    Returns:
        dict: ``{"tables": [...]}`` with the names from ``db.table_names()``.

    Raises:
        HTTPException: 500 if the database call fails.
    """
    try:
        tables = db.table_names()
    except Exception as e:
        # Report the failure with a 5xx status (as /api/search does) rather
        # than a 200 carrying an error field, so clients and monitors see it.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Listing tables failed: {e}") from e
    return {"tables": tables}
@app.get("/api/test")
def test_endpoint():
    """Smoke-test route: prints a marker line to stdout and returns a fixed message."""
    marker = "TEST ENDPOINT HIT!"
    print(marker)
    payload = {"message": "API is working"}
    return payload
# Columns returned for each search hit. "_distance" is the distance column
# LanceDB adds to vector-search results (smaller means closer).
_SEARCH_RESULT_COLUMNS = [
    "id", "title", "primary_author", "description",
    "publisher", "published_date", "page_count",
    "primary_category", "avg_rating", "ratings_count",
    "thumbnail_url", "preview_link", "list_price", "buy_link",
    "_distance",
]


@app.post("/api/search")
def search_books(book_request: BookRequest):
    """Vector-search the books table for entries closest to the query text.

    The query is embedded with ``encode_query`` and searched in the LanceDB
    table named by ``LANCEDB_TABLE`` (default ``"book_db"``). Progress is
    printed step by step. On failure, the step that was running is logged
    together with a traceback.

    Args:
        book_request: Parsed request body holding the query text and the
            maximum number of results.

    Returns:
        dict: ``{"results": [...]}``, one dict per hit with the fields in
        ``_SEARCH_RESULT_COLUMNS``.

    Raises:
        HTTPException: 400 if the query is empty or only whitespace; 500 if
            embedding, opening the table, or the search raises.
    """
    # A whitespace-only query has nothing meaningful to embed; treat it like
    # an empty one.
    if not book_request.query.strip():
        raise HTTPException(status_code=400, detail="Query string is required")

    # Name of the stage in progress, so the error log can say where it failed.
    step = "starting search"
    try:
        print(f"1. Starting search for query: '{book_request.query}'")

        step = "generating embeddings"
        print("2. Generating embeddings...")
        query_embedding = encode_query(book_request.query)
        print(f"3. Embeddings generated successfully, shape: {query_embedding.shape}")

        step = "opening database table"
        print("4. Connecting to database table...")
        table = db.open_table(os.getenv("LANCEDB_TABLE", "book_db"))
        print("5. Database table opened successfully")

        step = "running vector search"
        print("6. Starting vector search...")
        results = (
            table.search(query_embedding)
            .select(_SEARCH_RESULT_COLUMNS)
            .limit(book_request.limit)
            .to_list()
        )
        print(f"7. Search completed successfully, found {len(results)} results")
        return {"results": results}
    except Exception as e:
        print(f"ERROR: Search failed while {step}: {e}")
        traceback.print_exc()
        # NOTE(review): the 500 detail echoes the raw exception text to the
        # client; consider a generic message if errors may contain config.
        raise HTTPException(status_code=500, detail=f"Search failed: {e}") from e
@app.get("/debug/env")
def check_env():
    """Debug endpoint: report which LANCEDB_* variables are set.

    Returns a dict of booleans plus a preview of the URI: its first 20
    characters followed by "...", or "NOT_SET" when the URI is missing or empty.

    NOTE(review): no authentication is applied to this route, and it reveals
    a prefix of the database URI; consider removing it outside debugging.
    """
    uri = os.getenv("LANCEDB_URI")
    report = {
        "LANCEDB_URI_set": bool(uri),
        "LANCEDB_API_KEY_set": bool(os.getenv("LANCEDB_API_KEY")),
        "LANCEDB_TABLE_set": bool(os.getenv("LANCEDB_TABLE")),
    }
    report["LANCEDB_URI_preview"] = f"{uri[:20]}..." if uri else "NOT_SET"
    return report
# Serve the built React frontend's static assets (CSS, JS, images) under
# /assets. The directory is resolved relative to the process working
# directory. Starlette's StaticFiles checks by default that it exists, so a
# missing build fails at startup.
_react_assets = StaticFiles(directory="dist/assets")
app.mount("/assets", _react_assets, name="assets")
# Catch-all: any GET path not matched by an earlier route returns the React
# entry page, so client-side routes load the app. Keep this route last.
_NON_SPA_PATHS = ("health", "tables", "docs", "redoc")


@app.get("/{full_path:path}")
async def serve_react_app(full_path: str):
    """Return ``dist/index.html``, except for API-style paths, which get a 404.

    Args:
        full_path: The request path without its leading slash.

    Raises:
        HTTPException: 404 for paths under ``api/`` and for the reserved
            names in ``_NON_SPA_PATHS``.
    """
    is_reserved = full_path.startswith("api/") or full_path in _NON_SPA_PATHS
    if is_reserved:
        raise HTTPException(status_code=404, detail="Not found")
    return FileResponse('dist/index.html')
if __name__ == "__main__":
    # Local entry point: serve on all interfaces at port 7860, the port this
    # app is configured for (also Hugging Face Spaces' default app port).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)