John Graham Reynolds commited on
Commit
102a4d9
·
1 Parent(s): c86902c

update chain to use open weights open-mistral-7b chat model

Browse files
Files changed (1) hide show
  1. src/chain.py +9 -4
src/chain.py CHANGED
@@ -15,7 +15,8 @@ class GlossaryChain:
15
  self.vector_store = load_vector_store()
16
  self.retriever = self.vector_store.as_retriever()
17
  self.llm = ChatMistralAI(
18
- model="mistral-large-latest",
 
19
  mistral_api_key=MISTRAL_API_KEY,
20
  temperature=0.2
21
  )
@@ -38,16 +39,20 @@ class GlossaryChain:
38
 
39
  def invoke(self, input: str) -> str:
40
  return self.chain.invoke(input=input)
41
-
42
 
43
  def format_docs(docs: list[Document]) -> str:
44
  """Format retrieved documents into a readable string"""
45
  return "\n\n".join(doc.page_content for doc in docs)
46
 
47
- @st.cache_data
48
  def load_vector_store() -> FAISS:
 
 
 
 
49
  return FAISS.load_local(
50
- folder_path="faiss_index",
51
  embeddings = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=MISTRAL_API_KEY),
52
  allow_dangerous_deserialization=True
53
  )
 
15
  self.vector_store = load_vector_store()
16
  self.retriever = self.vector_store.as_retriever()
17
  self.llm = ChatMistralAI(
18
+ # model="mistral-large-latest", # "error","message":"Service tier capacity exceeded for this model.","type":"service_tier_capacity_exceeded
19
+ model="open-mistral-7b", # we must use the open-weight model
20
  mistral_api_key=MISTRAL_API_KEY,
21
  temperature=0.2
22
  )
 
39
 
40
  def invoke(self, input: str) -> str:
41
  return self.chain.invoke(input=input)
42
+
43
 
44
  def format_docs(docs: list[Document]) -> str:
45
  """Format retrieved documents into a readable string"""
46
  return "\n\n".join(doc.page_content for doc in docs)
47
 
48
+ @st.cache_resource
49
  def load_vector_store() -> FAISS:
50
+ import os
51
+ # Get the absolute path to the faiss_index directory
52
+ current_dir = os.path.dirname(os.path.abspath(__file__))
53
+ faiss_path = os.path.join(os.path.dirname(current_dir), "faiss_index")
54
  return FAISS.load_local(
55
+ folder_path=faiss_path,
56
  embeddings = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=MISTRAL_API_KEY),
57
  allow_dangerous_deserialization=True
58
  )