lightspeed commited on
Commit
a8c0fef
·
verified ·
1 Parent(s): 0a0e99d

Upload 3 files

Browse files
Files changed (3) hide show
  1. config.py +457 -0
  2. log.py +143 -0
  3. multi_user_auth_web.py +275 -0
config.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration constants for the Geminicli2api proxy server.
3
+ Centralizes all configuration to avoid duplication across modules.
4
+ """
5
+ import os
6
+ from typing import Any, Optional
7
+
8
+ # Client Configuration
9
+
10
+ # 需要自动封禁的错误码 (默认值,可通过环境变量或配置覆盖)
11
+ AUTO_BAN_ERROR_CODES = [401, 403]
12
+
13
+ # Default Safety Settings for Google API
14
+ DEFAULT_SAFETY_SETTINGS = [
15
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
16
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
17
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
18
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
19
+ {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
20
+ ]
21
+
22
+ # Helper function to get base model name from any variant
23
+ def get_base_model_name(model_name):
24
+ """Convert variant model name to base model name."""
25
+ # Remove all possible suffixes in order
26
+ suffixes = ["-maxthinking", "-nothinking", "-search"]
27
+ for suffix in suffixes:
28
+ if model_name.endswith(suffix):
29
+ return model_name[:-len(suffix)]
30
+ return model_name
31
+
32
+ # Helper function to check if model uses search grounding
33
+ def is_search_model(model_name):
34
+ """Check if model name indicates search grounding should be enabled."""
35
+ return "-search" in model_name
36
+
37
+ # Helper function to check if model uses no thinking
38
+ def is_nothinking_model(model_name):
39
+ """Check if model name indicates thinking should be disabled."""
40
+ return "-nothinking" in model_name
41
+
42
+ # Helper function to check if model uses max thinking
43
+ def is_maxthinking_model(model_name):
44
+ """Check if model name indicates maximum thinking budget should be used."""
45
+ return "-maxthinking" in model_name
46
+
47
+ # Helper function to get thinking budget for a model
48
+ def get_thinking_budget(model_name):
49
+ """Get the appropriate thinking budget for a model based on its name and variant."""
50
+
51
+ if is_nothinking_model(model_name):
52
+ return 128 # Limited thinking for pro
53
+ elif is_maxthinking_model(model_name):
54
+ return 32768
55
+ else:
56
+ # Default thinking budget for regular models
57
+ return -1 # Default for all models
58
+
59
+ # Helper function to check if thinking should be included in output
60
+ def should_include_thoughts(model_name):
61
+ """Check if thoughts should be included in the response."""
62
+ if is_nothinking_model(model_name):
63
+ # For nothinking mode, still include thoughts if it's a pro model
64
+ base_model = get_base_model_name(model_name)
65
+ return "gemini-2.5-pro" in base_model
66
+ else:
67
+ # For all other modes, include thoughts
68
+ return True
69
+
70
+ # Dynamic Configuration System - Optimized for memory efficiency
71
+ async def get_config_value(key: str, default: Any = None, env_var: Optional[str] = None) -> Any:
72
+ """Get configuration value with priority: ENV > Storage > default."""
73
+ # Priority 1: Environment variable
74
+ if env_var and os.getenv(env_var):
75
+ return os.getenv(env_var)
76
+
77
+ # Priority 2: Storage system
78
+ try:
79
+ from src.storage_adapter import get_storage_adapter
80
+ storage_adapter = await get_storage_adapter()
81
+ value = await storage_adapter.get_config(key)
82
+ # 检查值是否存在(不是None),允许空字符串、0、False等有效值
83
+ if value is not None:
84
+ return value
85
+ except Exception as e:
86
+ # Debug: print import/storage errors
87
+ # print(f"Config storage error for key {key}: {e}")
88
+ pass
89
+
90
+ return default
91
+
92
+
93
+ # Configuration getters - all async
94
+ async def get_proxy_config():
95
+ """Get proxy configuration."""
96
+ proxy_url = await get_config_value("proxy", env_var="PROXY")
97
+ return proxy_url if proxy_url else None
98
+
99
+ async def get_calls_per_rotation() -> int:
100
+ """Get calls per rotation setting."""
101
+ env_value = os.getenv("CALLS_PER_ROTATION")
102
+ if env_value:
103
+ try:
104
+ return int(env_value)
105
+ except ValueError:
106
+ pass
107
+
108
+ return int(await get_config_value("calls_per_rotation", 100))
109
+
110
+ async def get_auto_ban_enabled() -> bool:
111
+ """Get auto ban enabled setting."""
112
+ env_value = os.getenv("AUTO_BAN")
113
+ if env_value:
114
+ return env_value.lower() in ("true", "1", "yes", "on")
115
+
116
+ return bool(await get_config_value("auto_ban_enabled", False))
117
+
118
+ async def get_auto_ban_error_codes() -> list:
119
+ """
120
+ Get auto ban error codes.
121
+
122
+ Environment variable: AUTO_BAN_ERROR_CODES (comma-separated, e.g., "400,403")
123
+ TOML config key: auto_ban_error_codes
124
+ Default: [400, 403]
125
+ """
126
+ env_value = os.getenv("AUTO_BAN_ERROR_CODES")
127
+ if env_value:
128
+ try:
129
+ return [int(code.strip()) for code in env_value.split(",") if code.strip()]
130
+ except ValueError:
131
+ pass
132
+
133
+ codes = await get_config_value("auto_ban_error_codes")
134
+ if codes and isinstance(codes, list):
135
+ return codes
136
+ return AUTO_BAN_ERROR_CODES
137
+
138
+ async def get_retry_429_max_retries() -> int:
139
+ """Get max retries for 429 errors."""
140
+ env_value = os.getenv("RETRY_429_MAX_RETRIES")
141
+ if env_value:
142
+ try:
143
+ return int(env_value)
144
+ except ValueError:
145
+ pass
146
+
147
+ return int(await get_config_value("retry_429_max_retries", 5))
148
+
149
+ async def get_retry_429_enabled() -> bool:
150
+ """Get 429 retry enabled setting."""
151
+ env_value = os.getenv("RETRY_429_ENABLED")
152
+ if env_value:
153
+ return env_value.lower() in ("true", "1", "yes", "on")
154
+
155
+ return bool(await get_config_value("retry_429_enabled", True))
156
+
157
+ async def get_retry_429_interval() -> float:
158
+ """Get 429 retry interval in seconds."""
159
+ env_value = os.getenv("RETRY_429_INTERVAL")
160
+ if env_value:
161
+ try:
162
+ return float(env_value)
163
+ except ValueError:
164
+ pass
165
+
166
+ return float(await get_config_value("retry_429_interval", 1))
167
+
168
+
169
+ # Model name lists for different features
170
+ BASE_MODELS = [
171
+ "gemini-2.5-pro-preview-06-05",
172
+ "gemini-2.5-pro",
173
+ "gemini-2.5-pro-preview-05-06",
174
+ "gemini-2.5-flash",
175
+ ]
176
+
177
+ def get_available_models(router_type="openai"):
178
+ """
179
+ Get available models with feature prefixes.
180
+
181
+ Args:
182
+ router_type: "openai" or "gemini"
183
+
184
+ Returns:
185
+ List of model names with feature prefixes
186
+ """
187
+ models = []
188
+
189
+ for base_model in BASE_MODELS:
190
+ # 基础模型
191
+ models.append(base_model)
192
+
193
+ # 假流式模型 (前缀格式)
194
+ models.append(f"假流式/{base_model}")
195
+
196
+ # 流式抗截断模型 (仅在流式传输时有效,前缀格式)
197
+ models.append(f"流式抗截断/{base_model}")
198
+
199
+ # 支持thinking模式后缀与功能前缀组合
200
+ for thinking_suffix in ["-maxthinking", "-nothinking", "-search"]:
201
+ # 基础模型 + thinking后缀
202
+ models.append(f"{base_model}{thinking_suffix}")
203
+
204
+ # 假流式 + thinking后缀
205
+ models.append(f"假流式/{base_model}{thinking_suffix}")
206
+
207
+ # 流式抗截断 + thinking后缀
208
+ models.append(f"流式抗截断/{base_model}{thinking_suffix}")
209
+
210
+ return models
211
+
212
+ def is_fake_streaming_model(model_name: str) -> bool:
213
+ """Check if model name indicates fake streaming should be used."""
214
+ return model_name.startswith("假流式/")
215
+
216
+ def is_anti_truncation_model(model_name: str) -> bool:
217
+ """Check if model name indicates anti-truncation should be used."""
218
+ return model_name.startswith("流式抗截断/")
219
+
220
+ def get_base_model_from_feature_model(model_name: str) -> str:
221
+ """Get base model name from feature model name."""
222
+ # Remove feature prefixes
223
+ for prefix in ["假流式/", "流式抗截断/"]:
224
+ if model_name.startswith(prefix):
225
+ return model_name[len(prefix):]
226
+ return model_name
227
+
228
+ async def get_anti_truncation_max_attempts() -> int:
229
+ """
230
+ Get maximum attempts for anti-truncation continuation.
231
+
232
+ Environment variable: ANTI_TRUNCATION_MAX_ATTEMPTS
233
+ TOML config key: anti_truncation_max_attempts
234
+ Default: 3
235
+ """
236
+ env_value = os.getenv("ANTI_TRUNCATION_MAX_ATTEMPTS")
237
+ if env_value:
238
+ try:
239
+ return int(env_value)
240
+ except ValueError:
241
+ pass
242
+
243
+ return int(await get_config_value("anti_truncation_max_attempts", 3))
244
+
245
+ # Server Configuration
246
+ async def get_server_host() -> str:
247
+ """
248
+ Get server host setting.
249
+
250
+ Environment variable: HOST
251
+ TOML config key: host
252
+ Default: 0.0.0.0
253
+ """
254
+ return str(await get_config_value("host", "0.0.0.0", "HOST"))
255
+
256
+ async def get_server_port() -> int:
257
+ """
258
+ Get server port setting.
259
+
260
+ Environment variable: PORT
261
+ TOML config key: port
262
+ Default: 7861
263
+ """
264
+ env_value = os.getenv("PORT")
265
+ if env_value:
266
+ try:
267
+ return int(env_value)
268
+ except ValueError:
269
+ pass
270
+
271
+ return int(await get_config_value("port", 7861))
272
+
273
+ async def get_api_password() -> str:
274
+ """
275
+ Get API password setting for chat endpoints.
276
+
277
+ Environment variable: API_PASSWORD
278
+ TOML config key: api_password
279
+ Default: Uses PASSWORD env var for compatibility, otherwise 'pwd'
280
+ """
281
+ # 优先使用 API_PASSWORD,如果没有则使用通用 PASSWORD 保证兼容性
282
+ api_password = await get_config_value("api_password", None, "API_PASSWORD")
283
+ if api_password is not None:
284
+ return str(api_password)
285
+
286
+ # 兼容性:使用通用密码
287
+ return str(await get_config_value("password", "pwd", "PASSWORD"))
288
+
289
+ async def get_panel_password() -> str:
290
+ """
291
+ Get panel password setting for web interface.
292
+
293
+ Environment variable: PANEL_PASSWORD
294
+ TOML config key: panel_password
295
+ Default: Uses PASSWORD env var for compatibility, otherwise 'pwd'
296
+ """
297
+ # 优先使用 PANEL_PASSWORD,如果没有则使用通用 PASSWORD 保证兼容性
298
+ panel_password = await get_config_value("panel_password", None, "PANEL_PASSWORD")
299
+ if panel_password is not None:
300
+ return str(panel_password)
301
+
302
+ # 兼容性:使用通用密码
303
+ return str(await get_config_value("password", "pwd", "PASSWORD"))
304
+
305
+ async def get_server_password() -> str:
306
+ """
307
+ Get server password setting (deprecated, use get_api_password or get_panel_password).
308
+
309
+ Environment variable: PASSWORD
310
+ TOML config key: password
311
+ Default: pwd
312
+ """
313
+ return str(await get_config_value("password", "pwd", "PASSWORD"))
314
+
315
+ async def get_credentials_dir() -> str:
316
+ """
317
+ Get credentials directory setting.
318
+
319
+ Environment variable: CREDENTIALS_DIR
320
+ TOML config key: credentials_dir
321
+ Default: ./creds
322
+ """
323
+ return str(await get_config_value("credentials_dir", "./creds", "CREDENTIALS_DIR"))
324
+
325
+ async def get_code_assist_endpoint() -> str:
326
+ """
327
+ Get Code Assist endpoint setting.
328
+
329
+ Environment variable: CODE_ASSIST_ENDPOINT
330
+ TOML config key: code_assist_endpoint
331
+ Default: https://cloudcode-pa.googleapis.com
332
+ """
333
+ return str(await get_config_value("code_assist_endpoint", "https://cloudcode-pa.googleapis.com", "CODE_ASSIST_ENDPOINT"))
334
+
335
+ async def get_auto_load_env_creds() -> bool:
336
+ """
337
+ Get auto load environment credentials setting.
338
+
339
+ Environment variable: AUTO_LOAD_ENV_CREDS
340
+ TOML config key: auto_load_env_creds
341
+ Default: False
342
+ """
343
+ env_value = os.getenv("AUTO_LOAD_ENV_CREDS")
344
+ if env_value:
345
+ return env_value.lower() in ("true", "1", "yes", "on")
346
+
347
+ return bool(await get_config_value("auto_load_env_creds", False))
348
+
349
+ async def get_compatibility_mode_enabled() -> bool:
350
+ """
351
+ Get compatibility mode setting.
352
+
353
+ 兼容性模式:启用后所有system消息全部转换成user,停用system_instructions。
354
+ 该选项可能会降低模型理解能力,但是能避免流式空回的情况。
355
+
356
+ Environment variable: COMPATIBILITY_MODE
357
+ TOML config key: compatibility_mode_enabled
358
+ Default: True
359
+ """
360
+ env_value = os.getenv("COMPATIBILITY_MODE")
361
+ if env_value:
362
+ return env_value.lower() in ("true", "1", "yes", "on")
363
+
364
+ return bool(await get_config_value("compatibility_mode_enabled", True))
365
+
366
+ async def get_oauth_proxy_url() -> str:
367
+ """
368
+ Get OAuth proxy URL setting.
369
+
370
+ 用于Google OAuth2认证的代理URL。
371
+
372
+ Environment variable: OAUTH_PROXY_URL
373
+ TOML config key: oauth_proxy_url
374
+ Default: https://oauth2.googleapis.com
375
+ """
376
+ return str(await get_config_value("oauth_proxy_url", "https://oauth2.googleapis.com", "OAUTH_PROXY_URL"))
377
+
378
+ async def get_googleapis_proxy_url() -> str:
379
+ """
380
+ Get Google APIs proxy URL setting.
381
+
382
+ 用于Google APIs调用的代理URL。
383
+
384
+ Environment variable: GOOGLEAPIS_PROXY_URL
385
+ TOML config key: googleapis_proxy_url
386
+ Default: https://www.googleapis.com
387
+ """
388
+ return str(await get_config_value("googleapis_proxy_url", "https://www.googleapis.com", "GOOGLEAPIS_PROXY_URL"))
389
+
390
+
391
+ async def get_resource_manager_api_url() -> str:
392
+ """
393
+ Get Google Cloud Resource Manager API URL setting.
394
+
395
+ 用于Google Cloud Resource Manager API的URL。
396
+
397
+ Environment variable: RESOURCE_MANAGER_API_URL
398
+ TOML config key: resource_manager_api_url
399
+ Default: https://cloudresourcemanager.googleapis.com
400
+ """
401
+ return str(await get_config_value("resource_manager_api_url", "https://cloudresourcemanager.googleapis.com", "RESOURCE_MANAGER_API_URL"))
402
+
403
+ async def get_service_usage_api_url() -> str:
404
+ """
405
+ Get Google Cloud Service Usage API URL setting.
406
+
407
+ 用于Google Cloud Service Usage API的URL。
408
+
409
+ Environment variable: SERVICE_USAGE_API_URL
410
+ TOML config key: service_usage_api_url
411
+ Default: https://serviceusage.googleapis.com
412
+ """
413
+ return str(await get_config_value("service_usage_api_url", "https://serviceusage.googleapis.com", "SERVICE_USAGE_API_URL"))
414
+
415
+
416
+ # MongoDB Configuration
417
+ async def get_mongodb_uri() -> str:
418
+ """
419
+ Get MongoDB connection URI setting.
420
+
421
+ MongoDB连接URI,用于分布式部署时的数据存储。
422
+ 设置此项后将不再使用本地/creds和TOML文件。
423
+
424
+ Environment variable: MONGODB_URI
425
+ TOML config key: mongodb_uri
426
+ Default: None (使用本地文件存储)
427
+
428
+ 示例格式:
429
+ - mongodb://username:password@localhost:27017/database
430
+ - mongodb+srv://username:password@cluster.mongodb.net/database
431
+ """
432
+ return str(await get_config_value("mongodb_uri", "", "MONGODB_URI"))
433
+
434
+ async def get_mongodb_database() -> str:
435
+ """
436
+ Get MongoDB database name setting.
437
+
438
+ MongoDB数据库名称。
439
+
440
+ Environment variable: MONGODB_DATABASE
441
+ TOML config key: mongodb_database
442
+ Default: gcli2api
443
+ """
444
+ return str(await get_config_value("mongodb_database", "gcli2api", "MONGODB_DATABASE"))
445
+
446
+ async def is_mongodb_mode() -> bool:
447
+ """
448
+ Check if MongoDB mode is enabled.
449
+
450
+ 检查是否启用了MongoDB模式。
451
+ 如果配置了MongoDB URI,则启用MongoDB模式,不再使用本地文件。
452
+
453
+ Returns:
454
+ bool: True if MongoDB mode is enabled, False otherwise
455
+ """
456
+ mongodb_uri = await get_mongodb_uri()
457
+ return bool(mongodb_uri and mongodb_uri.strip())
log.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 日志模块 - 使用环境变量配置
3
+ """
4
+ import os
5
+ import sys
6
+ import threading
7
+ from datetime import datetime
8
+
9
+ # 日志级别定义
10
+ LOG_LEVELS = {
11
+ 'debug': 0,
12
+ 'info': 1,
13
+ 'warning': 2,
14
+ 'error': 3,
15
+ 'critical': 4
16
+ }
17
+
18
+ # 线程锁,用于文件写入同步
19
+ _file_lock = threading.Lock()
20
+
21
+ # 文件写入状态标志
22
+ _file_writing_disabled = False
23
+ _disable_reason = None
24
+
25
+ def _get_current_log_level():
26
+ """获取当前日志级别"""
27
+ level = os.getenv('LOG_LEVEL', 'info').lower()
28
+ return LOG_LEVELS.get(level, LOG_LEVELS['info'])
29
+
30
+ def _get_log_file_path():
31
+ """获取日志文件路径"""
32
+ return os.getenv('LOG_FILE', 'log.txt')
33
+
34
+ def _write_to_file(message: str):
35
+ """线程安全地写入日志文件"""
36
+ global _file_writing_disabled, _disable_reason
37
+
38
+ # 如果文件写入已被禁用,直接返回
39
+ if _file_writing_disabled:
40
+ return
41
+
42
+ try:
43
+ log_file = _get_log_file_path()
44
+ with _file_lock:
45
+ with open(log_file, 'a', encoding='utf-8') as f:
46
+ f.write(message + '\n')
47
+ f.flush() # 强制刷新到磁盘,确保实时写入
48
+ except (PermissionError, OSError, IOError) as e:
49
+ # 检测只读文件系统或权限问题,禁用文件写入
50
+ _file_writing_disabled = True
51
+ _disable_reason = str(e)
52
+ print(f"Warning: File system appears to be read-only or permission denied. Disabling log file writing: {e}", file=sys.stderr)
53
+ print(f"Log messages will continue to display in console only.", file=sys.stderr)
54
+ except Exception as e:
55
+ # 其他异常仍然输出警告但不禁用写入(可能是临时问题)
56
+ print(f"Warning: Failed to write to log file: {e}", file=sys.stderr)
57
+
58
+ def _log(level: str, message: str):
59
+ """
60
+ 内部日志函数
61
+ """
62
+ level = level.lower()
63
+ if level not in LOG_LEVELS:
64
+ print(f"Warning: Unknown log level '{level}'", file=sys.stderr)
65
+ return
66
+
67
+ # 检查日志级别
68
+ current_level = _get_current_log_level()
69
+ if LOG_LEVELS[level] < current_level:
70
+ return
71
+
72
+ # 格式化日志消息
73
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
74
+ entry = f"[{timestamp}] [{level.upper()}] {message}"
75
+
76
+ # 输出到控制台
77
+ if level in ('error', 'critical'):
78
+ print(entry, file=sys.stderr)
79
+ else:
80
+ print(entry)
81
+
82
+ # 实时写入文件
83
+ _write_to_file(entry)
84
+
85
+ def set_log_level(level: str):
86
+ """设置日志级别提示"""
87
+ level = level.lower()
88
+ if level not in LOG_LEVELS:
89
+ print(f"Warning: Unknown log level '{level}'. Valid levels: {', '.join(LOG_LEVELS.keys())}")
90
+ return False
91
+
92
+ print(f"Note: To set log level '{level}', please set LOG_LEVEL environment variable")
93
+ return True
94
+
95
+ class Logger:
96
+ """支持 log('info', 'msg') 和 log.info('msg') 两种调用方式"""
97
+
98
+ def __call__(self, level: str, message: str):
99
+ """支持 log('info', 'message') 调用方式"""
100
+ _log(level, message)
101
+
102
+ def debug(self, message: str):
103
+ """记录调试信息"""
104
+ _log('debug', message)
105
+
106
+ def info(self, message: str):
107
+ """记录一般信息"""
108
+ _log('info', message)
109
+
110
+ def warning(self, message: str):
111
+ """记录警告信息"""
112
+ _log('warning', message)
113
+
114
+ def error(self, message: str):
115
+ """记录错误信息"""
116
+ _log('error', message)
117
+
118
+ def critical(self, message: str):
119
+ """记录严重错误信息"""
120
+ _log('critical', message)
121
+
122
+ def get_current_level(self) -> str:
123
+ """获取当前日志级别名称"""
124
+ current_level = _get_current_log_level()
125
+ for name, value in LOG_LEVELS.items():
126
+ if value == current_level:
127
+ return name
128
+ return 'info'
129
+
130
+ def get_log_file(self) -> str:
131
+ """获取当前日志文件路径"""
132
+ return _get_log_file_path()
133
+
134
+
135
+ # 导出全局日志实例
136
+ log = Logger()
137
+
138
+ # 导出的公共接口
139
+ __all__ = ['log', 'set_log_level', 'LOG_LEVELS']
140
+
141
+ # 使用说明:
142
+ # 1. 设置日志级别: export LOG_LEVEL=debug (或在.env文件中设置)
143
+ # 2. 设置日志文件: export LOG_FILE=log.txt (或在.env文件中设置)
multi_user_auth_web.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ OAuth Web 服务器 - 独立的OAuth认证服务
4
+ 提供简化的OAuth认证界面,只包含验证功能,不包含上传和管理功能
5
+ """
6
+
7
+ from log import log
8
+ import asyncio
9
+ from contextlib import asynccontextmanager
10
+ from fastapi import FastAPI, HTTPException, Depends
11
+ from fastapi.responses import HTMLResponse, JSONResponse
12
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13
+ from pydantic import BaseModel
14
+
15
+ from src.auth import (
16
+ create_auth_url,
17
+ verify_password,
18
+ generate_auth_token,
19
+ verify_auth_token,
20
+ asyncio_complete_auth_flow,
21
+ complete_auth_flow_from_callback_url,
22
+ CALLBACK_HOST,
23
+ )
24
+
25
+ # 创建FastAPI应用
26
+ app = FastAPI(
27
+ title="Google OAuth 认证服务",
28
+ description="独立的OAuth认证服务,用于获取Google Cloud认证文件",
29
+ )
30
+
31
+ # HTTP Bearer认证
32
+ security = HTTPBearer()
33
+
34
+ # 请求模型
35
+ class LoginRequest(BaseModel):
36
+ password: str
37
+
38
+ class AuthStartRequest(BaseModel):
39
+ project_id: str = None # 现在是可选的,支持自动检测
40
+
41
+ class AuthCallbackRequest(BaseModel):
42
+ project_id: str = None # 现在是可选的,支持自动检测
43
+
44
+ class AuthCallbackUrlRequest(BaseModel):
45
+ callback_url: str # OAuth回调完整URL
46
+ project_id: str = None # 可选的项目ID
47
+
48
+ def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
49
+ """验证认证令牌"""
50
+ if not verify_auth_token(credentials.credentials):
51
+ raise HTTPException(status_code=401, detail="无效的认证令牌")
52
+ return credentials.credentials
53
+
54
+
55
+ @app.get("/", response_class=HTMLResponse)
56
+ async def serve_oauth_page():
57
+ """提供OAuth认证页面"""
58
+ try:
59
+ # 读取HTML文件
60
+ html_file_path = "./front/multi_user_auth_web.html"
61
+
62
+ with open(html_file_path, "r", encoding="utf-8") as f:
63
+ html_content = f.read()
64
+
65
+ return HTMLResponse(content=html_content)
66
+ except FileNotFoundError:
67
+ raise HTTPException(status_code=404, detail="认证页面不存在")
68
+ except Exception as e:
69
+ log.error(f"加载认证页面失败: {e}")
70
+ raise HTTPException(status_code=500, detail="服务器内部错误")
71
+
72
+ @app.post("/auth/login")
73
+ async def login(request: LoginRequest):
74
+ """用户登录"""
75
+ try:
76
+ if await verify_password(request.password):
77
+ token = generate_auth_token()
78
+ return JSONResponse(content={"token": token, "message": "登录成功"})
79
+ else:
80
+ raise HTTPException(status_code=401, detail="密码错误")
81
+ except HTTPException:
82
+ raise
83
+ except Exception as e:
84
+ log.error(f"登录失败: {e}")
85
+ raise HTTPException(status_code=500, detail=str(e))
86
+
87
+
88
+ @app.post("/auth/start")
89
+ async def start_auth(request: AuthStartRequest, token: str = Depends(verify_token)):
90
+ """开始认证流程,支持自动检测项目ID"""
91
+ try:
92
+ # 如果没有提供项目ID,尝试自动检测
93
+ project_id = request.project_id
94
+ if not project_id:
95
+ log.info("未提供项目ID,后续将尝试自动检测...")
96
+
97
+ # 使用认证令牌作为用户会话标识
98
+ user_session = token if token else None
99
+ result = await create_auth_url(project_id, user_session)
100
+
101
+ if result['success']:
102
+ # 构建动态回调URL
103
+ callback_port = result.get('callback_port')
104
+ callback_url = f"http://{CALLBACK_HOST}:{callback_port}" if callback_port else None
105
+
106
+ response_data = {
107
+ "auth_url": result['auth_url'],
108
+ "state": result['state'],
109
+ "auto_project_detection": result.get('auto_project_detection', False),
110
+ "detected_project_id": result.get('detected_project_id')
111
+ }
112
+
113
+ # 如果有回调端口信息,添加到响应中
114
+ if callback_port:
115
+ response_data["callback_port"] = callback_port
116
+ response_data["callback_url"] = callback_url
117
+
118
+ return JSONResponse(content=response_data)
119
+ else:
120
+ raise HTTPException(status_code=500, detail=result['error'])
121
+
122
+ except HTTPException:
123
+ raise
124
+ except Exception as e:
125
+ log.error(f"开始认证流程失败: {e}")
126
+ raise HTTPException(status_code=500, detail=str(e))
127
+
128
+
129
+ @app.post("/auth/callback")
130
+ async def auth_callback(request: AuthCallbackRequest, token: str = Depends(verify_token)):
131
+ """处理认证回调(异步等待),支持自动检测项目ID"""
132
+ try:
133
+ # 项目ID现在是可选的,在回调处理中进行自动检测
134
+ project_id = request.project_id
135
+
136
+ # 使用认证令牌作为用户会话标识
137
+ user_session = token if token else None
138
+ # 异步等待OAuth回调完成
139
+ result = await asyncio_complete_auth_flow(project_id, user_session)
140
+
141
+ if result['success']:
142
+ return JSONResponse(content={
143
+ "credentials": result['credentials'],
144
+ "file_path": result['file_path'],
145
+ "message": "认证成功,凭证已保存",
146
+ "auto_detected_project": result.get('auto_detected_project', False)
147
+ })
148
+ else:
149
+ # 如果需要手动项目ID或项目选择,在响应中标明
150
+ if result.get('requires_manual_project_id'):
151
+ # 使用JSON响应
152
+ return JSONResponse(
153
+ status_code=400,
154
+ content={
155
+ "error": result['error'],
156
+ "requires_manual_project_id": True
157
+ }
158
+ )
159
+ elif result.get('requires_project_selection'):
160
+ # 返回项目列表供用户选择
161
+ return JSONResponse(
162
+ status_code=400,
163
+ content={
164
+ "error": result['error'],
165
+ "requires_project_selection": True,
166
+ "available_projects": result['available_projects']
167
+ }
168
+ )
169
+ else:
170
+ raise HTTPException(status_code=400, detail=result['error'])
171
+
172
+ except HTTPException:
173
+ raise
174
+ except Exception as e:
175
+ log.error(f"处理认证回调失败: {e}")
176
+ raise HTTPException(status_code=500, detail=str(e))
177
+
178
+
179
+ @app.post("/auth/callback-url")
180
+ async def auth_callback_url(request: AuthCallbackUrlRequest, token: str = Depends(verify_token)):
181
+ """从回调URL直接完成认证,无需启动本地服务器"""
182
+ try:
183
+ # 验证URL格式
184
+ if not request.callback_url or not request.callback_url.startswith(('http://', 'https://')):
185
+ raise HTTPException(status_code=400, detail="请提供有效的回调URL")
186
+
187
+ # 从回调URL完成认证
188
+ result = await complete_auth_flow_from_callback_url(request.callback_url, request.project_id)
189
+
190
+ if result['success']:
191
+ return JSONResponse(content={
192
+ "credentials": result['credentials'],
193
+ "file_path": result['file_path'],
194
+ "message": "从回调URL认证成功,凭证已保存",
195
+ "auto_detected_project": result.get('auto_detected_project', False)
196
+ })
197
+ else:
198
+ # 处理各种错误情况
199
+ if result.get('requires_manual_project_id'):
200
+ return JSONResponse(
201
+ status_code=400,
202
+ content={
203
+ "error": result['error'],
204
+ "requires_manual_project_id": True
205
+ }
206
+ )
207
+ elif result.get('requires_project_selection'):
208
+ return JSONResponse(
209
+ status_code=400,
210
+ content={
211
+ "error": result['error'],
212
+ "requires_project_selection": True,
213
+ "available_projects": result['available_projects']
214
+ }
215
+ )
216
+ else:
217
+ raise HTTPException(status_code=400, detail=result['error'])
218
+
219
+ except HTTPException:
220
+ raise
221
+ except Exception as e:
222
+ log.error(f"从回调URL处理认证失败: {e}")
223
+ raise HTTPException(status_code=500, detail=str(e))
224
+
225
+
226
+ @asynccontextmanager
227
+ async def lifespan(app: FastAPI):
228
+ log.info("OAuth认证服务启动中...")
229
+
230
+ # OAuth回调服务器现在动态按需启动,每个认证流程使用独立端口
231
+ log.info("OAuth回调服务器将为每个认证流程动态分配端口")
232
+
233
+ # 从配置获取密码和端口
234
+ from config import get_panel_password, get_server_port
235
+ password = await get_panel_password()
236
+ port = await get_server_port()
237
+
238
+ log.info("Web服务已由 ASGI 服务器启动")
239
+
240
+ print("\n" + "="*60)
241
+ print("🚀 Google OAuth 认证服务已启动")
242
+ print("="*60)
243
+ print(f"📱 Web界面: http://localhost:{port}")
244
+ print(f"🔐 默认密码: {'已设置' if password else 'pwd (请设置PASSWORD环境变量)'}")
245
+ print(f"🔄 多用户并发: 支持多用户同时认证(动态端口分配)")
246
+ print("="*60 + "\n")
247
+
248
+ try:
249
+ yield
250
+ finally:
251
+ log.info("OAuth认证服务关闭中...")
252
+ # OAuth服务器由认证流程自动管理,无需手动清理
253
+ log.info("OAuth认证服务已关闭")
254
+
255
+ # 注册 lifespan 处理器
256
+ app.router.lifespan_context = lifespan
257
+
258
+ if __name__ == "__main__":
259
+ from hypercorn.asyncio import serve
260
+ from hypercorn.config import Config
261
+
262
+ async def main():
263
+ # 从配置获取端口
264
+ from config import get_server_port
265
+ PORT = await get_server_port()
266
+
267
+ config = Config()
268
+ config.bind = [f"0.0.0.0:{PORT}"]
269
+ config.accesslog = "-"
270
+ config.errorlog = "-"
271
+ config.loglevel = "INFO"
272
+
273
+ await serve(app, config)
274
+
275
+ asyncio.run(main())