# Committed by innofacisteven ("innofacisteven's picture")
# Commit message: "Update main.py"
# Commit 1203ceb (verified)
import os
import secrets
import shutil

import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

# Import logic from your rag.py
from rag import get_retriever, add_document
app = FastAPI(title="LlamaIndex RAG API")

# Bearer-token scheme: FastAPI parses the Authorization header with it and
# hands the credentials to validate_api_key.
security = HTTPBearer()

# Shared secret, configured as the "HF_TOKEN" secret of the Hugging Face Space.
API_KEY = os.getenv("HF_TOKEN")
if not API_KEY:
    # Fail fast at import time: the service must not start without a key.
    raise RuntimeError("HF_TOKEN secret is not set in Hugging Face Space settings!")
def validate_api_key(auth: HTTPAuthorizationCredentials = Security(security)):
    """
    Validate the Bearer token provided in the Authorization header.

    Expects: ``Authorization: Bearer <your_hf_token>``. With the default
    ``HTTPBearer()`` scheme, FastAPI itself rejects a missing or non-Bearer
    header before this function runs.

    Args:
        auth: Scheme and token parsed from the Authorization header.

    Returns:
        The token string, when it matches ``API_KEY``.

    Raises:
        HTTPException: 403 when the token does not match ``API_KEY``.
    """
    # Constant-time comparison, so response timing does not reveal how much
    # of a guessed key was correct. Compare bytes because compare_digest
    # raises TypeError for str arguments containing non-ASCII characters.
    supplied = auth.credentials.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=403,
            detail="Invalid or missing API Key",
        )
    return auth.credentials
# Enable CORS for any origin.
# The API authenticates with an explicit Bearer token, not cookies, so
# credentialed CORS is not needed. FastAPI's CORS docs also state that
# allow_credentials=True must not be combined with a "*" origin wildcard;
# with that combination Starlette echoes back whatever Origin the browser
# sent, alongside Allow-Credentials: true.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
class QueryRequest(BaseModel):
    """JSON request body accepted by ``POST /retrieve``."""

    # Question text; passed unchanged to the query engine.
    query: str
@app.get("/")
def read_root():
    """Public health check endpoint (registered without the API-key dependency)."""
    payload = dict(status="online", service="LlamaIndex Retrieval Service")
    return payload
def _safe_upload_name(filename):
    """Return the final path component of a client-supplied name, or None if unusable."""
    # UploadFile.filename comes straight from the multipart request: it may be
    # missing, or carry directory parts such as "../../main.py" that would
    # otherwise let the client write outside uploads/.
    name = os.path.basename(filename or "")
    if name in ("", ".", ".."):
        return None
    return name


@app.post("/upload", dependencies=[Depends(validate_api_key)])
async def upload_file(file: UploadFile = File(...)):
    """
    Endpoint for n8n to upload files.
    Protected by API Key.

    The file is written to ``uploads/<basename of the client filename>`` and
    then passed to ``rag.add_document`` for indexing.

    Returns:
        ``{"filename", "status", "total_documents_indexed"}``; the last value
        is whatever ``add_document`` returns.

    Raises:
        HTTPException: 400 if the upload has no usable file name,
            500 if saving or indexing the file fails.
    """
    # NOTE(review): the disk copy and add_document() are synchronous and run
    # on the event loop, so other requests wait while a file is indexed.
    # Offloading them (e.g. run_in_threadpool) would let uploads overlap --
    # confirm add_document is thread-safe before doing that.
    try:
        safe_name = _safe_upload_name(file.filename)
        if safe_name is None:
            raise HTTPException(
                status_code=400,
                detail="Upload failed: missing or invalid file name",
            )
        os.makedirs("uploads", exist_ok=True)
        file_path = os.path.join("uploads", safe_name)
        # Stream the file to disk
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        # Index the document via LlamaIndex
        doc_count = add_document(file_path)
    except HTTPException:
        # Client errors raised above keep their own status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}") from e
    finally:
        # Always release the upload's temporary file, even on failure.
        await file.close()
    return {
        "filename": file.filename,
        "status": "success",
        "total_documents_indexed": doc_count,
    }
@app.post("/retrieve", dependencies=[Depends(validate_api_key)])
async def retrieve(request: QueryRequest):
    """
    Run ``request.query`` through the RAG query engine and return its sources.

    Returns:
        dict with:
        - ``query``: the query string as received;
        - ``total_chunks``: number of source nodes in the engine response;
        - ``retrieved_chunks``: one dict per source node with ``chunk_id``
          (1-based), ``text``, ``score`` (0.0 when the engine gave none)
          and the node's ``metadata``;
        - ``markdown``: the same chunks rendered as one Markdown string,
          so n8n can consume a single field.

    Raises:
        HTTPException: 500 wrapping any error from building or querying
            the engine.
    """
    # NOTE(review): get_retriever() and engine.query() are synchronous and
    # block the event loop for the duration of each query.
    try:
        engine = get_retriever()
        response = engine.query(request.query)

        nodes = []
        # Collect Markdown pieces and join once, instead of repeated +=.
        markdown_parts = [f"# Query: {request.query}\n\n"]
        for i, node in enumerate(response.source_nodes, start=1):
            text = node.node.get_content()
            score = float(node.score) if node.score is not None else 0.0
            file_name = node.node.metadata.get("file_name", "unknown")

            # Structured entry for the JSON response.
            nodes.append({
                "chunk_id": i,
                "text": text,
                "score": score,
                "metadata": node.node.metadata,
            })
            # Same chunk rendered as a Markdown section.
            markdown_parts.append(
                f"### Chunk {i} (Score: {score:.4f})\n"
                f"**Source:** {file_name}\n\n"
                f"{text}\n\n---\n\n"
            )

        return {
            "query": request.query,
            "total_chunks": len(nodes),
            "retrieved_chunks": nodes,
            "markdown": "".join(markdown_parts),  # one big string for n8n
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def _truncate_preview(text, limit=200):
    """Return at most ``limit`` characters of ``text``, adding "..." only if it was cut."""
    if len(text) > limit:
        return text[:limit] + "..."
    return text


@app.get("/inspect_index", dependencies=[Depends(validate_api_key)])
async def inspect_index():
    """
    View all parsed text blocks in the current index.

    Returns:
        ``{"total_nodes": int, "nodes": [preview, ...]}``, where each preview
        is the first 200 characters of a node's content, followed by "..."
        only when the content is longer than that. On any failure, returns
        ``{"error": <message>}`` in a normal response instead of raising.
    """
    try:
        # Imported here rather than at module top; an import failure is
        # reported through the error response below.
        from rag import initialize_index
        index = initialize_index()
        docstore = index.storage_context.docstore
        # NOTE(review): confirm get_all_ref_nodes() exists on the installed
        # LlamaIndex docstore; otherwise this endpoint always returns the
        # AttributeError text via the except branch.
        all_nodes = docstore.get_all_ref_nodes()
        return {
            "total_nodes": len(all_nodes),
            "nodes": [_truncate_preview(n.get_content()) for n in all_nodes],
        }
    except Exception as e:
        return {"error": str(e)}
if __name__ == "__main__":
    listen_host = "0.0.0.0"  # listen on all network interfaces
    listen_port = 7860       # port Hugging Face Spaces serves
    uvicorn.run(app, host=listen_host, port=listen_port)