SanskarModi committed on
Commit
2cfed75
·
1 Parent(s): bb17e33

added summarizer mode

Browse files
backend/app/api/routes_chat.py CHANGED
@@ -1,8 +1,9 @@
1
- """Chat routes for Graph-RAG."""
2
 
3
  from app.core.llm import llm_chat
4
- from app.core.prompts import build_rag_prompt
5
  from app.models.api import ChatRequest, ChatResponse
 
6
  from app.retrieval.citation_filter import filter_citations
7
  from app.retrieval.retrieve import hybrid_graph_search
8
  from fastapi import APIRouter
@@ -12,7 +13,28 @@ router = APIRouter()
12
 
13
  @router.post("/ask", response_model=ChatResponse)
14
  def chat(request: ChatRequest) -> ChatResponse:
15
- """Graph-augmented RAG endpoint."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  results = hybrid_graph_search(request.query, request.top_k)
17
 
18
  if not results:
@@ -22,20 +44,9 @@ def chat(request: ChatRequest) -> ChatResponse:
22
  )
23
 
24
  context = "\n\n".join(sc.chunk.text for sc in results)
25
-
26
- messages = build_rag_prompt(
27
- context=context,
28
- question=request.query,
29
- )
30
 
31
  answer = llm_chat(messages=messages)
 
32
 
33
- citations = filter_citations(
34
- answer=answer,
35
- chunks=results,
36
- )
37
-
38
- return ChatResponse(
39
- answer=answer,
40
- citations=citations,
41
- )
 
1
+ """Chat routes for QA and summarization."""
2
 
3
  from app.core.llm import llm_chat
4
+ from app.core.prompts import build_rag_prompt, build_summary_prompt
5
  from app.models.api import ChatRequest, ChatResponse
6
+ from app.retrieval.chunk_registry import get_chunks
7
  from app.retrieval.citation_filter import filter_citations
8
  from app.retrieval.retrieve import hybrid_graph_search
9
  from fastapi import APIRouter
 
13
 
14
  @router.post("/ask", response_model=ChatResponse)
15
  def chat(request: ChatRequest) -> ChatResponse:
16
+ """Unified QA + Summarization endpoint."""
17
+ if request.mode == "summarize":
18
+ # Summarization uses ALL chunks (no top_k truncation)
19
+ chunks = get_chunks()
20
+
21
+ if not chunks:
22
+ return ChatResponse(
23
+ answer="No documents available to summarize.",
24
+ citations=[],
25
+ )
26
+
27
+ context = "\n\n".join(chunk.text for chunk in chunks)
28
+ messages = build_summary_prompt(context)
29
+
30
+ answer = llm_chat(messages=messages)
31
+
32
+ # no citations for summarization
33
+ citations = []
34
+
35
+ return ChatResponse(answer=answer, citations=citations)
36
+
37
+ # QA MODE (default)
38
  results = hybrid_graph_search(request.query, request.top_k)
39
 
40
  if not results:
 
44
  )
45
 
46
  context = "\n\n".join(sc.chunk.text for sc in results)
47
+ messages = build_rag_prompt(context, request.query)
 
 
 
 
48
 
49
  answer = llm_chat(messages=messages)
50
+ citations = filter_citations(answer=answer, chunks=results)
51
 
52
+ return ChatResponse(answer=answer, citations=citations)
 
 
 
 
 
 
 
 
backend/app/api/routes_chat_langchain.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chat routes using LangChain retriever."""
2
+
3
+ from app.config import settings
4
+ from app.models.api import ChatRequest, ChatResponse
5
+ from app.models.retrieval import ScoredChunk
6
+ from app.retrieval.citation_filter import filter_citations
7
+ from app.retrieval.langchain_retriever import AtlasGraphRetriever
8
+ from fastapi import APIRouter
9
+ from langchain.chains import RetrievalQA
10
+ from langchain_groq import ChatGroq
11
+
12
+ router = APIRouter()
13
+
14
+
15
+ @router.post("/ask/langchain", response_model=ChatResponse)
16
+ def chat_langchain(request: ChatRequest) -> ChatResponse:
17
+ """LangChain-powered RAG endpoint with citation filtering."""
18
+ retriever = AtlasGraphRetriever(top_k=request.top_k)
19
+
20
+ llm = ChatGroq(
21
+ api_key=settings.groq_api_key,
22
+ model=settings.default_model,
23
+ )
24
+
25
+ qa_chain = RetrievalQA.from_chain_type(
26
+ llm=llm,
27
+ retriever=retriever,
28
+ return_source_documents=True,
29
+ )
30
+
31
+ result = qa_chain.invoke({"query": request.query})
32
+
33
+ answer = result["result"]
34
+ source_docs = result.get("source_documents", [])
35
+
36
+ # Convert LangChain docs → ScoredChunk
37
+ scored_chunks = [
38
+ ScoredChunk(
39
+ chunk=doc.metadata["chunk"],
40
+ score=doc.metadata["score"],
41
+ )
42
+ for doc in source_docs
43
+ ]
44
+
45
+ citations = filter_citations(
46
+ answer=answer,
47
+ chunks=scored_chunks,
48
+ )
49
+
50
+ return ChatResponse(
51
+ answer=answer,
52
+ citations=citations,
53
+ )
backend/app/api/routes_summarize.py DELETED
@@ -1,33 +0,0 @@
1
- """Document summarization route (LangChain-based)."""
2
-
3
- from app.models.api import ChatResponse
4
- from app.retrieval.chunk_registry import get_chunks
5
- from app.summarization.langchain_summarizer import DocumentSummarizer
6
- from fastapi import APIRouter, HTTPException
7
-
8
- router = APIRouter()
9
- summarizer = DocumentSummarizer()
10
-
11
-
12
- @router.post("/langchain", response_model=ChatResponse)
13
- def summarize_document() -> ChatResponse:
14
- """Summarize all ingested documents.
15
-
16
- Note:
17
- - This is recall-heavy by design
18
- - No citations (summary ≠ factual QA)
19
- """
20
- chunks = get_chunks()
21
-
22
- if not chunks:
23
- raise HTTPException(
24
- status_code=400,
25
- detail="No documents available for summarization.",
26
- )
27
-
28
- summary = summarizer.summarize(chunks)
29
-
30
- return ChatResponse(
31
- answer=summary,
32
- citations=[],
33
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/app/core/prompts.py CHANGED
@@ -10,6 +10,16 @@ Rules:
10
  - Do NOT add external knowledge.
11
  """
12
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def build_rag_prompt(context: str, question: str) -> list[dict]:
15
  """Build messages for RAG-based QA."""
@@ -26,3 +36,14 @@ Question:
26
  """.strip(),
27
  },
28
  ]
 
 
 
 
 
 
 
 
 
 
 
 
10
  - Do NOT add external knowledge.
11
  """
12
 
13
+ SUMMARY_SYSTEM_PROMPT = """
14
+ You are a document summarization assistant.
15
+
16
+ Rules:
17
+ - Produce a concise, well-structured summary of the provided content.
18
+ - Capture key ideas, steps, and distinctions.
19
+ - Do NOT invent information.
20
+ - Do NOT include instructions, questions, or meta commentary.
21
+ """
22
+
23
 
24
  def build_rag_prompt(context: str, question: str) -> list[dict]:
25
  """Build messages for RAG-based QA."""
 
36
  """.strip(),
37
  },
38
  ]
39
+
40
+
41
+ def build_summary_prompt(context: str) -> list[dict]:
42
+ """Build messages for RAG-based summarization."""
43
+ return [
44
+ {"role": "system", "content": SUMMARY_SYSTEM_PROMPT},
45
+ {
46
+ "role": "user",
47
+ "content": f"Document Content:\n{context}",
48
+ },
49
+ ]
backend/app/main.py CHANGED
@@ -1,8 +1,8 @@
1
  """Main FastAPI application for AtlasRAG backend."""
2
 
3
  from app.api.routes_chat import router as chat_router
 
4
  from app.api.routes_docs import router as docs_router
5
- from app.api.routes_summarize import router as summarize_langchain_router
6
  from fastapi import FastAPI
7
  from fastapi.middleware.cors import CORSMiddleware
8
 
@@ -24,4 +24,4 @@ app.add_middleware(
24
  # Include routers
25
  app.include_router(chat_router, prefix="/chat")
26
  app.include_router(docs_router, prefix="/docs")
27
- app.include_router(summarize_langchain_router, prefix="/summarize")
 
1
  """Main FastAPI application for AtlasRAG backend."""
2
 
3
  from app.api.routes_chat import router as chat_router
4
+ from app.api.routes_chat_langchain import router as chat_langchain_router
5
  from app.api.routes_docs import router as docs_router
 
6
  from fastapi import FastAPI
7
  from fastapi.middleware.cors import CORSMiddleware
8
 
 
24
  # Include routers
25
  app.include_router(chat_router, prefix="/chat")
26
  app.include_router(docs_router, prefix="/docs")
27
+ app.include_router(chat_langchain_router, prefix="/chat")
backend/app/models/api.py CHANGED
@@ -1,5 +1,7 @@
1
  """Pydantic models for API request and response bodies."""
2
 
 
 
3
  from pydantic import BaseModel
4
 
5
 
@@ -8,6 +10,7 @@ class ChatRequest(BaseModel):
8
 
9
  query: str
10
  top_k: int = 5
 
11
 
12
 
13
  class Citation(BaseModel):
 
1
  """Pydantic models for API request and response bodies."""
2
 
3
+ from typing import Literal
4
+
5
  from pydantic import BaseModel
6
 
7
 
 
10
 
11
  query: str
12
  top_k: int = 5
13
+ mode: Literal["qa", "summarize"] = "qa"
14
 
15
 
16
  class Citation(BaseModel):
backend/app/retrieval/langchain_retriever.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangChain retriever wrapper for AtlasRAG."""
2
+
3
+ from typing import List
4
+
5
+ from app.retrieval.retrieve import hybrid_graph_search
6
+ from langchain_core.documents import Document
7
+ from langchain_core.retrievers import BaseRetriever
8
+
9
+
10
+ class AtlasGraphRetriever(BaseRetriever):
11
+ """LangChain-compatible retriever wrapping hybrid Graph-RAG."""
12
+
13
+ top_k: int = 5
14
+
15
+ def _get_relevant_documents(self, query: str) -> List[Document]:
16
+ """Retrieve documents for LangChain."""
17
+ results = hybrid_graph_search(query, self.top_k)
18
+
19
+ documents: List[Document] = []
20
+
21
+ for sc in results:
22
+ documents.append(
23
+ Document(
24
+ page_content=sc.chunk.text,
25
+ metadata={
26
+ "doc_id": sc.chunk.doc_id,
27
+ "page_start": sc.chunk.page_start,
28
+ "page_end": sc.chunk.page_end,
29
+ "chunk": sc.chunk,
30
+ "score": sc.score,
31
+ },
32
+ )
33
+ )
34
+
35
+ return documents
backend/app/summarization/__init__.py DELETED
File without changes
backend/app/summarization/langchain_summarizer.py DELETED
@@ -1,48 +0,0 @@
1
- """LangChain-based document summarization using a local HF model."""
2
-
3
- from typing import List
4
-
5
- from app.models.ingestion import Chunk
6
- from langchain.chains.summarize import load_summarize_chain
7
- from langchain.docstore.document import Document
8
- from langchain.llms import HuggingFacePipeline
9
- from transformers import pipeline
10
-
11
-
12
- class DocumentSummarizer:
13
- """Document summarizer using LangChain + local HF model."""
14
-
15
- def __init__(self) -> None:
16
- """Initialize HF Pipeline."""
17
- summarizer = pipeline(
18
- "summarization",
19
- model="facebook/bart-large-cnn",
20
- device=-1,
21
- )
22
-
23
- self.llm = HuggingFacePipeline(pipeline=summarizer)
24
-
25
- self.chain = load_summarize_chain(
26
- llm=self.llm,
27
- chain_type="map_reduce",
28
- verbose=False,
29
- )
30
-
31
- def summarize(self, chunks: List[Chunk]) -> str:
32
- """Summarize document chunks."""
33
- if not chunks:
34
- return "No content available to summarize."
35
-
36
- documents = [
37
- Document(
38
- page_content=chunk.text,
39
- metadata={
40
- "doc_id": chunk.doc_id,
41
- "page_start": chunk.page_start,
42
- "page_end": chunk.page_end,
43
- },
44
- )
45
- for chunk in chunks
46
- ]
47
-
48
- return self.chain.run(documents)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -9,9 +9,12 @@ httpx==0.27.0
9
  # LLM & Embedding Clients
10
  openai==1.37.0
11
  groq==0.5.0
12
- langchain==0.2.11
 
 
13
  langchain-groq==0.1.4
14
  langchain-openai==0.1.8
 
15
 
16
  # Vector Databases
17
  qdrant-client==1.9.0
 
9
  # LLM & Embedding Clients
10
  openai==1.37.0
11
  groq==0.5.0
12
+ langchain==0.2.12
13
+ langchain-core==0.2.27
14
+ langchain-community==0.2.11
15
  langchain-groq==0.1.4
16
  langchain-openai==0.1.8
17
+ langchain-huggingface==0.0.3
18
 
19
  # Vector Databases
20
  qdrant-client==1.9.0