import os
import time
import requests
from typing import List, Dict
# ================== Environment configuration ==================
# Base URL of the LiteLLM proxy that this script manages.
LITELLM_BASE_URL = os.getenv("LITELLM_BASE_URL", "http://localhost:7860")
# Required: a missing variable raises KeyError at import time.
LITELLM_MASTER_KEY = os.environ["LITELLM_MASTER_KEY"]
# Required: base URL of the CLIProxy service whose model list is mirrored.
CLIPROXY_BASE_URL = os.environ["CLIPROXY_BASE_URL"]
# Optional key; when empty, the CLIProxy listing request is sent without auth.
CLIPROXY_API_KEY = os.getenv("CLIPROXY_API_KEY", "")
# Credential name attached to models when credential mode is on.
CREDENTIAL_NAME = os.getenv("LITELLM_CREDENTIAL_NAME", "cliproxy")
# True only when FORCE_USE_CREDENTIAL equals "true" (case-insensitive).
FORCE_CREDENTIAL = os.getenv("FORCE_USE_CREDENTIAL", "false").lower() == "true"
# Model (group) whose fallback list is rewritten on every sync.
PRIMARY_MODEL_GROUP = os.getenv("FALLBACK_PRIMARY_MODEL", "cliproxy/*")
# Pause between two sync passes, in seconds.
SYNC_INTERVAL = int(os.getenv("SYNC_INTERVAL_SECONDS", 3600))
# Tag stored with every model created by this script.
SYNC_TAG = "cliproxy-synced"
# Authorization header sent with every LiteLLM request.
HEADERS = {"Authorization": "Bearer " + LITELLM_MASTER_KEY}
# Fallback ordering by owner: a lower value sorts earlier in sort_fallback().
# Owners not listed here are ranked 50 there; embedding models are ranked 99.
PROVIDER_PRIORITY = {
    "cliproxy": 0,
    "moonshotai": 1,
    "kimi": 1,
    "anthropic": 2,
    "google": 3,
}
# ================== Basic helpers ==================
def wait_for_litellm():
    """Block until the LiteLLM health endpoint answers with HTTP 200.

    Polls ``{LITELLM_BASE_URL}/health`` with a 5 second request timeout and
    sleeps 5 seconds after every unsuccessful attempt. Never gives up.
    """
    health_url = f"{LITELLM_BASE_URL}/health"
    while True:
        try:
            response = requests.get(health_url, headers=HEADERS, timeout=5)
            if response.status_code == 200:
                print("✅ LiteLLM 已就绪")
                return
        except Exception:
            # Deliberate best-effort: any failure in this attempt simply
            # leads to another attempt after the pause below.
            pass
        time.sleep(5)
def infer_provider(owner: str, model_id: str) -> str:
    """Pick the LiteLLM provider prefix for an upstream model.

    Matching is case-insensitive. Models owned by ``cliproxy`` always map to
    ``"openai"``. Otherwise a Gemini match (owner ``google`` or ``gemini`` in
    the id) wins over a Claude match (owner ``anthropic`` or ``claude`` in
    the id). Everything else, including ``moonshotai`` and ``kimi`` owners,
    falls back to ``"openai"``.

    Args:
        owner: Value of the model's ``owned_by`` field.
        model_id: Upstream model identifier.

    Returns:
        One of ``"openai"``, ``"gemini"`` or ``"anthropic"``.
    """
    owner_key = owner.lower()
    model_key = model_id.lower()

    if owner_key != "cliproxy":
        if owner_key == "google" or "gemini" in model_key:
            return "gemini"
        if owner_key == "anthropic" or "claude" in model_key:
            return "anthropic"
    return "openai"
def is_system_model(name: str) -> bool:
    """Tell whether a model name starts with a reserved prefix.

    The check is case-insensitive; the reserved prefixes are ``container``,
    ``hf``, ``litellm`` and ``internal``.
    """
    lowered = name.lower()
    reserved_prefixes = ("container", "hf", "litellm", "internal")
    return any(lowered.startswith(prefix) for prefix in reserved_prefixes)
def sort_fallback(models: Dict[str, str]) -> List[str]:
    """Order model names for use as a fallback chain.

    Args:
        models: Mapping of model name to owner.

    Returns:
        The model names sorted by ``(rank, name)``. The rank is 99 when the
        name contains ``embed`` (case-insensitive); otherwise it is the
        owner's entry in ``PROVIDER_PRIORITY``, or 50 for unlisted owners.
    """
    def rank(model_name: str):
        if "embed" in model_name.lower():
            return (99, model_name)
        owner_rank = PROVIDER_PRIORITY.get(models[model_name].lower(), 50)
        return (owner_rank, model_name)

    return sorted(models, key=rank)
# ================== LiteLLM API ==================
def get_existing_models(timeout: float = 30) -> List[dict]:
    """Fetch the model entries currently known to LiteLLM.

    Args:
        timeout: Value passed to ``requests`` as ``timeout`` (seconds).
            Without it a stalled connection would block the daemon forever.

    Returns:
        The list stored under ``data`` in the JSON response, or an empty
        list when that key is absent.

    Raises:
        requests.HTTPError: If LiteLLM answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    # NOTE(review): sync() groups these entries by "model_name" and reads
    # "litellm_params" from them. Confirm that /v1/models really returns
    # entries of that shape; if it only returns OpenAI-style {"id": ...}
    # objects, sync() will treat every upstream model as new on each pass.
    r = requests.get(
        f"{LITELLM_BASE_URL}/v1/models", headers=HEADERS, timeout=timeout
    )
    r.raise_for_status()
    return r.json().get("data", [])
def get_cliproxy_models(timeout: float = 30) -> List[dict]:
    """Fetch the model list published by CLIProxy.

    Args:
        timeout: Value passed to ``requests`` as ``timeout`` (seconds).
            Without it a stalled connection would block the daemon forever.

    Returns:
        The ``data`` list when the response is a JSON object, otherwise the
        decoded JSON value itself.

    Raises:
        requests.HTTPError: If CLIProxy answers with an error status.
        requests.RequestException: On connection failures or timeouts.
        KeyError: If the response is a JSON object without a ``data`` key.
            This is kept on purpose: returning an empty list here would make
            sync() treat every synced model as stale.
    """
    headers = {}
    if CLIPROXY_API_KEY:
        headers["Authorization"] = f"Bearer {CLIPROXY_API_KEY}"
    # Strip a trailing slash, as add_model() does, so a base URL ending in
    # "/" does not produce a "//models" path.
    url = CLIPROXY_BASE_URL.rstrip("/") + "/models"
    r = requests.get(url, headers=headers, timeout=timeout)
    r.raise_for_status()
    data = r.json()
    return data["data"] if isinstance(data, dict) else data
def add_model(original_id: str, owner: str, use_credential: bool, timeout: float = 30):
    """Register one upstream model in LiteLLM via ``POST /model/new``.

    The model is published as ``"{owner}/{original_id}"`` and routed to
    ``"{provider}/{original_id}"``, where the provider comes from
    ``infer_provider``. The outcome is printed; nothing is returned.

    Args:
        original_id: Model id as reported by CLIProxy.
        owner: Value of the upstream ``owned_by`` field.
        use_credential: When true, reference the stored credential
            ``CREDENTIAL_NAME``; otherwise embed the CLIProxy ``api_base``
            and ``api_key`` directly in the model parameters.
        timeout: Value passed to ``requests`` as ``timeout`` (seconds).
            Without it a stalled connection would block the daemon forever.

    Raises:
        requests.RequestException: On connection failures or timeouts.
    """
    provider = infer_provider(owner, original_id)
    name = f"{owner}/{original_id}"

    params = {"model": f"{provider}/{original_id}"}
    if use_credential:
        params["litellm_credential_name"] = CREDENTIAL_NAME
    else:
        params["api_base"] = CLIPROXY_BASE_URL.rstrip("/") + "/v1"
        params["api_key"] = CLIPROXY_API_KEY

    info = {
        "owned_by": owner,
        # Marks the entry as managed by this script (used by stale cleanup).
        "tags": [SYNC_TAG],
        "model_type": "embedding" if "embed" in original_id.lower() else "chat",
    }
    payload = {
        "model_name": name,
        "litellm_params": params,
        "model_info": info,
    }

    r = requests.post(
        f"{LITELLM_BASE_URL}/model/new",
        json=payload,
        headers=HEADERS,
        timeout=timeout,
    )
    if r.ok:
        print(f"➕ 新增模型 {name}")
    else:
        print(f"❌ 新增失败 {name}: {r.text}")
def update_model(model_id: str, params: dict, info: dict, timeout: float = 30):
    """Overwrite a LiteLLM model's parameters via ``POST /model/update``.

    The outcome is printed; nothing is returned.

    Args:
        model_id: Id of the LiteLLM model entry to update.
        params: New ``litellm_params`` value.
        info: New ``model_info`` value.
        timeout: Value passed to ``requests`` as ``timeout`` (seconds).
            Without it a stalled connection would block the daemon forever.

    Raises:
        requests.RequestException: On connection failures or timeouts.
    """
    payload = {"id": model_id, "litellm_params": params, "model_info": info}
    r = requests.post(
        f"{LITELLM_BASE_URL}/model/update",
        json=payload,
        headers=HEADERS,
        timeout=timeout,
    )
    if r.ok:
        print(f"🔄 更新模型 {model_id}")
    else:
        print(f"❌ 更新失败 {model_id}: {r.text}")
def delete_model(model_id: str, name: str, timeout: float = 30):
    """Remove a LiteLLM model entry via ``POST /model/delete``.

    The outcome is printed; nothing is returned.

    Args:
        model_id: Id of the LiteLLM model entry to delete.
        name: Model name, used only in the printed message.
        timeout: Value passed to ``requests`` as ``timeout`` (seconds).
            Without it a stalled connection would block the daemon forever.

    Raises:
        requests.RequestException: On connection failures or timeouts.
    """
    r = requests.post(
        f"{LITELLM_BASE_URL}/model/delete",
        json={"id": model_id},
        headers=HEADERS,
        timeout=timeout,
    )
    if r.ok:
        print(f"🗑️ 删除模型 {name}")
    else:
        # Report failures the same way add_model()/update_model() do, so a
        # rejected delete does not go unnoticed.
        print(f"❌ 删除失败 {name}: {r.text}")
def update_fallback(models: List[str], timeout: float = 30):
    """Replace the general fallback list of ``PRIMARY_MODEL_GROUP``.

    First posts a delete for the existing rule. Then, if ``models`` is
    non-empty, posts a new rule containing at most the first 50 names.

    Args:
        models: Fallback model names in priority order.
        timeout: Value passed to ``requests`` as ``timeout`` (seconds).
            Without it a stalled connection would block the daemon forever.

    Raises:
        requests.RequestException: On connection failures or timeouts.
    """
    # The delete result is intentionally not checked: there may be no rule
    # to remove yet (e.g. on the first pass).
    requests.post(
        f"{LITELLM_BASE_URL}/fallback/delete",
        json={"model": PRIMARY_MODEL_GROUP, "fallback_type": "general"},
        headers=HEADERS,
        timeout=timeout,
    )
    if not models:
        return

    r = requests.post(
        f"{LITELLM_BASE_URL}/fallback",
        json={
            "model": PRIMARY_MODEL_GROUP,
            "fallback_models": models[:50],
            "fallback_type": "general",
        },
        headers=HEADERS,
        timeout=timeout,
    )
    if not r.ok:
        # The old rule has already been deleted above, so a failure here
        # leaves the group without fallbacks; make that visible.
        print(f"❌ 更新 fallback 失败 {PRIMARY_MODEL_GROUP}: {r.text}")
# ================== Idempotent sync: main logic ==================
def _entry_id(entry: dict):
    """Return the id of an existing LiteLLM model entry.

    Prefers the top-level ``id``; otherwise falls back to
    ``model_info["id"]``.

    Raises:
        KeyError: If neither location holds an id.
    """
    if "id" in entry:
        return entry["id"]
    # NOTE(review): presumably the id sits inside model_info when it is not
    # at the top level; confirm against the actual LiteLLM response.
    return (entry.get("model_info") or {})["id"]


def _entry_tags(entry: dict) -> list:
    """Collect tags from the entry's top level and from its ``model_info``."""
    info = entry.get("model_info") or {}
    return list(entry.get("tags") or []) + list(info.get("tags") or [])


def sync():
    """Run one reconciliation pass between CLIProxy and LiteLLM.

    Steps:
      1. Read the upstream model list and the existing LiteLLM entries.
      2. Add upstream models that LiteLLM lacks. For models that exist,
         update the first entry when its credential mode differs from
         ``FORCE_CREDENTIAL`` and delete any further entries with the same
         name.
      3. Delete entries carrying ``SYNC_TAG`` whose name is no longer
         upstream.
      4. Rewrite the fallback list from the upstream names, excluding
         embedding and system models.

    Raises:
        KeyError: If an upstream model lacks ``owned_by`` or ``id``, or an
            existing entry that must be changed has no id.
        requests.RequestException: Propagated from the HTTP helpers.
    """
    print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] 同步开始")
    use_credential = FORCE_CREDENTIAL

    # Keep (owner, upstream id) together instead of re-splitting the
    # composed name later: an owner containing "/" would otherwise yield a
    # wrong model id.
    upstream: Dict[str, tuple] = {}
    for m in get_cliproxy_models():
        owner, original_id = m["owned_by"], m["id"]
        upstream[f"{owner}/{original_id}"] = (owner, original_id)
    # name -> owner view, the shape sort_fallback() takes.
    clip_map = {name: owner for name, (owner, _) in upstream.items()}

    by_name: Dict[str, List[dict]] = {}
    for m in get_existing_models():
        if "model_name" in m:
            by_name.setdefault(m["model_name"], []).append(m)

    # === Add / update / de-duplicate ===
    for name, (owner, original_id) in upstream.items():
        models = by_name.get(name, [])
        if not models:
            add_model(original_id, owner, use_credential)
            continue

        primary = models[0]
        current_params = primary.get("litellm_params") or {}
        has_cred = "litellm_credential_name" in current_params
        # Only the credential mode is compared; other parameter drift is
        # not detected here.
        if has_cred != use_credential:
            provider = infer_provider(owner, original_id)
            desired_params = {"model": f"{provider}/{original_id}"}
            if use_credential:
                desired_params["litellm_credential_name"] = CREDENTIAL_NAME
            else:
                desired_params["api_base"] = CLIPROXY_BASE_URL.rstrip("/") + "/v1"
                desired_params["api_key"] = CLIPROXY_API_KEY
            desired_info = {
                "owned_by": owner,
                "tags": [SYNC_TAG],
                "model_type": "embedding" if "embed" in original_id.lower() else "chat",
            }
            update_model(_entry_id(primary), desired_params, desired_info)

        for dup in models[1:]:
            delete_model(_entry_id(dup), name)

    # === Remove stale models ===
    for name, models in by_name.items():
        if name in upstream:
            continue
        for m in models:
            # add_model() stores SYNC_TAG inside model_info, so look there
            # as well as at the top level; a top-level-only check would
            # never match entries created by this script.
            if SYNC_TAG in _entry_tags(m):
                delete_model(_entry_id(m), name)

    # === Fallback ===
    sorted_models = sort_fallback(clip_map)
    fallback_models = [
        m for m in sorted_models
        if "embed" not in m.lower() and not is_system_model(m)
    ]
    update_fallback(fallback_models)
    print("✅ 同步完成\n")
# ================== Daemon loop ==================
if __name__ == "__main__":
    import traceback  # local import: only this entry point needs it

    # Do not start syncing before LiteLLM reports healthy.
    wait_for_litellm()
    print(f"🚀 同步守护启动,每 {SYNC_INTERVAL}s 执行一次")
    while True:
        try:
            sync()
        except Exception as e:
            # Keep the daemon alive after a failed pass, but show the
            # exception type and stack: the bare message of e.g. a KeyError
            # is only the missing key.
            print("同步异常:", e)
            traceback.print_exc()
        time.sleep(SYNC_INTERVAL)