|
|
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT |
|
|
from ..presets import * |
|
|
from ..utils import * |
|
|
|
|
|
from .base_model import BaseLLMModel |
|
|
|
|
|
|
|
|
class Claude_Client(BaseLLMModel):
    """Chat model client that talks to Anthropic through its Messages API.

    The conversation is read from ``self.history`` and the system prompt,
    token limit and image helpers are read from ``self`` as well
    (presumably all provided by ``BaseLLMModel`` -- confirm there).
    """

    def __init__(self, model_name, api_secret) -> None:
        """Initialise the base model and create the Anthropic client.

        Args:
            model_name: Model identifier; forwarded to the base class.
            api_secret: API key handed to ``Anthropic(api_key=...)``.

        Raises:
            Exception: If ``api_secret`` is ``None``.
        """
        super().__init__(model_name=model_name)
        self.api_secret = api_secret
        if self.api_secret is None:
            raise Exception("请在配置文件或者环境变量中设置Claude的API Secret")
        self.claude_client = Anthropic(api_key=self.api_secret)

    def _claude_image_block(self, image):
        """Return the base64 ``image`` content block for one buffered image.

        Args:
            image: The ``content`` value of an ``"image"`` history entry; it
                is passed unchanged to ``self.get_image_type`` and
                ``self.get_base64_image``.
        """
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": f"image/{self.get_image_type(image)}",
                "data": self.get_base64_image(image),
            },
        }

    def _get_claude_style_history(self):
        """Convert ``self.history`` into the message list sent to Anthropic.

        Entries with role ``"image"`` are buffered and attached to the next
        ``"user"`` entry, in front of that entry's text. If only one image
        has been seen in the whole conversation so far it is attached without
        a label; otherwise every buffered image is preceded by an
        ``"Image N:"`` text block, where N counts images across the whole
        conversation. ``"assistant"`` entries and image-less ``"user"``
        entries are passed through unchanged. Entries with any other role
        are skipped, and images that are never followed by a user entry are
        not emitted.

        Returns:
            list: The converted messages, in conversation order.
        """
        history = []
        image_buffer = []
        image_count = 0  # images seen so far in the whole conversation
        for message in self.history:
            role = message["role"]
            if role == "image":
                image_buffer.append(message["content"])
                image_count += 1
            elif role == "assistant":
                history.append(message)
            elif role == "user":
                if not image_buffer:
                    # Plain text turn: forward the stored message as it is.
                    history.append(message)
                    continue

                content = []
                if image_count == 1:
                    # A single image in the conversation needs no label.
                    content.append(self._claude_image_block(image_buffer[0]))
                else:
                    # Number of the first buffered image within the
                    # conversation-wide image sequence.
                    first_number = image_count - len(image_buffer) + 1
                    for number, image in enumerate(image_buffer, start=first_number):
                        content.append({"type": "text", "text": f"Image {number}:"})
                        content.append(self._claude_image_block(image))
                content.append({"type": "text", "text": message["content"]})
                history.append(construct_user(content))
                image_buffer = []
        return history

    def get_answer_stream_iter(self):
        """Stream the model's reply, yielding the text accumulated so far.

        Yields:
            str: The reply text received up to this point (cumulative, not
            the delta). If the request raises, a single string made of
            ``i18n(GENERAL_ERROR_MSG)`` and the exception text is yielded
            instead of propagating the error.
        """
        system_prompt = self.system_prompt
        history = self._get_claude_style_history()

        try:
            with self.claude_client.messages.stream(
                model=self.model_name,
                max_tokens=self.max_generation_token,
                messages=history,
                system=system_prompt,
            ) as stream:
                partial_text = ""
                for text in stream.text_stream:
                    partial_text += text
                    yield partial_text
        except Exception as e:
            # Report the failure through the stream so the consumer can show it.
            yield i18n(GENERAL_ERROR_MSG) + ": " + str(e)

    def get_answer_at_once(self):
        """Request a complete reply in one call.

        Returns:
            tuple: ``(text, output_tokens)`` taken from the first content
            block and the usage data of the response. When the response is
            ``None`` or has no content blocks, ``(i18n("获取资源错误"), 0)``
            is returned instead.
        """
        system_prompt = self.system_prompt
        history = self._get_claude_style_history()

        response = self.claude_client.messages.create(
            model=self.model_name,
            max_tokens=self.max_generation_token,
            messages=history,
            system=system_prompt,
        )
        # An empty content list would make ``content[0]`` raise IndexError,
        # so it is treated like a missing response.
        if response is None or not response.content:
            return i18n("获取资源错误"), 0
        return response.content[0].text, response.usage.output_tokens
|
|
|