mylesai commited on
Commit
2235a91
·
verified ·
1 Parent(s): 859aeaa

Update chat_engine.py

Browse files
Files changed (1) hide show
  1. chat_engine.py +23 -10
chat_engine.py CHANGED
@@ -69,23 +69,36 @@ def extract_text_from_pdf(pdf_path):
69
 
70
 
71
  def get_api_type(api_type):
72
- if api_type == 'openai':
73
- # default is gpt-3.5-turbo, can also be gpt-4-0314
74
- return OpenAI(model='gpt-4o', temperature=0.9) # for QA, temp is low
75
- elif api_type == 'claude':
 
 
 
76
  return Anthropic(model="claude-3-5-sonnet-20240620")
77
- elif api_type == 'llama':
78
- return Anyscale(model='meta-llama/Llama-2-70b-chat-hf')
79
- elif api_type == 'mistral':
80
- return Anyscale(model='mistralai/Mixtral-8x7B-Instruct-v0.1', max_tokens=10000)
 
 
 
 
 
 
 
 
 
 
81
  else:
82
  raise NotImplementedError
83
 
84
 
85
- def get_chat_engine(files, progress=gr.Progress()):
86
 
87
  progress(0, desc="Uploading Documents...")
88
- llm = get_api_type('openai')
89
  Settings.llm = llm
90
  embed_model = MistralAIEmbedding(model_name='mistral-embed', api_key=mistral_api_key)
91
  Settings.embed_model = embed_model
 
69
 
70
 
71
  def get_api_type(api_type):
72
+ if model_name == 'openai-gpt-4o':
73
+ return OpenAI(model='gpt-4o', temperature=0.7)
74
+ elif model_name == 'openai-gpt-4-turbo':
75
+ return OpenAI(model='gpt-4-turbo', temperature=0.7'})
76
+ elif model_name == 'openai-gpt-3.5-turbo':
77
+ return OpenAI(model='gpt-3.5-turbo', temperature=0.7'})
78
+ elif model_name == 'claude-sonnet-3.5':
79
  return Anthropic(model="claude-3-5-sonnet-20240620")
80
+ elif model_name == 'claude-opus-3':
81
+ return Anthropic(model="claude-3-opus-20240229")
82
+ elif model_name == 'claude-sonnet-3':
83
+ return Anthropic(model="claude-3-sonnet-20240229")
84
+ elif model_name == 'claude-haiku-3':
85
+ return Anthropic(model="claude-3-haiku-20240307")
86
+ elif model_name == 'llama-3-70B':
87
+ return Anyscale(model='meta-llama/Meta-Llama-3-70B-Instruct')
88
+ elif model_name == 'llama-3-8B':
89
+ return Anyscale(model='meta-llama/Meta-Llama-3-70B-Instruct')
90
+ elif model_name == 'mistral-8x7B':
91
+ return Anyscale(model='mistralai/Mixtral-8x7B-Instruct-v0.1')
92
+ elif model_name == 'mistral-8x22B':
93
+ return Anyscale(model='mistralai/Mixtral-8x7B-Instruct-v0.1')
94
  else:
95
  raise NotImplementedError
96
 
97
 
98
+ def get_chat_engine(files, api_type, progress=gr.Progress()):
99
 
100
  progress(0, desc="Uploading Documents...")
101
+ llm = get_api_type(api_type)
102
  Settings.llm = llm
103
  embed_model = MistralAIEmbedding(model_name='mistral-embed', api_key=mistral_api_key)
104
  Settings.embed_model = embed_model