File size: 4,210 Bytes
af641e5
86d14ea
ff017f1
86d14ea
af641e5
86d14ea
f97a9a4
039946a
a8a963a
f97a9a4
a8a963a
f97a9a4
 
 
039946a
 
 
 
 
 
 
 
 
 
 
 
205280e
c78c47f
 
 
 
a8a963a
 
 
 
 
c78c47f
039946a
 
6425e65
 
 
 
039946a
 
 
 
 
 
6425e65
039946a
 
6425e65
 
 
 
 
205280e
 
 
 
 
 
 
 
 
 
 
 
 
039946a
6425e65
205280e
f97a9a4
6425e65
039946a
 
6425e65
 
 
 
 
 
c78c47f
f97a9a4
 
039946a
 
 
 
f97a9a4
 
86d14ea
a8a963a
f97a9a4
039946a
243a90b
 
 
 
f97a9a4
 
 
 
 
 
 
86d14ea
f97a9a4
6425e65
 
f97a9a4
 
 
 
 
 
ff017f1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
FROM python:3.11-slim

# Flush print() output straight to the container log instead of buffering it
# (stdout is not a TTY inside the container, so Python would block-buffer it).
ENV PYTHONUNBUFFERED=1

WORKDIR /app

# --no-cache-dir keeps pip's download cache out of the image layer.
# NOTE(review): versions are unpinned, so a rebuild may pick up breaking
# releases of fastapi / uvicorn / google-genai — consider pinning them.
RUN pip install --no-cache-dir fastapi uvicorn google-genai

# Write the application source inline with a heredoc (needs a BuildKit
# Dockerfile frontend that supports heredocs).
RUN cat > main.py << 'PYEOF'
import asyncio
import os
import secrets

import uvicorn
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import JSONResponse
from google import genai

app = FastAPI()

# ========== Optional proxy access key ==========
# A non-blank PROXY_API_KEY makes verify_proxy_key demand it as a bearer
# token; leaving it unset (or blank) disables that check entirely.
PROXY_API_KEY = (os.environ.get("PROXY_API_KEY") or "").strip()

def verify_proxy_key(request: Request):
    """Enforce the optional proxy access key on an incoming request.

    When PROXY_API_KEY is non-empty, the request's Authorization header must
    be exactly "Bearer <PROXY_API_KEY>". When PROXY_API_KEY is empty, every
    request is let through.

    Raises:
        HTTPException: 401 if a proxy key is configured and the header does
            not match it.
    """
    if not PROXY_API_KEY:
        return  # No proxy key configured: authentication is disabled.
    auth = request.headers.get("Authorization", "")
    expected = f"Bearer {PROXY_API_KEY}"
    # Constant-time comparison so response timing does not reveal how much of
    # the key a caller guessed correctly. Compare UTF-8 bytes because
    # compare_digest raises TypeError for non-ASCII str arguments.
    if not secrets.compare_digest(auth.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid proxy API key")

# ========== Model configuration ==========
# Model entries served by /v1/models when no live list can be obtained.
STATIC_MODELS = [
    {"id": model_id, "object": "model", "owned_by": "gemini"}
    for model_id in ("gemini-embedding-001", "gemini-embedding-2-preview")
]

# Short (OpenAI-style) model name -> fully qualified Gemini resource name.
MODEL_MAP = {
    short_name: f"models/{short_name}"
    for short_name in (
        "text-embedding-004",
        "gemini-embedding-001",
        "gemini-embedding-2-preview",
    )
}

def get_gemini_key(request: Request | None = None) -> str:
    """Return the Gemini API key to use for this request.

    A bearer token in the Authorization header is taken as the caller's own
    Gemini key, unless it is empty, equals PROXY_API_KEY, or starts with
    "hf_" (a Hugging Face token). Otherwise the GOOGLE_API_KEY environment
    variable is used; the result is "" when that is unset as well.
    """
    if request is not None:
        auth_header = request.headers.get("Authorization", "")
        if auth_header.startswith("Bearer "):
            token = auth_header[7:].strip()
            # The proxy access key authenticates against this service, not Gemini.
            is_proxy_key = bool(PROXY_API_KEY) and token == PROXY_API_KEY
            # An empty token must fall through to the environment key instead
            # of being returned as a blank Gemini key.
            if token and not is_proxy_key and not token.startswith("hf_"):
                return token
    # Fall back to the server-side key.
    return os.getenv("GOOGLE_API_KEY", "").strip()

def get_client(api_key: str) -> genai.Client:
    """Construct a new Gemini SDK client authenticated with ``api_key``."""
    client = genai.Client(api_key=api_key)
    return client

def fetch_embedding_models(client):
    """List the models visible to ``client`` that support ``embedContent``.

    Returns a list of OpenAI-style model entries whose ``id`` is the model
    name with any ``models/`` prefix removed, or None if listing raised.
    """
    try:
        embedding_models = []
        for m in client.models.list():
            # supported_actions / name may be absent or None on some models;
            # skip those instead of letting a TypeError abort the whole list.
            actions = getattr(m, "supported_actions", None) or ()
            name = getattr(m, "name", None)
            if name and "embedContent" in actions:
                embedding_models.append({
                    "id": name.split("/")[-1],
                    "object": "model",
                    "owned_by": "gemini",
                })
        return embedding_models
    except Exception as e:
        # Best effort: report the failure (network/auth) and signal it with None.
        print(f"动态获取模型列表失败: {e}")
        return None

@app.get("/v1/models")
async def list_models(request: Request):
    """OpenAI-compatible ``GET /v1/models`` listing Gemini embedding models.

    With a Gemini key available, the live model list is fetched. If there is
    no key, or the live fetch returns nothing (error or no embedding models),
    STATIC_MODELS is returned instead.
    """
    verify_proxy_key(request)
    api_key = get_gemini_key(request)
    if api_key:
        client = get_client(api_key)
        # The SDK listing call is blocking network I/O; run it in a worker
        # thread so it does not stall the event loop for other requests.
        dynamic_models = await asyncio.to_thread(fetch_embedding_models, client)
        if dynamic_models:
            return JSONResponse({"object": "list", "data": dynamic_models})
    return JSONResponse({"object": "list", "data": STATIC_MODELS})

@app.post("/v1/embeddings")
async def create_embeddings(request: Request):
    """OpenAI-compatible ``POST /v1/embeddings`` backed by Gemini ``embed_content``.

    Reads ``model`` (default "gemini-embedding-001") and ``input`` (a string
    or a list of strings) from the JSON body. Returns one OpenAI-style
    embedding object per embedding Gemini produced, indexed in order.

    Raises:
        HTTPException: 401 for a bad proxy key or when no Gemini key is
            available; 400 when the body is not a JSON object.

    Upstream/SDK failures are returned as ``{"error": ...}`` with status 500.
    """
    verify_proxy_key(request)
    gemini_key = get_gemini_key(request)
    if not gemini_key:
        raise HTTPException(status_code=401, detail="Missing Gemini API key")

    # A malformed body is the client's fault: answer 400, not 500.
    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from None
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    try:
        model_short = body.get("model", "gemini-embedding-001")
        model_full = MODEL_MAP.get(model_short, f"models/{model_short}")
        input_text = body.get("input", "")
        client = get_client(gemini_key)
        # embed_content is blocking network I/O; keep it off the event loop.
        result = await asyncio.to_thread(
            client.models.embed_content,
            model=model_full,
            contents=input_text,
        )
        # OpenAI clients may send a list of inputs and expect one embedding
        # per input, so return every embedding rather than only the first.
        data = [
            {"object": "embedding", "embedding": emb.values, "index": i}
            for i, emb in enumerate(result.embeddings or [])
        ]
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

    if not data:
        return JSONResponse(
            {"error": "Gemini returned no embeddings"}, status_code=500
        )
    return JSONResponse({"object": "list", "data": data, "model": model_short})

if __name__ == "__main__":
    # Listen on every interface so the server is reachable from outside the
    # container, on the fixed port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
PYEOF

# Launch main.py, whose __main__ guard starts uvicorn on 0.0.0.0:7860.
# Exec (JSON) form runs python without a wrapping shell, so it receives the
# container's stop signals directly.
CMD ["python", "main.py"]