# Commit reference: "ingestor fix" (6329a68, mtyrrell) — web-UI header residue converted to a comment so the module parses.
"""
LangGraph orchestration nodes for retrieval and generation
NEEDS TO BE UPDATED
"""
import logging
logger = logging.getLogger(__name__)
from datetime import datetime
import json
from typing import TYPE_CHECKING
from langchain_core.documents import Document
from .telemetry import extract_retriever_telemetry
from components.ingestor.ingestor import process_document
# Assuming these Type definitions are available from state.py and retriever_orchestrator.py
if TYPE_CHECKING:
from components.retriever.retriever_orchestrator import ChaBoHFEndpointRetriever
from components.generator.generator_orchestrator import Generator
from components.orchestration.state import GraphState
async def retrieve_node(
    state: 'GraphState',
    retriever: 'ChaBoHFEndpointRetriever'  # Injected LangChain BaseRetriever instance
) -> 'GraphState':
    """
    Retrieve relevant context using the ChaBoHFEndpointRetriever.

    The retriever performs Embed -> Search -> Rerank in one async call.
    Optional metadata filters from state are forwarded to the retriever,
    and retrieval telemetry is recorded into the state metadata dict.

    Args:
        state: Graph state; reads 'query', 'metadata_filters', 'metadata'.
        retriever: Configured retriever instance exposing ainvoke() plus the
            initial_k / final_k / qdrant_mode attributes used for telemetry.

    Returns:
        Partial state update with 'raw_documents' (empty list on failure)
        and the updated 'metadata' dict.
    """
    start_time = datetime.now()
    # 1. Extract query and optional filters
    filters = state.get("metadata_filters")
    metadata = state.get("metadata", {})
    logger.info(f"Retrieval: {state['query'][:50]}...")
    raw_documents: list[Document] = []
    try:
        retriever_kwargs = {}
        if filters:
            retriever_kwargs['filters'] = filters
        raw_documents = await retriever.ainvoke(
            input=state['query'],
            **retriever_kwargs
        )
        duration = (datetime.now() - start_time).total_seconds()
        retriever_config = {
            "initial_k": retriever.initial_k,
            "final_k": retriever.final_k,
            "qdrant_mode": retriever.qdrant_mode,
        }
        retriever_telemetry = extract_retriever_telemetry(raw_documents, retriever_config)
        metadata.update({
            "retrieval_duration": duration,
            # default=str prevents a TypeError on non-JSON-serializable filter
            # values, which would otherwise trip the except branch and mark a
            # successful retrieval as failed.
            "filters_applied": json.dumps(filters, default=str) if filters else "None",
            "retriever_config": retriever_telemetry,
            "retrieval_success": True
        })
        return {
            "raw_documents": raw_documents,
            "metadata": metadata
        }
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Retrieval failed: {str(e)}", exc_info=True)
        metadata.update({
            "retrieval_duration": duration,
            "retrieval_success": False,
            "retrieval_error": str(e)
        })
        # Return an empty document list so downstream nodes do not cascade the failure.
        return {"raw_documents": [], "metadata": metadata}
async def generate_node_streaming(state: "GraphState", generator: "Generator", *, writer):
    """
    Generate the final response, emitting events via LangGraph's StreamWriter.

    Each event produced by the generator is forwarded through `writer` so it
    can be captured with stream_mode="custom". If a file was ingested this
    turn ('ingestor_context' present), its content is prepended to the
    retrieved documents as a pseudo-document.

    Args:
        state: Graph state; reads 'query', 'raw_documents', 'metadata',
            'ingestor_context', 'filename', 'conversation_context'.
        generator: Generator orchestrator exposing generate_streaming().
        writer: LangGraph StreamWriter callable for custom stream events.

    Returns:
        Partial state update containing the updated 'metadata' dict with
        generation telemetry (duration, success flag, response length or
        error message).
    """
    start_time = datetime.now()
    query = state.get("query")
    raw_docs = state.get("raw_documents", [])
    metadata = state.get("metadata", {})
    ingestor_context = state.get("ingestor_context")
    # If we have ingestor_context, prepend it to raw_docs as a Document
    if ingestor_context:
        ingestor_doc = Document(
            page_content=ingestor_context,
            metadata={"source": "uploaded_file", "filename": state.get("filename", "unknown")}
        )
        raw_docs = [ingestor_doc] + raw_docs
        logger.info(f"Including ingestor context ({len(ingestor_context)} chars) with retrieved docs")
    accumulated_text = ""
    # Guard against a missing query so the log line cannot raise TypeError on None[:50].
    logger.info(f"Generation: {(query or '')[:50]}... ({len(raw_docs)} docs)")
    conversation_context = state.get("conversation_context")
    try:
        async for event in generator.generate_streaming(
            query=query,
            context=raw_docs,
            chatui_format=True,
            conversation_context=conversation_context
        ):
            # Track content to calculate metadata (length) at the end
            if event.get("event") == "data":
                accumulated_text += event.get("data", "")
            # Use StreamWriter to emit custom events
            writer(event)
        # Final Telemetry Update
        duration = (datetime.now() - start_time).total_seconds()
        metadata.update({
            "generation_duration": duration,
            "generation_success": True,
            "response_length": len(accumulated_text)
        })
        logger.info(f"Streaming complete in {duration:.2f}s. Length: {len(accumulated_text)}")
        # Return the update explicitly (matching retrieve_node/ingest_node):
        # relying on in-place mutation loses the telemetry when state had no
        # pre-existing 'metadata' dict.
        return {"metadata": metadata}
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Generation node failed: {e}", exc_info=True)
        metadata.update({
            "generation_duration": duration,
            "generation_success": False,
            "generation_error": str(e)
        })
        writer({"event": "error", "data": {"error": str(e)}})
        return {"metadata": metadata}
async def ingest_node(state: 'GraphState') -> 'GraphState':
    """
    Process an uploaded document (PDF, DOCX) and extract chunked context.

    Only runs if both 'file_content' and 'filename' are present in state;
    otherwise it is a no-op returning an empty update.

    Args:
        state: Graph state; reads 'file_content', 'filename', 'metadata'.

    Returns:
        Partial state update with 'ingestor_context' (the extracted text,
        empty string on failure) and the updated 'metadata' dict; an empty
        dict when no file was uploaded.
    """
    start_time = datetime.now()
    file_content = state.get("file_content")
    filename = state.get("filename")
    metadata = state.get("metadata", {})
    # Skip if no file uploaded
    if not file_content or not filename:
        logger.info("No file to ingest, skipping ingest_node")
        return {}
    # Fix: log the actual filename (the previous message interpolated nothing).
    logger.info(f"Ingesting document: {filename}")
    try:
        # Process document and get chunked context
        ingestor_context = process_document(file_content, filename)
        duration = (datetime.now() - start_time).total_seconds()
        metadata.update({
            "ingest_duration": duration,
            "ingest_success": True,
            "ingested_filename": filename,
            "ingestor_context_length": len(ingestor_context)
        })
        logger.info(f"Document ingested successfully in {duration:.2f}s")
        return {
            "ingestor_context": ingestor_context,
            "metadata": metadata
        }
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Document ingestion failed: {str(e)}", exc_info=True)
        metadata.update({
            "ingest_duration": duration,
            "ingest_success": False,
            "ingest_error": str(e)
        })
        # Empty context string keeps downstream nodes working on failure.
        return {"ingestor_context": "", "metadata": metadata}
# --- Legacy draft implementations below (commented out) — retained for reference only; safe to delete ---
# from .state import GraphState
# if TYPE_CHECKING:
# from components.retriever.retriever_orchestrator import RetrieverOrchestrator
# from components.orchestration.state import GraphState
# async def retrieve_node(
# state: GraphState,
# retriever: 'RetrieverOrchestrator' # Injected service instance
# ) -> GraphState:
# """Retrieve relevant context using adapter"""
# start_time = datetime.now()
# logger.info(f"Retrieval: {state['query'][:50]}...")
# context = ""
# try:
# # Get filters from state (provided by ChatUI or LLM agent)
# filters = state.get("metadata_filters")
# # --- FILLED CODE START ---
# # Call the async method on the injected service instance
# # The retriever orchestrator handles the remote API call to the Reranker/Embedder service
# context_docs, retriever_meta = await retriever.aretrieve(
# query=latest_message,
# filters=filters
# )
# # Format the retrieved documents into a single context string
# # (This is commonly done here or inside the orchestrator)
# context = "\n---\n".join([doc.page_content for doc in context_docs])
# # --- FILLED CODE END ---
# duration = (datetime.now() - start_time).total_seconds()
# metadata = state.get("metadata", {})
# # Update metadata and append retriever-specific metadata
# metadata.update({
# "retrieval_duration": duration,
# "context_length": len(context) if context else 0,
# "retrieval_success": True,
# "filters_applied": filters,
# "retriever_config": retriever_meta, # Add metadata from retriever call
# })
# # Return the updated state
# return {"context": context, "metadata": metadata}
# except Exception as e:
# # ... (Error handling logic is good, no change needed) ...
# duration = (datetime.now() - start_time).total_seconds()
# logger.error(f"Retrieval failed: {str(e)}")
# metadata = state.get("metadata", {})
# metadata.update({
# "retrieval_duration": duration,
# "retrieval_success": False,
# "retrieval_error": str(e)
# })
# # Note: We return context as an empty string on failure to avoid cascading errors
# return {"context": "", "metadata": metadata}
# async def retrieve_node(state: GraphState) -> GraphState:
# """Retrieve relevant context using adapter"""
# start_time = datetime.now()
# logger.info(f"Retrieval: {state['query'][:50]}...")
# try:
# # Get filters from state (provided by ChatUI or LLM agent)
# filters = state.get("metadata_filters")
# # instantiate the retirever instance
# # get context using aysnc call
# duration = (datetime.now() - start_time).total_seconds()
# metadata = state.get("metadata", {})
# metadata.update({
# "retrieval_duration": duration,
# "context_length": len(context) if context else 0,
# "retrieval_success": True,
# "filters_applied": filters,
# "retriever_config": # get metadata from retirever
# })
# return {"context": context, "metadata": metadata}
# except Exception as e:
# duration = (datetime.now() - start_time).total_seconds()
# logger.error(f"Retrieval failed: {str(e)}")
# metadata = state.get("metadata", {})
# metadata.update({
# "retrieval_duration": duration,
# "retrieval_success": False,
# "retrieval_error": str(e)
# })
# return {"context": "", "metadata": metadata}