File size: 7,032 Bytes
650ab3e
 
 
39ed49b
650ab3e
cbe330b
2878fe7
 
cbe330b
2878fe7
 
cbe330b
2878fe7
 
cbe330b
2878fe7
 
650ab3e
 
39ed49b
650ab3e
f9bb630
 
 
 
 
 
 
89b7097
39ed49b
 
 
84cce07
39ed49b
cbe330b
39ed49b
84cce07
cbe330b
84cce07
 
f9bb630
2878fe7
cbe330b
 
f9bb630
 
 
 
 
 
 
 
 
 
cbe330b
 
 
f9bb630
39ed49b
f9bb630
cbe330b
 
 
 
 
 
39ed49b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbe330b
2878fe7
39ed49b
89b7097
cbe330b
89b7097
cbe330b
89b7097
cbe330b
 
f9bb630
39ed49b
f9bb630
 
cbe330b
f9bb630
 
650ab3e
39ed49b
cbe330b
39ed49b
650ab3e
78bc769
39ed49b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9bb630
cbe330b
 
 
 
39ed49b
cbe330b
39ed49b
 
 
 
 
 
 
 
 
 
 
 
89b7097
39ed49b
2878fe7
39ed49b
cbe330b
 
89b7097
cbe330b
39ed49b
cbe330b
39ed49b
 
89b7097
39ed49b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbe330b
f9bb630
2878fe7
f9bb630
cbe330b
650ab3e
cbe330b
78bc769
39ed49b
2878fe7
39ed49b
cbe330b
c977f3d
2878fe7
 
 
cbe330b
ad2b7e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import os
import time
import requests
from typing import List, Dict

# ================== Environment configuration ==================
# LiteLLM proxy connection. The master key is mandatory: a missing
# LITELLM_MASTER_KEY raises KeyError at import time.
LITELLM_BASE_URL = os.getenv("LITELLM_BASE_URL", "http://localhost:7860")
LITELLM_MASTER_KEY = os.environ["LITELLM_MASTER_KEY"]

# Upstream CLIProxy endpoint (mandatory) and its optional API key ("" when unset).
CLIPROXY_BASE_URL = os.environ["CLIPROXY_BASE_URL"]
CLIPROXY_API_KEY = os.getenv("CLIPROXY_API_KEY", "")

# Name of the stored LiteLLM credential, and whether deployments should reference it
# instead of embedding api_base/api_key (enabled only by the case-insensitive value "true").
CREDENTIAL_NAME = os.getenv("LITELLM_CREDENTIAL_NAME", "cliproxy")
FORCE_CREDENTIAL = "true" == os.getenv("FORCE_USE_CREDENTIAL", "false").lower()

# Model group whose fallback chain is rewritten on every sync, and the sync period (seconds).
PRIMARY_MODEL_GROUP = os.getenv("FALLBACK_PRIMARY_MODEL", "cliproxy/*")
SYNC_INTERVAL = int(os.getenv("SYNC_INTERVAL_SECONDS", 3600))

# Tag written into model_info of every deployment this script creates.
SYNC_TAG = "cliproxy-synced"
# Auth header sent with every LiteLLM admin API call.
HEADERS = {"Authorization": "Bearer " + LITELLM_MASTER_KEY}

# Owner -> rank used to order the fallback chain; a lower rank comes earlier.
PROVIDER_PRIORITY = dict(
    cliproxy=0,
    moonshotai=1,
    kimi=1,
    anthropic=2,
    google=3,
)

# ================== Basic helpers ==================
def wait_for_litellm(poll_interval: float = 5, timeout: float = 5):
    """Block until the LiteLLM proxy answers its readiness probe with HTTP 200.

    Polls ``/health/readiness`` instead of ``/health``. Per LiteLLM's docs, ``/health``
    makes a real API call to every configured deployment. With many synced models
    that is slow and spends quota, and it can exceed *timeout* on every attempt,
    so the wait would never finish.

    Args:
        poll_interval: seconds to sleep between attempts.
        timeout: per-request timeout in seconds.
    """
    url = f"{LITELLM_BASE_URL}/health/readiness"
    while True:
        try:
            r = requests.get(url, headers=HEADERS, timeout=timeout)
            if r.status_code == 200:
                print("✅ LiteLLM 已就绪")
                return
        except requests.RequestException:
            # The proxy is not accepting connections yet (or the probe timed out):
            # retry after the sleep. Catch only network errors so that real bugs
            # surface instead of spinning forever.
            pass
        time.sleep(poll_interval)

def infer_provider(owner: str, model_id: str) -> str:
    """Pick the LiteLLM provider prefix for a CLIProxy model.

    Rules are evaluated in this order, case-insensitively:
      1. owner is ``cliproxy``                          -> ``"openai"``
      2. owner is ``google`` or id contains ``gemini``    -> ``"gemini"``
      3. owner is ``anthropic`` or id contains ``claude`` -> ``"anthropic"``
      4. anything else (moonshotai and kimi included)     -> ``"openai"``
    """
    who = owner.lower()
    mid = model_id.lower()
    if who == "cliproxy":
        return "openai"
    if who == "google" or "gemini" in mid:
        return "gemini"
    if who == "anthropic" or "claude" in mid:
        return "anthropic"
    # moonshotai / kimi fall through to the same default as every other owner.
    return "openai"

def is_system_model(name: str) -> bool:
    """Return True when *name* begins (case-insensitively) with a reserved prefix.

    Reserved prefixes: ``container``, ``hf``, ``litellm``, ``internal``.
    sync() uses this to keep such models out of the fallback chain.
    """
    lowered = name.lower()
    reserved = ("container", "hf", "litellm", "internal")
    return any(lowered.startswith(prefix) for prefix in reserved)

def sort_fallback(models: Dict[str, str]) -> List[str]:
    """Order model names for the fallback chain.

    Args:
        models: mapping of model name -> owner.

    Returns:
        The names sorted by ``(rank, name)``. Names containing ``embed``
        (case-insensitive) get rank 99. Every other name gets the rank of its
        lower-cased owner in PROVIDER_PRIORITY, or 50 if the owner is not listed.
    """
    def rank(entry):
        model_name, model_owner = entry
        if "embed" not in model_name.lower():
            return PROVIDER_PRIORITY.get(model_owner.lower(), 50), model_name
        return 99, model_name

    ranked = sorted(models.items(), key=rank)
    return [model_name for model_name, _owner in ranked]

# ================== LiteLLM API ==================
def normalize_deployment(entry: dict) -> dict:
    """Return a shallow copy of a ``/model/info`` entry with ``id`` and ``tags`` at the top level.

    LiteLLM nests the deployment id and our tags inside ``model_info``, while the
    sync logic reads them at the top level. Values that are already at the top
    level are kept. Missing tags become ``[]``. *entry* itself is not mutated.
    """
    info = entry.get("model_info") or {}
    flat = dict(entry)
    flat.setdefault("id", info.get("id"))
    flat.setdefault("tags", info.get("tags") or [])
    return flat


def get_existing_models(timeout: float = 30) -> List[dict]:
    """List the deployments currently registered in LiteLLM.

    Uses ``/model/info`` rather than ``/v1/models``. The latter is the OpenAI-style
    listing: its entries carry only ``id``/``object``/``owned_by``, with no
    ``model_name``, ``litellm_params`` or deployment id. Sync would then never see
    existing deployments and would re-add every model on every pass.

    Args:
        timeout: request timeout in seconds.

    Returns:
        One dict per deployment, flattened with normalize_deployment().

    Raises:
        requests.HTTPError: on a non-2xx response.
    """
    r = requests.get(f"{LITELLM_BASE_URL}/model/info", headers=HEADERS, timeout=timeout)
    r.raise_for_status()
    return [normalize_deployment(m) for m in r.json().get("data", [])]

def get_cliproxy_models(timeout: float = 30) -> List[dict]:
    """Fetch the model catalogue advertised by CLIProxy.

    Sends ``Authorization: Bearer <CLIPROXY_API_KEY>`` only when a key is configured.

    Args:
        timeout: request timeout in seconds (without one, a stalled upstream would
            hang the daemon indefinitely).

    Returns:
        The ``data`` list when the JSON body is an object; otherwise the decoded JSON
        as-is (presumably already a list of model dicts — not validated here).

    Raises:
        requests.HTTPError: on a non-2xx response.
    """
    headers = {}
    if CLIPROXY_API_KEY:
        headers["Authorization"] = f"Bearer {CLIPROXY_API_KEY}"
    # Strip a trailing slash so "http://host/" does not yield "http://host//models".
    # This matches how add_model() derives api_base from the same variable.
    # NOTE(review): add_model() targets "<base>/v1" for inference, but this lists
    # "<base>/models" (no /v1). Confirm which path CLIProxy actually serves.
    url = CLIPROXY_BASE_URL.rstrip("/") + "/models"
    r = requests.get(url, headers=headers, timeout=timeout)
    r.raise_for_status()
    data = r.json()
    return data["data"] if isinstance(data, dict) else data

def add_model(original_id: str, owner: str, use_credential: bool, timeout: float = 30):
    """Register one CLIProxy model as a new LiteLLM deployment named ``"<owner>/<original_id>"``.

    Args:
        original_id: model id as reported by CLIProxy.
        owner: the model's ``owned_by`` value. It becomes the name prefix and
            drives provider inference.
        use_credential: if True, reference the stored LiteLLM credential
            CREDENTIAL_NAME. Otherwise embed CLIProxy's ``<base>/v1`` api_base and
            CLIPROXY_API_KEY directly in the deployment.
        timeout: request timeout in seconds, so a stalled proxy cannot hang the daemon.

    A non-2xx response is printed rather than raised, so one bad model does not abort
    the sync. Network errors (including the timeout) propagate.
    """
    provider = infer_provider(owner, original_id)
    name = f"{owner}/{original_id}"

    params = {"model": f"{provider}/{original_id}"}
    if use_credential:
        params["litellm_credential_name"] = CREDENTIAL_NAME
    else:
        params["api_base"] = CLIPROXY_BASE_URL.rstrip("/") + "/v1"
        params["api_key"] = CLIPROXY_API_KEY

    info = {
        "owned_by": owner,
        # The tag marks this deployment as script-managed (eligible for pruning).
        "tags": [SYNC_TAG],
        "model_type": "embedding" if "embed" in original_id.lower() else "chat",
    }

    payload = {
        "model_name": name,
        "litellm_params": params,
        "model_info": info,
    }

    r = requests.post(
        f"{LITELLM_BASE_URL}/model/new", json=payload, headers=HEADERS, timeout=timeout
    )
    if r.ok:
        print(f"➕ 新增模型 {name}")
    else:
        print(f"❌ 新增失败 {name}: {r.text}")

def update_model(model_id: str, params: dict, info: dict, timeout: float = 30):
    """Overwrite ``litellm_params`` and ``model_info`` of an existing deployment.

    Args:
        model_id: LiteLLM deployment id.
        params: new ``litellm_params``.
        info: new ``model_info``. It is copied, not mutated; the copy gets ``id`` added.
        timeout: request timeout in seconds.

    A non-2xx response is printed rather than raised. Network errors propagate.
    """
    payload = {
        # NOTE(review): LiteLLM's /model/update locates the deployment through
        # model_info.id. The top-level "id" is kept in case an older proxy version
        # reads it; confirm against the deployed LiteLLM version.
        "id": model_id,
        "litellm_params": params,
        "model_info": {**info, "id": model_id},
    }
    r = requests.post(
        f"{LITELLM_BASE_URL}/model/update", json=payload, headers=HEADERS, timeout=timeout
    )
    if r.ok:
        print(f"🔄 更新模型 {model_id}")
    else:
        print(f"❌ 更新失败 {model_id}: {r.text}")

def delete_model(model_id: str, name: str, timeout: float = 30):
    """Delete the deployment *model_id*; *name* is used only in log output.

    A failed delete is now reported, consistent with add_model/update_model, instead
    of being silently ignored. Network errors (including the timeout) propagate.
    """
    r = requests.post(
        f"{LITELLM_BASE_URL}/model/delete",
        json={"id": model_id},
        headers=HEADERS,
        timeout=timeout,
    )
    if r.ok:
        print(f"🗑️ 删除模型 {name}")
    else:
        print(f"❌ 删除失败 {name}: {r.text}")

def update_fallback(models: List[str], max_models: int = 50, timeout: float = 30):
    """Replace the ``general`` fallback chain of PRIMARY_MODEL_GROUP with *models*.

    Steps:
      1. Delete the existing entry. This is best effort and its result is not
         checked, because on a first run there may be nothing to delete.
      2. Create a new entry with at most *max_models* names, if any remain.

    PRIMARY_MODEL_GROUP itself is dropped from *models* so the group never lists
    itself as its own fallback. A failed create is printed rather than raised.

    Args:
        models: candidate fallback names, in priority order.
        max_models: cap on the number of fallbacks sent (previously a hard-coded 50).
        timeout: per-request timeout in seconds.
    """
    candidates = [m for m in models if m != PRIMARY_MODEL_GROUP][:max_models]

    requests.post(
        f"{LITELLM_BASE_URL}/fallback/delete",
        json={"model": PRIMARY_MODEL_GROUP, "fallback_type": "general"},
        headers=HEADERS,
        timeout=timeout,
    )
    if not candidates:
        return

    r = requests.post(
        f"{LITELLM_BASE_URL}/fallback",
        json={
            "model": PRIMARY_MODEL_GROUP,
            "fallback_models": candidates,
            "fallback_type": "general",
        },
        headers=HEADERS,
        timeout=timeout,
    )
    if not r.ok:
        print(f"❌ fallback 更新失败 {PRIMARY_MODEL_GROUP}: {r.text}")

# ================== Idempotent sync main logic ==================
def _deployment_id(m: dict):
    """Return a deployment's id from the top level, else from ``model_info`` (None if absent)."""
    return m.get("id") or (m.get("model_info") or {}).get("id")


def _deployment_tags(m: dict) -> list:
    """Return a deployment's tags from the top level, else from ``model_info`` ([] if absent)."""
    return m.get("tags") or (m.get("model_info") or {}).get("tags") or []


def sync():
    """Run one reconciliation pass from CLIProxy's catalogue into LiteLLM.

    Steps:
      1. Fetch CLIProxy models and key them as ``"<owned_by>/<id>"``. Entries lacking
         either field are skipped with a warning instead of aborting the pass.
      2. Group LiteLLM's current deployments by ``model_name``.
      3. For each CLIProxy model, add it if absent. Otherwise update the first
         deployment when its credential mode or ``litellm_params.model`` differs
         from the desired value, and delete extra deployments that carry SYNC_TAG.
         Untagged duplicates (e.g. user-defined load-balancing deployments) are left
         alone.
      4. Delete SYNC_TAG deployments whose name CLIProxy no longer reports.
      5. Rebuild PRIMARY_MODEL_GROUP's fallback chain, excluding embedding and system
         models.

    If CLIProxy reports no usable models, steps 2-5 are skipped. This keeps a
    transient upstream glitch from pruning every synced deployment and clearing
    the fallback chain. HTTP errors from the two catalogue fetches propagate.
    """
    print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] 同步开始")

    use_credential = FORCE_CREDENTIAL

    # name -> owner (consumed by sort_fallback) and name -> original model id.
    # The id is stored separately rather than re-split from the name, because an
    # owner containing "/" would make name.split("/", 1) return the wrong id.
    clip_map: Dict[str, str] = {}
    clip_ids: Dict[str, str] = {}
    for m in get_cliproxy_models():
        model_id, owner = m.get("id"), m.get("owned_by")
        if not model_id or not owner:
            print(f"⚠️ 跳过缺少 id/owned_by 的模型: {m}")
            continue
        name = f"{owner}/{model_id}"
        clip_map[name] = owner
        clip_ids[name] = model_id

    if not clip_map:
        print("⚠️ CLIProxy 未返回任何模型,跳过本轮同步\n")
        return

    by_name: Dict[str, List[dict]] = {}
    for m in get_existing_models():
        if "model_name" in m:
            by_name.setdefault(m["model_name"], []).append(m)

    # === Add / update / de-duplicate ===
    for name, owner in clip_map.items():
        original_id = clip_ids[name]
        provider = infer_provider(owner, original_id)

        desired_params = {"model": f"{provider}/{original_id}"}
        if use_credential:
            desired_params["litellm_credential_name"] = CREDENTIAL_NAME
        else:
            desired_params["api_base"] = CLIPROXY_BASE_URL.rstrip("/") + "/v1"
            desired_params["api_key"] = CLIPROXY_API_KEY

        desired_info = {
            "owned_by": owner,
            "tags": [SYNC_TAG],
            "model_type": "embedding" if "embed" in original_id.lower() else "chat",
        }

        models = by_name.get(name, [])
        if not models:
            add_model(original_id, owner, use_credential)
            continue

        primary = models[0]
        current_params = primary.get("litellm_params") or {}
        has_cred = "litellm_credential_name" in current_params
        # Only treat the model string as drifted when LiteLLM actually reports one.
        current_model = current_params.get("model")
        model_drifted = current_model is not None and current_model != desired_params["model"]

        primary_id = _deployment_id(primary)
        if primary_id and (has_cred != use_credential or model_drifted):
            update_model(primary_id, desired_params, desired_info)

        for dup in models[1:]:
            dup_id = _deployment_id(dup)
            if dup_id and SYNC_TAG in _deployment_tags(dup):
                delete_model(dup_id, name)

    # === Prune deployments CLIProxy no longer reports ===
    for name, models in by_name.items():
        if name in clip_map:
            continue
        for m in models:
            dep_id = _deployment_id(m)
            if dep_id and SYNC_TAG in _deployment_tags(m):
                delete_model(dep_id, name)

    # === fallback ===
    sorted_models = sort_fallback(clip_map)
    fallback_models = [
        m for m in sorted_models
        if "embed" not in m.lower() and not is_system_model(m)
    ]
    update_fallback(fallback_models)

    print("✅ 同步完成\n")

# ================== Daemon entry point ==================
if __name__ == "__main__":
    import traceback

    wait_for_litellm()
    print(f"🚀 同步守护启动,每 {SYNC_INTERVAL}s 执行一次")
    while True:
        try:
            sync()
        except Exception as e:
            # Top-level boundary: report the failure, with its traceback, and keep
            # the daemon alive for the next cycle.
            print("同步异常:", e)
            traceback.print_exc()
        # time.sleep() raises ValueError on a negative value, and that would happen
        # outside the try and kill the daemon. Clamp to at least one second, which
        # also prevents a busy loop when the interval is 0.
        time.sleep(max(SYNC_INTERVAL, 1))