Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, Depends | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from datetime import datetime | |
| # import uuid | |
| from global_state import get | |
| from db.tbs_db import TbsDb | |
| from auth import get_current_user | |
| from db_model.user import UserModel | |
| from db_model.chat import ChatModel | |
# Router object for this module.
# NOTE(review): chat_completions is not decorated with a route in the visible
# code; confirm it is registered on this router elsewhere.
router: APIRouter = APIRouter()

# Path passed, together with "Cloudflare", to every TbsDb(...) constructed by
# the database helpers in this module.
db_module_filename = "{}/db/cloudflare.py".format(get('project_root'))
async def chat_completions(chat_model: ChatModel, current_user: UserModel = Depends(get_current_user)):
    """Answer a chat request and return an OpenAI-style ``chat.completion`` dict.

    Resolves the model name (falling back to the database default), selects an
    API key for it, builds the matching LangChain chat model and sends the
    conversation in ``chat_model.messages`` to it.

    Args:
        chat_model: Request body; ``model`` (optional) and ``messages`` (items
            with ``role`` and ``content``) are read.
        current_user: Injected via ``Depends(get_current_user)``; not otherwise
            used here (presumably it only enforces authentication).

    Returns:
        The dict produced by ``convert_to_openai_format`` on success, or
        ``{'error': <message>}`` when no API key is configured for the model or
        the LLM call fails.
    """
    model = getattr(chat_model, "model", None)
    if not model:
        model = await get_default_model()

    # `or {}` also covers a lookup that yields None when nothing matches.
    api_key_info = await get_api_key(model) or {}
    if not api_key_info:
        return {'error': f"No API key configured for model '{model}'"}

    llm = _build_llm(model, api_key_info)

    # Pass (role, content) tuples straight to the chat model instead of going
    # through ChatPromptTemplate.from_messages: the template parses "{...}" in
    # message text as template variables, so any message containing braces
    # (code, JSON, ...) made invoke({}) fail with a missing-variable error.
    lc_messages = [(message.role, message.content) for message in chat_model.messages]
    try:
        # ainvoke awaits the provider without blocking the event loop, unlike
        # the synchronous invoke() previously called inside this coroutine.
        result = await llm.ainvoke(lc_messages)  # AIMessage
    except Exception as e:
        # Deliberate boundary catch: report provider/network errors to the client.
        return {'error': str(e)}
    # Convert to the OpenAI response format.
    return convert_to_openai_format(result)


def _build_llm(model, api_key_info):
    """Return a LangChain chat model for ``model`` configured from ``api_key_info``.

    ``group_name == 'gemini'`` selects ChatGoogleGenerativeAI (Google API); any
    other group is treated as an OpenAI-compatible endpoint at ``base_url``.
    """
    api_key = api_key_info.get('api_key', '')
    group_name = api_key_info.get('group_name', '')
    base_url = api_key_info.get('base_url', '')
    if group_name == 'gemini':
        # Imported inside the branch so only the selected provider package is imported.
        from langchain_google_genai import ChatGoogleGenerativeAI
        return ChatGoogleGenerativeAI(
            api_key=api_key,
            model=model,
        )
    from langchain_openai import ChatOpenAI
    return ChatOpenAI(
        model=model,
        api_key=api_key,
        base_url=base_url,
    )
async def get_default_model():
    """Fetch the default model name from the database.

    Returns the ``api_name`` of the ``api_names`` row with the lowest
    ``default_order``, or ``''`` when the table is empty or the response lacks
    the expected ``response['result'][0]['results']`` shape.
    """
    # NOTE(review): get_item is called without await, so any I/O it performs
    # blocks the event loop while this coroutine runs.
    query = "SELECT api_name FROM api_names ORDER BY default_order LIMIT 1"
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    try:
        return response['result'][0]['results'][0]['api_name']
    except (KeyError, IndexError, TypeError):
        # Empty result set or unexpected response shape: no default model.
        return ''
async def get_api_key(model):
    """Pick an enabled LLM API key for ``model`` and mark it as just used.

    Selects, from the api_keys/api_groups/api_names join, the matching row with
    the oldest ``last_call_at``, then sets that key's ``last_call_at`` to the
    current time, so successive calls prefer the least recently used key.

    Args:
        model: Model name matched against ``api_names.api_name``. It can come
            from the client request, so it is quoted with ``_sql_literal``.

    Returns:
        The result row as a dict with ``api_name``, ``api_key``, ``base_url``
        and ``group_name``; an empty dict when no row matches or the response
        has an unexpected shape.
    """
    query = f"""
    SELECT an.api_name, ak.api_key, an.base_url, ag.group_name
    FROM api_keys ak
    JOIN api_groups ag ON ak.api_group_id = ag.id
    JOIN api_names an ON an.api_group_id = ag.id
    WHERE ak.category='LLM' and an.api_name={_sql_literal(model)} and disabled=0
    ORDER BY ak.last_call_at
    limit 1
    """
    response = TbsDb(db_module_filename, "Cloudflare").get_item(query)
    try:
        result = response['result'][0]['results'][0]
        api_key = result['api_key']
    except (KeyError, IndexError, TypeError):
        # No matching key (or unexpected response shape). Previously this fell
        # through to `return result` with `result` unbound (UnboundLocalError)
        # after running an update for api_key=''.
        return {}
    if api_key:
        # Stamp the key so ORDER BY last_call_at picks a different one next time.
        TbsDb(db_module_filename, "Cloudflare").execute_query(
            f"update api_keys set last_call_at=datetime('now') where api_key={_sql_literal(api_key)}"
        )
    return result


def _sql_literal(value):
    """Return ``value`` rendered as a single-quoted SQLite string literal.

    SQLite escapes a quote inside a string literal by doubling it (backslash is
    not special), so the value cannot close the literal early and inject SQL.
    """
    return "'" + str(value).replace("'", "''") + "'"
def convert_to_openai_format(original_json):
    """Convert a LangChain chat result into an OpenAI ``chat.completion`` dict.

    Args:
        original_json: The message returned by the chat model (an ``AIMessage``
            in this module); its ``content`` and, when present,
            ``usage_metadata`` are read.

    Returns:
        A dict with a unique ``id``, ``object``, ``created`` (Unix seconds), a
        single assistant entry in ``choices`` and ``usage`` token counts.
        Token counts are 0 when the provider reported no usage.
    """
    import uuid

    # usage_metadata is optional on LangChain messages and is None when the
    # provider does not report token usage; calling .get on None used to raise.
    usage = getattr(original_json, "usage_metadata", None) or {}
    return {
        # Unique per response, as the OpenAI format expects (was a fixed placeholder).
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),  # current Unix time
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": original_json.content,  # model output, unchanged
                },
                "finish_reason": "stop",  # always reported as "stop"
            }
        ],
        "usage": {
            "prompt_tokens": usage.get("input_tokens", 0),
            "completion_tokens": usage.get("output_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        },
    }