from __future__ import annotations

from time import perf_counter
from typing import Any, Dict, List, Optional, TypedDict

from langchain_core.runnables.config import RunnableConfig
from langgraph.graph import END, StateGraph

from app.core.config import get_settings
from app.core.errors import UpstreamServiceError
from app.core.logging import get_logger
from app.services.llm.groq_llm import get_llm
from app.services.prompts.rag_prompt import build_rag_messages
from app.services.pinecone_store import search as pinecone_search
from app.services.tools.tavily_tool import get_tavily_tool, is_tavily_configured

logger = get_logger(__name__)


class ChatState(TypedDict, total=False):
    query: str
    namespace: str
    top_k: int
    min_score: float
    use_web_fallback: bool
    max_web_results: int
    chat_history: List[Dict[str, str]]
    retrieved: List[Dict[str, Any]]
    web_results: List[Dict[str, Any]]
    answer: str
    timings: Dict[str, float]
    tavily_available: bool
    web_fallback_used: bool
    top_score: float


def _ensure_timings(state: ChatState) -> Dict[str, float]:
    timings = state.get("timings") or {}
    if not isinstance(timings, dict):
        timings = {}
    state["timings"] = timings
    return timings  # type: ignore[return-value]


def normalize_input(state: ChatState, _config: RunnableConfig | None = None) -> ChatState:
    """Normalise input state with default values from settings."""
    settings = get_settings()
    namespace = state.get("namespace") or settings.PINECONE_NAMESPACE
    top_k = int(state.get("top_k") or settings.RAG_DEFAULT_TOP_K)
    min_score = float(state.get("min_score") or settings.RAG_MIN_SCORE)
    max_web_results = int(state.get("max_web_results") or settings.RAG_MAX_WEB_RESULTS)
    # Coerce use_web_fallback to a real bool here so later reads (including the
    # log line below) never hit a missing key on this total=False TypedDict.
    use_web_fallback = bool(state.get("use_web_fallback", False))
    chat_history = state.get("chat_history") or []

    # Normalise chat_history into a list of {role, content} dicts
    normalized_history: List[Dict[str, str]] = []
    for item in chat_history:
        role = item.get("role", "user")
        content = item.get("content", "")
        if content:
            normalized_history.append({"role": role, "content": content})

    new_state: ChatState = {
        **state,
        "namespace": namespace,
        "top_k": top_k,
        "min_score": min_score,
        "use_web_fallback": use_web_fallback,
        "max_web_results": max_web_results,
        "chat_history": normalized_history,
        "retrieved": [],
        "web_results": [],
        "timings": state.get("timings") or {},
        "tavily_available": is_tavily_configured(),
        "web_fallback_used": False,
    }
    logger.info(
        "Chat graph input normalised namespace='%s' top_k=%d min_score=%.3f "
        "use_web_fallback=%s max_web_results=%d tavily_available=%s",
        new_state["namespace"],
        new_state["top_k"],
        new_state["min_score"],
        new_state["use_web_fallback"],
        new_state["max_web_results"],
        new_state["tavily_available"],
    )
    return new_state


def retrieve_context(state: ChatState, _config: RunnableConfig | None = None) -> ChatState:
    """Retrieve relevant document chunks from Pinecone."""
    settings = get_settings()
    timings = _ensure_timings(state)

    start = perf_counter()
    raw_hits: List[Dict[str, Any]] = pinecone_search(
        namespace=state["namespace"],
        query_text=state["query"],
        top_k=state["top_k"],
        filters=None,
        fields=None,
    )
    elapsed_ms = (perf_counter() - start) * 1000.0
    timings["retrieve_ms"] = elapsed_ms
    state["timings"] = timings

    text_field = settings.PINECONE_TEXT_FIELD
    retrieved: List[Dict[str, Any]] = []
    top_score = 0.0
    for hit in raw_hits:
        hit_score = float(hit.get("_score") or hit.get("score") or 0.0)
        fields: Dict[str, Any] = hit.get("fields") or {}
        raw_text = fields.get(text_field, "") or ""
        # Map the configured text field into a stable chunk_text key
        chunk_text = str(raw_text)
        title = str(fields.get("title") or "")
        source = str(fields.get("source") or "unknown")
        url = str(fields.get("url") or "")
        retrieved.append(
            {
                "source": source,
                "title": title,
                "url": url,
                "score": hit_score,
                "chunk_text": chunk_text,
            }
        )
        top_score = max(top_score, hit_score)

    state["retrieved"] = retrieved
    state["top_score"] = top_score
    logger.info(
        "Pinecone retrieval completed namespace='%s' top_k=%d hits=%d top_score=%.4f",
        state["namespace"],
        state["top_k"],
        len(retrieved),
        top_score,
    )
    return state
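
# Shape note (illustrative; the dict layout is a convention of this module,
# not an enforced schema): each entry appended to state["retrieved"] above
# looks like
#     {"source": "docs/guide.md", "title": "Guide", "url": "",
#      "score": 0.87, "chunk_text": "..."}
# web_search below emits the same keys, so build_rag_messages can consume
# both lists interchangeably.
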
def decide_next(state: ChatState, _config: RunnableConfig | None = None) -> ChatState:
    """Decide whether to proceed with web search or answer generation."""
    use_web = bool(state.get("use_web_fallback"))
    tavily_available = bool(state.get("tavily_available"))
    retrieved = state.get("retrieved") or []
    min_score = float(state.get("min_score") or 0.0)
    top_score = float(state.get("top_score") or 0.0)

    should_use_web = False
    if use_web and tavily_available:
        if not retrieved:
            should_use_web = True
        elif top_score < min_score:
            should_use_web = True

    state["web_fallback_used"] = should_use_web
    logger.info(
        "Chat routing decision web_fallback_used=%s tavily_available=%s "
        "retrieved=%d top_score=%.4f min_score=%.4f",
        should_use_web,
        tavily_available,
        len(retrieved),
        top_score,
        min_score,
    )
    return state


def _route_after_decide_next(state: ChatState) -> str:
    """Conditional routing function for LangGraph."""
    if state.get("web_fallback_used"):
        return "web_search"
    return "generate_answer"
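
# Worked example of the routing above (numbers are illustrative): with
# min_score=0.75 and both use_web_fallback and tavily_available true, an
# empty retrieval or a best hit scoring 0.62 routes to "web_search", while a
# best hit scoring 0.80 goes straight to "generate_answer".
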
def web_search(state: ChatState, config: RunnableConfig | None = None) -> ChatState:
    """Perform Tavily web search and convert results into pseudo-doc chunks."""
    timings = _ensure_timings(state)
    max_results = int(state.get("max_web_results") or 5)
    tool = get_tavily_tool(max_results=max_results)
    if tool is None:
        logger.warning("Tavily tool unavailable; skipping web search.")
        timings.setdefault("web_ms", 0.0)
        state["timings"] = timings
        state["web_results"] = []
        return state

    start = perf_counter()
    try:
        # The TavilySearchResults tool is a Runnable, so we can pass config for tracing.
        results: Any = tool.invoke({"query": state["query"]}, config=config or {})
    except Exception as exc:  # noqa: BLE001
        elapsed_ms = (perf_counter() - start) * 1000.0
        timings["web_ms"] = elapsed_ms
        state["timings"] = timings
        logger.error("Tavily web search failed: %s", exc)
        raise UpstreamServiceError(
            service="Tavily",
            message="Upstream Tavily web search failed. Try again later or disable web fallback.",
        ) from exc

    elapsed_ms = (perf_counter() - start) * 1000.0
    timings["web_ms"] = elapsed_ms
    state["timings"] = timings

    web_hits: List[Dict[str, Any]] = []
    # TavilySearchResults returns a list of dicts by default.
    if isinstance(results, list):
        iterable = results
    else:
        iterable = getattr(results, "data", []) or []
    for item in iterable:
        if not isinstance(item, dict):
            continue
        url = str(item.get("url") or "")
        title = str(item.get("title") or "") or url
        content = str(item.get("content") or item.get("snippet") or "")
        web_hits.append(
            {
                "source": "web",
                "title": title,
                "url": url,
                "score": 0.0,
                "chunk_text": content,
            }
        )

    logger.info(
        "Tavily web search completed results=%d elapsed_ms=%.2f",
        len(web_hits),
        elapsed_ms,
    )
    state["web_results"] = web_hits
    return state


def generate_answer(state: ChatState, config: RunnableConfig | None = None) -> ChatState:
    """Generate an answer using the Groq-backed chat model."""
    timings = _ensure_timings(state)
    messages = build_rag_messages(
        chat_history=state.get("chat_history") or [],
        question=state["query"],
        sources=(state.get("retrieved") or []) + (state.get("web_results") or []),
    )
    llm = get_llm()

    start = perf_counter()
    try:
        response = llm.invoke(messages, config=config or {})
    except Exception as exc:  # noqa: BLE001
        elapsed_ms = (perf_counter() - start) * 1000.0
        timings["generate_ms"] = elapsed_ms
        state["timings"] = timings
        logger.error("Groq chat completion failed: %s", exc)
        raise UpstreamServiceError(
            service="Groq",
            message="Upstream Groq chat completion failed. Please try again later.",
        ) from exc

    elapsed_ms = (perf_counter() - start) * 1000.0
    timings["generate_ms"] = elapsed_ms
    state["timings"] = timings

    answer_text: str
    try:
        answer_text = str(getattr(response, "content", "") or response)
    except Exception:  # noqa: BLE001
        answer_text = str(response)

    state["answer"] = answer_text
    logger.info("Answer generation completed elapsed_ms=%.2f", elapsed_ms)
    return state


def format_response(state: ChatState, _config: RunnableConfig | None = None) -> ChatState:
    """No-op node reserved for future formatting; currently returns state."""
    # This node exists mainly to keep the graph structure explicit and ready
    # for future formatting steps (e.g. re-ranking or response post-processing).
    return state


_graph: Optional[Any] = None


def get_chat_graph() -> Any:
    """Return the compiled LangGraph chat graph (lazy singleton)."""
    global _graph
    if _graph is not None:
        return _graph

    workflow: StateGraph = StateGraph(ChatState)
    workflow.add_node("normalize_input", normalize_input)
    workflow.add_node("retrieve_context", retrieve_context)
    workflow.add_node("decide_next", decide_next)
    workflow.add_node("web_search", web_search)
    workflow.add_node("generate_answer", generate_answer)
    workflow.add_node("format_response", format_response)

    workflow.set_entry_point("normalize_input")
    workflow.add_edge("normalize_input", "retrieve_context")
    workflow.add_edge("retrieve_context", "decide_next")
    workflow.add_conditional_edges(
        "decide_next",
        _route_after_decide_next,
        {
            "web_search": "web_search",
            "generate_answer": "generate_answer",
        },
    )
    workflow.add_edge("web_search", "generate_answer")
    workflow.add_edge("generate_answer", "format_response")
    workflow.add_edge("format_response", END)

    _graph = workflow.compile()
    logger.info("Chat LangGraph compiled and initialised.")
    return _graph
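

if __name__ == "__main__":
    # Usage sketch (illustrative; in practice this graph is driven by an API
    # handler and the question below is made up). Only "query" is required in
    # the initial state; normalize_input fills everything else from settings.
    chat_graph = get_chat_graph()
    final_state = chat_graph.invoke(
        {"query": "How do I rotate my API key?", "use_web_fallback": True}
    )
    print(final_state.get("answer", ""))
    print(final_state.get("timings", {}))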