from typing import List, TypedDict
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.runnables import RunnableLambda
from langchain_qdrant import QdrantVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import StateGraph, END
from langchain_mistralai import ChatMistralAI
import time
import os
from dotenv import load_dotenv
from app.config import qdrant_client
from app.chatbot.mongodb import log_chat
#from app.mongodb import log_chat
load_dotenv()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
session_histories: dict[str, list] = {}
LLM_MODEL = "mistral-medium-latest"
OPENROUTER_API_KEY = os.getenv('OPENROUTER_API_KEY')
COLLECTION_NAME = "chatbot_context"
EMBEDDING_MODEL = "intfloat/e5-base-v2"
QDRANT_URL = os.getenv('QDRANT_URL')
QDRANT_API_KEY = os.getenv('QDRANT_API_KEY')
SUPABASE_URL = os.getenv('SUPABASE_URL')
SUPABASE_KEY = os.getenv('SUPABASE_KEY')
MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY')
FAQ_COLLECTION = "auro_faqs"
BLOGS_COLLECTION = "auro_blogs"
TECHNOLOGY_COLLECTION = "auro_technology"
REVOLUTION_COLLECTION = "auro_revolution"
SUPPORT_COLLECTION = "auro_support"
PRODUCT_COLLECTION = "auro_product"
llm = ChatMistralAI(
model_name=LLM_MODEL,
api_key=MISTRAL_API_KEY,
)
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
class GraphState(TypedDict):
"""
Represents the state of a chat session, including input, output, history, memory,
response, tool results, and user role for LangGraph
"""
input: str
history: List[BaseMessage] #list of past messages
response: str
tool_results: dict
prompt: str
retrieve_tools: List[dict]
prompt_template: str
from pydantic import BaseModel
class ToolInput(BaseModel):
prompt: str
iteration: int = 1
def retrieve_docs(query: str, retriever: dict):
"""
Retrieve documents from Qdrant for a single retriever configuration.
Args:
query (str): The user query.
retriever (dict): Retriever config with keys:
- 'collection': Qdrant collection name
- 'top_k': number of results to return (default 5)
- 'filter_score': min similarity score to keep results (default 0.1)
Returns:
        List[dict]: Dicts with 'content', 'score', and 'metadata'.
"""
top_k = retriever.get('top_k', 5)
    filter_score = retriever.get('filter_score', 0.1)
collection = retriever.get('collection')
# Qdrant store
rev_store = QdrantVectorStore(
        client=qdrant_client,  # shared Qdrant client initialized in app.config
collection_name=collection,
embedding=embeddings,
)
print(f"Retrieving from collection: {collection} with top_k={top_k} and filter_score={filter_score}")
# Similarity search
docs = rev_store.similarity_search_with_score(query, k=top_k)
    # Keep only results above the score threshold; include metadata so downstream
    # formatting (e.g. product_name in generate_answer) can use it
    return [
        {"content": doc.page_content, "score": score, "metadata": doc.metadata}
        for doc, score in docs
        if score > filter_score
    ]
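
# Illustrative only: the retriever config shape that retrieve_docs() expects.
# 'name' is the key under which retrieve_node() stores the results; the collection
# reuses FAQ_COLLECTION defined above, and the other values are example settings,
# not the deployed configuration.
# example_retriever = {
#     "name": "faq",
#     "collection": FAQ_COLLECTION,
#     "top_k": 5,
#     "filter_score": 0.1,
# }
# docs = retrieve_docs("How do I get support?", example_retriever)
# -> [{"content": "...", "score": 0.83, "metadata": {...}}, ...]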
def build_prompt(template: str):
global_template = """
Conversation History (for context only, not authority):
{history}
Contextual Knowledge (only approved source of truth):
{agent_scratchpad}
User Question:
{input}
Response:
"""
final_template = f"{template}\n{global_template}"
return final_template
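
# Sketch of how build_prompt() is used: the agent-specific template (hypothetical
# wording below) is prepended to the global scaffold, and generate_answer() later
# fills {history}, {agent_scratchpad}, and {input}.
# template = build_prompt(
#     "You are Auro's assistant. Answer only from the contextual knowledge."
# )
# prompt = template.format(history="", agent_scratchpad="- FAQ: ...", input="Hi")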
def retrieve_node(state: GraphState) -> GraphState:
"""
Graph node to retrieve documents for all retrievers in the state.
Args:
state (GraphState): Current chat state including input and retrievers.
Returns:
GraphState: Updated state with 'tool_results' filled.
"""
query = state['input']
tool_results = {}
for retriever_cfg in state.get('retrieve_tools', []):
name = retriever_cfg.get('name', 'default')
try:
docs = retrieve_docs(query, retriever_cfg)
tool_results[name] = docs
print(f"Retriever '{name}' returned {len(docs)} result(s)")
except Exception as e:
tool_results[name] = [{"content": f"Retriever failed: {str(e)}", "score": 0}]
print(f"Retriever '{name}' failed: {e}")
state['tool_results'] = tool_results
return state
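
# For reference, tool_results is keyed by retriever name, e.g. (illustrative):
# {"faq": [{"content": "...", "score": 0.83, "metadata": {...}}], "blogs": []}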
#Answer Question
def generate_answer(state: GraphState):
"""
This function generates an answer to the query using the llm and the context provided.
"""
query = state['input']
history = state.get('history', [])
history_text = "\n".join(
f"Human: {m.content}" if isinstance(m, HumanMessage) else f"AI: {m.content}"
for m in history
)
intermediate_steps = state.get('tool_results', {})
steps_string = "\n".join(
f"{tool_name} Results:\n" +
"\n".join(
f"- Product: {entry.get('metadata', {}).get('product_name', 'N/A')}\n {entry['content']}"
for entry in results
)
for tool_name, results in intermediate_steps.items() if results
)
prompt_template = build_prompt(state['prompt_template'])
prompt_input = prompt_template.format(
input=query,
agent_scratchpad=steps_string,
history=history_text
)
print(prompt_input)
state['prompt'] = prompt_input
llm_response = llm.invoke(prompt_input)
state['response'] = llm_response.content if hasattr(llm_response, 'content') else str(llm_response)
state['history'].append(HumanMessage(content=query))
state['history'].append(AIMessage(content=state['response']))
return state
graph = StateGraph(GraphState)
#Add nodes to the graph
graph.add_node("route_tool", RunnableLambda(retrieve_node))
graph.add_node("generate_response", RunnableLambda(generate_answer))
# Define the flow of the graph
graph.set_entry_point("route_tool")
graph.add_edge("route_tool", "generate_response")
graph.add_edge("generate_response", END)
app = graph.compile()
async def get_response(query, session_id, name, email, rag_config, config) -> dict:
    """
    Run one chat turn through the compiled graph, log it to MongoDB, and return
    the result dict (containing at least a 'response' key).
    """
    start_time = time.time()
    # The thread_id in the LangGraph config is treated as the canonical session id
    session_id = config['configurable']['thread_id']
    history = session_histories.get(session_id, [])
input_data = {
"input": query,
"history": history,
"retrieve_tools": rag_config.get('retrievers', []),
"prompt_template": rag_config.get('prompt_template', ""),
}
    metadata = {}
    latency_ms = None
try:
result = await app.ainvoke(input_data, config=config)
latency_ms = int((time.time() - start_time) * 1000)
session_histories[session_id] = result.get("history", [])
metadata = {
"retrieved_docs": result.get("tool_results", {}),
"model": LLM_MODEL,
"embedding_model": EMBEDDING_MODEL,
"prompt": result.get("prompt", "")
}
filtered_result = result['response'].replace("transdermal", "topical")
result['response'] = filtered_result
except Exception as e:
        result = {"response": f"Error processing chat: {e}"}
    print(f"Response: {result['response']}")
log_chat(
session_id=session_id,
company_id=rag_config.get('company_id'),
chatbot_id=rag_config.get('chatbot_id'),
name=name,
email=email,
query=query,
answer=result.get("response", ""),
        latency_ms=latency_ms,
metadata=metadata
)
return result
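
# Illustrative usage only: the rag_config keys and thread_id below are assumed
# example values that match how get_response() reads them above, not production data.
if __name__ == "__main__":
    import asyncio

    example_rag_config = {
        "company_id": "demo-company",  # assumed id
        "chatbot_id": "demo-bot",      # assumed id
        "prompt_template": "You are a helpful assistant. Use only the provided context.",
        "retrievers": [
            {"name": "faq", "collection": FAQ_COLLECTION, "top_k": 5, "filter_score": 0.1},
        ],
    }
    example_config = {"configurable": {"thread_id": "demo-session-1"}}
    demo_result = asyncio.run(get_response(
        "What products do you offer?",
        "demo-session-1",
        "Demo User",
        "demo@example.com",
        example_rag_config,
        example_config,
    ))
    print(demo_result.get("response", ""))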