MeteKaba committed on
Commit
3749f20
·
verified ·
1 Parent(s): ecb8437

Update src/rag_pipeline.py

Browse files
Files changed (1) hide show
  1. src/rag_pipeline.py +21 -3
src/rag_pipeline.py CHANGED
@@ -3,7 +3,6 @@ from datasets import load_dataset
3
  from langchain_community.embeddings import HuggingFaceEmbeddings
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_text_splitters import CharacterTextSplitter
6
- from langchain.chat_models import ChatGoogleGenerativeAI
7
  from langchain_core.documents import Document
8
  from langgraph.graph import START, StateGraph
9
  from langgraph.checkpoint.memory import MemorySaver
@@ -12,6 +11,7 @@ from langchain_core.prompts import ChatPromptTemplate
12
  from huggingface_hub import login
13
  from dotenv import load_dotenv
14
  from typing import TypedDict, List
 
15
 
16
  # Load environment variables
17
  load_dotenv()
@@ -19,6 +19,9 @@ load_dotenv()
19
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
20
  HF_TOKEN = os.getenv("HF_TOKEN")
21
 
 
 
 
22
  # Authenticate Hugging Face
23
  if HF_TOKEN:
24
  try:
@@ -29,7 +32,6 @@ if HF_TOKEN:
29
  else:
30
  print("⚠️ No HF_TOKEN found in .env file. Using public mode.")
31
 
32
-
33
  # --- STATE DEFINITION ---
34
  class RAGState(TypedDict):
35
  question: str
@@ -38,6 +40,22 @@ class RAGState(TypedDict):
38
  chat_history: List[str]
39
  source_documents: List[Document]
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  def build_rag_pipeline():
43
  """Builds a LangGraph-based RAG pipeline compatible with LangChain 1.x."""
@@ -65,7 +83,7 @@ def build_rag_pipeline():
65
  retriever = vector_db.as_retriever(search_kwargs={"k": 3})
66
 
67
  # --- LLM ---
68
- llm = ChatGoogleGenerativeAI(model="models/gemini-2.5-flash", google_api_key=GOOGLE_API_KEY)
69
 
70
  # --- PROMPT TEMPLATE ---
71
  prompt = ChatPromptTemplate.from_template(
 
3
  from langchain_community.embeddings import HuggingFaceEmbeddings
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_text_splitters import CharacterTextSplitter
 
6
  from langchain_core.documents import Document
7
  from langgraph.graph import START, StateGraph
8
  from langgraph.checkpoint.memory import MemorySaver
 
11
  from huggingface_hub import login
12
  from dotenv import load_dotenv
13
  from typing import TypedDict, List
14
+ import google.generativeai as genai # Official Google Gemini SDK
15
 
16
  # Load environment variables
17
  load_dotenv()
 
19
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
20
  HF_TOKEN = os.getenv("HF_TOKEN")
21
 
22
+ # Configure Google Gemini SDK
23
+ genai.configure(api_key=GOOGLE_API_KEY)
24
+
25
  # Authenticate Hugging Face
26
  if HF_TOKEN:
27
  try:
 
32
  else:
33
  print("⚠️ No HF_TOKEN found in .env file. Using public mode.")
34
 
 
35
  # --- STATE DEFINITION ---
36
  class RAGState(TypedDict):
37
  question: str
 
40
  chat_history: List[str]
41
  source_documents: List[Document]
42
 
43
+ # --- LLM Wrapper ---
44
+ class GeminiLLMWrapper:
45
+ """
46
+ A simple wrapper around google-generativeai chat API to mimic
47
+ the ChatGoogleGenerativeAI interface for compatibility with app.py.
48
+ """
49
+ def invoke(self, prompt: str):
50
+ response = genai.chat.create(
51
+ model="models/gemini-2.5-flash",
52
+ messages=[{"role": "user", "content": prompt}]
53
+ )
54
+ # Wrap the response to have a .content attribute
55
+ class Result:
56
+ content = response.last
57
+ return Result()
58
+
59
 
60
  def build_rag_pipeline():
61
  """Builds a LangGraph-based RAG pipeline compatible with LangChain 1.x."""
 
83
  retriever = vector_db.as_retriever(search_kwargs={"k": 3})
84
 
85
  # --- LLM ---
86
+ llm = GeminiLLMWrapper() # Use wrapper instead of ChatGoogleGenerativeAI
87
 
88
  # --- PROMPT TEMPLATE ---
89
  prompt = ChatPromptTemplate.from_template(