hins111's picture
Upload 11 files
6a0eaad verified
import argparse # 新增导入
from flask import Flask, request, jsonify
import requests
import time
import uuid
import logging
import json
import sys # 新增导入
from typing import Dict, Any
from datetime import datetime, UTC
# 自定义日志 Handler,确保刷新
class FlushingStreamHandler(logging.StreamHandler):
def emit(self, record):
try:
super().emit(record)
self.flush()
except Exception:
self.handleError(record)
# 配置日志(更改为中文)
log_format = '%(asctime)s [%(levelname)s] %(message)s'
formatter = logging.Formatter(log_format)
# 创建一个 handler 明确指向 sys.stderr 并使用自定义的 FlushingStreamHandler
# sys.stderr 在子进程中应该被 gui_launcher.py 的 PIPE 捕获
stderr_handler = FlushingStreamHandler(sys.stderr)
stderr_handler.setFormatter(formatter)
stderr_handler.setLevel(logging.INFO)
# 获取根 logger 并添加我们的 handler
# 这能确保所有传播到根 logger 的日志 (包括 Flask 和 Werkzeug 的,如果它们没有自己的特定 handler)
# 都会经过这个 handler。
root_logger = logging.getLogger()
# 清除可能存在的由 basicConfig 或其他库添加的默认 handlers,以避免重复日志或意外输出
if root_logger.hasHandlers():
root_logger.handlers.clear()
root_logger.addHandler(stderr_handler)
root_logger.setLevel(logging.INFO) # 确保根 logger 级别也设置了
logger = logging.getLogger(__name__) # 获取名为 'llm' 的 logger,它会继承根 logger 的配置
app = Flask(__name__)
# Flask 的 app.logger 默认会传播到 root logger。
# 如果需要,也可以为 app.logger 和 werkzeug logger 单独配置,但通常让它们传播到 root 就够了。
# 例如:
# app.logger.handlers.clear() # 清除 Flask 可能添加的默认 handler
# app.logger.addHandler(stderr_handler)
# app.logger.setLevel(logging.INFO)
#
# werkzeug_logger = logging.getLogger('werkzeug')
# werkzeug_logger.handlers.clear()
# werkzeug_logger.addHandler(stderr_handler)
# werkzeug_logger.setLevel(logging.INFO)
# 启用模型配置:直接定义启用的模型名称
# 用户可添加/删除模型名称,动态生成元数据
ENABLED_MODELS = {
"gemini-2.5-pro-preview-05-06",
"gemini-2.5-flash-preview-04-17",
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
"gemini-1.5-pro",
"gemini-1.5-flash",
"gemini-1.5-flash-8b",
}
# API 配置
API_URL = "" # 将在 main 函数中根据参数设置
DEFAULT_MAIN_SERVER_PORT = 2048
# 请替换为你的 API 密钥(请勿公开分享)
API_KEY = "123456"
# 模拟 Ollama 聊天响应数据库
OLLAMA_MOCK_RESPONSES = {
"What is the capital of France?": "The capital of France is Paris.",
"Tell me about AI.": "AI is the simulation of human intelligence in machines, enabling tasks like reasoning and learning.",
"Hello": "Hi! How can I assist you today?"
}
@app.route("/", methods=["GET"])
def root_endpoint():
"""模拟 Ollama 根路径,返回 'Ollama is running'"""
logger.info("收到根路径请求")
return "Ollama is running", 200
@app.route("/api/tags", methods=["GET"])
def tags_endpoint():
"""模拟 Ollama 的 /api/tags 端点,动态生成启用模型列表"""
logger.info("收到 /api/tags 请求")
models = []
for model_name in ENABLED_MODELS:
# 推导 family:从模型名称提取前缀(如 "gpt-4o" -> "gpt")
family = model_name.split('-')[0].lower() if '-' in model_name else model_name.lower()
# 特殊处理已知模型
if 'llama' in model_name:
family = 'llama'
format = 'gguf'
size = 1234567890
parameter_size = '405B' if '405b' in model_name else 'unknown'
quantization_level = 'Q4_0'
elif 'mistral' in model_name:
family = 'mistral'
format = 'gguf'
size = 1234567890
parameter_size = 'unknown'
quantization_level = 'unknown'
else:
format = 'unknown'
size = 9876543210
parameter_size = 'unknown'
quantization_level = 'unknown'
models.append({
"name": model_name,
"model": model_name,
"modified_at": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
"size": size,
"digest": str(uuid.uuid4()),
"details": {
"parent_model": "",
"format": format,
"family": family,
"families": [family],
"parameter_size": parameter_size,
"quantization_level": quantization_level
}
})
logger.info(f"返回 {len(models)} 个模型: {[m['name'] for m in models]}")
return jsonify({"models": models}), 200
def generate_ollama_mock_response(prompt: str, model: str) -> Dict[str, Any]:
"""生成模拟的 Ollama 聊天响应,符合 /api/chat 格式"""
response_content = OLLAMA_MOCK_RESPONSES.get(
prompt, f"Echo: {prompt} (这是来自模拟 Ollama 服务器的响应。)"
)
return {
"model": model,
"created_at": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
"message": {
"role": "assistant",
"content": response_content
},
"done": True,
"total_duration": 123456789,
"load_duration": 1234567,
"prompt_eval_count": 10,
"prompt_eval_duration": 2345678,
"eval_count": 20,
"eval_duration": 3456789
}
def convert_api_to_ollama_response(api_response: Dict[str, Any], model: str) -> Dict[str, Any]:
"""将 API 的 OpenAI 格式响应转换为 Ollama 格式"""
try:
content = api_response["choices"][0]["message"]["content"]
total_duration = api_response.get("usage", {}).get("total_tokens", 30) * 1000000
prompt_tokens = api_response.get("usage", {}).get("prompt_tokens", 10)
completion_tokens = api_response.get("usage", {}).get("completion_tokens", 20)
return {
"model": model,
"created_at": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
"message": {
"role": "assistant",
"content": content
},
"done": True,
"total_duration": total_duration,
"load_duration": 1234567,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_tokens * 100000,
"eval_count": completion_tokens,
"eval_duration": completion_tokens * 100000
}
except KeyError as e:
logger.error(f"转换API响应失败: 缺少键 {str(e)}")
return {"error": f"无效的API响应格式: 缺少键 {str(e)}"}
def print_request_params(data: Dict[str, Any], endpoint: str) -> None:
"""打印请求参数"""
model = data.get("model", "未指定")
temperature = data.get("temperature", "未指定")
stream = data.get("stream", False)
messages_info = []
for msg in data.get("messages", []):
role = msg.get("role", "未知")
content = msg.get("content", "")
content_preview = content[:50] + "..." if len(content) > 50 else content
messages_info.append(f"[{role}] {content_preview}")
params_str = {
"端点": endpoint,
"模型": model,
"温度": temperature,
"流式输出": stream,
"消息数量": len(data.get("messages", [])),
"消息预览": messages_info
}
logger.info(f"请求参数: {json.dumps(params_str, ensure_ascii=False, indent=2)}")
@app.route("/api/chat", methods=["POST"])
def ollama_chat_endpoint():
"""模拟 Ollama 的 /api/chat 端点,所有模型都能使用"""
try:
data = request.get_json()
if not data or "messages" not in data:
logger.error("无效请求: 缺少 'messages' 字段")
return jsonify({"error": "无效请求: 缺少 'messages' 字段"}), 400
messages = data.get("messages", [])
if not messages or not isinstance(messages, list):
logger.error("无效请求: 'messages' 必须是非空列表")
return jsonify({"error": "无效请求: 'messages' 必须是非空列表"}), 400
model = data.get("model", "llama3.2")
user_message = next(
(msg["content"] for msg in reversed(messages) if msg.get("role") == "user"),
""
)
if not user_message:
logger.error("未找到用户消息")
return jsonify({"error": "未找到用户消息"}), 400
# 打印请求参数
print_request_params(data, "/api/chat")
logger.info(f"处理 /api/chat 请求, 模型: {model}")
# 移除模型限制,所有模型都使用API
api_request = {
"model": model,
"messages": messages,
"stream": False,
"temperature": data.get("temperature", 0.7)
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {API_KEY}"
}
try:
logger.info(f"转发请求到API: {API_URL}")
response = requests.post(API_URL, json=api_request, headers=headers, timeout=300000)
response.raise_for_status()
api_response = response.json()
ollama_response = convert_api_to_ollama_response(api_response, model)
logger.info(f"收到来自API的响应,模型: {model}")
return jsonify(ollama_response), 200
except requests.RequestException as e:
logger.error(f"API请求失败: {str(e)}")
# 如果API请求失败,使用模拟响应作为备用
logger.info(f"使用模拟响应作为备用方案,模型: {model}")
response = generate_ollama_mock_response(user_message, model)
return jsonify(response), 200
except Exception as e:
logger.error(f"/api/chat 服务器错误: {str(e)}")
return jsonify({"error": f"服务器错误: {str(e)}"}), 500
@app.route("/v1/chat/completions", methods=["POST"])
def api_chat_endpoint():
"""转发到API的 /v1/chat/completions 端点,并转换为 Ollama 格式"""
try:
data = request.get_json()
if not data or "messages" not in data:
logger.error("无效请求: 缺少 'messages' 字段")
return jsonify({"error": "无效请求: 缺少 'messages' 字段"}), 400
messages = data.get("messages", [])
if not messages or not isinstance(messages, list):
logger.error("无效请求: 'messages' 必须是非空列表")
return jsonify({"error": "无效请求: 'messages' 必须是非空列表"}), 400
model = data.get("model", "grok-3")
user_message = next(
(msg["content"] for msg in reversed(messages) if msg.get("role") == "user"),
""
)
if not user_message:
logger.error("未找到用户消息")
return jsonify({"error": "未找到用户消息"}), 400
# 打印请求参数
print_request_params(data, "/v1/chat/completions")
logger.info(f"处理 /v1/chat/completions 请求, 模型: {model}")
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {API_KEY}"
}
try:
logger.info(f"转发请求到API: {API_URL}")
response = requests.post(API_URL, json=data, headers=headers, timeout=300000)
response.raise_for_status()
api_response = response.json()
ollama_response = convert_api_to_ollama_response(api_response, model)
logger.info(f"收到来自API的响应,模型: {model}")
return jsonify(ollama_response), 200
except requests.RequestException as e:
logger.error(f"API请求失败: {str(e)}")
return jsonify({"error": f"API请求失败: {str(e)}"}), 500
except Exception as e:
logger.error(f"/v1/chat/completions 服务器错误: {str(e)}")
return jsonify({"error": f"服务器错误: {str(e)}"}), 500
def main():
"""启动模拟服务器"""
global API_URL # 声明我们要修改全局变量
parser = argparse.ArgumentParser(description="LLM Mock Service for AI Studio Proxy")
parser.add_argument(
"--main-server-port",
type=int,
default=DEFAULT_MAIN_SERVER_PORT,
help=f"Port of the main AI Studio Proxy server (default: {DEFAULT_MAIN_SERVER_PORT})"
)
args = parser.parse_args()
API_URL = f"http://localhost:{args.main_server_port}/v1/chat/completions"
logger.info(f"模拟 Ollama 和 API 代理服务器将转发请求到: {API_URL}")
logger.info("正在启动模拟 Ollama 和 API 代理服务器,地址: http://localhost:11434")
app.run(host="0.0.0.0", port=11434, debug=False)
if __name__ == "__main__":
main()