"""Lauki FAQ retrieval service.

Loads question/answer pairs from a CSV file, embeds them with a
SentenceTransformer model, and serves top-k similarity search over them
through a small FastAPI app backed by an in-memory FAISS flat index.
"""

import csv
import os
from pathlib import Path

import faiss
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

# Both settings can be overridden through environment variables.
DATA_PATH = Path(os.getenv("FAQ_CSV_PATH", "data/lauki_qna.csv"))
MODEL_NAME = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")

# Header columns every FAQ CSV must provide.
_REQUIRED_COLUMNS = ("question", "answer")

app = FastAPI(title="Lauki FAQ Retrieval Service")


class SearchRequest(BaseModel):
    """Request body for ``POST /search``."""

    query: str = Field(min_length=1)
    # Number of FAQ entries to return; validated to the range 1..10.
    k: int = Field(default=3, ge=1, le=10)


class SearchResult(BaseModel):
    """One ranked FAQ hit."""

    rank: int  # 1-based position within the returned results
    score: float  # inner product of normalized embeddings (cosine similarity)
    question: str
    answer: str
    content: str  # the "Q: ...\nA: ..." text that was embedded


class SearchResponse(BaseModel):
    """Response body for ``POST /search``."""

    query: str
    k: int
    context: str  # all results joined into a single string
    results: list[SearchResult]


def load_rows(path: Path) -> list[dict[str, str]]:
    """Read FAQ entries from a CSV file with ``question`` and ``answer`` columns.

    Args:
        path: Location of the CSV file.

    Returns:
        One dict per non-blank row with keys ``question``, ``answer`` and
        ``content`` (``"Q: <question>\\nA: <answer>"``, the text that gets embedded).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the header lacks a required column.
    """
    if not path.exists():
        raise FileNotFoundError(f"FAQ CSV not found: {path}")

    rows: list[dict[str, str]] = []
    # newline="" is what the csv module requires so quoted multi-line answers
    # parse correctly; utf-8-sig strips a leading BOM (common in spreadsheet
    # exports) that would otherwise corrupt the first header name.
    with path.open("r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        header = reader.fieldnames or []
        missing = [column for column in _REQUIRED_COLUMNS if column not in header]
        if missing:
            raise ValueError(
                f"FAQ CSV {path} is missing required column(s): {', '.join(missing)}"
            )
        for row in reader:
            # Rows shorter than the header yield None for absent cells.
            question = (row["question"] or "").strip()
            answer = (row["answer"] or "").strip()
            if not question and not answer:
                continue  # nothing meaningful to embed
            rows.append(
                {
                    "question": question,
                    "answer": answer,
                    "content": f"Q: {question}\nA: {answer}",
                }
            )
    return rows


rows = load_rows(DATA_PATH)
model = SentenceTransformer(MODEL_NAME)


def _encode(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` as L2-normalized float32 vectors (FAISS expects float32)."""
    vectors = model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
    return np.asarray(vectors, dtype="float32")


# Built only when there is something to index: an empty batch may not encode to
# a 2-D array, so its width cannot size the index. With no rows, /search
# answers 503 instead of the service failing at import time.
index: faiss.IndexFlatIP | None = None
embeddings = np.empty((0, 0), dtype="float32")
if rows:
    embeddings = _encode([row["content"] for row in rows])
    # Inner product over unit-length vectors equals cosine similarity.
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)


@app.get("/health")
def health() -> dict[str, int | str]:
    """Report service status, number of indexed FAQ entries, and model name."""
    return {"status": "ok", "documents": len(rows), "model": MODEL_NAME}


@app.post("/search", response_model=SearchResponse)
def search(req: SearchRequest) -> SearchResponse:
    """Return the ``req.k`` FAQ entries most similar to ``req.query``.

    Raises:
        HTTPException: 503 when no FAQ entries were loaded.
    """
    if index is None or not rows:
        raise HTTPException(status_code=503, detail="FAQ index is empty")

    query_embedding = _encode([req.query])
    # Never request more neighbours than there are indexed documents.
    scores, indices = index.search(query_embedding, min(req.k, len(rows)))

    results: list[SearchResult] = []
    for score, idx in zip(scores[0], indices[0]):
        if idx < 0:  # FAISS uses -1 for slots it could not fill
            continue
        row = rows[int(idx)]
        results.append(
            SearchResult(
                # Derived from kept results so ranks stay contiguous even if a
                # -1 slot was skipped above.
                rank=len(results) + 1,
                score=float(score),
                question=row["question"],
                answer=row["answer"],
                content=row["content"],
            )
        )

    context = "\n\n---\n\n".join(
        f"FAQ Entry {result.rank}:\n{result.content}" for result in results
    )
    return SearchResponse(query=req.query, k=req.k, context=context, results=results)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))