Fix Bedrock system prompt (#2062)
### What problem does this PR solve?
Bugfix: the Bedrock Converse API requires the system prompt (for models
that support one) to be passed via the dedicated `system` parameter rather
than inserted into the conversation history as a `"system"`-role message.
This PR fixes it.
https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- rag/llm/chat_model.py +4 -6
rag/llm/chat_model.py
CHANGED
|
@@ -667,8 +667,6 @@ class BedrockChat(Base):
|
|
| 667 |
|
| 668 |
def chat(self, system, history, gen_conf):
|
| 669 |
from botocore.exceptions import ClientError
|
| 670 |
-
if system:
|
| 671 |
-
history.insert(0, {"role": "system", "content": system})
|
| 672 |
for k in list(gen_conf.keys()):
|
| 673 |
if k not in ["temperature", "top_p", "max_tokens"]:
|
| 674 |
del gen_conf[k]
|
|
@@ -688,7 +686,8 @@ class BedrockChat(Base):
|
|
| 688 |
response = self.client.converse(
|
| 689 |
modelId=self.model_name,
|
| 690 |
messages=history,
|
| 691 |
-
inferenceConfig=gen_conf
|
|
|
|
| 692 |
)
|
| 693 |
|
| 694 |
# Extract and print the response text.
|
|
@@ -700,8 +699,6 @@ class BedrockChat(Base):
|
|
| 700 |
|
| 701 |
def chat_streamly(self, system, history, gen_conf):
|
| 702 |
from botocore.exceptions import ClientError
|
| 703 |
-
if system:
|
| 704 |
-
history.insert(0, {"role": "system", "content": system})
|
| 705 |
for k in list(gen_conf.keys()):
|
| 706 |
if k not in ["temperature", "top_p", "max_tokens"]:
|
| 707 |
del gen_conf[k]
|
|
@@ -720,7 +717,8 @@ class BedrockChat(Base):
|
|
| 720 |
response = self.client.converse(
|
| 721 |
modelId=self.model_name,
|
| 722 |
messages=history,
|
| 723 |
-
inferenceConfig=gen_conf
|
|
|
|
| 724 |
)
|
| 725 |
ans = response["output"]["message"]["content"][0]["text"]
|
| 726 |
return ans, num_tokens_from_string(ans)
|
|
|
|
| 667 |
|
| 668 |
def chat(self, system, history, gen_conf):
|
| 669 |
from botocore.exceptions import ClientError
|
|
|
|
|
|
|
| 670 |
for k in list(gen_conf.keys()):
|
| 671 |
if k not in ["temperature", "top_p", "max_tokens"]:
|
| 672 |
del gen_conf[k]
|
|
|
|
| 686 |
response = self.client.converse(
|
| 687 |
modelId=self.model_name,
|
| 688 |
messages=history,
|
| 689 |
+
inferenceConfig=gen_conf,
|
| 690 |
+
system=[{"text": system}] if system else None,
|
| 691 |
)
|
| 692 |
|
| 693 |
# Extract and print the response text.
|
|
|
|
| 699 |
|
| 700 |
def chat_streamly(self, system, history, gen_conf):
|
| 701 |
from botocore.exceptions import ClientError
|
|
|
|
|
|
|
| 702 |
for k in list(gen_conf.keys()):
|
| 703 |
if k not in ["temperature", "top_p", "max_tokens"]:
|
| 704 |
del gen_conf[k]
|
|
|
|
| 717 |
response = self.client.converse(
|
| 718 |
modelId=self.model_name,
|
| 719 |
messages=history,
|
| 720 |
+
inferenceConfig=gen_conf,
|
| 721 |
+
system=[{"text": system}] if system else None,
|
| 722 |
)
|
| 723 |
ans = response["output"]["message"]["content"][0]["text"]
|
| 724 |
return ans, num_tokens_from_string(ans)
|