MaheshLEO4 committed on
Commit
afa47fa
·
1 Parent(s): 592ce9d

Add Gemini provider selection

Browse files
README.md CHANGED
@@ -15,8 +15,8 @@ Upload PDFs, index them, and chat with a multi-agent RAG workflow.
15
 
16
  1. Create a new Space and choose **Docker**.
17
  2. Upload this repository contents.
18
- 3. Add a secret named `GROQ_API_KEY` in **Settings → Secrets**.
19
- 4. The app will start automatically.
20
 
21
  ## Notes
22
 
 
15
 
16
  1. Create a new Space and choose **Docker**.
17
  2. Upload this repository contents.
18
+ 3. Add a secret named `GROQ_API_KEY` or `GEMINI_API_KEY` in **Settings → Secrets**.
19
+ 4. Choose the provider and model in the app sidebar.
20
 
21
  ## Notes
22
 
agents/base_agent.py CHANGED
@@ -1,7 +1,14 @@
1
  import os
2
  from pathlib import Path
3
  from langchain_groq import ChatGroq
4
- from config import GROQ_API_KEY, LLM_MODEL
 
 
 
 
 
 
 
5
  from utils import get_logger
6
 
7
  logger = get_logger(__name__)
@@ -16,23 +23,51 @@ class BaseAgent:
16
  provides a ChatGroq client.
17
  """
18
 
19
- def __init__(self, prompt_file: str, temperature: float = 0.0, max_tokens: int = 512):
20
- if not GROQ_API_KEY:
21
- raise EnvironmentError(
22
- "GROQ_API_KEY is not set. "
23
- "Add it to your .env file or Streamlit secrets."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  )
 
 
 
 
 
 
25
 
26
- self.llm = ChatGroq(
27
- model_name=LLM_MODEL,
28
- temperature=temperature,
29
- max_tokens=max_tokens,
30
- groq_api_key=GROQ_API_KEY,
31
- )
 
 
32
 
33
  prompt_path = PROMPT_DIR / prompt_file
34
  self.prompt_template = prompt_path.read_text(encoding="utf-8")
35
- logger.info(f"{self.__class__.__name__} ready (model={LLM_MODEL})")
 
 
36
 
37
  def _call_llm(self, prompt: str) -> str:
38
  response = self.llm.invoke(prompt)
 
1
  import os
2
  from pathlib import Path
3
  from langchain_groq import ChatGroq
4
+ from langchain_google_genai import ChatGoogleGenerativeAI
5
+ from config import (
6
+ GROQ_API_KEY,
7
+ GEMINI_API_KEY,
8
+ LLM_MODEL,
9
+ DEFAULT_PROVIDER,
10
+ DEFAULT_MODEL,
11
+ )
12
  from utils import get_logger
13
 
14
  logger = get_logger(__name__)
 
23
  provides a ChatGroq client.
24
  """
25
 
26
+ def __init__(
27
+ self,
28
+ prompt_file: str,
29
+ temperature: float = 0.0,
30
+ max_tokens: int = 512,
31
+ model_provider: str | None = None,
32
+ model_name: str | None = None,
33
+ ):
34
+ provider = (model_provider or DEFAULT_PROVIDER).lower()
35
+ model = model_name or DEFAULT_MODEL or LLM_MODEL
36
+
37
+ if provider == "groq":
38
+ if not GROQ_API_KEY:
39
+ raise EnvironmentError(
40
+ "GROQ_API_KEY is not set. "
41
+ "Add it to your .env file or Streamlit secrets."
42
+ )
43
+
44
+ self.llm = ChatGroq(
45
+ model_name=model,
46
+ temperature=temperature,
47
+ max_tokens=max_tokens,
48
+ groq_api_key=GROQ_API_KEY,
49
  )
50
+ elif provider == "gemini":
51
+ if not GEMINI_API_KEY:
52
+ raise EnvironmentError(
53
+ "GEMINI_API_KEY is not set. "
54
+ "Add it to your .env file or Streamlit secrets."
55
+ )
56
 
57
+ self.llm = ChatGoogleGenerativeAI(
58
+ model=model,
59
+ temperature=temperature,
60
+ max_output_tokens=max_tokens,
61
+ google_api_key=GEMINI_API_KEY,
62
+ )
63
+ else:
64
+ raise ValueError(f"Unknown model provider: {provider}")
65
 
66
  prompt_path = PROMPT_DIR / prompt_file
67
  self.prompt_template = prompt_path.read_text(encoding="utf-8")
68
+ logger.info(
69
+ f"{self.__class__.__name__} ready (provider={provider}, model={model})"
70
+ )
71
 
72
  def _call_llm(self, prompt: str) -> str:
73
  response = self.llm.invoke(prompt)
agents/relevance_agent.py CHANGED
@@ -13,8 +13,14 @@ class RelevanceAgent(BaseAgent):
13
  taking conversation history into account.
14
  """
15
 
16
- def __init__(self):
17
- super().__init__(prompt_file="relevance.txt", temperature=0.0, max_tokens=10)
 
 
 
 
 
 
18
 
19
  def check(self, question: str, documents: list[Document], history: str) -> str:
20
  """
 
13
  taking conversation history into account.
14
  """
15
 
16
+ def __init__(self, model_provider: str | None = None, model_name: str | None = None):
17
+ super().__init__(
18
+ prompt_file="relevance.txt",
19
+ temperature=0.0,
20
+ max_tokens=10,
21
+ model_provider=model_provider,
22
+ model_name=model_name,
23
+ )
24
 
25
  def check(self, question: str, documents: list[Document], history: str) -> str:
26
  """
agents/research_agent.py CHANGED
@@ -13,8 +13,14 @@ class ResearchAgent(BaseAgent):
13
  Also performs query rewriting when history is present.
14
  """
15
 
16
- def __init__(self):
17
- super().__init__(prompt_file="research.txt", temperature=0.1, max_tokens=600)
 
 
 
 
 
 
18
 
19
  # Load query-rewrite prompt from same prompts/ directory
20
  from pathlib import Path
 
13
  Also performs query rewriting when history is present.
14
  """
15
 
16
+ def __init__(self, model_provider: str | None = None, model_name: str | None = None):
17
+ super().__init__(
18
+ prompt_file="research.txt",
19
+ temperature=0.1,
20
+ max_tokens=600,
21
+ model_provider=model_provider,
22
+ model_name=model_name,
23
+ )
24
 
25
  # Load query-rewrite prompt from same prompts/ directory
26
  from pathlib import Path
agents/verification_agent.py CHANGED
@@ -10,8 +10,14 @@ class VerificationAgent(BaseAgent):
10
  Checks whether the draft answer is grounded in the retrieved documents.
11
  """
12
 
13
- def __init__(self):
14
- super().__init__(prompt_file="verification.txt", temperature=0.0, max_tokens=220)
 
 
 
 
 
 
15
 
16
  def check(self, answer: str, documents: list[Document]) -> dict:
17
  """
 
10
  Checks whether the draft answer is grounded in the retrieved documents.
11
  """
12
 
13
+ def __init__(self, model_provider: str | None = None, model_name: str | None = None):
14
+ super().__init__(
15
+ prompt_file="verification.txt",
16
+ temperature=0.0,
17
+ max_tokens=220,
18
+ model_provider=model_provider,
19
+ model_name=model_name,
20
+ )
21
 
22
  def check(self, answer: str, documents: list[Document]) -> dict:
23
  """
app.py CHANGED
@@ -4,7 +4,16 @@ import streamlit as st
4
  from ingestion import ingest_pdfs
5
  from retriever import HybridRetriever
6
  from graph import AgentWorkflow, Turn
7
- from config import UPLOAD_DIR, INDEX_DIR
 
 
 
 
 
 
 
 
 
8
 
9
  # ─────────────────────────────────────────────────────────────────────────────
10
  # Page config
@@ -27,6 +36,26 @@ with st.sidebar:
27
  "πŸ” **Verification Mode**: ~6–10 s β€” checks answer quality"
28
  )
29
  st.divider()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  st.caption("Conversation memory: last **4** Q&A pairs")
31
 
32
  # ─────────────────────────────────────────────────────────────────────────────
@@ -40,6 +69,8 @@ defaults = {
40
  "retriever": None,
41
  "files_indexed": False,
42
  "uploaded_file_names": set(),
 
 
43
  }
44
  for key, val in defaults.items():
45
  if key not in st.session_state:
@@ -127,6 +158,13 @@ for msg in st.session_state.chat_history:
127
  question = st.chat_input("Ask a question about your uploaded PDFs…")
128
 
129
  if question:
 
 
 
 
 
 
 
130
  if not os.path.exists(INDEX_DIR) or not os.listdir(INDEX_DIR):
131
  st.warning("⚠️ Please upload and index PDFs first.")
132
  st.stop()
@@ -144,6 +182,8 @@ if question:
144
  question=question,
145
  retriever=st.session_state.retriever,
146
  conversation_history=st.session_state.conversation_history,
 
 
147
  )
148
 
149
  # ── Persist updated history window back to session ────────────────────
 
4
  from ingestion import ingest_pdfs
5
  from retriever import HybridRetriever
6
  from graph import AgentWorkflow, Turn
7
+ from config import (
8
+ UPLOAD_DIR,
9
+ INDEX_DIR,
10
+ GROQ_FREE_MODELS,
11
+ GEMINI_FREE_MODELS,
12
+ DEFAULT_PROVIDER,
13
+ DEFAULT_MODEL,
14
+ GROQ_API_KEY,
15
+ GEMINI_API_KEY,
16
+ )
17
 
18
  # ─────────────────────────────────────────────────────────────────────────────
19
  # Page config
 
36
  "πŸ” **Verification Mode**: ~6–10 s β€” checks answer quality"
37
  )
38
  st.divider()
39
+ st.subheader("Model")
40
+
41
+ provider_labels = ["Groq", "Gemini"]
42
+ provider_index = 0 if st.session_state.model_provider == "groq" else 1
43
+ provider_label = st.selectbox("Provider", provider_labels, index=provider_index)
44
+ model_provider = provider_label.lower()
45
+
46
+ model_options = GROQ_FREE_MODELS if model_provider == "groq" else GEMINI_FREE_MODELS
47
+ if st.session_state.model_name not in model_options:
48
+ st.session_state.model_name = model_options[0]
49
+
50
+ model_name = st.selectbox(
51
+ "Model",
52
+ model_options,
53
+ index=model_options.index(st.session_state.model_name),
54
+ )
55
+
56
+ st.session_state.model_provider = model_provider
57
+ st.session_state.model_name = model_name
58
+ st.divider()
59
  st.caption("Conversation memory: last **4** Q&A pairs")
60
 
61
  # ─────────────────────────────────────────────────────────────────────────────
 
69
  "retriever": None,
70
  "files_indexed": False,
71
  "uploaded_file_names": set(),
72
+ "model_provider": DEFAULT_PROVIDER,
73
+ "model_name": DEFAULT_MODEL,
74
  }
75
  for key, val in defaults.items():
76
  if key not in st.session_state:
 
158
  question = st.chat_input("Ask a question about your uploaded PDFs…")
159
 
160
  if question:
161
+ if st.session_state.model_provider == "groq" and not GROQ_API_KEY:
162
+ st.error("GROQ_API_KEY is not set. Add it to your secrets or .env file.")
163
+ st.stop()
164
+ if st.session_state.model_provider == "gemini" and not GEMINI_API_KEY:
165
+ st.error("GEMINI_API_KEY is not set. Add it to your secrets or .env file.")
166
+ st.stop()
167
+
168
  if not os.path.exists(INDEX_DIR) or not os.listdir(INDEX_DIR):
169
  st.warning("⚠️ Please upload and index PDFs first.")
170
  st.stop()
 
182
  question=question,
183
  retriever=st.session_state.retriever,
184
  conversation_history=st.session_state.conversation_history,
185
+ model_provider=st.session_state.model_provider,
186
+ model_name=st.session_state.model_name,
187
  )
188
 
189
  # ── Persist updated history window back to session ────────────────────
config.py CHANGED
@@ -28,7 +28,22 @@ BATCH_SIZE = 1000 # nodes per indexing batch for large PDFs
28
 
29
  # ── LLM ──────────────────────────────────────────────────────────────────────
30
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
 
31
  LLM_MODEL = "llama-3.1-8b-instant"
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # ── Workflow ──────────────────────────────────────────────────────────────────
34
  MAX_ITERATIONS = 2 # max research→verify loops before forcing end
 
28
 
29
  # ── LLM ──────────────────────────────────────────────────────────────────────
30
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
31
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
32
  LLM_MODEL = "llama-3.1-8b-instant"
33
 
34
+ GROQ_FREE_MODELS = [
35
+ "llama-3.1-8b-instant",
36
+ "llama-3.1-70b-versatile",
37
+ "mixtral-8x7b-32768",
38
+ ]
39
+
40
+ GEMINI_FREE_MODELS = [
41
+ "gemini-1.5-flash",
42
+ "gemini-1.5-flash-8b",
43
+ ]
44
+
45
+ DEFAULT_PROVIDER = "groq"
46
+ DEFAULT_MODEL = GROQ_FREE_MODELS[0]
47
+
48
  # ── Workflow ──────────────────────────────────────────────────────────────────
49
  MAX_ITERATIONS = 2 # max research→verify loops before forcing end
graph/nodes.py CHANGED
@@ -39,7 +39,10 @@ def rewrite_query_node(state: AgentState) -> dict:
39
 
40
  logger.info("Node: rewrite_query")
41
  history = _format_history(state)
42
- agent = ResearchAgent()
 
 
 
43
  rewritten = agent.rewrite_query(state["question"], history)
44
  return {"rewritten_query": rewritten}
45
 
@@ -55,7 +58,10 @@ def check_relevance_node(state: AgentState) -> dict:
55
 
56
  logger.info("Node: check_relevance")
57
  history = _format_history(state)
58
- agent = RelevanceAgent()
 
 
 
59
 
60
  label = agent.check(
61
  question=state["rewritten_query"],
@@ -85,7 +91,10 @@ def research_node(state: AgentState) -> dict:
85
 
86
  logger.info("Node: research")
87
  history = _format_history(state)
88
- agent = ResearchAgent()
 
 
 
89
 
90
  result = agent.generate(
91
  question=state["rewritten_query"],
@@ -110,7 +119,10 @@ def verify_node(state: AgentState) -> dict:
110
  from agents.verification_agent import VerificationAgent
111
 
112
  logger.info("Node: verify")
113
- agent = VerificationAgent()
 
 
 
114
  result = agent.check(
115
  answer=state["draft_answer"],
116
  documents=state["documents"],
 
39
 
40
  logger.info("Node: rewrite_query")
41
  history = _format_history(state)
42
+ agent = ResearchAgent(
43
+ model_provider=state.get("model_provider"),
44
+ model_name=state.get("model_name"),
45
+ )
46
  rewritten = agent.rewrite_query(state["question"], history)
47
  return {"rewritten_query": rewritten}
48
 
 
58
 
59
  logger.info("Node: check_relevance")
60
  history = _format_history(state)
61
+ agent = RelevanceAgent(
62
+ model_provider=state.get("model_provider"),
63
+ model_name=state.get("model_name"),
64
+ )
65
 
66
  label = agent.check(
67
  question=state["rewritten_query"],
 
91
 
92
  logger.info("Node: research")
93
  history = _format_history(state)
94
+ agent = ResearchAgent(
95
+ model_provider=state.get("model_provider"),
96
+ model_name=state.get("model_name"),
97
+ )
98
 
99
  result = agent.generate(
100
  question=state["rewritten_query"],
 
119
  from agents.verification_agent import VerificationAgent
120
 
121
  logger.info("Node: verify")
122
+ agent = VerificationAgent(
123
+ model_provider=state.get("model_provider"),
124
+ model_name=state.get("model_name"),
125
+ )
126
  result = agent.check(
127
  answer=state["draft_answer"],
128
  documents=state["documents"],
graph/state.py CHANGED
@@ -32,3 +32,5 @@ class AgentState(TypedDict):
32
  retriever: Any # HybridRetriever instance (passed through)
33
  iteration_count: int # tracks research→verify loops
34
  enable_verification: bool # toggle slower verification path
 
 
 
32
  retriever: Any # HybridRetriever instance (passed through)
33
  iteration_count: int # tracks research→verify loops
34
  enable_verification: bool # toggle slower verification path
35
+ model_provider: str # "groq" | "gemini"
36
+ model_name: str # selected model name
graph/workflow.py CHANGED
@@ -132,6 +132,8 @@ class AgentWorkflow:
132
  question: str,
133
  retriever: Any,
134
  conversation_history: list[Turn] | None = None,
 
 
135
  ) -> dict:
136
  """
137
  Run the full pipeline for one user turn.
@@ -183,6 +185,8 @@ class AgentWorkflow:
183
  "retriever": retriever,
184
  "iteration_count": 0,
185
  "enable_verification": self.enable_verification,
 
 
186
  }
187
 
188
  try:
 
132
  question: str,
133
  retriever: Any,
134
  conversation_history: list[Turn] | None = None,
135
+ model_provider: str | None = None,
136
+ model_name: str | None = None,
137
  ) -> dict:
138
  """
139
  Run the full pipeline for one user turn.
 
185
  "retriever": retriever,
186
  "iteration_count": 0,
187
  "enable_verification": self.enable_verification,
188
+ "model_provider": model_provider or "groq",
189
+ "model_name": model_name or "",
190
  }
191
 
192
  try:
requirements.txt CHANGED
@@ -14,3 +14,4 @@ pypdf>=4.2.0
14
  langchain>=0.1.20
15
  langgraph>=0.0.40
16
  langchain-groq>=0.1.4
 
 
14
  langchain>=0.1.20
15
  langgraph>=0.0.40
16
  langchain-groq>=0.1.4
17
+ langchain-google-genai>=1.0.7