Seth0330 committed
Commit 165c3c2 · verified · 1 Parent(s): 36c52bd

Update app.py

Files changed (1): app.py (+29, -1)
app.py CHANGED
@@ -15,7 +15,11 @@ from langchain_openai import ChatOpenAI
 from langchain.prompts import ChatPromptTemplate
 
 DB_PATH = "json_vector.db"
+
+# Read API keys
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")  # NEW
+
 EMBEDDING_MODEL = "text-embedding-ada-002"
 
 if "ingested_batches" not in st.session_state:
@@ -226,7 +230,31 @@ prompt = ChatPromptTemplate.from_messages([
     ("human", "Here are the most relevant records:\n{context}\n\nQuestion: {question}")
 ])
 
-llm = ChatOpenAI(model="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0)
+# --- LLM PROVIDER SELECTION --- # NEW/MODIFIED FOR LLM SELECTION
+llm_provider = st.selectbox(
+    "Select LLM Provider",
+    options=["OpenAI GPT-4", "Mistral (OpenRouter)"],
+    index=0,
+    help="Choose which LLM to use for answering your questions."
+)
+
+def get_llm(llm_provider):
+    if llm_provider == "OpenAI GPT-4":
+        return ChatOpenAI(
+            model="gpt-4.1",
+            openai_api_key=OPENAI_API_KEY,
+            temperature=0,
+        )
+    else:  # "Mistral (OpenRouter)"
+        return ChatOpenAI(
+            model="mistralai/mistral-small-3.1-24b-instruct:free",  # Or another Mistral model if desired
+            openai_api_key=OPENROUTER_API_KEY,
+            openai_api_base="https://openrouter.ai/api/v1",
+            temperature=0,
+        )
+
+llm = get_llm(llm_provider)
+
 retriever = HybridRetriever(top_k=5)
 qa_chain = RetrievalQA.from_chain_type(
     llm=llm,
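
For anyone trying the change locally: a minimal sketch of how the OpenRouter branch of get_llm() could be smoke-tested outside Streamlit. The environment-variable name, model id, and endpoint come from the diff above; the snippet itself is illustrative and not part of this commit.

    # Illustrative sketch, not part of the commit: exercises the same ChatOpenAI
    # configuration that get_llm() uses for the "Mistral (OpenRouter)" option.
    import os
    from langchain_openai import ChatOpenAI

    OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")  # export this before running

    llm = ChatOpenAI(
        model="mistralai/mistral-small-3.1-24b-instruct:free",
        openai_api_key=OPENROUTER_API_KEY,
        openai_api_base="https://openrouter.ai/api/v1",  # OpenRouter's OpenAI-compatible endpoint
        temperature=0,
    )

    # A one-line round trip confirms the key and endpoint are wired up correctly.
    print(llm.invoke("Reply with the single word: ready").content)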