""" LangGraph orchestration nodes for retrieval and generation NEEDS TO BE UPDATED """ import logging logger = logging.getLogger(__name__) from datetime import datetime import json from typing import TYPE_CHECKING from langchain_core.documents import Document from .telemetry import extract_retriever_telemetry from components.ingestor.ingestor import process_document # Assuming these Type definitions are available from state.py and retriever_orchestrator.py if TYPE_CHECKING: from components.retriever.retriever_orchestrator import ChaBoHFEndpointRetriever from components.generator.generator_orchestrator import Generator from components.orchestration.state import GraphState async def retrieve_node( state: 'GraphState', retriever: 'ChaBoHFEndpointRetriever' # Injected LangChain BaseRetriever instance ) -> 'GraphState': """ Node to retrieve relevant context using the ChaBoHFEndpointRetriever. The retriever performs Embed -> Search -> Rerank in one async call. """ start_time = datetime.now() # 1. Extract Query and Filters filters = state.get("metadata_filters") metadata = state.get("metadata", {}) logger.info(f"Retrieval: {state['query'][:50]}...") raw_documents: list[Document] = [] try: retriever_kwargs = {} if filters: retriever_kwargs['filters'] = filters raw_documents = await retriever.ainvoke( input=state['query'], **retriever_kwargs ) duration = (datetime.now() - start_time).total_seconds() retriever_config = { "initial_k": retriever.initial_k, "final_k": retriever.final_k, "qdrant_mode": retriever.qdrant_mode, } retriever_telemetry = extract_retriever_telemetry(raw_documents, retriever_config) metadata.update({ "retrieval_duration": duration, "filters_applied": json.dumps(filters) if filters else "None", "retriever_config": retriever_telemetry, "retrieval_success": True }) return { "raw_documents": raw_documents, "metadata": metadata } except Exception as e: duration = (datetime.now() - start_time).total_seconds() logger.error(f"Retrieval failed: {str(e)}", exc_info=True) metadata.update({ "retrieval_duration": duration, "retrieval_success": False, "retrieval_error": str(e) }) return {"raw_documents": [], "metadata": metadata} async def generate_node_streaming(state: "GraphState", generator: "Generator", *, writer): """ Node to generate the final response with StreamWriter for LangGraph custom streaming. Uses StreamWriter to emit events that LangGraph can capture with stream_mode="custom". """ start_time = datetime.now() query = state.get("query") raw_docs = state.get("raw_documents", []) metadata = state.get("metadata", {}) ingestor_context = state.get("ingestor_context") # If we have ingestor_context, prepend it to raw_docs as a Document if ingestor_context: ingestor_doc = Document( page_content=ingestor_context, metadata={"source": "uploaded_file", "filename": state.get("filename", "unknown")} ) raw_docs = [ingestor_doc] + raw_docs logger.info(f"Including ingestor context ({len(ingestor_context)} chars) with retrieved docs") accumulated_text = "" logger.info(f"Generation: {query[:50]}... ({len(raw_docs)} docs)") conversation_context = state.get("conversation_context") try: async for event in generator.generate_streaming( query=query, context=raw_docs, chatui_format=True, conversation_context=conversation_context ): # Track content to calculate metadata (length) at the end if event.get("event") == "data": accumulated_text += event.get("data", "") # Use StreamWriter to emit custom events writer(event) # Final Telemetry Update duration = (datetime.now() - start_time).total_seconds() metadata.update({ "generation_duration": duration, "generation_success": True, "response_length": len(accumulated_text) }) logger.info(f"Streaming complete in {duration:.2f}s. Length: {len(accumulated_text)}") except Exception as e: duration = (datetime.now() - start_time).total_seconds() logger.error(f"Generation node failed: {e}", exc_info=True) metadata.update({ "generation_duration": duration, "generation_success": False, "generation_error": str(e) }) writer({"event": "error", "data": {"error": str(e)}}) async def ingest_node(state: 'GraphState') -> 'GraphState': """ Node to process uploaded documents (PDF, DOCX) and extract chunked context. Only runs if file_content and filename are present in state. """ start_time = datetime.now() file_content = state.get("file_content") filename = state.get("filename") metadata = state.get("metadata", {}) # Skip if no file uploaded if not file_content or not filename: logger.info("No file to ingest, skipping ingest_node") return {} logger.info(f"Ingesting document: {filename}") try: # Process document and get chunked context ingestor_context = process_document(file_content, filename) duration = (datetime.now() - start_time).total_seconds() metadata.update({ "ingest_duration": duration, "ingest_success": True, "ingested_filename": filename, "ingestor_context_length": len(ingestor_context) }) logger.info(f"Document ingested successfully in {duration:.2f}s") return { "ingestor_context": ingestor_context, "metadata": metadata } except Exception as e: duration = (datetime.now() - start_time).total_seconds() logger.error(f"Document ingestion failed: {str(e)}", exc_info=True) metadata.update({ "ingest_duration": duration, "ingest_success": False, "ingest_error": str(e) }) return {"ingestor_context": "", "metadata": metadata} # from .state import GraphState # if TYPE_CHECKING: # from components.retriever.retriever_orchestrator import RetrieverOrchestrator # from components.orchestration.state import GraphState # async def retrieve_node( # state: GraphState, # retriever: 'RetrieverOrchestrator' # Injected service instance # ) -> GraphState: # """Retrieve relevant context using adapter""" # start_time = datetime.now() # logger.info(f"Retrieval: {state['query'][:50]}...") # context = "" # try: # # Get filters from state (provided by ChatUI or LLM agent) # filters = state.get("metadata_filters") # # --- FILLED CODE START --- # # Call the async method on the injected service instance # # The retriever orchestrator handles the remote API call to the Reranker/Embedder service # context_docs, retriever_meta = await retriever.aretrieve( # query=latest_message, # filters=filters # ) # # Format the retrieved documents into a single context string # # (This is commonly done here or inside the orchestrator) # context = "\n---\n".join([doc.page_content for doc in context_docs]) # # --- FILLED CODE END --- # duration = (datetime.now() - start_time).total_seconds() # metadata = state.get("metadata", {}) # # Update metadata and append retriever-specific metadata # metadata.update({ # "retrieval_duration": duration, # "context_length": len(context) if context else 0, # "retrieval_success": True, # "filters_applied": filters, # "retriever_config": retriever_meta, # Add metadata from retriever call # }) # # Return the updated state # return {"context": context, "metadata": metadata} # except Exception as e: # # ... (Error handling logic is good, no change needed) ... # duration = (datetime.now() - start_time).total_seconds() # logger.error(f"Retrieval failed: {str(e)}") # metadata = state.get("metadata", {}) # metadata.update({ # "retrieval_duration": duration, # "retrieval_success": False, # "retrieval_error": str(e) # }) # # Note: We return context as an empty string on failure to avoid cascading errors # return {"context": "", "metadata": metadata} # async def retrieve_node(state: GraphState) -> GraphState: # """Retrieve relevant context using adapter""" # start_time = datetime.now() # logger.info(f"Retrieval: {state['query'][:50]}...") # try: # # Get filters from state (provided by ChatUI or LLM agent) # filters = state.get("metadata_filters") # # instantiate the retirever instance # # get context using aysnc call # duration = (datetime.now() - start_time).total_seconds() # metadata = state.get("metadata", {}) # metadata.update({ # "retrieval_duration": duration, # "context_length": len(context) if context else 0, # "retrieval_success": True, # "filters_applied": filters, # "retriever_config": # get metadata from retirever # }) # return {"context": context, "metadata": metadata} # except Exception as e: # duration = (datetime.now() - start_time).total_seconds() # logger.error(f"Retrieval failed: {str(e)}") # metadata = state.get("metadata", {}) # metadata.update({ # "retrieval_duration": duration, # "retrieval_success": False, # "retrieval_error": str(e) # }) # return {"context": "", "metadata": metadata}