import os
import sys

from typing import Any, Dict, List, Literal, TypedDict

from dotenv import load_dotenv
import google.generativeai as genai
from google.generativeai import GenerativeModel
from langgraph.graph import StateGraph, END

# Add the parent directory to the path to import utils
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
from utils.create_vectordb import query_chroma_db

load_dotenv()

# Initialize Gemini model
api_key = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=api_key)
model = GenerativeModel("gemini-2.5-flash-preview-05-20")


# Shared state passed between the graph nodes. LangGraph's StateGraph requires
# a schema with type annotations (e.g. a TypedDict); a bare Dict[str, Any] is
# not a valid state schema.
class GraphState(TypedDict, total=False):
    user_input: str
    context: str
    response: str
    chat_history: List[Dict[str, str]]
    next: str
    needs_clarification: bool

def retrieve_context(state: GraphState) -> Dict[str, Any]:
    """
    Retrieve relevant context from the vector database based on the user query.
    """
    query = state.get("user_input", "")
    if not query:
        return {"context": "No query provided.", "user_input": query, "next": "request_clarification"}

    # Treat the query as unclear if it is very short, or if it contains
    # neither a question mark nor a question word.
    if len(query.split()) < 3 or ("?" not in query and not any(
        w in query.lower() for w in ["what", "how", "why", "when", "where", "who", "which"]
    )):
        return {"context": "", "user_input": query, "next": "request_clarification"}

    # Query the vector database
    results = query_chroma_db(query, n_results=3)

    # Extract the retrieved documents
    documents = results.get("documents", [[]])[0]
    metadatas = results.get("metadatas", [[]])[0]

    # Format the context
    formatted_context = []
    for i, (doc, metadata) in enumerate(zip(documents, metadatas)):
        source = metadata.get("source", "Unknown")
        formatted_context.append(f"Document {i+1} (Source: {source}):\n{doc}\n")

    context = "\n".join(formatted_context) if formatted_context else ""

    # Determine next step based on context quality
    if not context or len(context) < 50:
        next_step = "use_gemini_knowledge"
    else:
        next_step = "generate_response"

    return {"context": context, "user_input": query, "next": next_step}

def request_clarification(state: GraphState) -> Dict[str, Any]:
    """
    Request clarification from the user when the query is unclear.
    """
    query = state.get("user_input", "")

    clarification_message = model.generate_content(
        f"""The user asked: "{query}"

        This query seems vague or unclear. Generate a polite response asking for more specific details.
        Focus on what additional information would help you understand their request better.
        Keep your response under 3 sentences and make it conversational."""
    )
    response = clarification_message.text

    # Update chat history
    chat_history = state.get("chat_history", [])
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response},
    ]

    return {
        "response": response,
        "chat_history": new_chat_history,
        "needs_clarification": True,
    }

def use_gemini_knowledge(state: GraphState) -> Dict[str, Any]:
    """
    Use Gemini's general knowledge when no relevant information is found in the vector database.
    """
    query = state.get("user_input", "")
    chat_history = state.get("chat_history", [])

    # Construct the prompt
    prompt_template = """
    I couldn't find specific information about this in my local database. However, I can try to answer based on my general knowledge.

    User Question: {query}

    First, acknowledge that you're answering from general knowledge rather than the specific database.
    Then provide a helpful, accurate response based on what you know about the topic.
    """

    # Generate response
    response = model.generate_content(prompt_template.format(query=query))

    # Update chat history
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response.text},
    ]

    return {
        "response": response.text,
        "chat_history": new_chat_history,
    }

def generate_response(state: GraphState) -> Dict[str, Any]:
    """
    Generate a response using the LLM based on the retrieved context and user query.
    """
    context = state.get("context", "")
    query = state.get("user_input", "")
    chat_history = state.get("chat_history", [])

    # Construct the prompt
    prompt_template = """
    You are a helpful assistant that answers questions based on the provided context.

    Context:
    {context}

    Chat History:
    {chat_history}

    User Question: {query}

    Answer the question based only on the provided context. If the context doesn't contain enough information,
    acknowledge this but still try to provide a helpful response based on the available information.
    Provide a clear, concise, and helpful response.
    """

    # Format chat history for the prompt
    formatted_chat_history = "\n".join(f"{msg['role']}: {msg['content']}" for msg in chat_history)

    # Generate response
    response = model.generate_content(
        prompt_template.format(
            context=context,
            chat_history=formatted_chat_history,
            query=query,
        )
    )

    # Update chat history
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response.text},
    ]

    return {
        "response": response.text,
        "chat_history": new_chat_history,
    }

def decide_next_step(state: GraphState) -> Literal["request_clarification", "use_gemini_knowledge", "generate_response"]:
    """
    Decide the next step in the workflow based on the "next" key set by retrieve_context.
    """
    return state["next"]

# Define the workflow
def build_graph():
    # StateGraph needs an annotated schema (the GraphState TypedDict above);
    # a bare Dict[str, Any] would fail at construction time.
    workflow = StateGraph(GraphState)

    # Add nodes
    workflow.add_node("retrieve_context", retrieve_context)
    workflow.add_node("request_clarification", request_clarification)
    workflow.add_node("use_gemini_knowledge", use_gemini_knowledge)
    workflow.add_node("generate_response", generate_response)

    # Define edges with conditional routing
    workflow.set_entry_point("retrieve_context")
    workflow.add_conditional_edges(
        "retrieve_context",
        decide_next_step,
        {
            "request_clarification": "request_clarification",
            "use_gemini_knowledge": "use_gemini_knowledge",
            "generate_response": "generate_response",
        },
    )

    # Set finish points
    workflow.add_edge("request_clarification", END)
    workflow.add_edge("use_gemini_knowledge", END)
    workflow.add_edge("generate_response", END)

    # Compile the graph
    return workflow.compile()

# Create the graph
graph = build_graph()
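

# A minimal usage sketch, not part of the original module: it assumes
# GOOGLE_API_KEY is set in the environment and that the Chroma database has
# already been built by utils.create_vectordb. The compiled graph is a
# LangGraph Runnable, so it is driven with .invoke() on an initial state.
if __name__ == "__main__":
    result = graph.invoke({
        "user_input": "What topics does the document collection cover?",
        "chat_history": [],
    })
    print(result["response"])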