juliaturc commited on
Commit
a520549
·
1 Parent(s): fc7dede

Fix Pinecone import (#18)

Browse files
Files changed (2) hide show
  1. src/chat.py +8 -1
  2. src/vector_store.py +2 -2
src/chat.py CHANGED
@@ -77,7 +77,6 @@ if __name__ == "__main__":
77
  parser.add_argument("--llm_provider", default="anthropic", choices=["openai", "anthropic", "ollama"])
78
  parser.add_argument(
79
  "--llm_model",
80
- default="claude-3-opus-20240229",
81
  help="The LLM name. Must be supported by the provider specified via --llm_provider.",
82
  )
83
  parser.add_argument("--vector_store_type", default="pinecone", choices=["pinecone", "marqo"])
@@ -94,6 +93,14 @@ if __name__ == "__main__":
94
  )
95
  args = parser.parse_args()
96
 
 
 
 
 
 
 
 
 
97
  rag_chain = build_rag_chain(args)
98
 
99
  def _predict(message, history):
 
77
  parser.add_argument("--llm_provider", default="anthropic", choices=["openai", "anthropic", "ollama"])
78
  parser.add_argument(
79
  "--llm_model",
 
80
  help="The LLM name. Must be supported by the provider specified via --llm_provider.",
81
  )
82
  parser.add_argument("--vector_store_type", default="pinecone", choices=["pinecone", "marqo"])
 
93
  )
94
  args = parser.parse_args()
95
 
96
+ if not args.llm_model:
97
+ if args.llm_provider == "openai":
98
+ args.llm_model = "gpt-4"
99
+ elif args.llm_provider == "anthropic":
100
+ args.llm_model = "claude-3-opus-20240229"
101
+ else:
102
+ raise ValueError("Please specify --llm_model")
103
+
104
  rag_chain = build_rag_chain(args)
105
 
106
  def _predict(message, history):
src/vector_store.py CHANGED
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
4
  from typing import Dict, Generator, List, Tuple
5
 
6
  import marqo
7
- from langchain_community.vectorstores import Marqo
8
  from langchain_core.documents import Document
9
  from langchain_openai import OpenAIEmbeddings
10
  from pinecone import Pinecone
@@ -61,7 +61,7 @@ class PineconeVectorStore(VectorStore):
61
  self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
62
 
63
  def to_langchain(self):
64
- return Pinecone.from_existing_index(
65
  index_name=self.index_name, embedding=OpenAIEmbeddings(), namespace=self.namespace
66
  )
67
 
 
4
  from typing import Dict, Generator, List, Tuple
5
 
6
  import marqo
7
+ from langchain_community.vectorstores import Marqo, Pinecone as LangChainPinecone
8
  from langchain_core.documents import Document
9
  from langchain_openai import OpenAIEmbeddings
10
  from pinecone import Pinecone
 
61
  self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
62
 
63
  def to_langchain(self):
64
+ return LangChainPinecone.from_existing_index(
65
  index_name=self.index_name, embedding=OpenAIEmbeddings(), namespace=self.namespace
66
  )
67