Support assigning different models to different API keys
Browse files- main.py +45 -21
- response.py +17 -11
main.py
CHANGED
|
@@ -27,12 +27,6 @@ async def lifespan(app: FastAPI):
|
|
| 27 |
|
| 28 |
app = FastAPI(lifespan=lifespan)
|
| 29 |
|
| 30 |
-
# 模拟存储API Key的数据库
|
| 31 |
-
api_keys_db = {
|
| 32 |
-
"sk-KjjI60Yf0JFcsvgRmXqFwgGmWUd9GZnmi3KlvowmRWpWpQRo": "user1",
|
| 33 |
-
# 可以添加更多的API Key
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
# 安全性依赖
|
| 37 |
security = HTTPBearer()
|
| 38 |
|
|
@@ -49,7 +43,7 @@ def load_config():
|
|
| 49 |
return []
|
| 50 |
|
| 51 |
config = load_config()
|
| 52 |
-
for index, provider in enumerate(config):
|
| 53 |
model_dict = {}
|
| 54 |
for model in provider['model']:
|
| 55 |
if type(model) == str:
|
|
@@ -57,8 +51,10 @@ for index, provider in enumerate(config):
|
|
| 57 |
if type(model) == dict:
|
| 58 |
model_dict.update({value: key for key, value in model.items()})
|
| 59 |
provider['model'] = model_dict
|
| 60 |
-
config[index] = provider
|
| 61 |
-
|
|
|
|
|
|
|
| 62 |
|
| 63 |
async def process_request(request: RequestModel, provider: Dict):
|
| 64 |
print("provider: ", provider['provider'])
|
|
@@ -93,17 +89,30 @@ class ModelRequestHandler:
|
|
| 93 |
def __init__(self):
|
| 94 |
self.last_provider_index = -1
|
| 95 |
|
| 96 |
-
def get_matching_providers(self, model_name):
|
| 97 |
# for provider in config:
|
| 98 |
# print("provider", model_name, list(provider['model'].keys()))
|
| 99 |
# if model_name in provider['model'].keys():
|
| 100 |
# print("provider", provider)
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
async def request_model(self, request: RequestModel, token: str):
|
| 104 |
model_name = request.model
|
| 105 |
-
matching_providers = self.get_matching_providers(model_name)
|
| 106 |
-
|
| 107 |
|
| 108 |
if not matching_providers:
|
| 109 |
raise HTTPException(status_code=404, detail="No matching model found")
|
|
@@ -153,7 +162,7 @@ async def log_requests(request: Request, call_next):
|
|
| 153 |
|
| 154 |
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
| 155 |
token = credentials.credentials
|
| 156 |
-
if token not in
|
| 157 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
| 158 |
return token
|
| 159 |
|
|
@@ -161,27 +170,42 @@ def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)
|
|
| 161 |
async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
|
| 162 |
return await model_handler.request_model(request, token)
|
| 163 |
|
| 164 |
-
def get_all_models():
|
| 165 |
all_models = []
|
| 166 |
unique_models = set()
|
| 167 |
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
| 170 |
if model not in unique_models:
|
| 171 |
unique_models.add(model)
|
| 172 |
model_info = {
|
| 173 |
"id": model,
|
| 174 |
"object": "model",
|
| 175 |
"created": 1720524448858,
|
| 176 |
-
"owned_by":
|
| 177 |
}
|
| 178 |
all_models.append(model_info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
return all_models
|
| 181 |
|
| 182 |
-
@app.
|
| 183 |
-
async def list_models():
|
| 184 |
-
models = get_all_models()
|
| 185 |
return {
|
| 186 |
"object": "list",
|
| 187 |
"data": models
|
|
|
|
| 27 |
|
| 28 |
app = FastAPI(lifespan=lifespan)
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# 安全性依赖
|
| 31 |
security = HTTPBearer()
|
| 32 |
|
|
|
|
| 43 |
return []
|
| 44 |
|
| 45 |
config = load_config()
|
| 46 |
+
for index, provider in enumerate(config['providers']):
|
| 47 |
model_dict = {}
|
| 48 |
for model in provider['model']:
|
| 49 |
if type(model) == str:
|
|
|
|
| 51 |
if type(model) == dict:
|
| 52 |
model_dict.update({value: key for key, value in model.items()})
|
| 53 |
provider['model'] = model_dict
|
| 54 |
+
config['providers'][index] = provider
|
| 55 |
+
api_keys_db = config['api_keys']
|
| 56 |
+
api_list = [item["api"] for item in api_keys_db]
|
| 57 |
+
print(json.dumps(config, indent=4, ensure_ascii=False))
|
| 58 |
|
| 59 |
async def process_request(request: RequestModel, provider: Dict):
|
| 60 |
print("provider: ", provider['provider'])
|
|
|
|
| 89 |
    def __init__(self):
        # Index of the provider used for the previous request; -1 means no
        # request has been routed yet (presumably used for provider rotation
        # in request_model — confirm against that method's full body).
        self.last_provider_index = -1
|
| 91 |
|
| 92 |
+
def get_matching_providers(self, model_name, token):
|
| 93 |
# for provider in config:
|
| 94 |
# print("provider", model_name, list(provider['model'].keys()))
|
| 95 |
# if model_name in provider['model'].keys():
|
| 96 |
# print("provider", provider)
|
| 97 |
+
api_index = api_list.index(token)
|
| 98 |
+
provider_rules = {}
|
| 99 |
+
|
| 100 |
+
for model in config['api_keys'][api_index]['model']:
|
| 101 |
+
if "/" in model:
|
| 102 |
+
provider_name = model.split("/")[0]
|
| 103 |
+
model = model.split("/")[1]
|
| 104 |
+
if model_name == model:
|
| 105 |
+
provider_rules[provider_name] = model
|
| 106 |
+
provider_list = []
|
| 107 |
+
for provider in config['providers']:
|
| 108 |
+
if model_name in provider['model'].keys() and ((provider_rules != {} and provider['provider'] in provider_rules.keys()) or provider_rules == {}):
|
| 109 |
+
provider_list.append(provider)
|
| 110 |
+
return provider_list
|
| 111 |
|
| 112 |
async def request_model(self, request: RequestModel, token: str):
|
| 113 |
model_name = request.model
|
| 114 |
+
matching_providers = self.get_matching_providers(model_name, token)
|
| 115 |
+
print("matching_providers", json.dumps(matching_providers, indent=4, ensure_ascii=False))
|
| 116 |
|
| 117 |
if not matching_providers:
|
| 118 |
raise HTTPException(status_code=404, detail="No matching model found")
|
|
|
|
| 162 |
|
| 163 |
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: validate the bearer token against configured keys.

    Returns the token when it appears in the module-level ``api_list`` so
    route handlers can identify the caller; otherwise rejects the request.

    Raises:
        HTTPException: 403 when the presented token is not a configured key.
    """
    presented = credentials.credentials
    if presented in api_list:
        return presented
    raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
| 168 |
|
|
|
|
| 170 |
async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
    # Thin route handler: authentication happens in the verify_api_key
    # dependency; routing and dispatch are delegated to the module-level
    # ModelRequestHandler instance.
    return await model_handler.request_model(request, token)
|
| 172 |
|
| 173 |
+
def get_all_models(token):
    """Return the list of model-info dicts visible to the given API key.

    If the key's config entry lists specific models, only those are returned
    with ``owned_by`` set to the model id itself (matching prior behavior);
    otherwise every model of every configured provider is returned with
    ``owned_by`` set to the provider name. Duplicate ids are emitted once,
    preserving first-seen order.

    Args:
        token: API key identifying the caller.

    Raises:
        HTTPException: 403 when the token is not a configured key.
    """
    if token not in api_list:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")

    api_index = api_list.index(token)
    all_models = []
    unique_models = set()

    def _append(model_id, owner):
        # Emit each model id at most once; both branches share this shape.
        if model_id not in unique_models:
            unique_models.add(model_id)
            all_models.append({
                "id": model_id,
                "object": "model",
                # NOTE(review): this literal looks like a millisecond
                # timestamp; OpenAI-style "created" fields are usually
                # seconds — confirm intended unit.
                "created": 1720524448858,
                "owned_by": owner,
            })

    # .get() keeps key entries without a "model" field from raising KeyError;
    # they fall through to the "all providers" branch instead.
    key_models = config['api_keys'][api_index].get('model')
    if key_models:
        for model in key_models:
            _append(model, model)
    else:
        for provider in config["providers"]:
            for model in provider['model']:
                _append(model, provider['provider'])

    return all_models
|
| 205 |
|
| 206 |
+
@app.post("/v1/models")
|
| 207 |
+
async def list_models(token: str = Depends(verify_api_key)):
|
| 208 |
+
models = get_all_models(token)
|
| 209 |
return {
|
| 210 |
"object": "list",
|
| 211 |
"data": models
|
response.py
CHANGED
|
@@ -136,14 +136,20 @@ async def fetch_response(client, url, headers, payload):
|
|
| 136 |
return response.json()
|
| 137 |
|
| 138 |
async def fetch_response_stream(client, url, headers, payload, engine, model):
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
return response.json()
|
| 137 |
|
| 138 |
async def fetch_response_stream(client, url, headers, payload, engine, model):
    """Dispatch a streaming request to the engine-specific stream handler.

    Retries the whole request once on ``httpx.ConnectError``. NOTE(review):
    if a connect error fires after chunks were already yielded, the retry
    re-streams from the beginning and the caller sees duplicated chunks —
    tolerable only because ConnectError normally occurs before any data
    flows; confirm this assumption holds for all handlers.

    Yields:
        Chunks produced by the engine-specific stream handler.

    Raises:
        ValueError: when *engine* is not a supported engine.
        httpx.ConnectError: when the connection fails on the final attempt.
    """
    max_attempts = 2
    for attempt in range(max_attempts):
        try:
            if engine == "gemini":
                async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
                    yield chunk
            elif engine == "claude":
                async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
                    yield chunk
            elif engine == "gpt":
                async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
                    yield chunk
            else:
                raise ValueError("Unknown response")
            break
        except httpx.ConnectError as e:
            print(f"连接错误: {e}")
            # Previously the loop just fell off the end after the last
            # attempt, handing the caller a silently truncated stream;
            # surface the failure instead.
            if attempt == max_attempts - 1:
                raise
|