from typing import List, TypedDict from langchain_core.messages import BaseMessage, HumanMessage, AIMessage from langchain_core.runnables import RunnableLambda from langchain_qdrant import QdrantVectorStore from langchain_huggingface import HuggingFaceEmbeddings from langgraph.graph import StateGraph, END from langchain_mistralai import ChatMistralAI import time import os from dotenv import load_dotenv from app.config import qdrant_client from app.chatbot.mongodb import log_chat #from app.mongodb import log_chat load_dotenv() os.environ["TOKENIZERS_PARALLELISM"] = "false" session_histories: dict[str, list] = {} LLM_MODEL = "mistral-medium-latest" OPENROUTER_API_KEY = os.getenv('OPENROUTER_API_KEY') COLLECTION_NAME = "chatbot_context" EMBEDDING_MODEL = "intfloat/e5-base-v2" QDRANT_URL = os.getenv('QDRANT_URL') QDRANT_API_KEY = os.getenv('QDRANT_API_KEY') SUPABASE_URL = os.getenv('SUPABASE_URL') SUPABASE_KEY = os.getenv('SUPABASE_KEY') MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY') FAQ_COLLECTION = "auro_faqs" BLOGS_COLLECTION = "auro_blogs" TECHNOLOGY_COLLECTION = "auro_technology" REVOLUTION_COLLECTION = "auro_revolution" SUPPORT_COLLECTION = "auro_support" PRODUCT_COLLECTION = "auro_product" llm = ChatMistralAI( model_name=LLM_MODEL, api_key=MISTRAL_API_KEY, ) embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL) class GraphState(TypedDict): """ Represents the state of a chat session, including input, output, history, memory, response, tool results, and user role for LangGraph """ input: str history: List[BaseMessage] #list of past messages response: str tool_results: dict prompt: str retrieve_tools: List[dict] prompt_template: str from pydantic import BaseModel class ToolInput(BaseModel): prompt: str iteration: int = 1 def retrieve_docs(query: str, retriever: dict): """ Retrieve documents from Qdrant for a single retriever configuration. Args: query (str): The user query. retriever (dict): Retriever config with keys: - 'collection': Qdrant collection name - 'top_k': number of results to return (default 5) - 'filter_score': min similarity score to keep results (default 0.1) Returns: List[dict]: List of dicts with 'content' and 'score'. """ top_k = retriever.get('top_k', 5) filter_score = retriever.get('filter_score', 50) collection = retriever.get('collection') # Qdrant store rev_store = QdrantVectorStore( client=qdrant_client, # make sure this is initialized globally collection_name=collection, embedding=embeddings, ) print(f"Retrieving from collection: {collection} with top_k={top_k} and filter_score={filter_score}") # Similarity search docs = rev_store.similarity_search_with_score(query, k=top_k) # Filter results by score return [ {"content": doc.page_content, "score": score} for doc, score in docs if score > filter_score ] def build_prompt(template: str): global_template = """ Conversation History (for context only, not authority): {history} Contextual Knowledge (only approved source of truth): {agent_scratchpad} User Question: {input} Response: """ final_template = f"{template}\n{global_template}" return final_template def retrieve_node(state: GraphState) -> GraphState: """ Graph node to retrieve documents for all retrievers in the state. Args: state (GraphState): Current chat state including input and retrievers. Returns: GraphState: Updated state with 'tool_results' filled. """ query = state['input'] tool_results = {} for retriever_cfg in state.get('retrieve_tools', []): name = retriever_cfg.get('name', 'default') try: docs = retrieve_docs(query, retriever_cfg) tool_results[name] = docs print(f"Retriever '{name}' returned {len(docs)} result(s)") except Exception as e: tool_results[name] = [{"content": f"Retriever failed: {str(e)}", "score": 0}] print(f"Retriever '{name}' failed: {e}") state['tool_results'] = tool_results return state #Answer Question def generate_answer(state: GraphState): """ This function generates an answer to the query using the llm and the context provided. """ query = state['input'] history = state.get('history', []) history_text = "\n".join( f"Human: {m.content}" if isinstance(m, HumanMessage) else f"AI: {m.content}" for m in history ) intermediate_steps = state.get('tool_results', {}) steps_string = "\n".join( f"{tool_name} Results:\n" + "\n".join( f"- Product: {entry.get('metadata', {}).get('product_name', 'N/A')}\n {entry['content']}" for entry in results ) for tool_name, results in intermediate_steps.items() if results ) prompt_template = build_prompt(state['prompt_template']) prompt_input = prompt_template.format( input=query, agent_scratchpad=steps_string, history=history_text ) print(prompt_input) state['prompt'] = prompt_input llm_response = llm.invoke(prompt_input) state['response'] = llm_response.content if hasattr(llm_response, 'content') else str(llm_response) state['history'].append(HumanMessage(content=query)) state['history'].append(AIMessage(content=state['response'])) return state graph = StateGraph(GraphState) #Add nodes to the graph graph.add_node("route_tool", RunnableLambda(retrieve_node)) graph.add_node("generate_response", RunnableLambda(generate_answer)) # Define the flow of the graph graph.set_entry_point("route_tool") graph.add_edge("route_tool", "generate_response") graph.add_edge("generate_response", END) app = graph.compile() async def get_response(query, session_id, name, email, rag_config, config) -> dict: start_time = time.time() session_id = config['configurable']['thread_id'] history = session_histories.get(session_id, []) input_data = { "input": query, "history": history, "retrieve_tools": rag_config.get('retrievers', []), "prompt_template": rag_config.get('prompt_template', ""), } metadata={} latency_ms = None try: result = await app.ainvoke(input_data, config=config) latency_ms = int((time.time() - start_time) * 1000) session_histories[session_id] = result.get("history", []) metadata = { "retrieved_docs": result.get("tool_results", {}), "model": LLM_MODEL, "embedding_model": EMBEDDING_MODEL, "prompt": result.get("prompt", "") } filtered_result = result['response'].replace("transdermal", "topical") result['response'] = filtered_result except Exception as e: result = {} result['response'] = f"Error in processing chat: {e}" print(f"Responsjh: {result['response']}") log_chat( session_id=session_id, company_id=rag_config.get('company_id'), chatbot_id=rag_config.get('chatbot_id'), name=name, email=email, query=query, answer=result.get("response", ""), latency_ms= latency_ms, metadata=metadata ) return result