"""Hybrid-search retrieval and grounded generation pipelines built on LangGraph.

``search_graph`` retrieves documents from a Pinecone hybrid (dense + BM25)
index and groups the returned chunks per source document.  ``generation_graph``
produces a cited answer for a single document, moderates the output, and —
unless moderation already rejected it — runs a hallucination check.
"""

import logging
from typing import TypedDict

from dotenv import load_dotenv
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, START, StateGraph
from pinecone import Pinecone
from pinecone_text.sparse import BM25Encoder

from src.common.settings import cfg
from src.model.citations import answer_citations, format_docs_with_id
from src.model.hallucination import hallucination_grader
from src.model.moderation import moderate

_ = load_dotenv()

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Module-level logger instead of root-logger calls (stdlib logging convention).
logger = logging.getLogger(__name__)


class SearchState(TypedDict):
    """State flowing through the search graph."""

    query: str
    documents: list[str]


class GenerationState(TypedDict):
    """State flowing through the generation graph."""

    query: str
    document: str
    generation: str
    chunks: list[dict]
    hallucination: str
    inappropriate: str


# NOTE(review): the last separator was garbled by a line wrap in the original
# source; ". " (sentence boundary) is the most plausible reading — confirm.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    separators=["\n\n", "\n", ". "],
    keep_separator=False,
)


def _group_by_document(documents: list[Document]) -> list[Document]:
    """Merge retrieved chunks that share a ``metadata['id']`` into one Document.

    The merged document concatenates the chunks' page content (in retrieval
    order) and carries the first chunk's metadata with ``score`` replaced by
    the maximum score among the chunks.
    """
    grouped: dict[str, list[Document]] = {}
    for chunk in documents:
        # `doc_id` (not `id`) to avoid shadowing the builtin.
        doc_id = chunk.metadata["id"]
        grouped.setdefault(doc_id, []).append(chunk)

    merged: list[Document] = []
    for chunks in grouped.values():
        content = "\n\n".join(c.page_content for c in chunks)
        best_score = max(c.metadata["score"] for c in chunks)
        merged.append(
            Document(
                page_content=content,
                metadata=chunks[0].metadata | {"score": best_score},
            )
        )
    return merged


def create_retriever() -> PineconeHybridSearchRetriever:
    """Build the hybrid (dense embeddings + BM25 sparse) Pinecone retriever.

    Loads pre-fitted BM25 weights from ``bm25/bm25_values.json``; index name,
    host, embedding model, ``top_k`` and ``alpha`` come from ``cfg``.
    """
    bm25_encoder = BM25Encoder().load("bm25/bm25_values.json")
    pc = Pinecone()
    index = pc.Index(cfg.datastore.index_name, host=cfg.datastore.host)
    embeddings = OpenAIEmbeddings(model=cfg.datastore.embed_model)
    return PineconeHybridSearchRetriever(
        embeddings=embeddings,
        sparse_encoder=bm25_encoder,
        index=index,
        top_k=cfg.model.top_k,
        alpha=cfg.model.alpha,
    )


def retrieve(state: SearchState, retriever) -> dict:
    """Graph node: run the retriever for ``state['query']`` and group results."""
    logger.info("Starting retrieval process...")
    query = state["query"]
    documents = retriever.invoke(query)
    documents = _group_by_document(documents)
    return {"documents": documents, "query": query}


def explain_dataset(state: GenerationState) -> dict:
    """Graph node: split the document into chunks and generate a cited answer.

    Returns the query, document, serialized chunks, and whatever keys the
    citation chain produced (merged in via ``|``).
    """
    logger.info("Starting explain generation...")
    query = state["query"]
    document = state["document"]
    chunks = text_splitter.split_documents([document])
    docs = format_docs_with_id(chunks)
    generation = answer_citations.invoke({"query": query, "context": docs})
    return {
        "query": query,
        "document": document,
        "chunks": [c.dict() for c in chunks],
    } | generation


def moderate_generation(state: GenerationState) -> GenerationState:
    """Graph node: replace the generation with a sentinel if moderation alters it.

    The moderation chain is expected to return its (possibly rewritten) text
    under ``output``; any difference from the input is treated as a flag.  The
    original generation is preserved under ``inappropriate`` for inspection.
    """
    logger.info("Starting moderation...")
    generation = state["generation"]
    moderation = moderate.invoke(generation)
    if moderation["output"] != generation:
        logger.warning("Inappropriate content found in generation")
        # Sentinel string is also matched by skip_hallucination — keep in sync.
        state["generation"] = "Inappropriate content found in generation."
        state["inappropriate"] = generation
    else:
        logger.info("Generation content is appropriate")
    return state


def check_hallucination(state: GenerationState) -> GenerationState:
    """Graph node: grade the generation against its source document.

    ``binary_score == "yes"`` means the generation is grounded; otherwise the
    generation is replaced with a sentinel message.
    """
    logger.info("Starting hallucination check process...")
    score = hallucination_grader.invoke(
        {"document": state["document"], "generation": state["generation"]}
    )
    if score.binary_score == "yes":
        # Grounded: leave state["generation"] untouched.
        logger.info("No hallucination found in generation")
    else:
        logger.warning("Hallucination found in generation")
        state["generation"] = "Hallucination found in generation."
    return state


def skip_hallucination(state: GenerationState) -> str:
    """Conditional edge: skip the hallucination check if moderation rejected."""
    if state["generation"] == "Inappropriate content found in generation.":
        return END
    return "check_hallucination"


def search_graph():
    """Compile the one-node search graph: START -> retrieve -> END."""
    retriever = create_retriever()
    workflow = StateGraph(SearchState)
    workflow.add_node("retrieve", lambda state: retrieve(state, retriever))
    workflow.add_edge(START, "retrieve")
    workflow.add_edge("retrieve", END)
    return workflow.compile()


def generation_graph():
    """Compile the generation graph.

    START -> explain_dataset -> moderate_generation -> (conditionally)
    check_hallucination -> END.  Moderation failures route straight to END.
    """
    workflow = StateGraph(GenerationState)
    workflow.add_node("explain_dataset", explain_dataset)
    workflow.add_node("moderate_generation", moderate_generation)
    workflow.add_node("check_hallucination", check_hallucination)
    workflow.add_edge(START, "explain_dataset")
    workflow.add_edge("explain_dataset", "moderate_generation")
    workflow.add_conditional_edges("moderate_generation", skip_hallucination)
    workflow.add_edge("check_hallucination", END)
    return workflow.compile()


def search(query: str, thread_id: str) -> dict:
    """Run the search graph for ``query`` under the given LangGraph thread id."""
    # `graph` (not `search`) so the local doesn't shadow this function.
    graph = search_graph()
    output = graph.invoke(
        {"query": query}, config={"configurable": {"thread_id": thread_id}}
    )
    logger.info("Search done")
    return output


def generate(query: str, document, thread_id: str) -> dict:
    """Run the generation graph for ``query`` over one retrieved ``document``."""
    graph = generation_graph()
    output = graph.invoke(
        {"query": query, "document": document},
        config={"configurable": {"thread_id": thread_id}},
    )
    logger.info("Generation done")
    return output


if __name__ == "__main__":
    query = "farming in estonia"
    out = search(query=query, thread_id="1234")
    out_gen = generate(query=query, document=out["documents"][0], thread_id="1234")