hins111 commited on
Commit
0cbb38a
·
verified ·
1 Parent(s): 6033f88

Upload 4 files

Browse files
Files changed (4) hide show
  1. models/__init__.py +35 -0
  2. models/chat.py +37 -0
  3. models/exceptions.py +3 -0
  4. models/logging.py +108 -0
models/__init__.py CHANGED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 聊天相关模型
2
+ from .chat import (
3
+ FunctionCall,
4
+ ToolCall,
5
+ MessageContentItem,
6
+ Message,
7
+ ChatCompletionRequest
8
+ )
9
+
10
+ # 异常类
11
+ from .exceptions import ClientDisconnectedError
12
+
13
+ # 日志工具类
14
+ from .logging import (
15
+ StreamToLogger,
16
+ WebSocketConnectionManager,
17
+ WebSocketLogHandler
18
+ )
19
+
20
+ __all__ = [
21
+ # 聊天模型
22
+ 'FunctionCall',
23
+ 'ToolCall',
24
+ 'MessageContentItem',
25
+ 'Message',
26
+ 'ChatCompletionRequest',
27
+
28
+ # 异常
29
+ 'ClientDisconnectedError',
30
+
31
+ # 日志工具
32
+ 'StreamToLogger',
33
+ 'WebSocketConnectionManager',
34
+ 'WebSocketLogHandler'
35
+ ]
models/chat.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union
2
+ from pydantic import BaseModel
3
+ from config import MODEL_NAME
4
+
5
+
6
+ class FunctionCall(BaseModel):
7
+ name: str
8
+ arguments: str
9
+
10
+
11
+ class ToolCall(BaseModel):
12
+ id: str
13
+ type: str = "function"
14
+ function: FunctionCall
15
+
16
+
17
+ class MessageContentItem(BaseModel):
18
+ type: str
19
+ text: Optional[str] = None
20
+
21
+
22
+ class Message(BaseModel):
23
+ role: str
24
+ content: Union[str, List[MessageContentItem], None] = None
25
+ name: Optional[str] = None
26
+ tool_calls: Optional[List[ToolCall]] = None
27
+ tool_call_id: Optional[str] = None
28
+
29
+
30
+ class ChatCompletionRequest(BaseModel):
31
+ messages: List[Message]
32
+ model: Optional[str] = MODEL_NAME
33
+ stream: Optional[bool] = False
34
+ temperature: Optional[float] = None
35
+ max_output_tokens: Optional[int] = None
36
+ stop: Optional[Union[str, List[str]]] = None
37
+ top_p: Optional[float] = None
models/exceptions.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ class ClientDisconnectedError(Exception):
2
+ """客户端断开连接异常"""
3
+ pass
models/logging.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import datetime
3
+ import json
4
+ import logging
5
+ import sys
6
+ from typing import Dict
7
+ from fastapi import WebSocket, WebSocketDisconnect
8
+
9
+
10
+ class StreamToLogger:
11
+ def __init__(self, logger_instance, log_level=logging.INFO):
12
+ self.logger = logger_instance
13
+ self.log_level = log_level
14
+ self.linebuf = ''
15
+
16
+ def write(self, buf):
17
+ try:
18
+ temp_linebuf = self.linebuf + buf
19
+ self.linebuf = ''
20
+ for line in temp_linebuf.splitlines(True):
21
+ if line.endswith(('\n', '\r')):
22
+ self.logger.log(self.log_level, line.rstrip())
23
+ else:
24
+ self.linebuf += line
25
+ except Exception as e:
26
+ print(f"StreamToLogger 错误: {e}", file=sys.__stderr__)
27
+
28
+ def flush(self):
29
+ try:
30
+ if self.linebuf != '':
31
+ self.logger.log(self.log_level, self.linebuf.rstrip())
32
+ self.linebuf = ''
33
+ except Exception as e:
34
+ print(f"StreamToLogger Flush 错误: {e}", file=sys.__stderr__)
35
+
36
+ def isatty(self):
37
+ return False
38
+
39
+
40
+ class WebSocketConnectionManager:
41
+ def __init__(self):
42
+ self.active_connections: Dict[str, WebSocket] = {}
43
+
44
+ async def connect(self, client_id: str, websocket: WebSocket):
45
+ await websocket.accept()
46
+ self.active_connections[client_id] = websocket
47
+ logger = logging.getLogger("AIStudioProxyServer")
48
+ logger.info(f"WebSocket 日志客户端已连接: {client_id}")
49
+ try:
50
+ await websocket.send_text(json.dumps({
51
+ "type": "connection_status",
52
+ "status": "connected",
53
+ "message": "已连接到实时日志流。",
54
+ "timestamp": datetime.datetime.now().isoformat()
55
+ }))
56
+ except Exception as e:
57
+ logger.warning(f"向 WebSocket 客户端 {client_id} 发送欢迎消息失败: {e}")
58
+
59
+ def disconnect(self, client_id: str):
60
+ if client_id in self.active_connections:
61
+ del self.active_connections[client_id]
62
+ logger = logging.getLogger("AIStudioProxyServer")
63
+ logger.info(f"WebSocket 日志客户端已断开: {client_id}")
64
+
65
+ async def broadcast(self, message: str):
66
+ if not self.active_connections:
67
+ return
68
+ disconnected_clients = []
69
+ active_conns_copy = list(self.active_connections.items())
70
+ logger = logging.getLogger("AIStudioProxyServer")
71
+ for client_id, connection in active_conns_copy:
72
+ try:
73
+ await connection.send_text(message)
74
+ except WebSocketDisconnect:
75
+ logger.info(f"[WS Broadcast] 客户端 {client_id} 在广播期间断开连接。")
76
+ disconnected_clients.append(client_id)
77
+ except RuntimeError as e:
78
+ if "Connection is closed" in str(e):
79
+ logger.info(f"[WS Broadcast] 客户端 {client_id} 的连接已关闭。")
80
+ disconnected_clients.append(client_id)
81
+ else:
82
+ logger.error(f"广播到 WebSocket {client_id} 时发生运行时错误: {e}")
83
+ disconnected_clients.append(client_id)
84
+ except Exception as e:
85
+ logger.error(f"广播到 WebSocket {client_id} 时发生未知错误: {e}")
86
+ disconnected_clients.append(client_id)
87
+ if disconnected_clients:
88
+ for client_id_to_remove in disconnected_clients:
89
+ self.disconnect(client_id_to_remove)
90
+
91
+
92
+ class WebSocketLogHandler(logging.Handler):
93
+ def __init__(self, manager: WebSocketConnectionManager):
94
+ super().__init__()
95
+ self.manager = manager
96
+ self.formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
97
+
98
+ def emit(self, record: logging.LogRecord):
99
+ if self.manager and self.manager.active_connections:
100
+ try:
101
+ log_entry_str = self.format(record)
102
+ try:
103
+ current_loop = asyncio.get_running_loop()
104
+ current_loop.create_task(self.manager.broadcast(log_entry_str))
105
+ except RuntimeError:
106
+ pass
107
+ except Exception as e:
108
+ print(f"WebSocketLogHandler 错误: 广播日志失败 - {e}", file=sys.__stderr__)