import os
from types import SimpleNamespace
from typing import TypedDict, List

from datasets import load_dataset
from dotenv import load_dotenv
from google import genai
from huggingface_hub import login
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import CharacterTextSplitter
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph

load_dotenv()

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

if not GOOGLE_API_KEY:
    raise ValueError("Please set GOOGLE_API_KEY in your environment variables.")

if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("✅ Logged in to Hugging Face using HF_TOKEN.")
    except Exception as e:
        print(f"⚠️ Hugging Face login failed: {e}")
else:
    print("⚠️ No HF_TOKEN found in .env file. Using public mode.")

class RAGState(TypedDict):
    """State shared between the LangGraph nodes."""

    question: str
    context: str
    answer: str
    chat_history: List[str]
    source_documents: List[Document]

class GeminiLLMWrapper:
    """
    Wrapper around the Google Gemini API using the google-genai Client interface.

    invoke() mirrors LangChain's chat-model interface by returning an object
    that exposes a .content attribute.
    """

    def __init__(self):
        self.client = genai.Client(api_key=GOOGLE_API_KEY)

    def invoke(self, prompt: str):
        response = self.client.models.generate_content(
            model="gemini-2.5-flash",
            contents=prompt,
        )
        # Lightweight stand-in for a LangChain AIMessage: callers only need
        # the .content attribute.
        return SimpleNamespace(content=response.text)
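
# Illustrative sanity check for the wrapper (assumes a valid GOOGLE_API_KEY;
# the prompt is just an example):
#
#   llm = GeminiLLMWrapper()
#   print(llm.invoke("Summarize retrieval-augmented generation in one sentence.").content)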


def build_rag_pipeline():
    """Builds a LangGraph-based RAG pipeline compatible with LangChain 1.x."""
    # Load the source dataset; if the anonymous download fails, retry the same
    # dataset with the Hugging Face token.
    try:
        dataset = load_dataset("fadodr/mental_health_therapy", split="train[:300]")
        print("✅ Loaded dataset: fadodr/mental_health_therapy")
    except Exception as e:
        print(f"⚠️ Could not load dataset anonymously: {e}")
        dataset = load_dataset("fadodr/mental_health_therapy", split="train[:300]", token=HF_TOKEN)

    # Format each row as a Q/A pair from the dataset's `instruction` and
    # `input` fields, skipping rows with an empty `input`.
    texts = [f"Q: {d['instruction']}\nA: {d['input']}" for d in dataset if d.get("input", "").strip()]
    if not texts:
        raise ValueError("No valid text found in dataset to create embeddings!")

    splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    docs = [Document(page_content=t) for t in texts]
    split_docs = splitter.split_documents(docs)

    # Embed the chunks and index them in a persistent local Chroma store.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_db = Chroma.from_documents(split_docs, embeddings, persist_directory="chroma_db")
    retriever = vector_db.as_retriever(search_kwargs={"k": 3})

    llm = GeminiLLMWrapper()

    prompt = ChatPromptTemplate.from_template(
        """
You are a helpful assistant. Use the following retrieved context to answer the user's question.
If the context doesn't contain the answer, say so politely.

Context:
{context}

Question:
{question}

Answer:
"""
    )

    # Graph nodes: retrieve relevant chunks, then generate an answer.
    def retrieve_docs(state: RAGState):
        query = state["question"]
        docs = retriever.invoke(query)
        context = "\n\n".join(d.page_content for d in docs)
        return {"context": context, "source_documents": docs}

    def generate_answer(state: RAGState):
        # ChatPromptTemplate.format() renders the template to a plain string,
        # which is passed to Gemini as the prompt.
        prompt_text = prompt.format(context=state["context"], question=state["question"])
        response = llm.invoke(prompt_text)
        return {"answer": response.content}

    graph_builder = StateGraph(RAGState)
    graph_builder.add_node("retrieve", retrieve_docs)
    graph_builder.add_node("generate", generate_answer)
    graph_builder.add_edge(START, "retrieve")
    graph_builder.add_edge("retrieve", "generate")
    graph_builder.add_edge("generate", END)

    # MemorySaver checkpoints graph state per thread_id, so repeated calls on
    # the same thread share conversation state.
    memory = MemorySaver()
    graph = graph_builder.compile(checkpointer=memory)

    # Thin adapter so callers can use the graph like a classic RetrievalQA
    # chain: rag_chain({"question": ...}) -> {"answer", "source_documents"}.
    # The fixed thread_id means all calls share a single conversation thread.
    class RAGChainWrapper:
        def __init__(self, graph):
            self.graph = graph

        def __call__(self, inputs: dict):
            question = inputs.get("question", "")
            state = {"question": question, "chat_history": []}
            result = self.graph.invoke(
                state,
                config={"configurable": {"thread_id": "default"}},
            )
            return {
                "answer": result.get("answer", ""),
                "source_documents": result.get("source_documents", []),
            }

    rag_chain = RAGChainWrapper(graph)

    return llm, retriever, rag_chain
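

# Example usage: a minimal sketch. Assumes a .env file providing GOOGLE_API_KEY
# (and optionally HF_TOKEN); the question below is illustrative.
if __name__ == "__main__":
    llm, retriever, rag_chain = build_rag_pipeline()
    result = rag_chain({"question": "How can I cope with exam anxiety?"})
    print("Answer:", result["answer"])
    print(f"Retrieved {len(result['source_documents'])} source documents.")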