File size: 5,295 Bytes
ecb8437
 
 
 
 
 
 
 
 
 
 
 
 
f267137
 
ecb8437
 
 
 
 
 
 
68af232
 
 
 
 
3749f20
ecb8437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3749f20
 
 
68af232
3749f20
68af232
 
 
 
3749f20
68af232
 
 
 
3749f20
68af232
3749f20
3379d9a
3749f20
 
ecb8437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68af232
ecb8437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
from datasets import load_dataset
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter
from langchain_core.documents import Document
from langgraph.graph import START, StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from langchain_core.prompts import ChatPromptTemplate
from huggingface_hub import login
from dotenv import load_dotenv
from typing import TypedDict, List
from google import genai


# --- Environment & credentials ---
# Pull secrets from a local .env file into the process environment.
load_dotenv()

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

# The modern google-genai SDK authenticates through Client(api_key=...);
# the old genai.configure() was removed, so all we need here is the key.
if not GOOGLE_API_KEY:
    raise ValueError("Please set GOOGLE_API_KEY in your environment variables.")

# Hugging Face login is best-effort: a missing or failing token degrades to
# anonymous ("public") access instead of aborting startup.
if not HF_TOKEN:
    print("⚠️ No HF_TOKEN found in .env file. Using public mode.")
else:
    try:
        login(token=HF_TOKEN)
    except Exception as e:
        print(f"⚠️ Hugging Face login failed: {e}")
    else:
        print("✅ Logged in to Hugging Face using HF_TOKEN.")

# --- STATE DEFINITION ---
class RAGState(TypedDict):
    """Shared state flowing through the LangGraph RAG pipeline nodes."""

    question: str  # the user's query; drives retrieval
    context: str  # retrieved document texts joined together for the prompt
    answer: str  # LLM-generated answer produced by the "generate" node
    chat_history: List[str]  # seeded empty by the caller; currently not written by any node
    source_documents: List[Document]  # raw retrieved Documents, surfaced to the caller

# --- LLM Wrapper ---
class GeminiLLMWrapper:
    """
    Wrapper around Google Gemini API using the latest Client interface.

    Exposes a LangChain-like ``invoke(prompt)`` that returns an object with
    a ``.content`` attribute, so it can stand in for a chat model elsewhere
    in the pipeline.
    """

    class _Result:
        """Minimal response holder mimicking a LangChain message (``.content``)."""

        def __init__(self, content: str):
            self.content = content

    def __init__(self):
        # Create a Gemini Client authenticated with the module-level API key.
        self.client = genai.Client(api_key=GOOGLE_API_KEY)

    def invoke(self, prompt: str):
        """Generate text for ``prompt`` and return an object exposing ``.content``.

        Fix: the original defined a fresh ``Result`` class on every call whose
        class body captured ``response`` from the enclosing scope — a fragile
        trick that also allocates a new class per invocation. A single
        reusable holder class is clearer and cheaper.
        """
        response = self.client.models.generate_content(
            model="gemini-2.5-flash",
            contents=prompt
        )
        # response.text contains the generated text
        return GeminiLLMWrapper._Result(response.text)

def build_rag_pipeline():
    """Builds a LangGraph-based RAG pipeline compatible with LangChain 1.x.

    Loads a therapy Q&A dataset, embeds it into a persistent Chroma store,
    and wires a two-node LangGraph (retrieve -> generate) around the Gemini
    LLM wrapper, with in-memory checkpointing for conversation state.

    Returns:
        tuple: ``(llm, retriever, rag_chain)`` where ``rag_chain`` is a
        callable taking ``{"question": str}`` and returning
        ``{"answer": str, "source_documents": list[Document]}``.

    Raises:
        ValueError: if no usable rows are found in the dataset.
    """

    # --- Load dataset ---
    # NOTE(review): the fallback re-load drops the org namespace
    # ("mental_health_therapy" vs "fadodr/mental_health_therapy"); confirm that
    # id actually resolves on the Hub, otherwise the except branch re-raises.
    try:
        dataset = load_dataset("fadodr/mental_health_therapy", split="train[:300]")
        print("✅ Loaded dataset: fadodr/mental_health_therapy")
    except Exception as e:
        print(f"⚠️ Could not load dataset: {e}")
        dataset = load_dataset("mental_health_therapy", split="train[:300]", token=HF_TOKEN)

    # --- Prepare documents ---
    # Each row is rendered as a "Q: ...\nA: ..." string; rows whose "input" is
    # empty/whitespace-only are skipped. Assumes "instruction" holds the
    # question and "input" holds the answer text — TODO confirm the dataset
    # schema (some instruction-tuning sets keep the answer under "output").
    texts = [f"Q: {d['instruction']}\nA: {d['input']}" for d in dataset if d.get("input", "").strip()]
    if not texts:
        raise ValueError("No valid text found in dataset to create embeddings!")

    # Chunk the rendered strings so each embedded piece stays small enough
    # for retrieval to be useful.
    splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    docs = [Document(page_content=t) for t in texts]
    split_docs = splitter.split_documents(docs)

    # --- Embeddings + Chroma DB ---
    # Embeds with a small sentence-transformers model and persists the vector
    # store to ./chroma_db; retrieval returns the top-3 nearest chunks.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_db = Chroma.from_documents(split_docs, embeddings, persist_directory="chroma_db")
    retriever = vector_db.as_retriever(search_kwargs={"k": 3})

    # --- LLM ---
    llm = GeminiLLMWrapper()  # Use wrapper with updated Client

    # --- PROMPT TEMPLATE ---
    # NOTE(review): this is a ChatPromptTemplate but it is rendered below via
    # .format(), which yields a role-prefixed transcript string ("Human: ...").
    # Confirm that string is the intended payload for Gemini — a plain
    # PromptTemplate may have been meant.
    prompt = ChatPromptTemplate.from_template(
        """
        You are a helpful assistant. Use the following retrieved context to answer the user's question.
        If the context doesn't contain the answer, say so politely.
        Context:
        {context}

        Question:
        {question}

        Answer:
        """
    )

    # --- NODES (GRAPH FUNCTIONS) ---
    def retrieve_docs(state: RAGState) -> dict:
        """Graph node: fetch top-k chunks for the question and join them."""
        query = state["question"]
        docs = retriever.invoke(query)
        context = "\n\n".join([d.page_content for d in docs])
        # Returned keys are merged into the graph state by LangGraph.
        return {"context": context, "source_documents": docs}

    def generate_answer(state: RAGState) -> dict:
        """Graph node: render the prompt with context and call the LLM."""
        prompt_text = prompt.format(context=state["context"], question=state["question"])
        response = llm.invoke(prompt_text)
        return {"answer": response.content}

    # --- BUILD THE GRAPH ---
    # Linear topology: START -> retrieve -> generate.
    graph_builder = StateGraph(RAGState)
    graph_builder.add_node("retrieve", retrieve_docs)
    graph_builder.add_node("generate", generate_answer)
    graph_builder.add_edge(START, "retrieve")
    graph_builder.add_edge("retrieve", "generate")

    # Add in-memory checkpointing (conversation memory)
    memory = MemorySaver()

    graph = graph_builder.compile(checkpointer=memory)

    # Wrap in a callable interface so app.py still works
    class RAGChainWrapper:
        """Adapter exposing the compiled graph as a plain dict-in/dict-out callable."""

        def __init__(self, graph):
            self.graph = graph

        def __call__(self, inputs: dict) -> dict:
            """Run the pipeline for ``inputs["question"]`` and return answer + sources."""
            question = inputs.get("question", "")
            # chat_history is seeded empty; no node currently appends to it.
            state = {"question": question, "chat_history": []}
            # A fixed thread_id means every call shares one checkpoint thread.
            result = self.graph.invoke(
                state,
                config={"configurable": {"thread_id": "default"}}
            )
            return {
                "answer": result.get("answer", ""),
                "source_documents": result.get("source_documents", [])
            }

    rag_chain = RAGChainWrapper(graph)

    return llm, retriever, rag_chain