Spaces:
Sleeping
Sleeping
Commit ·
2cfed75
1
Parent(s): bb17e33
added summarizer mode
Browse files
- backend/app/api/routes_chat.py +28 -17
- backend/app/api/routes_chat_langchain.py +53 -0
- backend/app/api/routes_summarize.py +0 -33
- backend/app/core/prompts.py +21 -0
- backend/app/main.py +2 -2
- backend/app/models/api.py +3 -0
- backend/app/retrieval/langchain_retriever.py +35 -0
- backend/app/summarization/__init__.py +0 -0
- backend/app/summarization/langchain_summarizer.py +0 -48
- requirements.txt +4 -1
backend/app/api/routes_chat.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
-
"""Chat routes for
|
| 2 |
|
| 3 |
from app.core.llm import llm_chat
|
| 4 |
-
from app.core.prompts import build_rag_prompt
|
| 5 |
from app.models.api import ChatRequest, ChatResponse
|
|
|
|
| 6 |
from app.retrieval.citation_filter import filter_citations
|
| 7 |
from app.retrieval.retrieve import hybrid_graph_search
|
| 8 |
from fastapi import APIRouter
|
|
@@ -12,7 +13,28 @@ router = APIRouter()
|
|
| 12 |
|
| 13 |
@router.post("/ask", response_model=ChatResponse)
|
| 14 |
def chat(request: ChatRequest) -> ChatResponse:
|
| 15 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
results = hybrid_graph_search(request.query, request.top_k)
|
| 17 |
|
| 18 |
if not results:
|
|
@@ -22,20 +44,9 @@ def chat(request: ChatRequest) -> ChatResponse:
|
|
| 22 |
)
|
| 23 |
|
| 24 |
context = "\n\n".join(sc.chunk.text for sc in results)
|
| 25 |
-
|
| 26 |
-
messages = build_rag_prompt(
|
| 27 |
-
context=context,
|
| 28 |
-
question=request.query,
|
| 29 |
-
)
|
| 30 |
|
| 31 |
answer = llm_chat(messages=messages)
|
|
|
|
| 32 |
|
| 33 |
-
|
| 34 |
-
answer=answer,
|
| 35 |
-
chunks=results,
|
| 36 |
-
)
|
| 37 |
-
|
| 38 |
-
return ChatResponse(
|
| 39 |
-
answer=answer,
|
| 40 |
-
citations=citations,
|
| 41 |
-
)
|
|
|
|
| 1 |
+
"""Chat routes for QA and summarization."""
|
| 2 |
|
| 3 |
from app.core.llm import llm_chat
|
| 4 |
+
from app.core.prompts import build_rag_prompt, build_summary_prompt
|
| 5 |
from app.models.api import ChatRequest, ChatResponse
|
| 6 |
+
from app.retrieval.chunk_registry import get_chunks
|
| 7 |
from app.retrieval.citation_filter import filter_citations
|
| 8 |
from app.retrieval.retrieve import hybrid_graph_search
|
| 9 |
from fastapi import APIRouter
|
|
|
|
| 13 |
|
| 14 |
@router.post("/ask", response_model=ChatResponse)
|
| 15 |
def chat(request: ChatRequest) -> ChatResponse:
|
| 16 |
+
"""Unified QA + Summarization endpoint."""
|
| 17 |
+
if request.mode == "summarize":
|
| 18 |
+
# Summarization uses ALL chunks (no top_k truncation)
|
| 19 |
+
chunks = get_chunks()
|
| 20 |
+
|
| 21 |
+
if not chunks:
|
| 22 |
+
return ChatResponse(
|
| 23 |
+
answer="No documents available to summarize.",
|
| 24 |
+
citations=[],
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
context = "\n\n".join(chunk.text for chunk in chunks)
|
| 28 |
+
messages = build_summary_prompt(context)
|
| 29 |
+
|
| 30 |
+
answer = llm_chat(messages=messages)
|
| 31 |
+
|
| 32 |
+
# no citations for summarization
|
| 33 |
+
citations = []
|
| 34 |
+
|
| 35 |
+
return ChatResponse(answer=answer, citations=citations)
|
| 36 |
+
|
| 37 |
+
# QA MODE (default)
|
| 38 |
results = hybrid_graph_search(request.query, request.top_k)
|
| 39 |
|
| 40 |
if not results:
|
|
|
|
| 44 |
)
|
| 45 |
|
| 46 |
context = "\n\n".join(sc.chunk.text for sc in results)
|
| 47 |
+
messages = build_rag_prompt(context, request.query)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
answer = llm_chat(messages=messages)
|
| 50 |
+
citations = filter_citations(answer=answer, chunks=results)
|
| 51 |
|
| 52 |
+
return ChatResponse(answer=answer, citations=citations)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/app/api/routes_chat_langchain.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Chat routes using LangChain retriever."""
|
| 2 |
+
|
| 3 |
+
from app.config import settings
|
| 4 |
+
from app.models.api import ChatRequest, ChatResponse
|
| 5 |
+
from app.models.retrieval import ScoredChunk
|
| 6 |
+
from app.retrieval.citation_filter import filter_citations
|
| 7 |
+
from app.retrieval.langchain_retriever import AtlasGraphRetriever
|
| 8 |
+
from fastapi import APIRouter
|
| 9 |
+
from langchain.chains import RetrievalQA
|
| 10 |
+
from langchain_groq import ChatGroq
|
| 11 |
+
|
| 12 |
+
router = APIRouter()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@router.post("/ask/langchain", response_model=ChatResponse)
|
| 16 |
+
def chat_langchain(request: ChatRequest) -> ChatResponse:
|
| 17 |
+
"""LangChain-powered RAG endpoint with citation filtering."""
|
| 18 |
+
retriever = AtlasGraphRetriever(top_k=request.top_k)
|
| 19 |
+
|
| 20 |
+
llm = ChatGroq(
|
| 21 |
+
api_key=settings.groq_api_key,
|
| 22 |
+
model=settings.default_model,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
qa_chain = RetrievalQA.from_chain_type(
|
| 26 |
+
llm=llm,
|
| 27 |
+
retriever=retriever,
|
| 28 |
+
return_source_documents=True,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
result = qa_chain.invoke({"query": request.query})
|
| 32 |
+
|
| 33 |
+
answer = result["result"]
|
| 34 |
+
source_docs = result.get("source_documents", [])
|
| 35 |
+
|
| 36 |
+
# Convert LangChain docs → ScoredChunk
|
| 37 |
+
scored_chunks = [
|
| 38 |
+
ScoredChunk(
|
| 39 |
+
chunk=doc.metadata["chunk"],
|
| 40 |
+
score=doc.metadata["score"],
|
| 41 |
+
)
|
| 42 |
+
for doc in source_docs
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
citations = filter_citations(
|
| 46 |
+
answer=answer,
|
| 47 |
+
chunks=scored_chunks,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
return ChatResponse(
|
| 51 |
+
answer=answer,
|
| 52 |
+
citations=citations,
|
| 53 |
+
)
|
backend/app/api/routes_summarize.py
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
"""Document summarization route (LangChain-based)."""
|
| 2 |
-
|
| 3 |
-
from app.models.api import ChatResponse
|
| 4 |
-
from app.retrieval.chunk_registry import get_chunks
|
| 5 |
-
from app.summarization.langchain_summarizer import DocumentSummarizer
|
| 6 |
-
from fastapi import APIRouter, HTTPException
|
| 7 |
-
|
| 8 |
-
router = APIRouter()
|
| 9 |
-
summarizer = DocumentSummarizer()
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
@router.post("/langchain", response_model=ChatResponse)
|
| 13 |
-
def summarize_document() -> ChatResponse:
|
| 14 |
-
"""Summarize all ingested documents.
|
| 15 |
-
|
| 16 |
-
Note:
|
| 17 |
-
- This is recall-heavy by design
|
| 18 |
-
- No citations (summary ≠ factual QA)
|
| 19 |
-
"""
|
| 20 |
-
chunks = get_chunks()
|
| 21 |
-
|
| 22 |
-
if not chunks:
|
| 23 |
-
raise HTTPException(
|
| 24 |
-
status_code=400,
|
| 25 |
-
detail="No documents available for summarization.",
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
summary = summarizer.summarize(chunks)
|
| 29 |
-
|
| 30 |
-
return ChatResponse(
|
| 31 |
-
answer=summary,
|
| 32 |
-
citations=[],
|
| 33 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/app/core/prompts.py
CHANGED
|
@@ -10,6 +10,16 @@ Rules:
|
|
| 10 |
- Do NOT add external knowledge.
|
| 11 |
"""
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def build_rag_prompt(context: str, question: str) -> list[dict]:
|
| 15 |
"""Build messages for RAG-based QA."""
|
|
@@ -26,3 +36,14 @@ Question:
|
|
| 26 |
""".strip(),
|
| 27 |
},
|
| 28 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
- Do NOT add external knowledge.
|
| 11 |
"""
|
| 12 |
|
| 13 |
+
# System prompt for summarization mode: asks for a concise, faithful summary
# with no hallucinated content or meta commentary.
SUMMARY_SYSTEM_PROMPT = """
You are a document summarization assistant.

Rules:
- Produce a concise, well-structured summary of the provided content.
- Capture key ideas, steps, and distinctions.
- Do NOT invent information.
- Do NOT include instructions, questions, or meta commentary.
"""

| 23 |
|
| 24 |
def build_rag_prompt(context: str, question: str) -> list[dict]:
|
| 25 |
"""Build messages for RAG-based QA."""
|
|
|
|
| 36 |
""".strip(),
|
| 37 |
},
|
| 38 |
]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def build_summary_prompt(context: str) -> list[dict]:
    """Build messages for RAG-based summarization.

    Args:
        context: Concatenated chunk text to summarize.

    Returns:
        Chat messages: the summarization system prompt followed by a user
        message carrying the document content.
    """
    return [
        {"role": "system", "content": SUMMARY_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Document Content:\n{context}",
        },
    ]
|
backend/app/main.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
"""Main FastAPI application for AtlasRAG backend."""
|
| 2 |
|
| 3 |
from app.api.routes_chat import router as chat_router
|
|
|
|
| 4 |
from app.api.routes_docs import router as docs_router
|
| 5 |
-
from app.api.routes_summarize import router as summarize_langchain_router
|
| 6 |
from fastapi import FastAPI
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
|
|
@@ -24,4 +24,4 @@ app.add_middleware(
|
|
| 24 |
# Include routers
|
| 25 |
app.include_router(chat_router, prefix="/chat")
|
| 26 |
app.include_router(docs_router, prefix="/docs")
|
| 27 |
-
app.include_router(
|
|
|
|
| 1 |
"""Main FastAPI application for AtlasRAG backend."""
|
| 2 |
|
| 3 |
from app.api.routes_chat import router as chat_router
|
| 4 |
+
from app.api.routes_chat_langchain import router as chat_langchain_router
|
| 5 |
from app.api.routes_docs import router as docs_router
|
|
|
|
| 6 |
from fastapi import FastAPI
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
|
|
|
|
| 24 |
# Include routers
|
| 25 |
app.include_router(chat_router, prefix="/chat")
|
| 26 |
app.include_router(docs_router, prefix="/docs")
|
| 27 |
+
app.include_router(chat_langchain_router, prefix="/chat")
|
backend/app/models/api.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
"""Pydantic models for API request and response bodies."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
from pydantic import BaseModel
|
| 4 |
|
| 5 |
|
|
@@ -8,6 +10,7 @@ class ChatRequest(BaseModel):
|
|
| 8 |
|
| 9 |
query: str
|
| 10 |
top_k: int = 5
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class Citation(BaseModel):
|
|
|
|
| 1 |
"""Pydantic models for API request and response bodies."""
|
| 2 |
|
| 3 |
+
from typing import Literal
|
| 4 |
+
|
| 5 |
from pydantic import BaseModel
|
| 6 |
|
| 7 |
|
|
|
|
| 10 |
|
| 11 |
query: str
|
| 12 |
top_k: int = 5
|
| 13 |
+
mode: Literal["qa", "summarize"] = "qa"
|
| 14 |
|
| 15 |
|
| 16 |
class Citation(BaseModel):
|
backend/app/retrieval/langchain_retriever.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LangChain retriever wrapper for AtlasRAG."""
|
| 2 |
+
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
from app.retrieval.retrieve import hybrid_graph_search
|
| 6 |
+
from langchain_core.documents import Document
|
| 7 |
+
from langchain_core.retrievers import BaseRetriever
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AtlasGraphRetriever(BaseRetriever):
|
| 11 |
+
"""LangChain-compatible retriever wrapping hybrid Graph-RAG."""
|
| 12 |
+
|
| 13 |
+
top_k: int = 5
|
| 14 |
+
|
| 15 |
+
def _get_relevant_documents(self, query: str) -> List[Document]:
|
| 16 |
+
"""Retrieve documents for LangChain."""
|
| 17 |
+
results = hybrid_graph_search(query, self.top_k)
|
| 18 |
+
|
| 19 |
+
documents: List[Document] = []
|
| 20 |
+
|
| 21 |
+
for sc in results:
|
| 22 |
+
documents.append(
|
| 23 |
+
Document(
|
| 24 |
+
page_content=sc.chunk.text,
|
| 25 |
+
metadata={
|
| 26 |
+
"doc_id": sc.chunk.doc_id,
|
| 27 |
+
"page_start": sc.chunk.page_start,
|
| 28 |
+
"page_end": sc.chunk.page_end,
|
| 29 |
+
"chunk": sc.chunk,
|
| 30 |
+
"score": sc.score,
|
| 31 |
+
},
|
| 32 |
+
)
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
return documents
|
backend/app/summarization/__init__.py
DELETED
|
File without changes
|
backend/app/summarization/langchain_summarizer.py
DELETED
|
@@ -1,48 +0,0 @@
|
|
| 1 |
-
"""LangChain-based document summarization using a local HF model."""
|
| 2 |
-
|
| 3 |
-
from typing import List
|
| 4 |
-
|
| 5 |
-
from app.models.ingestion import Chunk
|
| 6 |
-
from langchain.chains.summarize import load_summarize_chain
|
| 7 |
-
from langchain.docstore.document import Document
|
| 8 |
-
from langchain.llms import HuggingFacePipeline
|
| 9 |
-
from transformers import pipeline
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class DocumentSummarizer:
|
| 13 |
-
"""Document summarizer using LangChain + local HF model."""
|
| 14 |
-
|
| 15 |
-
def __init__(self) -> None:
|
| 16 |
-
"""Initialize HF Pipeline."""
|
| 17 |
-
summarizer = pipeline(
|
| 18 |
-
"summarization",
|
| 19 |
-
model="facebook/bart-large-cnn",
|
| 20 |
-
device=-1,
|
| 21 |
-
)
|
| 22 |
-
|
| 23 |
-
self.llm = HuggingFacePipeline(pipeline=summarizer)
|
| 24 |
-
|
| 25 |
-
self.chain = load_summarize_chain(
|
| 26 |
-
llm=self.llm,
|
| 27 |
-
chain_type="map_reduce",
|
| 28 |
-
verbose=False,
|
| 29 |
-
)
|
| 30 |
-
|
| 31 |
-
def summarize(self, chunks: List[Chunk]) -> str:
|
| 32 |
-
"""Summarize document chunks."""
|
| 33 |
-
if not chunks:
|
| 34 |
-
return "No content available to summarize."
|
| 35 |
-
|
| 36 |
-
documents = [
|
| 37 |
-
Document(
|
| 38 |
-
page_content=chunk.text,
|
| 39 |
-
metadata={
|
| 40 |
-
"doc_id": chunk.doc_id,
|
| 41 |
-
"page_start": chunk.page_start,
|
| 42 |
-
"page_end": chunk.page_end,
|
| 43 |
-
},
|
| 44 |
-
)
|
| 45 |
-
for chunk in chunks
|
| 46 |
-
]
|
| 47 |
-
|
| 48 |
-
return self.chain.run(documents)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -9,9 +9,12 @@ httpx==0.27.0
|
|
| 9 |
# LLM & Embedding Clients
|
| 10 |
openai==1.37.0
|
| 11 |
groq==0.5.0
|
| 12 |
-
langchain==0.2.
|
|
|
|
|
|
|
| 13 |
langchain-groq==0.1.4
|
| 14 |
langchain-openai==0.1.8
|
|
|
|
| 15 |
|
| 16 |
# Vector Databases
|
| 17 |
qdrant-client==1.9.0
|
|
|
|
| 9 |
# LLM & Embedding Clients
|
| 10 |
openai==1.37.0
|
| 11 |
groq==0.5.0
|
| 12 |
+
langchain==0.2.12
|
| 13 |
+
langchain-core==0.2.27
|
| 14 |
+
langchain-community==0.2.11
|
| 15 |
langchain-groq==0.1.4
|
| 16 |
langchain-openai==0.1.8
|
| 17 |
+
langchain-huggingface==0.0.3
|
| 18 |
|
| 19 |
# Vector Databases
|
| 20 |
qdrant-client==1.9.0
|