MeteKaba commited on
Commit
14cb65f
·
verified ·
1 Parent(s): 002e843

Upload 2 files

Browse files
Files changed (2) hide show
  1. src/rag_pipeline.py +128 -0
  2. src/streamlit_app.py +0 -0
src/rag_pipeline.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datasets import load_dataset
3
+ from langchain_community.embeddings import HuggingFaceEmbeddings
4
+ from langchain_community.vectorstores import Chroma
5
+ from langchain_text_splitters import CharacterTextSplitter
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_core.documents import Document
8
+ from langgraph.graph import START, StateGraph
9
+ from langgraph.checkpoint.memory import MemorySaver
10
+ from langgraph.prebuilt import create_react_agent
11
+ from langchain_core.prompts import ChatPromptTemplate
12
+ from huggingface_hub import login
13
+ from dotenv import load_dotenv
14
+ from typing import TypedDict, List
15
+
16
# Pull configuration from the environment (.env values are merged in by python-dotenv).
load_dotenv()

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

# Best-effort Hugging Face authentication: a missing or rejected token
# degrades to anonymous (public) access instead of aborting the import.
if not HF_TOKEN:
    print("⚠️ No HF_TOKEN found in .env file. Using public mode.")
else:
    try:
        login(token=HF_TOKEN)
    except Exception as exc:
        print(f"⚠️ Hugging Face login failed: {exc}")
    else:
        print("✅ Logged in to Hugging Face using HF_TOKEN.")
31
+
32
+
33
+ # --- STATE DEFINITION ---
34
class RAGState(TypedDict):
    """Shared state dict passed between the LangGraph nodes of the RAG pipeline."""

    question: str                     # user's current query
    context: str                      # retrieved chunks joined into one string
    answer: str                       # LLM-generated response text
    chat_history: List[str]           # prior turns; declared but not populated by the graph nodes here
    source_documents: List[Document]  # raw retriever hits backing the answer
40
+
41
+
42
def build_rag_pipeline():
    """Build a LangGraph-based RAG pipeline over a mental-health Q&A dataset.

    Steps: load ~300 dataset rows, split them into overlapping 500-char
    chunks, embed them into a persistent Chroma store ("chroma_db"), and wire
    a two-node LangGraph (retrieve -> generate) around a Gemini chat model
    with in-memory checkpointing.

    Returns:
        tuple: ``(llm, retriever, rag_chain)`` where ``rag_chain`` is a
        callable accepting ``{"question": str}`` and returning a dict with
        ``"answer"`` and ``"source_documents"`` keys.

    Raises:
        ValueError: if no usable text rows are found in the dataset.
    """
    # --- Load dataset ---
    try:
        dataset = load_dataset("fadodr/mental_health_therapy", split="train[:300]")
        print("✅ Loaded dataset: fadodr/mental_health_therapy")
    except Exception as e:
        print(f"⚠️ Could not load dataset: {e}")
        # NOTE(review): this un-namespaced fallback id likely needs the same
        # "fadodr/" prefix to resolve on the Hub — confirm before relying on it.
        dataset = load_dataset("mental_health_therapy", split="train[:300]", token=HF_TOKEN)

    # --- Prepare documents ---
    # BUGFIX: `d.get("input", "").strip()` crashes with AttributeError when a
    # row has "input" present but set to None (.get's default only covers
    # *missing* keys). `(d.get("input") or "")` normalizes None to "".
    texts = [
        f"Q: {d['instruction']}\nA: {d['input']}"
        for d in dataset
        if (d.get("input") or "").strip()
    ]
    if not texts:
        raise ValueError("No valid text found in dataset to create embeddings!")

    splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    docs = [Document(page_content=t) for t in texts]
    split_docs = splitter.split_documents(docs)

    # --- Embeddings + Chroma DB ---
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_db = Chroma.from_documents(split_docs, embeddings, persist_directory="chroma_db")
    retriever = vector_db.as_retriever(search_kwargs={"k": 3})  # top-3 chunks per query

    # --- LLM ---
    llm = ChatGoogleGenerativeAI(model="models/gemini-2.5-flash", google_api_key=GOOGLE_API_KEY)

    # --- PROMPT TEMPLATE ---
    prompt = ChatPromptTemplate.from_template(
        """
You are a helpful assistant. Use the following retrieved context to answer the user's question.
If the context doesn't contain the answer, say so politely.
Context:
{context}

Question:
{question}

Answer:
"""
    )

    # --- NODES (GRAPH FUNCTIONS) ---
    def retrieve_docs(state: RAGState):
        """Fetch the top-k chunks for the question and flatten them into context text."""
        docs = retriever.invoke(state["question"])
        context = "\n\n".join(d.page_content for d in docs)
        return {"context": context, "source_documents": docs}

    def generate_answer(state: RAGState):
        """Render the prompt with the retrieved context and call the LLM."""
        prompt_text = prompt.format(context=state["context"], question=state["question"])
        response = llm.invoke(prompt_text)
        return {"answer": response.content}

    # --- BUILD THE GRAPH ---
    graph_builder = StateGraph(RAGState)
    graph_builder.add_node("retrieve", retrieve_docs)
    graph_builder.add_node("generate", generate_answer)
    graph_builder.add_edge(START, "retrieve")
    graph_builder.add_edge("retrieve", "generate")

    # In-memory checkpointing gives the graph per-thread conversation memory.
    memory = MemorySaver()
    graph = graph_builder.compile(checkpointer=memory)

    # Wrap the compiled graph in a dict-in/dict-out callable so existing
    # callers (e.g. the Streamlit app) keep their chain-style interface.
    class RAGChainWrapper:
        def __init__(self, graph):
            self.graph = graph

        def __call__(self, inputs: dict):
            question = inputs.get("question", "")
            state = {"question": question, "chat_history": []}
            # NOTE: a fixed thread_id means every call shares one checkpoint
            # thread — all users/sessions share the same memory stream.
            result = self.graph.invoke(
                state,
                config={"configurable": {"thread_id": "default"}},
            )
            return {
                "answer": result.get("answer", ""),
                "source_documents": result.get("source_documents", []),
            }

    rag_chain = RAGChainWrapper(graph)

    return llm, retriever, rag_chain
src/streamlit_app.py ADDED
File without changes