# RAG chatbot pipeline: Qdrant retrieval + Mistral generation wired together with LangGraph.
| from typing import List, TypedDict | |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage | |
| from langchain_core.runnables import RunnableLambda | |
| from langchain_qdrant import QdrantVectorStore | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langgraph.graph import StateGraph, END | |
| from langchain_mistralai import ChatMistralAI | |
| import time | |
| import os | |
| from dotenv import load_dotenv | |
| from app.config import qdrant_client | |
| from app.chatbot.mongodb import log_chat | |
| #from app.mongodb import log_chat | |
load_dotenv()
# Silence/avoid fork-related warnings from the HuggingFace tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# In-memory per-session chat history, keyed by the LangGraph thread id.
# NOTE(review): not persistent and not shared across processes — confirm this
# is acceptable for the deployment model.
session_histories: dict[str, list] = {}

# Model and service configuration (secrets come from the environment / .env).
LLM_MODEL = "mistral-medium-latest"
OPENROUTER_API_KEY = os.getenv('OPENROUTER_API_KEY')
COLLECTION_NAME = "chatbot_context"
EMBEDDING_MODEL = "intfloat/e5-base-v2"
QDRANT_URL = os.getenv('QDRANT_URL')
QDRANT_API_KEY = os.getenv('QDRANT_API_KEY')
SUPABASE_URL = os.getenv('SUPABASE_URL')
SUPABASE_KEY = os.getenv('SUPABASE_KEY')
MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY')

# Qdrant collection names for the different content domains.
FAQ_COLLECTION = "auro_faqs"
BLOGS_COLLECTION = "auro_blogs"
TECHNOLOGY_COLLECTION = "auro_technology"
REVOLUTION_COLLECTION = "auro_revolution"
SUPPORT_COLLECTION = "auro_support"
PRODUCT_COLLECTION = "auro_product"

# Shared LLM and embedding model instances, created once at import time.
llm = ChatMistralAI(
    model_name=LLM_MODEL,
    api_key=MISTRAL_API_KEY,
)
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
class GraphState(TypedDict):
    """
    State carried between LangGraph nodes for a single chat turn.

    Keys:
        input: the raw user query for this turn.
        history: prior conversation messages (HumanMessage/AIMessage).
        response: the generated answer (written by generate_answer).
        tool_results: retriever name -> list of retrieved document dicts.
        prompt: the fully rendered prompt sent to the LLM.
        retrieve_tools: retriever configs (collection, top_k, filter_score, name).
        prompt_template: chatbot-specific system template prefix.
    """
    input: str
    history: List[BaseMessage]  # list of past messages
    response: str
    tool_results: dict
    prompt: str
    retrieve_tools: List[dict]
    prompt_template: str
# NOTE(review): mid-file import — conventionally this belongs in the import
# block at the top of the file.
from pydantic import BaseModel


class ToolInput(BaseModel):
    """Input schema for a tool call: the prompt text and an iteration counter."""
    prompt: str
    iteration: int = 1
def retrieve_docs(query: str, retriever: dict):
    """
    Retrieve documents from Qdrant for a single retriever configuration.

    Args:
        query (str): The user query.
        retriever (dict): Retriever config with keys:
            - 'collection': Qdrant collection name (required)
            - 'top_k': number of results to return (default 5)
            - 'filter_score': min similarity score to keep results (default 0.1)

    Returns:
        List[dict]: List of dicts with 'content' and 'score', keeping only
        results whose score is strictly greater than filter_score.
    """
    top_k = retriever.get('top_k', 5)
    # Fix: default aligned with the documented contract (was 50, which would
    # discard every hit when similarity scores lie in [0, 1]).
    filter_score = retriever.get('filter_score', 0.1)
    collection = retriever.get('collection')

    # Vector store over the configured Qdrant collection.
    rev_store = QdrantVectorStore(
        client=qdrant_client,  # assumes qdrant_client was initialized in app.config
        collection_name=collection,
        embedding=embeddings,
    )
    print(f"Retrieving from collection: {collection} with top_k={top_k} and filter_score={filter_score}")

    # Similarity search returns (Document, score) pairs.
    docs = rev_store.similarity_search_with_score(query, k=top_k)

    # Keep only sufficiently similar results.
    return [
        {"content": doc.page_content, "score": score}
        for doc, score in docs
        if score > filter_score
    ]
def build_prompt(template: str):
    """Append the shared history/context/question scaffold to a chatbot-specific template."""
    scaffold_lines = (
        "",
        "Conversation History (for context only, not authority):",
        "{history}",
        "Contextual Knowledge (only approved source of truth):",
        "{agent_scratchpad}",
        "User Question:",
        "{input}",
        "Response:",
        "",
    )
    return template + "\n" + "\n".join(scaffold_lines)
def retrieve_node(state: GraphState) -> GraphState:
    """
    Graph node: run every configured retriever against the user query.

    Args:
        state (GraphState): current chat state; reads 'input' and 'retrieve_tools'.

    Returns:
        GraphState: the same state with 'tool_results' populated
        (retriever name -> list of document dicts). A failing retriever
        contributes a single error entry instead of propagating the exception.
    """
    question = state['input']
    gathered = {}
    for cfg in state.get('retrieve_tools', []):
        label = cfg.get('name', 'default')
        try:
            hits = retrieve_docs(question, cfg)
        except Exception as err:
            gathered[label] = [{"content": f"Retriever failed: {str(err)}", "score": 0}]
            print(f"Retriever '{label}' failed: {err}")
        else:
            gathered[label] = hits
            print(f"Retriever '{label}' returned {len(hits)} result(s)")
    state['tool_results'] = gathered
    return state
| #Answer Question | |
def generate_answer(state: GraphState):
    """
    Generate an answer with the LLM from retrieved context plus chat history.

    Reads 'input', 'history', 'tool_results' and 'prompt_template' from the
    state; writes 'prompt' and 'response', and appends this turn's
    Human/AI messages to 'history'.
    """
    query = state['input']
    history = state.get('history', [])
    history_text = "\n".join(
        f"Human: {m.content}" if isinstance(m, HumanMessage) else f"AI: {m.content}"
        for m in history
    )

    intermediate_steps = state.get('tool_results', {})
    # NOTE(review): retrieve_docs entries carry only 'content'/'score', so the
    # 'metadata' lookup below always falls back to 'N/A' — confirm whether
    # retrievers were meant to include document metadata.
    steps_string = "\n".join(
        f"{tool_name} Results:\n" +
        "\n".join(
            f"- Product: {entry.get('metadata', {}).get('product_name', 'N/A')}\n {entry['content']}"
            for entry in results
        )
        for tool_name, results in intermediate_steps.items() if results
    )

    prompt_template = build_prompt(state['prompt_template'])
    prompt_input = prompt_template.format(
        input=query,
        agent_scratchpad=steps_string,
        history=history_text
    )
    print(prompt_input)
    state['prompt'] = prompt_input

    llm_response = llm.invoke(prompt_input)
    state['response'] = llm_response.content if hasattr(llm_response, 'content') else str(llm_response)

    # Fix: append to the local list and write it back, so a state that arrives
    # without a 'history' key no longer raises KeyError (line 125 already
    # tolerated its absence via .get).
    history.append(HumanMessage(content=query))
    history.append(AIMessage(content=state['response']))
    state['history'] = history
    return state
# Build the two-node LangGraph pipeline: retrieve context, then generate.
graph = StateGraph(GraphState)

#Add nodes to the graph
graph.add_node("route_tool", RunnableLambda(retrieve_node))
graph.add_node("generate_response", RunnableLambda(generate_answer))

# Define the flow of the graph
graph.set_entry_point("route_tool")
graph.add_edge("route_tool", "generate_response")
graph.add_edge("generate_response", END)

# Compiled runnable used by get_response (invoked via ainvoke).
app = graph.compile()
async def get_response(query, session_id, name, email, rag_config, config) -> dict:
    """
    Run one chat turn through the compiled graph and log the exchange.

    Args:
        query: the user question.
        session_id: caller-supplied id; intentionally overridden below by
            config['configurable']['thread_id'], which is authoritative.
        name: user's display name, passed through to chat logging.
        email: user's email, passed through to chat logging.
        rag_config: dict with 'retrievers', 'prompt_template', 'company_id',
            'chatbot_id'.
        config: LangGraph run config; must contain configurable.thread_id.

    Returns:
        dict: the graph result containing at least a 'response' key
        (error text on failure — this function does not raise).
    """
    start_time = time.time()
    # The graph thread id is the authoritative session identifier.
    session_id = config['configurable']['thread_id']
    history = session_histories.get(session_id, [])

    input_data = {
        "input": query,
        "history": history,
        "retrieve_tools": rag_config.get('retrievers', []),
        "prompt_template": rag_config.get('prompt_template', ""),
    }

    metadata = {}
    latency_ms = None
    try:
        result = await app.ainvoke(input_data, config=config)
        latency_ms = int((time.time() - start_time) * 1000)
        session_histories[session_id] = result.get("history", [])
        metadata = {
            "retrieved_docs": result.get("tool_results", {}),
            "model": LLM_MODEL,
            "embedding_model": EMBEDDING_MODEL,
            "prompt": result.get("prompt", "")
        }
        # Brand wording filter applied to the final answer.
        result['response'] = result['response'].replace("transdermal", "topical")
    except Exception as e:
        # Boundary handler: surface the failure as the chat response instead
        # of raising, so the caller always gets a well-formed result.
        result = {}
        result['response'] = f"Error in processing chat: {e}"

    # Fix: corrected typo in the debug label (was "Responsjh").
    print(f"Response: {result['response']}")

    log_chat(
        session_id=session_id,
        company_id=rag_config.get('company_id'),
        chatbot_id=rag_config.get('chatbot_id'),
        name=name,
        email=email,
        query=query,
        answer=result.get("response", ""),
        latency_ms=latency_ms,
        metadata=metadata
    )
    return result