SanskarModi commited on
Commit
0451125
·
1 Parent(s): 58611cd

Added optional LangChain-based retriever

Browse files
backend/app/api/routes_chat.py CHANGED
@@ -2,7 +2,8 @@
2
 
3
  from app.core.llm import llm_chat
4
  from app.core.prompts import build_rag_prompt
5
- from app.models.api import ChatRequest, ChatResponse, Citation
 
6
  from app.retrieval.retrieve import hybrid_graph_search
7
  from fastapi import APIRouter
8
 
@@ -29,13 +30,12 @@ def chat(request: ChatRequest) -> ChatResponse:
29
 
30
  answer = llm_chat(messages=messages)
31
 
32
- citations = [
33
- Citation(
34
- page_start=sc.chunk.page_start,
35
- page_end=sc.chunk.page_end,
36
- snippet=sc.chunk.text[:300],
37
- )
38
- for sc in results
39
- ]
40
 
41
- return ChatResponse(answer=answer, citations=citations)
 
 
 
 
2
 
3
  from app.core.llm import llm_chat
4
  from app.core.prompts import build_rag_prompt
5
+ from app.models.api import ChatRequest, ChatResponse
6
+ from app.retrieval.citation_filter import filter_citations
7
  from app.retrieval.retrieve import hybrid_graph_search
8
  from fastapi import APIRouter
9
 
 
30
 
31
  answer = llm_chat(messages=messages)
32
 
33
+ citations = filter_citations(
34
+ answer=answer,
35
+ chunks=results,
36
+ )
 
 
 
 
37
 
38
+ return ChatResponse(
39
+ answer=answer,
40
+ citations=citations,
41
+ )
backend/app/api/routes_chat_langchain.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chat routes using LangChain retriever."""
2
+
3
+ from app.config import settings
4
+ from app.models.api import ChatRequest, ChatResponse
5
+ from app.models.retrieval import ScoredChunk
6
+ from app.retrieval.citation_filter import filter_citations
7
+ from app.retrieval.langchain_retriever import AtlasGraphRetriever
8
+ from fastapi import APIRouter
9
+ from langchain.chains import RetrievalQA
10
+ from langchain_groq import ChatGroq
11
+
12
+ router = APIRouter()
13
+
14
+
15
@router.post("/ask/langchain", response_model=ChatResponse)
def chat_langchain(request: ChatRequest) -> ChatResponse:
    """LangChain-powered RAG endpoint with citation filtering.

    Builds a RetrievalQA chain around the project's hybrid Graph-RAG
    retriever, asks the Groq-hosted LLM, then trims the returned source
    documents down to answer-supporting citations.
    """
    # Chain is assembled per request because top_k comes from the request.
    chain = RetrievalQA.from_chain_type(
        llm=ChatGroq(
            api_key=settings.groq_api_key,
            model=settings.default_model,
        ),
        retriever=AtlasGraphRetriever(top_k=request.top_k),
        return_source_documents=True,
    )

    outcome = chain.invoke({"query": request.query})
    answer = outcome["result"]

    # Rebuild ScoredChunk objects from the metadata the retriever attached
    # to each LangChain Document.
    scored_chunks = [
        ScoredChunk(
            chunk=doc.metadata["chunk"],
            score=doc.metadata["score"],
        )
        for doc in outcome.get("source_documents", [])
    ]

    return ChatResponse(
        answer=answer,
        citations=filter_citations(answer=answer, chunks=scored_chunks),
    )
backend/app/main.py CHANGED
@@ -1,6 +1,7 @@
1
  """Main FastAPI application for AtlasRAG backend."""
2
 
3
  from app.api.routes_chat import router as chat_router
 
4
  from app.api.routes_docs import router as docs_router
5
  from fastapi import FastAPI
6
  from fastapi.middleware.cors import CORSMiddleware
@@ -23,3 +24,4 @@ app.add_middleware(
23
  # Include routers
24
  app.include_router(chat_router, prefix="/chat")
25
  app.include_router(docs_router, prefix="/docs")
 
 
1
  """Main FastAPI application for AtlasRAG backend."""
2
 
3
  from app.api.routes_chat import router as chat_router
4
+ from app.api.routes_chat_langchain import router as chat_langchain_router
5
  from app.api.routes_docs import router as docs_router
6
  from fastapi import FastAPI
7
  from fastapi.middleware.cors import CORSMiddleware
 
24
  # Include routers
25
  app.include_router(chat_router, prefix="/chat")
26
  app.include_router(docs_router, prefix="/docs")
27
+ app.include_router(chat_langchain_router, prefix="/chat")
backend/app/retrieval/citation_filter.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Citation filtering utilities.
2
+
3
+ Selects only the sentences from retrieved chunks that
4
+ directly support the generated answer.
5
+ """
6
+
7
+ import re
8
+ from typing import List
9
+
10
+ from app.models.api import Citation
11
+ from app.models.retrieval import ScoredChunk
12
+ from sentence_transformers import SentenceTransformer, util
13
+
14
+ # Lightweight sentence embedder
15
+ _SENTENCE_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
16
+
17
+ # Conservative threshold: avoids noise
18
+ _SIMILARITY_THRESHOLD = 0.45
19
+ _MAX_SENTENCES_PER_CHUNK = 2
20
+
21
+
22
+ def _split_sentences(text: str) -> List[str]:
23
+ """Split text into clean sentences."""
24
+ sentences = re.split(r"(?<=[.!?])\s+", text)
25
+ return [s.strip() for s in sentences if len(s.strip()) >= 20]
26
+
27
+
28
def filter_citations(
    answer: str,
    chunks: List[ScoredChunk],
) -> List[Citation]:
    """Keep only the citation snippets whose sentences support *answer*.

    Embeds the answer once, then scores every candidate sentence from each
    retrieved chunk against it by cosine similarity. A chunk contributes a
    citation only if at least one sentence clears the similarity threshold;
    at most _MAX_SENTENCES_PER_CHUNK sentences are kept per chunk.
    """
    if not answer.strip():
        return []

    query_vec = _SENTENCE_MODEL.encode(answer, normalize_embeddings=True)

    citations: List[Citation] = []

    for scored in chunks:
        candidate_sentences = _split_sentences(scored.chunk.text)
        if not candidate_sentences:
            continue

        sentence_vecs = _SENTENCE_MODEL.encode(
            candidate_sentences,
            normalize_embeddings=True,
        )
        scores = util.cos_sim(query_vec, sentence_vecs)[0]

        # Keep the first sentences (in document order) that clear the
        # threshold, capped per chunk to keep snippets short.
        supporting: List[str] = []
        for sentence, sim in zip(candidate_sentences, scores):
            if float(sim) < _SIMILARITY_THRESHOLD:
                continue
            supporting.append(sentence)
            if len(supporting) >= _MAX_SENTENCES_PER_CHUNK:
                break

        if not supporting:
            continue

        citations.append(
            Citation(
                page_start=scored.chunk.page_start,
                page_end=scored.chunk.page_end,
                snippet=" ".join(supporting),
            )
        )

    return citations
backend/app/retrieval/langchain_retriever.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangChain retriever wrapper for AtlasRAG."""
2
+
3
+ from typing import List
4
+
5
+ from app.retrieval.retrieve import hybrid_graph_search
6
+ from langchain_core.documents import Document
7
+ from langchain_core.retrievers import BaseRetriever
8
+
9
+
10
class AtlasGraphRetriever(BaseRetriever):
    """LangChain-compatible retriever wrapping hybrid Graph-RAG."""

    # Number of chunks to fetch per query.
    top_k: int = 5

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Run the hybrid graph search and wrap each hit as a Document.

        The original chunk object and its retrieval score are stashed in the
        metadata so downstream code can reconstruct ScoredChunk values.
        """
        hits = hybrid_graph_search(query, self.top_k)

        return [
            Document(
                page_content=hit.chunk.text,
                metadata={
                    "doc_id": hit.chunk.doc_id,
                    "page_start": hit.chunk.page_start,
                    "page_end": hit.chunk.page_end,
                    "chunk": hit.chunk,
                    "score": hit.score,
                },
            )
            for hit in hits
        ]