import os
from datasets import load_dataset
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter
from langchain_core.documents import Document
from langgraph.graph import START, StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.prompts import ChatPromptTemplate
from huggingface_hub import login
from dotenv import load_dotenv
from typing import TypedDict, List
from google import genai

# Load environment variables from .env
load_dotenv()
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

# The new google-genai SDK has no genai.configure(); authentication happens
# when the client is created via Client(api_key=...), so we only validate
# that the key is present here.
if not GOOGLE_API_KEY:
    raise ValueError("Please set GOOGLE_API_KEY in your environment variables.")

# Authenticate with Hugging Face if a token is available
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("✅ Logged in to Hugging Face using HF_TOKEN.")
    except Exception as e:
        print(f"⚠️ Hugging Face login failed: {e}")
else:
    print("⚠️ No HF_TOKEN found in .env file. Using public mode.")

# --- STATE DEFINITION ---
class RAGState(TypedDict):
    question: str
    context: str
    answer: str
    chat_history: List[str]
    source_documents: List[Document]
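
# Note: the LangGraph nodes defined below return partial state updates
# (e.g. only {"context": ..., "source_documents": ...}); LangGraph merges
# each update into the full RAGState between steps.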

# --- LLM WRAPPER ---
class GeminiLLMWrapper:
    """Wrapper around the Google Gemini API using the new Client interface."""

    def __init__(self):
        # Create a Gemini client; the API key authenticates every request
        self.client = genai.Client(api_key=GOOGLE_API_KEY)

    def invoke(self, prompt: str):
        # Generate text for the given prompt
        response = self.client.models.generate_content(
            model="gemini-2.5-flash",
            contents=prompt,
        )

        # Mimic LangChain's message interface: expose the generated text
        # on a .content attribute so callers can read result.content
        class Result:
            content = response.text

        return Result()
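
# Quick smoke test for the wrapper (a hedged sketch, not part of the pipeline;
# uncomment to try — it assumes GOOGLE_API_KEY points at a valid key):
# print(GeminiLLMWrapper().invoke("Say hello in one word.").content)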

def build_rag_pipeline():
    """Builds a LangGraph-based RAG pipeline compatible with LangChain 1.x."""
    # --- Load dataset ---
    try:
        dataset = load_dataset("fadodr/mental_health_therapy", split="train[:300]")
        print("✅ Loaded dataset: fadodr/mental_health_therapy")
    except Exception as e:
        print(f"⚠️ Could not load dataset: {e}")
        # Retry with the HF token in case the dataset is gated
        dataset = load_dataset("fadodr/mental_health_therapy", split="train[:300]", token=HF_TOKEN)

    # --- Prepare documents ---
    texts = [
        f"Q: {d['instruction']}\nA: {d['input']}"
        for d in dataset
        if d.get("input", "").strip()
    ]
    if not texts:
        raise ValueError("No valid text found in dataset to create embeddings!")

    splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    docs = [Document(page_content=t) for t in texts]
    split_docs = splitter.split_documents(docs)

    # --- Embeddings + Chroma DB ---
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_db = Chroma.from_documents(split_docs, embeddings, persist_directory="chroma_db")
    retriever = vector_db.as_retriever(search_kwargs={"k": 3})

    # --- LLM ---
    llm = GeminiLLMWrapper()  # Gemini wrapper built on the new Client API

    # --- PROMPT TEMPLATE ---
    prompt = ChatPromptTemplate.from_template(
        """
You are a helpful assistant. Use the following retrieved context to answer the user's question.
If the context doesn't contain the answer, say so politely.

Context:
{context}

Question:
{question}

Answer:
"""
    )

    # --- NODES (GRAPH FUNCTIONS) ---
    def retrieve_docs(state: RAGState):
        # Fetch the top-k documents for the question and join them into one context string
        query = state["question"]
        docs = retriever.invoke(query)
        context = "\n\n".join(d.page_content for d in docs)
        return {"context": context, "source_documents": docs}

    def generate_answer(state: RAGState):
        # Fill the prompt with the retrieved context and call the LLM
        prompt_text = prompt.format(context=state["context"], question=state["question"])
        response = llm.invoke(prompt_text)
        return {"answer": response.content}

    # --- BUILD THE GRAPH ---
    graph_builder = StateGraph(RAGState)
    graph_builder.add_node("retrieve", retrieve_docs)
    graph_builder.add_node("generate", generate_answer)
    graph_builder.add_edge(START, "retrieve")
    graph_builder.add_edge("retrieve", "generate")

    # Add in-memory checkpointing (conversation memory)
    memory = MemorySaver()
    graph = graph_builder.compile(checkpointer=memory)

    # Wrap the graph in a callable interface so app.py still works
    class RAGChainWrapper:
        def __init__(self, graph):
            self.graph = graph

        def __call__(self, inputs: dict):
            question = inputs.get("question", "")
            state = {"question": question, "chat_history": []}
            # A fixed thread_id keeps all turns in one checkpointed conversation
            result = self.graph.invoke(
                state,
                config={"configurable": {"thread_id": "default"}},
            )
            return {
                "answer": result.get("answer", ""),
                "source_documents": result.get("source_documents", []),
            }

    rag_chain = RAGChainWrapper(graph)
    return llm, retriever, rag_chain
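
# Example usage (a minimal sketch; assumes GOOGLE_API_KEY and HF_TOKEN are set
# in .env and that the Chroma index can be built locally — the sample question
# is purely illustrative):
if __name__ == "__main__":
    llm, retriever, rag_chain = build_rag_pipeline()
    result = rag_chain({"question": "How can I cope with anxiety before sleep?"})
    print(result["answer"])
    for doc in result["source_documents"]:
        print("-", doc.page_content[:80], "...")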