| | import os |
| | import getpass |
| | import pandas as pd |
| | from typing import Optional, Dict, Any |
| |
|
| | |
| | try: |
| | from langchain_core.runnables.base import Runnable |
| | except ImportError: |
| | try: |
| | from langchain.runnables.base import Runnable |
| | except ImportError: |
| | raise ImportError("Cannot find Runnable class. Please upgrade LangChain or check your installation.") |
| |
|
| | from langchain.docstore.document import Document |
| | from langchain.embeddings import HuggingFaceEmbeddings |
| | from langchain.vectorstores import FAISS |
| | from langchain.chains import RetrievalQA |
| |
|
| | from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel |
| | import litellm |
| |
|
| | from classification_chain import get_classification_chain |
| | from refusal_chain import get_refusal_chain |
| | from tailor_chain import get_tailor_chain |
| | from cleaner_chain import get_cleaner_chain |
| | from contextualize_chain import get_contextualize_chain |
| |
|
| | from langchain.llms.base import LLM |
| |
|
| |
|
| | |
| | |
| | |
# Make sure both provider keys are present before any model is constructed.
# A key that is missing (or set to an empty string) is requested interactively
# without echoing it to the terminal.
for _key_name, _key_prompt in (
    ("GEMINI_API_KEY", "Enter your Gemini API Key: "),
    ("GROQ_API_KEY", "Enter your GROQ API Key: "),
):
    if os.environ.get(_key_name):
        continue
    os.environ[_key_name] = getpass.getpass(_key_prompt)
| |
|
| | |
| | |
| | |
def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    """Return a FAISS vector store for a question/answer CSV.

    If ``store_dir`` already exists, the store is loaded from it and the CSV
    is not read. Otherwise the CSV is parsed, one document per row is
    embedded (answer as content, question as metadata), and the resulting
    store is saved to ``store_dir`` before being returned.

    Args:
        csv_path: Path to a CSV with a ``Question`` column and an ``Answers``
            (or ``Answer``) column. Surrounding whitespace in header names
            and pandas' ``Unnamed: N`` index columns are ignored.
        store_dir: Directory the FAISS store is loaded from / saved to.

    Returns:
        The loaded or newly built FAISS vector store.

    Raises:
        ValueError: If the CSV lacks the required columns, or has no row
            with a non-missing answer.
    """
    # Single source of truth for the embedding model so the load and build
    # paths can never drift apart (a store must be queried with the same
    # model it was built with).
    model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1"

    if os.path.exists(store_dir):
        print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading from disk.")
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        # NOTE(review): load_local unpickles the docstore, so store_dir must be
        # trusted; newer LangChain releases additionally require
        # allow_dangerous_deserialization=True here - confirm the pinned version.
        return FAISS.load_local(store_dir, embeddings)

    print(f"DEBUG: Building new store from CSV: {csv_path}")
    df = pd.read_csv(csv_path)
    df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
    # Stripping here already normalises headers such as "Question ", so no
    # separate rename for the trailing-space variant is needed.
    df.columns = df.columns.str.strip()

    # Accept the singular header, but never create a duplicate "Answers" column.
    if "Answer" in df.columns and "Answers" not in df.columns:
        df = df.rename(columns={"Answer": "Answers"})

    if "Question" not in df.columns or "Answers" not in df.columns:
        raise ValueError("CSV must have 'Question' and 'Answers' columns.")

    # A missing answer would otherwise be embedded as the literal text "nan".
    df = df.dropna(subset=["Answers"])
    if df.empty:
        raise ValueError(f"CSV '{csv_path}' contains no rows with an answer.")

    docs = [
        Document(page_content=str(answer), metadata={"question": str(question)})
        for question, answer in zip(df["Question"], df["Answers"])
    ]

    # The embedding model is only loaded once the CSV has been validated, so a
    # malformed file fails fast without paying for model initialisation.
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    vectorstore = FAISS.from_documents(docs, embedding=embeddings)
    vectorstore.save_local(store_dir)
    return vectorstore
| |
|
| | |
| | |
| | |
def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
    """Build a "stuff" RetrievalQA chain that answers from ``vectorstore``.

    Args:
        llm_model: smolagents model used to generate the answer. It is
            wrapped in a minimal LangChain ``LLM`` adapter.
        vectorstore: FAISS store whose top-3 most similar documents are
            supplied to the model as context.

    Returns:
        A RetrievalQA chain that also returns its source documents.
    """

    class GeminiLangChainLLM(LLM):
        """LangChain ``LLM`` adapter that forwards each prompt to ``llm_model``."""

        def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
            messages = [{"role": "user", "content": prompt}]
            reply = llm_model(messages, stop_sequences=stop)
            # LangChain requires a plain string here. Depending on the
            # smolagents version the model returns either a string or a
            # message object exposing the text as `.content`; handle both.
            content = getattr(reply, "content", reply)
            return "" if content is None else str(content)

        @property
        def _llm_type(self) -> str:
            return "custom_gemini"

    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    gemini_as_llm = GeminiLangChainLLM()
    rag_chain = RetrievalQA.from_chain_type(
        llm=gemini_as_llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return rag_chain
| |
|
| | |
| | |
| | |
# --- Prompt chains ----------------------------------------------------------
# Each chain is created once at import time and shared by every request.
classification_chain = get_classification_chain()
refusal_chain = get_refusal_chain()
tailor_chain = get_tailor_chain()
cleaner_chain = get_cleaner_chain()
contextualize_chain = get_contextualize_chain()

# --- Language model ---------------------------------------------------------
gemini_llm = LiteLLMModel(
    model_id="gemini/gemini-pro",
    api_key=os.environ.get("GEMINI_API_KEY"),
)

# --- Knowledge bases --------------------------------------------------------
# Source CSVs and the directories their FAISS stores are cached in.
wellness_csv = "AIChatbot.csv"
brand_csv = "BrandAI.csv"
wellness_store_dir = "faiss_wellness_store"
brand_store_dir = "faiss_brand_store"

wellness_vectorstore = build_or_load_vectorstore(
    csv_path=wellness_csv,
    store_dir=wellness_store_dir,
)
brand_vectorstore = build_or_load_vectorstore(
    csv_path=brand_csv,
    store_dir=brand_store_dir,
)

wellness_rag_chain = build_rag_chain(llm_model=gemini_llm, vectorstore=wellness_vectorstore)
brand_rag_chain = build_rag_chain(llm_model=gemini_llm, vectorstore=brand_vectorstore)

# --- Web-search agents ------------------------------------------------------
# A code agent with the DuckDuckGo tool is wrapped as a managed agent and
# handed to a tool-less manager agent, which is what do_web_search() runs.
search_tool = DuckDuckGoSearchTool()
web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
managed_web_agent = ManagedAgent(
    agent=web_agent,
    name="web_search",
    description="Runs web search for you.",
)
manager_agent = CodeAgent(
    tools=[],
    model=gemini_llm,
    managed_agents=[managed_web_agent],
)
| |
|
def do_web_search(query: str) -> str:
    """Run ``query`` through the manager agent and return its answer as text.

    Args:
        query: The (already contextualized) user question.

    Returns:
        The agent's final answer converted to ``str``; an empty string when
        the agent produced no answer.
    """
    print("DEBUG: Attempting web search for more info...")
    search_query = f"Give me relevant info: {query}"
    response = manager_agent.run(search_query)
    # The agent's final answer is not guaranteed to be a string; normalise it
    # so the declared return type holds and None does not surface as "None".
    return "" if response is None else str(response)
| |
|
| | |
| | |
| | |
def _chain_text(result: Any, key: str = "text") -> str:
    """Return the stripped text carried by a chain result (dict, message or str)."""
    if isinstance(result, dict):
        result = result.get(key, "")
    # Message-like results expose their text as `.content`.
    result = getattr(result, "content", result)
    return "" if result is None else str(result).strip()


def _styled_refusal(chat_history: Any) -> Dict[str, str]:
    """Generate a refusal and pass it through the tailoring chain."""
    refusal_text = refusal_chain.run({"chat_history": chat_history})
    final_refusal = tailor_chain.run({"response": refusal_text, "chat_history": chat_history})
    return {"answer": final_refusal.strip()}


def _answer_from_knowledge_base(
    rag_chain: Any,
    query: str,
    chat_history: Any,
    *,
    allow_web_fallback: bool,
) -> Dict[str, str]:
    """Answer from a RAG chain, optionally falling back to web search, then clean and tailor."""
    rag_result = rag_chain.invoke({"query": query, "chat_history": chat_history})
    csv_answer = rag_result["result"].strip()
    # The web is consulted only when allowed AND the knowledge base produced nothing.
    web_answer = do_web_search(query) if allow_web_fallback and not csv_answer else ""
    final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer, chat_history=chat_history)
    final_answer = tailor_chain.run({"response": final_merged, "chat_history": chat_history}).strip()
    return {"answer": final_answer}


def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
    """Route one user turn through contextualization, classification and answering.

    The query is first rewritten against the chat history, then classified.
    "Wellness" queries are answered from the wellness knowledge base (with a
    web-search fallback when that yields an empty answer), "Brand" queries
    from the brand knowledge base only, and every other label - including
    "OutOfScope" - receives a tailored refusal.

    Args:
        inputs: Mapping with the user's message under ``"input"`` and,
            optionally, prior turns under ``"chat_history"`` (defaults to an
            empty list).

    Returns:
        ``{"answer": <final text>}``.

    Raises:
        KeyError: If ``inputs`` has no ``"input"`` key.
    """
    user_query = inputs["input"]
    chat_history = inputs.get("chat_history", [])

    contextualized = contextualize_chain.invoke(
        {"user_query": user_query, "chat_history": chat_history}
    )
    # Read the rewritten query the same way the classification label is read,
    # so a dict-shaped chain result is never forwarded as the query itself.
    # Fall back to the raw user query if the rewrite came back empty.
    contextualized_query = _chain_text(contextualized) or user_query

    class_result = classification_chain.invoke(
        {"query": contextualized_query, "chat_history": chat_history}
    )
    classification = _chain_text(class_result)

    if classification == "Wellness":
        return _answer_from_knowledge_base(
            wellness_rag_chain, contextualized_query, chat_history, allow_web_fallback=True
        )

    if classification == "Brand":
        return _answer_from_knowledge_base(
            brand_rag_chain, contextualized_query, chat_history, allow_web_fallback=False
        )

    # "OutOfScope" and any unrecognised label are handled identically.
    return _styled_refusal(chat_history)
| |
|
| | |
| | |
| | |
class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
    """Expose ``run_with_chain_context`` through the LangChain Runnable interface."""

    def invoke(
        self,
        input: Dict[str, Any],
        config: Optional[Any] = None,
        **kwargs: Any,
    ) -> Dict[str, str]:
        """Run the pipeline for one input mapping and return ``{"answer": ...}``.

        ``config`` and any extra keyword arguments are accepted so the
        signature matches the base ``Runnable.invoke`` (callers may forward
        them), but the pipeline itself does not use them.
        """
        return run_with_chain_context(input)


# Shared, ready-to-use instance of the pipeline.
pipeline_runnable = PipelineRunnable()
| |
|