# airsmodel/utils/chat_response.py
# (Hugging Face page artifacts preserved from scrape: uploader "tanbushi",
#  commit message "update", revision 702fae5)
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import time
import re
# Chat response models (OpenAI chat-completion compatible).
class ChatChoice(BaseModel):
    """One completion choice inside a chat response."""
    index: int               # position of this choice in the `choices` list
    message: Dict[str, str]  # {"role": ..., "content": ...}
    finish_reason: str       # why generation stopped, e.g. "stop"
class ChatUsage(BaseModel):
    """Token accounting for one chat completion."""
    prompt_tokens: int      # tokens consumed by the input messages
    completion_tokens: int  # tokens in the generated reply
    total_tokens: int       # prompt_tokens + completion_tokens
class ChatResponse(BaseModel):
    """Top-level chat completion response (OpenAI-style schema)."""
    id: str                    # response identifier, e.g. "chatcmpl-..."
    object: str                # object type tag, e.g. "chat.completion"
    created: int               # Unix timestamp of creation
    model: str                 # model name echoed from the request
    choices: List[ChatChoice]  # generated completion choices
    usage: ChatUsage           # token usage statistics
def convert_json_format(input_data):
    """Convert raw text-generation pipeline output into a
    ``{"generations": [[{"text": ..., "generationInfo": ...}]]}`` structure.

    Args:
        input_data: iterable of pipeline result items; each item is a dict
            whose ``"generated_text"`` value is a list of chat messages
            (``{"role": ..., "content": ...}`` dicts).

    Returns:
        dict: ``{"generations": [...]}`` with one inner list per input item.
        The text is the first assistant message's content with any
        ``<think>...</think>`` reasoning block removed; items without an
        assistant message yield an empty string.
    """
    output_generations = []
    for item in input_data:
        generated_text_list = item.get('generated_text', [])
        # Take the first assistant message (empty string if none present).
        assistant_content = ""
        for message in generated_text_list:
            if message.get('role') == 'assistant':
                assistant_content = message.get('content', '')
                break
        # Remove the <think>...</think> reasoning block, then trim.
        # BUG FIX: the previous pattern was r'\s*' (the tag text was lost in
        # an HTML copy/paste), which stripped *all* whitespace from the reply.
        clean_content = re.sub(
            r'<think>.*?</think>\s*', '', assistant_content, flags=re.DOTALL
        ).strip()
        output_generations.append([
            {
                "text": clean_content,
                "generationInfo": {
                    "finish_reason": "stop"
                }
            }
        ])
    return {"generations": output_generations}
def _generate_completion(request, pipe):
    """Run the text-generation pipeline and return the cleaned reply text."""
    gen_kwargs = {}
    # Only forward max_new_tokens when the caller actually supplied
    # max_tokens; the original passed an explicit ``max_new_tokens=None``,
    # which needlessly overrides the pipeline's own default handling.
    max_tokens = getattr(request, "max_tokens", None)
    if max_tokens is not None:
        gen_kwargs["max_new_tokens"] = max_tokens
    result = pipe(request.messages, **gen_kwargs)
    converted = convert_json_format(result)
    return converted["generations"][0][0]["text"]


def _count_tokens(messages, completion_text, tokenizer):
    """Return ``(prompt_tokens, completion_tokens)``, exact or estimated.

    NOTE(review): assumes each message is a dict exposing ``.get("content")``
    — confirm against the request schema used by callers.
    """
    if tokenizer:
        prompt_tokens = sum(
            len(tokenizer.encode(msg.get("content", ""))) for msg in messages
        )
        completion_tokens = len(tokenizer.encode(completion_text))
    else:
        # Rough heuristic: ~4 characters per token.
        prompt_tokens = sum(len(msg.get("content", "")) for msg in messages) // 4
        completion_tokens = len(completion_text) // 4
    return prompt_tokens, completion_tokens


def create_chat_response(request: Any, pipe=None, tokenizer=None) -> ChatResponse:
    """Create an OpenAI-style chat response by running the pipeline.

    Args:
        request: chat request object with ``messages``, ``model`` and an
            optional ``max_tokens`` attribute.
        pipe: text-generation pipeline; when ``None`` a placeholder
            "model initializing" reply is returned instead of generating.
        tokenizer: optional tokenizer used for exact token counts; when
            absent usage is estimated at ~4 characters per token.

    Returns:
        ChatResponse: completed response with a single choice and usage.
    """
    if pipe is None:
        # Pipeline not initialised yet — return a placeholder reply.
        completion_text = "模型正在初始化中,请稍后重试..."
    else:
        completion_text = _generate_completion(request, pipe)

    response_message = {
        "role": "assistant",
        "content": completion_text
    }
    prompt_tokens, completion_tokens = _count_tokens(
        request.messages, completion_text, tokenizer
    )

    return ChatResponse(
        # Second-resolution timestamp id: not collision-proof under load.
        id=f"chatcmpl-{int(time.time())}",
        object="chat.completion",
        created=int(time.time()),
        model=request.model,
        choices=[
            ChatChoice(
                index=0,
                message=response_message,
                finish_reason="stop"
            )
        ],
        usage=ChatUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens
        )
    )