# Source file exported from a Hugging Face Space file view (5,757 bytes).
import logging
from typing import TypedDict
from dotenv import load_dotenv
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, START, StateGraph
from pinecone import Pinecone
from pinecone_text.sparse import BM25Encoder
from src.common.settings import cfg
from src.model.citations import answer_citations, format_docs_with_id
from src.model.hallucination import hallucination_grader
from src.model.moderation import moderate
# Load variables from a local .env file into os.environ (if one exists); the
# boolean result of load_dotenv is intentionally discarded.
_ = load_dotenv()
# Root logger configuration shared by every graph node below.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
class SearchState(TypedDict):
    """State for the search graph: the query and what was retrieved for it."""

    # Free-text search query supplied by the caller.
    query: str
    # Retrieved chunks merged per source document by `_group_by_document`;
    # these are langchain `Document` objects, not plain strings.
    documents: list[Document]
class GenerationState(TypedDict):
    """State for the generation graph (explain -> moderate -> hallucination check)."""

    # User question being answered.
    query: str
    # The selected retrieved document. explain_dataset hands it to
    # text_splitter.split_documents, so it must be a `Document`, not a str.
    document: Document
    # Model answer. It may be replaced by a fixed warning message after
    # moderation or after the hallucination check.
    generation: str
    # Chunks of `document`, each serialized with `.dict()` by explain_dataset.
    chunks: list[dict]
    # Hallucination-check bookkeeping; TODO(review): confirm which node writes it.
    hallucination: str
    # Original generation text, kept when moderation flagged it.
    inappropriate: str
# Splitter used by explain_dataset to cut the selected document into chunks of
# at most 512 units of the splitter's default length function. It tries
# paragraph breaks first, then line breaks, then ". ". With keep_separator=False
# the matched separator is dropped, so sentence-ending ". " does not appear in
# the chunk text.
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", ". "],
    keep_separator=False,
    chunk_size=512,
)
def _group_by_document(documents):
    """Merge retrieved chunks that share the same ``metadata["id"]``.

    Each group of chunks becomes one ``Document``:

    * ``page_content`` is the group's chunk texts joined by blank lines, in the
      order the chunks were received;
    * ``metadata`` is copied from the group's first chunk, with ``score``
      replaced by the highest ``score`` in the group.

    Groups are returned in the order their first chunk appeared.

    Args:
        documents: iterable of chunk ``Document`` objects whose metadata holds
            ``"id"`` and ``"score"`` keys.

    Returns:
        A list of merged ``Document`` objects (empty for empty input).

    Raises:
        KeyError: if a chunk's metadata lacks ``"id"`` or ``"score"``.
    """
    # Bucket chunks by metadata["id"]. Dicts preserve insertion order, so the
    # output keeps first-seen ordering. `doc_id` avoids shadowing builtin `id`.
    groups: dict[str, list[Document]] = {}
    for chunk in documents:
        doc_id = chunk.metadata["id"]
        groups.setdefault(doc_id, []).append(chunk)

    merged = []
    for chunks in groups.values():
        content = "\n\n".join(c.page_content for c in chunks)
        best_score = max(c.metadata["score"] for c in chunks)
        merged.append(
            Document(
                page_content=content,
                # `|` builds a new dict, so the first chunk's metadata is not mutated.
                metadata=chunks[0].metadata | {"score": best_score},
            )
        )
    return merged
def create_retriever():
    """Build the hybrid dense + sparse Pinecone retriever.

    * Sparse side: a BM25 encoder loaded from ``bm25/bm25_values.json``. The
      path is relative, so it resolves against the current working directory.
    * Dense side: OpenAI embeddings using ``cfg.datastore.embed_model``.
    * Index: the Pinecone index ``cfg.datastore.index_name`` at
      ``cfg.datastore.host``.
    * ``top_k`` and ``alpha`` (the dense/sparse weighting) come from
      ``cfg.model``.

    NOTE(review): ``Pinecone()`` is built without arguments, so credentials are
    presumably read from the environment; confirm.
    """
    sparse_encoder = BM25Encoder().load("bm25/bm25_values.json")
    client = Pinecone()
    pinecone_index = client.Index(cfg.datastore.index_name, host=cfg.datastore.host)
    dense_embeddings = OpenAIEmbeddings(model=cfg.datastore.embed_model)
    retriever_kwargs = dict(
        embeddings=dense_embeddings,
        sparse_encoder=sparse_encoder,
        index=pinecone_index,
        top_k=cfg.model.top_k,
        alpha=cfg.model.alpha,
    )
    return PineconeHybridSearchRetriever(**retriever_kwargs)
def retrieve(state, retriever):
    """Graph node: run ``retriever`` on ``state["query"]``.

    Returns a state update that echoes the query back and sets ``documents`` to
    the retrieved chunks merged per ``metadata["id"]`` by ``_group_by_document``.
    """
    logging.info("Starting retrieval process...")
    question = state["query"]
    merged_hits = _group_by_document(retriever.invoke(question))
    return {"documents": merged_hits, "query": question}
def explain_dataset(state):
    """Graph node: answer the query with citations over the selected document.

    The document is split with ``text_splitter``, and the chunks are formatted
    with ids by ``format_docs_with_id``. Both the query and the formatted
    context are sent to ``answer_citations``.

    Returns:
        ``query``, ``document`` and ``chunks`` (each chunk converted via
        ``.dict()``), merged with the mapping returned by ``answer_citations``.
        Keys from that mapping win on conflicts. NOTE(review): presumably that
        mapping supplies ``generation`` for moderate_generation; confirm.
    """
    logging.info("Starting explain generation...")
    question = state["query"]
    source_doc = state["document"]
    pieces = text_splitter.split_documents([source_doc])
    answer = answer_citations.invoke(
        {"query": question, "context": format_docs_with_id(pieces)}
    )
    base_update = {
        "query": question,
        "document": source_doc,
        "chunks": [piece.dict() for piece in pieces],
    }
    return base_update | answer
def moderate_generation(state):
    """Graph node: run the moderation chain over ``state["generation"]``.

    The text counts as flagged when the moderation output differs from the
    input. When flagged:

    * ``generation`` is replaced with a fixed warning message;
    * the original text is kept under ``inappropriate``.

    The state dict is mutated in place and returned.
    """
    logging.info("Starting moderation...")
    text = state["generation"]
    moderated = moderate.invoke(text)
    if moderated["output"] == text:
        logging.info("Generation content is appropriate")
        return state
    logging.warning("Inappropriate content found in generation")
    state["generation"] = "Inappropriate content found in generation."
    state["inappropriate"] = text
    return state
def check_hallucination(state):
    """Graph node: grade whether the generation is supported by the document.

    ``hallucination_grader`` receives the document and the generation.

    * If the grade's ``binary_score`` is ``"yes"``, the generation is left
      untouched.
    * Otherwise the rejected text is kept under ``hallucination``, and
      ``generation`` is replaced with a fixed warning message.

    The state dict is mutated in place and returned.
    """
    logging.info("Starting hallucination check process...")
    generation = state["generation"]
    score = hallucination_grader.invoke(
        {"document": state["document"], "generation": generation}
    )
    if score.binary_score == "yes":
        logging.info("No hallucination found in generation")
    else:
        logging.warning("Hallucination found in generation")
        # Same pattern as moderate_generation's "inappropriate" key: keep the
        # rejected text in its own (already declared) state key so it is not lost.
        state["hallucination"] = generation
        state["generation"] = "Hallucination found in generation."
    return state
def skip_hallucination(state):
    """Routing function after moderation.

    Ends the graph when moderation replaced the generation with its warning
    message. Otherwise routes to the ``check_hallucination`` node.
    """
    was_flagged = state["generation"] == "Inappropriate content found in generation."
    return END if was_flagged else "check_hallucination"
def search_graph():
    """Build and compile the search graph: START -> retrieve -> END.

    A new retriever is created (via ``create_retriever``) each time this is called.
    """
    retriever = create_retriever()

    def retrieve_node(state):
        # Bind the retriever built above into the graph node.
        return retrieve(state, retriever)

    graph = StateGraph(SearchState)
    graph.add_node("retrieve", retrieve_node)
    graph.add_edge(START, "retrieve")
    graph.add_edge("retrieve", END)
    return graph.compile()
def generation_graph():
    """Build and compile the generation graph.

    Flow: START -> explain_dataset -> moderate_generation, then
    ``skip_hallucination`` routes either to END or to check_hallucination
    -> END.
    """
    graph = StateGraph(GenerationState)
    nodes = (
        ("explain_dataset", explain_dataset),
        ("moderate_generation", moderate_generation),
        ("check_hallucination", check_hallucination),
    )
    for node_name, node_fn in nodes:
        graph.add_node(node_name, node_fn)
    graph.add_edge(START, "explain_dataset")
    graph.add_edge("explain_dataset", "moderate_generation")
    graph.add_conditional_edges("moderate_generation", skip_hallucination)
    graph.add_edge("check_hallucination", END)
    return graph.compile()
def search(query, thread_id):
    """Run the search graph for ``query`` and return its final state.

    A fresh graph, and therefore a fresh retriever, is built on every call.
    ``thread_id`` is forwarded as ``config["configurable"]["thread_id"]``.
    NOTE(review): the graph is compiled without a checkpointer, so this
    presumably has no effect; confirm.
    """
    graph = search_graph()
    run_config = {"configurable": {"thread_id": thread_id}}
    result = graph.invoke({"query": query}, config=run_config)
    logging.info("Search done")
    return result
def generate(query, document, thread_id):
    """Run the generation graph on ``query`` and ``document``; return its final state.

    ``document`` should be one of the Documents returned by ``search``. The graph
    is rebuilt on every call. ``thread_id`` is forwarded via the
    ``configurable`` config.
    """
    graph = generation_graph()
    run_config = {"configurable": {"thread_id": thread_id}}
    initial_state = {"query": query, "document": document}
    result = graph.invoke(initial_state, config=run_config)
    logging.info("Generation done")
    return result
if __name__ == "__main__":
    # Manual end-to-end run: search, then explain the first returned document.
    demo_query = "farming in estonia"
    search_output = search(query=demo_query, thread_id="1234")
    first_document = search_output["documents"][0]
    out_gen = generate(query=demo_query, document=first_document, thread_id="1234")
|