File size: 3,553 Bytes
6287022
8d5a4b2
 
 
6287022
8d5a4b2
 
 
 
 
 
 
 
 
 
 
 
 
 
6287022
8d5a4b2
 
 
 
 
 
 
 
 
 
 
6287022
10a955c
 
 
8d5a4b2
 
6287022
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d5a4b2
 
 
6287022
 
 
8d5a4b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6287022
8d5a4b2
6287022
8d5a4b2
 
 
6287022
 
 
 
 
8d5a4b2
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# app.py
from typing import Dict, Any, List

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from rag import (
    extract_text_from_pdf,
    chunk_text,
    create_session,
    retrieve_top_k,
    generate_answer,
    SESSIONS,
)

app = FastAPI(title="Mini RAG Backend")

# Fully permissive CORS: any origin may call this API from a browser, with any
# method and any request header, and credentials are allowed.
# NOTE(review): FastAPI's CORS docs say allow_origins / allow_methods /
# allow_headers should not be ["*"] when allow_credentials=True. Consider
# listing the frontend origin(s) explicitly, or dropping credentials if no
# cookies/auth headers are needed (nothing visible here uses them).
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)


class AskRequest(BaseModel):
    """JSON body for ``POST /ask``: which session to query and what to ask."""

    # Key into SESSIONS; the value returned as "session_id" by POST /ingest.
    session_id: str
    # Question text; must be non-empty (an empty string gives retrieval nothing to match).
    question: str = Field(..., min_length=1)
    # Number of chunks to retrieve. Must be >= 1: zero or negative values would be
    # passed straight through as retrieve_top_k's k and produce no/meaningless context.
    top_k: int = Field(3, ge=1)


@app.get("/")
def root():
    """Landing/health endpoint: report that the service is up and point clients at ``/docs``."""
    return dict(status="ok", service="rag-backend", docs="/docs")


async def _read_upload_text(file: UploadFile) -> str:
    """Return the stripped text of one uploaded .pdf/.txt file, raising HTTP 400 if unusable."""
    filename = (file.filename or "").lower()

    # Validate the extension before reading the body, so a disallowed upload is
    # rejected without first being buffered into memory.
    if not filename.endswith((".pdf", ".txt")):
        raise HTTPException(status_code=400, detail=f"Only PDF or TXT allowed: {file.filename}")

    content = await file.read()

    if filename.endswith(".pdf"):
        # extract_text_from_pdf is a synchronous call that presumably parses the
        # whole document; running it in the threadpool keeps the event loop free
        # for other requests while it works.
        text = await run_in_threadpool(extract_text_from_pdf, content)
    else:
        # errors="ignore" silently drops bytes that are not valid UTF-8.
        text = content.decode("utf-8", errors="ignore")

    # Tolerate an extractor that returns None for text-less PDFs.
    text = (text or "").strip()
    if not text:
        raise HTTPException(status_code=400, detail=f"No extractable text found in: {file.filename}")
    return text


@app.post("/ingest")
async def ingest(files: List[UploadFile] = File(...)) -> Dict[str, Any]:
    """
    Accept multiple PDF/TXT files and build ONE combined RAG index across all chunks.

    Each file is turned into text, split with ``chunk_text`` (350-word chunks,
    60-word overlap), and every chunk is prefixed with ``[SOURCE: <filename>]``.
    Chunks from all files are concatenated in upload order and passed, together
    with per-file metadata, to ``create_session``.

    Returns:
        ``session_id``, ``num_files``, ``num_chunks`` and ``docs``. ``docs`` is
        one entry per file with ``doc_id``, ``filename``, ``num_chunks`` and the
        inclusive ``chunk_range`` of that file's chunks in the combined list.

    Raises:
        HTTPException(400): no files were uploaded, a file is not .pdf/.txt,
            a file yields no text, or chunking yields no chunks.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    all_chunks: List[str] = []
    docs_meta: List[Dict[str, Any]] = []

    for doc_id, file in enumerate(files):
        text = await _read_upload_text(file)

        chunks = chunk_text(text, chunk_size_words=350, overlap_words=60)
        if not chunks:
            raise HTTPException(status_code=400, detail=f"Chunking produced 0 chunks for: {file.filename}")

        # Prefix each chunk with its filename so the LLM can cite which file it came from.
        # Retrieval still runs over these prefixed chunks.
        chunks = [f"[SOURCE: {file.filename}]\n{c}" for c in chunks]

        start_idx = len(all_chunks)
        all_chunks.extend(chunks)

        docs_meta.append(
            {
                "doc_id": doc_id,
                "filename": file.filename,
                "num_chunks": len(chunks),
                # Inclusive [first, last] indices into the combined chunk list.
                "chunk_range": [start_idx, len(all_chunks) - 1],
            }
        )

    # create_session is synchronous and presumably builds the retrieval index over
    # every chunk; run it in the threadpool rather than on the event loop.
    session_id = await run_in_threadpool(create_session, all_chunks, docs_meta)

    return {
        "session_id": session_id,
        "num_files": len(files),
        "num_chunks": len(all_chunks),
        "docs": docs_meta,
    }


@app.post("/ask")
async def ask(req: AskRequest) -> Dict[str, Any]:
    """
    Answer ``req.question`` from the documents indexed in session ``req.session_id``.

    Retrieves ``req.top_k`` hits with ``retrieve_top_k`` and joins their text
    (separated by ``---`` lines) into one context string. That context is then
    passed to ``generate_answer``.

    Returns:
        ``answer``; ``sources``, one entry per hit with ``chunk_id``, ``score``
        and the full chunk ``text``; and the session's per-file ``docs`` metadata.

    Raises:
        HTTPException(404): the session id is not present in SESSIONS.
    """
    sess = SESSIONS.get(req.session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="Invalid session_id")

    chunks = sess["chunks"]
    index = sess["index"]

    # retrieve_top_k and generate_answer are synchronous (they are not awaited).
    # Calling them directly inside this async handler would block the event loop,
    # stalling every other request until retrieval / answer generation finishes,
    # so run them in the threadpool instead.
    hits = await run_in_threadpool(retrieve_top_k, req.question, chunks, index, k=req.top_k)

    # Each hit is read positionally as (chunk_id, score, text).
    context = "\n\n---\n\n".join(h[2] for h in hits)

    answer = await run_in_threadpool(generate_answer, req.question, context)

    # Return full chunk text (no truncation)
    sources = [{"chunk_id": h[0], "score": h[1], "text": h[2]} for h in hits]

    return {
        "answer": answer,
        "sources": sources,
        "docs": sess.get("docs", []),
    }


@app.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """
    Remove ``session_id`` from SESSIONS.

    Idempotent: unknown or already-deleted ids are ignored, and the response
    is always ``{"status": "ok"}``.
    """
    # A single pop-with-default replaces the separate membership test + del:
    # one dict operation, no window between check and delete, no KeyError.
    SESSIONS.pop(session_id, None)
    return {"status": "ok"}