Spaces:
Runtime error
Runtime error
Update chat_engine.py
Browse files- chat_engine.py +23 -10
chat_engine.py
CHANGED
|
@@ -69,23 +69,36 @@ def extract_text_from_pdf(pdf_path):
|
|
| 69 |
|
| 70 |
|
| 71 |
def get_api_type(api_type):
|
| 72 |
-
if
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
| 76 |
return Anthropic(model="claude-3-5-sonnet-20240620")
|
| 77 |
-
elif
|
| 78 |
-
return
|
| 79 |
-
elif
|
| 80 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
else:
|
| 82 |
raise NotImplementedError
|
| 83 |
|
| 84 |
|
| 85 |
-
def get_chat_engine(files, progress=gr.Progress()):
|
| 86 |
|
| 87 |
progress(0, desc="Uploading Documents...")
|
| 88 |
-
llm = get_api_type(
|
| 89 |
Settings.llm = llm
|
| 90 |
embed_model = MistralAIEmbedding(model_name='mistral-embed', api_key=mistral_api_key)
|
| 91 |
Settings.embed_model = embed_model
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
def get_api_type(api_type):
    """Return an LLM client for the given UI model identifier.

    Fixes vs. previous revision:
    - branches compared an undefined name ``model_name`` instead of the
      ``api_type`` parameter, raising NameError on every call;
    - two OpenAI branches had a stray ``'}`` (syntax error);
    - 'llama-3-8B' and 'mistral-8x22B' returned the wrong (70B / 8x7B) models.

    Parameters:
        api_type: identifier string chosen in the UI, e.g. 'openai-gpt-4o',
            'claude-sonnet-3.5', 'llama-3-8B', 'mistral-8x22B'.

    Returns:
        A configured ``OpenAI``, ``Anthropic``, or ``Anyscale`` LLM instance.

    Raises:
        NotImplementedError: if ``api_type`` is not a recognized identifier.
    """
    # Dispatch table: UI identifier -> zero-arg factory. A dict keeps the
    # mapping in one place and avoids a long if/elif chain.
    factories = {
        'openai-gpt-4o': lambda: OpenAI(model='gpt-4o', temperature=0.7),
        'openai-gpt-4-turbo': lambda: OpenAI(model='gpt-4-turbo', temperature=0.7),
        'openai-gpt-3.5-turbo': lambda: OpenAI(model='gpt-3.5-turbo', temperature=0.7),
        'claude-sonnet-3.5': lambda: Anthropic(model="claude-3-5-sonnet-20240620"),
        'claude-opus-3': lambda: Anthropic(model="claude-3-opus-20240229"),
        'claude-sonnet-3': lambda: Anthropic(model="claude-3-sonnet-20240229"),
        'claude-haiku-3': lambda: Anthropic(model="claude-3-haiku-20240307"),
        'llama-3-70B': lambda: Anyscale(model='meta-llama/Meta-Llama-3-70B-Instruct'),
        # was a copy-paste of the 70B model name
        'llama-3-8B': lambda: Anyscale(model='meta-llama/Meta-Llama-3-8B-Instruct'),
        'mistral-8x7B': lambda: Anyscale(model='mistralai/Mixtral-8x7B-Instruct-v0.1'),
        # was a copy-paste of the 8x7B model name
        'mistral-8x22B': lambda: Anyscale(model='mistralai/Mixtral-8x22B-Instruct-v0.1'),
    }
    try:
        factory = factories[api_type]
    except KeyError:
        # Preserve the original contract for unknown identifiers.
        raise NotImplementedError
    return factory()
|
| 96 |
|
| 97 |
|
| 98 |
+
def get_chat_engine(files, api_type, progress=gr.Progress()):
|
| 99 |
|
| 100 |
progress(0, desc="Uploading Documents...")
|
| 101 |
+
llm = get_api_type(api_type)
|
| 102 |
Settings.llm = llm
|
| 103 |
embed_model = MistralAIEmbedding(model_name='mistral-embed', api_key=mistral_api_key)
|
| 104 |
Settings.embed_model = embed_model
|