MohitGupta41 commited on
Commit
8d5a4b2
·
1 Parent(s): 0570fc0

FastAPI RAG backend (Docker)

Browse files
Files changed (4) hide show
  1. Dockerfile +23 -0
  2. app.py +86 -0
  3. rag.py +102 -0
  4. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Lightweight CPU image for the FastAPI RAG backend (Hugging Face Spaces).
FROM python:3.10-slim

WORKDIR /app

# System deps (pymupdf may need extra libs sometimes; this minimal usually works)
# gcc is needed to build any wheels that ship without prebuilt binaries.
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    && rm -rf /var/lib/apt/lists/*

# Install Python deps first so this layer stays cached until requirements change.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application source last (changes most often).
COPY . .

# Hugging Face expects port 7860
EXPOSE 7860

# Recommended: keep caches in /tmp on Spaces
# (the app user may not have write access elsewhere on Spaces).
ENV HF_HOME=/tmp/hf
ENV TRANSFORMERS_CACHE=/tmp/hf/transformers
ENV SENTENCE_TRANSFORMERS_HOME=/tmp/hf/sentence-transformers

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any

# Project-local RAG helpers: extraction, chunking, session/index management,
# retrieval, and answer generation. SESSIONS is the shared in-memory store.
from rag import (
    extract_text_from_pdf,
    chunk_text,
    create_session,
    retrieve_top_k,
    generate_answer,
    SESSIONS,
)

app = FastAPI(title="Mini RAG Backend")

# Allow any origin so a separately hosted frontend can call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten later if needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
24
+
25
+
26
class AskRequest(BaseModel):
    """Request body for POST /ask."""

    # Session id returned by /ingest; keys into rag.SESSIONS.
    session_id: str
    # Natural-language question to answer from the ingested document.
    question: str
    # Number of chunks to retrieve as context for generation.
    top_k: int = 3
30
+
31
+
32
@app.post("/ingest")
async def ingest(file: UploadFile = File(...)) -> Dict[str, Any]:
    """Ingest an uploaded PDF or TXT: extract text, chunk, embed, index.

    Returns:
        {"session_id": str, "num_chunks": int} for use with /ask.

    Raises:
        HTTPException 400: unsupported file type, no extractable text,
        or chunking produced nothing.
    """
    # UploadFile.filename is Optional; guard before .lower() so a missing
    # name yields a clean 400 instead of an AttributeError (HTTP 500).
    filename = (file.filename or "").lower()
    content = await file.read()

    if filename.endswith(".pdf"):
        text = extract_text_from_pdf(content)
    elif filename.endswith(".txt"):
        # Best-effort decode: tolerate stray non-UTF-8 bytes.
        text = content.decode("utf-8", errors="ignore")
    else:
        raise HTTPException(status_code=400, detail="Only PDF or TXT allowed")

    text = text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="No extractable text found")

    chunks = chunk_text(text, chunk_size_words=350, overlap_words=60)
    if not chunks:
        raise HTTPException(status_code=400, detail="Chunking produced 0 chunks")

    session_id = create_session(chunks)

    return {
        "session_id": session_id,
        "num_chunks": len(chunks),
    }
58
+
59
+
60
@app.post("/ask")
async def ask(req: AskRequest) -> Dict[str, Any]:
    """Answer a question against a previously ingested session.

    Retrieves the top-k chunks for the question, joins them into a
    context, and generates a grounded answer.

    Returns:
        {"answer": str, "sources": [{"chunk_id", "score", "text"}, ...]}

    Raises:
        HTTPException 404: unknown session_id.
    """
    sess = SESSIONS.get(req.session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="Invalid session_id")

    chunks = sess["chunks"]
    index = sess["index"]

    hits = retrieve_top_k(req.question, chunks, index, k=req.top_k)
    context = "\n\n---\n\n".join([h[2] for h in hits])

    answer = generate_answer(req.question, context)

    # Only append an ellipsis when the snippet was actually truncated
    # (previously "..." was added unconditionally, even to short chunks).
    sources = [
        {
            "chunk_id": h[0],
            "score": h[1],
            "text": h[2] if len(h[2]) <= 400 else h[2][:400] + "...",
        }
        for h in hits
    ]

    return {"answer": answer, "sources": sources}
80
+
81
+
82
@app.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Drop a session's chunks and index; idempotent for unknown ids."""
    # pop() with a default is a no-op when the id is absent, matching the
    # original membership-check-then-delete behavior in one call.
    SESSIONS.pop(session_id, None)
    return {"status": "ok"}
rag.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import uuid
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple
import time

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

# PDF extraction
import fitz  # pymupdf

# LLM
from transformers import pipeline


# -----------------------------
# Globals (MVP)
# -----------------------------
# Single shared embedder, used for both document chunks and queries so the
# vectors live in the same space.
EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")

# flan-t5 is a seq2seq (encoder-decoder / T5-family) checkpoint, so it must
# run under the "text2text-generation" task; the causal "text-generation"
# pipeline rejects T5 architectures at load time.
GENERATOR = pipeline(
    "text2text-generation",
    model="google/flan-t5-base",  # small enough for CPU; swap for a bigger model later
    max_new_tokens=256,
)

# In-memory session store: session_id -> {chunks, index, created_at}.
# NOTE(review): unbounded and process-local — fine for an MVP, not for scale.
SESSIONS: Dict[str, Dict[str, Any]] = {}
31
+
32
+
33
+ # -----------------------------
34
+ # Helpers
35
+ # -----------------------------
36
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    """Extract plain text from a PDF given as raw bytes.

    Pages are joined with newlines; returns "" when the PDF has no
    extractable text (e.g. scanned images without an OCR layer).
    """
    # Context manager guarantees the document handle is closed even if
    # extraction raises — the original left the handle open (leak).
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        pages = [page.get_text("text") for page in doc]
    return "\n".join(pages).strip()
42
+
43
+
44
def chunk_text(text: str, chunk_size_words: int = 350, overlap_words: int = 60) -> List[str]:
    """Split *text* into overlapping word-window chunks.

    Each chunk holds up to ``chunk_size_words`` whitespace-separated
    tokens; consecutive chunks share ``overlap_words`` tokens. Returns
    an empty list for empty/whitespace-only input.
    """
    tokens = text.split()
    # Stride is clamped to at least 1 so a pathological overlap
    # (>= chunk size) cannot cause an infinite/empty walk.
    stride = max(1, chunk_size_words - overlap_words)
    return [
        " ".join(tokens[start:start + chunk_size_words])
        for start in range(0, len(tokens), stride)
    ]
53
+
54
+
55
def build_faiss_index(vectors: np.ndarray) -> faiss.Index:
    """Build a flat inner-product FAISS index over L2-normalized copies
    of *vectors*, so inner-product search behaves like cosine similarity.
    """
    # astype always copies, so normalizing below never mutates the
    # caller's array.
    mat = vectors.astype("float32")
    faiss.normalize_L2(mat)
    index = faiss.IndexFlatIP(mat.shape[1])
    index.add(mat)
    return index
62
+
63
+
64
def retrieve_top_k(query: str, chunks: List[str], index: faiss.Index, k: int = 3) -> List[Tuple[int, float, str]]:
    """Return up to *k* (chunk_id, score, chunk_text) hits for *query*,
    ranked by inner product over L2-normalized embeddings (cosine-like).
    """
    query_vec = EMBEDDER.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(query_vec)
    scores, ids = index.search(query_vec, k)

    hits: List[Tuple[int, float, str]] = []
    for score, chunk_id in zip(scores[0], ids[0]):
        # FAISS pads with id -1 when fewer than k vectors are indexed.
        if chunk_id == -1:
            continue
        pos = int(chunk_id)
        hits.append((pos, float(score), chunks[pos]))
    return hits
74
+
75
+
76
def generate_answer(question: str, context: str) -> str:
    """Generate an answer to *question* grounded only in *context*.

    Returns the model's generated text; falls back to the stringified
    raw pipeline output if the shape is unexpected.
    """
    prompt = (
        "Answer using ONLY the provided context. "
        "If not found in the context, say: Not found in the provided documents.\n\n"
        f"Context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
    )

    out = GENERATOR(prompt)
    # Both text-generation and text2text-generation pipelines return
    # [{"generated_text": ...}]; verify the shape defensively.
    if isinstance(out, list) and out and isinstance(out[0], dict) and "generated_text" in out[0]:
        answer = out[0]["generated_text"]
        # Causal LMs echo the prompt before the completion; strip it so
        # the API returns only the answer. Seq2seq output is unaffected.
        if answer.startswith(prompt):
            answer = answer[len(prompt):].lstrip()
        return answer
    return str(out)
90
+
91
+
92
def create_session(chunks: List[str]) -> str:
    """Embed *chunks*, build a FAISS index over them, and register a new
    session in SESSIONS.

    Returns the generated session id (UUID4 string).
    """
    vectors = EMBEDDER.encode(chunks, convert_to_numpy=True)

    session_id = str(uuid.uuid4())
    SESSIONS[session_id] = {
        "chunks": chunks,
        "index": build_faiss_index(vectors),
        "created_at": time.time(),
    }
    return session_id
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ python-multipart
4
+ pydantic
5
+ sentence-transformers
6
+ faiss-cpu
7
+ pymupdf
8
+ transformers
9
+ torch