File size: 6,471 Bytes
634117a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
kerdos_rag/server.py
FastAPI REST server exposing the KerdosRAG engine over HTTP.

Endpoints:
    GET  /health          — liveness probe
    GET  /status          — knowledge-base metadata
    POST /index           — upload + index documents (multipart/form-data)
    POST /chat            — ask a question (SSE streaming response)
    DELETE /reset         — clear the knowledge base

Authentication (optional):
    Set API_KEY env var to enable X-Api-Key header validation.
    Leave unset to run in open mode (suitable for local / trusted environments).
"""

from __future__ import annotations

import os
import asyncio
from typing import AsyncGenerator

from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel

from kerdos_rag.core import KerdosRAG

# ── App & CORS ────────────────────────────────────────────────────────────────
# The ASGI application object that every route below is registered on.  The
# keyword arguments are descriptive metadata only (title, version, contact,
# licence); they do not affect request handling.
app = FastAPI(
    title="Kerdos RAG API",
    description="Enterprise Document Q&A engine by Kerdos Infrasoft",
    version="0.1.0",
    contact={"name": "Kerdos Infrasoft", "url": "https://kerdos.in", "email": "partnership@kerdos.in"},
    license_info={"name": "MIT"},
)

# CORS is fully permissive: any origin, HTTP method and request header is
# allowed.
# NOTE(review): fine for local / trusted deployments; consider restricting
# allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Singleton engine ───────────────────────────────────────────────────────────
# A single KerdosRAG instance is constructed at import time and shared by every
# endpoint in this module, so all requests see (and mutate) the same state.
_engine = KerdosRAG()

# ── Auth ──────────────────────────────────────────────────────────────────────
# Expected API key, read once at import time.  An empty string (the default when
# the API_KEY env var is unset) makes _check_auth accept every request.
_API_KEY = os.environ.get("API_KEY", "")


def _check_auth(x_api_key: str | None = Header(default=None)) -> None:
    """Validate the ``X-Api-Key`` request header when an API key is configured.

    Intended for use as a FastAPI dependency.  If the ``API_KEY`` environment
    variable was empty or unset when this module was imported, the check is
    skipped and every request is allowed (open mode).

    Args:
        x_api_key: Value of the ``X-Api-Key`` header, or ``None`` when the
            header is absent.

    Raises:
        HTTPException: 401 if a key is configured and the header is missing
            or does not match it.
    """
    # Function-scope import: this is the only place in the module needing hmac.
    import hmac

    if not _API_KEY:
        # Open mode — no key configured, nothing to validate.
        return

    # Compare in constant time so response latency cannot reveal how much of a
    # guessed key is correct.  Both sides are encoded to bytes because
    # compare_digest() rejects non-ASCII ``str`` arguments with a TypeError.
    supplied = (x_api_key or "").encode("utf-8")
    expected = _API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid or missing X-Api-Key header.")


# ── Request / Response models ──────────────────────────────────────────────────
class ChatRequest(BaseModel):
    """JSON request body accepted by ``POST /chat``."""

    # The user's question; passed to the engine as the query.
    query: str
    # Optional earlier conversation turns, forwarded to the engine unchanged.
    # NOTE(review): the expected keys of each dict are not visible in this
    # module — confirm against KerdosRAG.chat.
    history: list[dict] | None = None
    # Optional per-request override of the engine's top_k; None leaves the
    # engine's current value in place.
    top_k: int | None = None


class StatusResponse(BaseModel):
    """Response body returned by ``GET /status``.

    Each field is filled from the engine attribute of the same name.
    """

    # Copied into a list from the engine's indexed_sources.
    indexed_sources: list[str]
    # presumably the number of indexed text chunks — confirm against KerdosRAG
    chunk_count: int
    # presumably the identifier of the LLM the engine queries — TODO confirm
    model: str
    # Engine retrieval settings as currently configured.
    top_k: int
    min_score: float


# ── Endpoints ─────────────────────────────────────────────────────────────────

@app.get("/health", tags=["Meta"])
def health() -> dict:
    """Liveness probe — always returns 200 OK.

    Requires no authentication and does not touch the engine.

    Returns:
        A dict with ``status`` fixed to ``"ok"`` and ``version`` set to the
        version declared on the FastAPI application.
    """
    # Read the version from the app object instead of repeating the literal,
    # so the probe cannot drift from the version passed to FastAPI(...).
    return {"status": "ok", "version": app.version}


@app.get("/status", response_model=StatusResponse, tags=["Meta"])
def status(_: None = Depends(_check_auth)) -> StatusResponse:
    """Report the engine's current knowledge-base metadata and settings."""
    # Collect the engine attributes first, then build the response model from
    # them in one step.
    snapshot = {
        "indexed_sources": list(_engine.indexed_sources),
        "chunk_count": _engine.chunk_count,
        "model": _engine.model,
        "top_k": _engine.top_k,
        "min_score": _engine.min_score,
    }
    return StatusResponse(**snapshot)


@app.post("/index", tags=["RAG"])
async def index_documents(
    files: list[UploadFile] = File(...),
    _: None = Depends(_check_auth),
) -> JSONResponse:
    """
    Upload and index one or more documents.

    Accepts: PDF (.pdf), Word (.docx), plain text (.txt, .md, .csv).
    Duplicate filenames are automatically skipped.

    Each upload is written into a private temporary directory, the saved
    paths are handed to the engine for indexing, and the directory is removed
    afterwards whether or not indexing succeeded.

    Raises:
        HTTPException: 400 if an upload carries no usable filename.
    """
    import shutil
    import tempfile
    from pathlib import Path

    saved_paths: list[str] = []
    tmp_dir = Path(tempfile.mkdtemp(prefix="kerdos_upload_"))

    try:
        for upload in files:
            # The filename is client-controlled.  Keep only its final path
            # component so "../" segments or an absolute path cannot make us
            # write outside tmp_dir.  Backslashes are normalised first so
            # Windows-style paths are stripped on POSIX hosts as well.
            raw_name = (upload.filename or "").replace("\\", "/")
            safe_name = Path(raw_name).name
            if safe_name in ("", ".", "..") or "\x00" in safe_name:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid upload filename: {upload.filename!r}",
                )

            dest = tmp_dir / safe_name
            with open(dest, "wb") as f:
                shutil.copyfileobj(upload.file, f)
            saved_paths.append(str(dest))

        result = _engine.index(saved_paths)
    finally:
        # Best-effort cleanup: never let a failed delete mask the real outcome.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    return JSONResponse(content=result)


@app.post("/chat", tags=["RAG"])
async def chat(req: ChatRequest, _: None = Depends(_check_auth)) -> StreamingResponse:
    """
    Ask a question and receive a **Server-Sent Events** stream of tokens.

    Each SSE event has the form:
        data: <partial answer so far>\\n\\n

    The stream ends with:
        data: [DONE]\\n\\n

    Example (curl):
        curl -X POST http://localhost:8000/chat \\
             -H "Content-Type: application/json" \\
             -d '{"query": "What is the refund policy?"}' \\
             --no-buffer

    Raises:
        HTTPException: 422 if nothing has been indexed yet; 401 if the engine
            has no Hugging Face token configured.
    """
    if not _engine.is_ready:
        raise HTTPException(
            status_code=422,
            detail="Knowledge base is empty. POST files to /index first.",
        )

    hf_token = _engine.hf_token
    if not hf_token:
        raise HTTPException(
            status_code=401,
            detail="No Hugging Face token configured. Set HF_TOKEN env var.",
        )

    async def event_generator() -> AsyncGenerator[str, None]:
        # Unique marker that next() returns once the engine's generator is
        # exhausted.  A StopIteration raised inside run_in_executor cannot be
        # delivered through an asyncio Future (the await hangs before Python
        # 3.12 and raises RuntimeError from 3.12 on), so exhaustion must be
        # signalled by value rather than by exception.
        exhausted = object()
        loop = asyncio.get_running_loop()

        # Temporarily override top_k if the caller specified it.  The override
        # and its restore live in the same scope so the engine setting cannot
        # leak if the stream is never consumed.
        # NOTE(review): the engine is shared, so concurrent /chat requests
        # with different top_k values can still observe each other's override.
        original_top_k = _engine.top_k
        if req.top_k is not None:
            _engine.top_k = req.top_k

        try:
            gen = _engine.chat(req.query, history=req.history)

            while True:
                # The engine yields synchronously — pull each item in the
                # default thread pool so the event loop is not blocked.
                token = await loop.run_in_executor(None, next, gen, exhausted)
                if token is exhausted:
                    break
                # SSE format: escape newlines in the data value
                escaped = token.replace("\n", "\\n")
                yield f"data: {escaped}\n\n"
        finally:
            _engine.top_k = original_top_k

        yield "data: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")


@app.delete("/reset", tags=["RAG"])
def reset(_: None = Depends(_check_auth)) -> dict:
    """Wipe the knowledge base by resetting the shared engine."""
    _engine.reset()
    # Fixed acknowledgement payload; the engine's return value is not used.
    acknowledgement = dict(ok=True, message="Knowledge base cleared.")
    return acknowledgement