thinker / app.py
hoduyquocbao's picture
fix: Xử lý an toàn tool_calls trong hàm respond
82e5fad
import os
from enum import Enum
from typing import Union, Iterator
from pydantic import BaseModel
from openai import OpenAI
import pytz
from datetime import datetime
import gradio as gr
from dotenv import load_dotenv
import logging
import json
# Cấu hình logging
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("gradio").setLevel(logging.WARNING)
logger = logging.getLogger("thinker")
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# Load environment variables
load_dotenv()
# Định nghĩa các models cho functions
class GetWeather(BaseModel):
location: str
class GetTime(BaseModel):
timezone: str
# Implement các functions
def get_weather(location: str) -> str:
"""Giả lập lấy thông tin thời tiết"""
return f"Thời tiết tại {location}: 30°C, Nắng nhẹ, Độ ẩm: 70%"
def get_time(timezone: str) -> str:
"""Lấy thời gian theo múi giờ"""
try:
tz = pytz.timezone(timezone)
current_time = datetime.now(tz)
return current_time.strftime("%Y-%m-%d %H:%M:%S %Z")
except:
return f"Không thể lấy thời gian cho múi giờ {timezone}"
# Định nghĩa tools theo schema của OpenAI
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Lấy thông tin thời tiết cho một địa điểm",
"parameters": GetWeather.model_json_schema()
}
},
{
"type": "function",
"function": {
"name": "get_time",
"description": "Lấy thời gian hiện tại cho một múi giờ",
"parameters": GetTime.model_json_schema()
}
}
]
# Khởi tạo OpenAI client
try:
client = OpenAI(
base_url="https://api-inference.huggingface.co/v1/",
api_key=os.getenv("HF_TOKEN")
)
logger.info("✅ Đã khởi tạo OpenAI client thành công")
except Exception as e:
logger.error(f"❌ Lỗi khởi tạo OpenAI client: {str(e)}")
raise
def respond(
message: str,
history: list[tuple[str, str]],
system_message: str = "You are a helpful assistant. Use the supplied tools to assist the user.",
max_tokens: int = 8192,
temperature: float = 0.1,
top_p: float = 0.7,
) -> Iterator[str]:
try:
messages = [{"role": "system", "content": system_message}]
for user_msg, assistant_msg in history:
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})
logger.info(f"📝 Xử lý tin nhắn mới: {message}")
logger.debug(f"Messages: {messages}")
stream = client.chat.completions.create(
model="Qwen/Qwen2.5-72B-Instruct",
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stream=True,
tools=tools
)
partial_message = ""
has_content = False
current_tool_call = {
"function": "",
"arguments": ""
}
for chunk in stream:
logger.debug(f"Chunk received: {chunk}")
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if hasattr(delta, 'tool_calls') and delta.tool_calls:
# Kiểm tra tool_calls là list và có phần tử
if isinstance(delta.tool_calls, list) and len(delta.tool_calls) > 0:
tool_call = delta.tool_calls[0]
else:
# Xử lý tool_calls là dict
tool_call = delta.tool_calls
# Xử lý function name
if hasattr(tool_call, 'function') and hasattr(tool_call.function, 'name'):
current_tool_call["function"] = tool_call.function.name
# Xử lý arguments
if hasattr(tool_call, 'function') and hasattr(tool_call.function, 'arguments'):
current_tool_call["arguments"] += tool_call.function.arguments
# Thực thi function khi nhận đủ arguments
if current_tool_call["arguments"].endswith('}'):
try:
args = json.loads(current_tool_call["arguments"])
if current_tool_call["function"] == "get_weather":
result = get_weather(args["location"])
has_content = True
yield result
elif current_tool_call["function"] == "get_time":
result = get_time(args["timezone"])
has_content = True
yield result
except json.JSONDecodeError:
logger.error(f"Invalid JSON: {current_tool_call['arguments']}")
current_tool_call = {"function": "", "arguments": ""}
elif hasattr(delta, 'content') and delta.content:
content = delta.content
has_content = True
partial_message += content
yield partial_message
if not has_content:
logger.warning("⚠️ Không nhận được nội dung từ API")
yield "Xin lỗi, tôi không thể xử lý yêu cầu này. Vui lòng thử lại sau."
logger.info("✅ Đã hoàn thành xử lý tin nhắn")
except Exception as e:
error_msg = f"❌ Lỗi trong quá trình xử lý: {str(e)}"
logger.error(error_msg)
logger.exception(e)
yield "Xin lỗi, tôi đang gặp vấn đề khi xử lý yêu cầu của bạn. Vui lòng thử lại sau."
# Tạo giao diện Gradio
demo = gr.ChatInterface(
respond,
additional_inputs=[
gr.Textbox("You are a helpful assistant. Use the supplied tools to assist the user.",
label="System Message"),
gr.Slider(0, 8192, value=8192, step=1, label="Max Tokens"),
gr.Slider(0, 2.0, value=0.1, step=0.1, label="Temperature"),
gr.Slider(0, 1.0, value=0.7, step=0.05, label="Top P"),
],
title="AI Chat",
description="Chat với AI sử dụng Qwen2.5-72B-Instruct",
)
if __name__ == "__main__":
is_space = os.getenv("SPACE_ID") is not None
demo.queue().launch(
share=not is_space,
server_port=7860,
server_name="0.0.0.0",
show_error=True,
)