Spaces:
Sleeping
Sleeping
| import os | |
| from enum import Enum | |
| from typing import Union, Iterator | |
| from pydantic import BaseModel | |
| from openai import OpenAI | |
| import pytz | |
| from datetime import datetime | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| import logging | |
| import json | |
| # Cấu hình logging | |
| logging.getLogger("urllib3").setLevel(logging.WARNING) | |
| logging.getLogger("gradio").setLevel(logging.WARNING) | |
| logger = logging.getLogger("thinker") | |
| logger.setLevel(logging.DEBUG) | |
| console_handler = logging.StreamHandler() | |
| formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') | |
| console_handler.setFormatter(formatter) | |
| logger.addHandler(console_handler) | |
| # Load environment variables | |
| load_dotenv() | |
| # Định nghĩa các models cho functions | |
| class GetWeather(BaseModel): | |
| location: str | |
| class GetTime(BaseModel): | |
| timezone: str | |
| # Implement các functions | |
| def get_weather(location: str) -> str: | |
| """Giả lập lấy thông tin thời tiết""" | |
| return f"Thời tiết tại {location}: 30°C, Nắng nhẹ, Độ ẩm: 70%" | |
| def get_time(timezone: str) -> str: | |
| """Lấy thời gian theo múi giờ""" | |
| try: | |
| tz = pytz.timezone(timezone) | |
| current_time = datetime.now(tz) | |
| return current_time.strftime("%Y-%m-%d %H:%M:%S %Z") | |
| except: | |
| return f"Không thể lấy thời gian cho múi giờ {timezone}" | |
| # Định nghĩa tools theo schema của OpenAI | |
| tools = [ | |
| { | |
| "type": "function", | |
| "function": { | |
| "name": "get_weather", | |
| "description": "Lấy thông tin thời tiết cho một địa điểm", | |
| "parameters": GetWeather.model_json_schema() | |
| } | |
| }, | |
| { | |
| "type": "function", | |
| "function": { | |
| "name": "get_time", | |
| "description": "Lấy thời gian hiện tại cho một múi giờ", | |
| "parameters": GetTime.model_json_schema() | |
| } | |
| } | |
| ] | |
| # Khởi tạo OpenAI client | |
| try: | |
| client = OpenAI( | |
| base_url="https://api-inference.huggingface.co/v1/", | |
| api_key=os.getenv("HF_TOKEN") | |
| ) | |
| logger.info("✅ Đã khởi tạo OpenAI client thành công") | |
| except Exception as e: | |
| logger.error(f"❌ Lỗi khởi tạo OpenAI client: {str(e)}") | |
| raise | |
| def respond( | |
| message: str, | |
| history: list[tuple[str, str]], | |
| system_message: str = "You are a helpful assistant. Use the supplied tools to assist the user.", | |
| max_tokens: int = 8192, | |
| temperature: float = 0.1, | |
| top_p: float = 0.7, | |
| ) -> Iterator[str]: | |
| try: | |
| messages = [{"role": "system", "content": system_message}] | |
| for user_msg, assistant_msg in history: | |
| messages.append({"role": "user", "content": user_msg}) | |
| messages.append({"role": "assistant", "content": assistant_msg}) | |
| messages.append({"role": "user", "content": message}) | |
| logger.info(f"📝 Xử lý tin nhắn mới: {message}") | |
| logger.debug(f"Messages: {messages}") | |
| stream = client.chat.completions.create( | |
| model="Qwen/Qwen2.5-72B-Instruct", | |
| messages=messages, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| top_p=top_p, | |
| stream=True, | |
| tools=tools | |
| ) | |
| partial_message = "" | |
| has_content = False | |
| current_tool_call = { | |
| "function": "", | |
| "arguments": "" | |
| } | |
| for chunk in stream: | |
| logger.debug(f"Chunk received: {chunk}") | |
| if not chunk.choices: | |
| continue | |
| delta = chunk.choices[0].delta | |
| if hasattr(delta, 'tool_calls') and delta.tool_calls: | |
| # Kiểm tra tool_calls là list và có phần tử | |
| if isinstance(delta.tool_calls, list) and len(delta.tool_calls) > 0: | |
| tool_call = delta.tool_calls[0] | |
| else: | |
| # Xử lý tool_calls là dict | |
| tool_call = delta.tool_calls | |
| # Xử lý function name | |
| if hasattr(tool_call, 'function') and hasattr(tool_call.function, 'name'): | |
| current_tool_call["function"] = tool_call.function.name | |
| # Xử lý arguments | |
| if hasattr(tool_call, 'function') and hasattr(tool_call.function, 'arguments'): | |
| current_tool_call["arguments"] += tool_call.function.arguments | |
| # Thực thi function khi nhận đủ arguments | |
| if current_tool_call["arguments"].endswith('}'): | |
| try: | |
| args = json.loads(current_tool_call["arguments"]) | |
| if current_tool_call["function"] == "get_weather": | |
| result = get_weather(args["location"]) | |
| has_content = True | |
| yield result | |
| elif current_tool_call["function"] == "get_time": | |
| result = get_time(args["timezone"]) | |
| has_content = True | |
| yield result | |
| except json.JSONDecodeError: | |
| logger.error(f"Invalid JSON: {current_tool_call['arguments']}") | |
| current_tool_call = {"function": "", "arguments": ""} | |
| elif hasattr(delta, 'content') and delta.content: | |
| content = delta.content | |
| has_content = True | |
| partial_message += content | |
| yield partial_message | |
| if not has_content: | |
| logger.warning("⚠️ Không nhận được nội dung từ API") | |
| yield "Xin lỗi, tôi không thể xử lý yêu cầu này. Vui lòng thử lại sau." | |
| logger.info("✅ Đã hoàn thành xử lý tin nhắn") | |
| except Exception as e: | |
| error_msg = f"❌ Lỗi trong quá trình xử lý: {str(e)}" | |
| logger.error(error_msg) | |
| logger.exception(e) | |
| yield "Xin lỗi, tôi đang gặp vấn đề khi xử lý yêu cầu của bạn. Vui lòng thử lại sau." | |
| # Tạo giao diện Gradio | |
| demo = gr.ChatInterface( | |
| respond, | |
| additional_inputs=[ | |
| gr.Textbox("You are a helpful assistant. Use the supplied tools to assist the user.", | |
| label="System Message"), | |
| gr.Slider(0, 8192, value=8192, step=1, label="Max Tokens"), | |
| gr.Slider(0, 2.0, value=0.1, step=0.1, label="Temperature"), | |
| gr.Slider(0, 1.0, value=0.7, step=0.05, label="Top P"), | |
| ], | |
| title="AI Chat", | |
| description="Chat với AI sử dụng Qwen2.5-72B-Instruct", | |
| ) | |
| if __name__ == "__main__": | |
| is_space = os.getenv("SPACE_ID") is not None | |
| demo.queue().launch( | |
| share=not is_space, | |
| server_port=7860, | |
| server_name="0.0.0.0", | |
| show_error=True, | |
| ) |