Update Dockerfile
Browse files- Dockerfile +43 -17
Dockerfile
CHANGED
|
@@ -1,7 +1,11 @@
|
|
| 1 |
FROM python:3.11-slim
|
|
|
|
| 2 |
WORKDIR /app
|
| 3 |
-
|
|
|
|
| 4 |
RUN pip install fastapi uvicorn google-genai
|
|
|
|
|
|
|
| 5 |
RUN cat > main.py << 'PYEOF'
|
| 6 |
from fastapi import FastAPI, Request
|
| 7 |
from fastapi.responses import JSONResponse
|
|
@@ -12,28 +16,50 @@ app = FastAPI()
|
|
| 12 |
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]
|
| 13 |
client = genai.Client(api_key=GOOGLE_API_KEY)
|
| 14 |
|
| 15 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
@app.get("/v1/models")
|
| 17 |
async def list_models():
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
"data": [
|
| 21 |
-
{"id": "gemini-embedding-001", "object": "model", "owned_by": "gemini"},
|
| 22 |
-
{"id": "gemini-embedding-2-preview", "object": "model", "owned_by": "gemini"}
|
| 23 |
-
]
|
| 24 |
-
})
|
| 25 |
|
| 26 |
@app.post("/v1/embeddings")
|
| 27 |
async def create_embeddings(request: Request):
|
| 28 |
try:
|
| 29 |
body = await request.json()
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
if not model.startswith("models/"):
|
| 34 |
-
model_full = f"models/{model}"
|
| 35 |
-
else:
|
| 36 |
-
model_full = model
|
| 37 |
input_text = body.get("input", "")
|
| 38 |
result = client.models.embed_content(
|
| 39 |
model=model_full,
|
|
@@ -46,7 +72,7 @@ async def create_embeddings(request: Request):
|
|
| 46 |
"embedding": result.embeddings[0].values,
|
| 47 |
"index": 0
|
| 48 |
}],
|
| 49 |
-
"model":
|
| 50 |
})
|
| 51 |
except Exception as e:
|
| 52 |
return JSONResponse({"error": str(e)}, status_code=500)
|
|
|
|
| 1 |
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install dependencies
|
| 6 |
RUN pip install fastapi uvicorn google-genai
|
| 7 |
+
|
| 8 |
+
# Generate main.py (supports fetching the model list automatically)
|
| 9 |
RUN cat > main.py << 'PYEOF'
|
| 10 |
from fastapi import FastAPI, Request
|
| 11 |
from fastapi.responses import JSONResponse
|
|
|
|
| 16 |
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]
|
| 17 |
client = genai.Client(api_key=GOOGLE_API_KEY)
|
| 18 |
|
| 19 |
+
# Static model list: used as the fallback when the API call fails
|
| 20 |
+
STATIC_MODELS = [
|
| 21 |
+
{"id": "gemini-embedding-001", "object": "model", "owned_by": "gemini"},
|
| 22 |
+
{"id": "gemini-embedding-2-preview", "object": "model", "owned_by": "gemini"},
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
# Model name mapping (short name -> full path), used by the /v1/embeddings endpoint
|
| 26 |
+
MODEL_MAP = {
|
| 27 |
+
"text-embedding-004": "models/text-embedding-004",
|
| 28 |
+
"gemini-embedding-001": "models/gemini-embedding-001",
|
| 29 |
+
"gemini-embedding-2-preview": "models/gemini-embedding-2-preview",
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
def fetch_embedding_models():
    """Return the Gemini models that support ``embedContent``.

    Iterates over ``client.models.list()`` and keeps every model whose
    ``supported_actions`` contains ``"embedContent"``. Each kept model is
    reported as a dict with the keys ``id`` (the last ``/``-separated
    component of the model name), ``object`` and ``owned_by``.

    Returns:
        A list of model dicts. If listing fails for any reason, the error
        is printed and ``STATIC_MODELS`` is returned instead.
    """
    try:
        embedding_models = []
        for m in client.models.list():
            # ``supported_actions`` / ``name`` may be absent or None on an
            # entry; skip such entries rather than letting a TypeError abort
            # the whole listing and force the static fallback.
            actions = getattr(m, "supported_actions", None) or ()
            name = getattr(m, "name", None)
            if not name or "embedContent" not in actions:
                continue
            embedding_models.append({
                # Short name, e.g. "models/gemini-embedding-001"
                # -> "gemini-embedding-001".
                "id": name.split("/")[-1],
                "object": "model",
                "owned_by": "gemini",
            })
        return embedding_models
    except Exception as e:
        # Deliberate best-effort: any failure falls back to the static list.
        print(f"动态获取模型列表失败,回退到静态列表: {e}")
        return STATIC_MODELS
|
| 50 |
+
|
| 51 |
@app.get("/v1/models")
async def list_models():
    """Handle ``GET /v1/models``.

    Returns:
        A ``JSONResponse`` of the form ``{"object": "list", "data": [...]}``
        where ``data`` is the result of ``fetch_embedding_models()``.
    """
    # Function-scope import: the file's top-level import block is not
    # visible from this section.
    import asyncio

    # fetch_embedding_models() is a plain synchronous function; run it in a
    # worker thread so this coroutine does not block the event loop while
    # the model list is being retrieved.
    models = await asyncio.to_thread(fetch_embedding_models)
    return JSONResponse({"object": "list", "data": models})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
@app.post("/v1/embeddings")
|
| 57 |
async def create_embeddings(request: Request):
|
| 58 |
try:
|
| 59 |
body = await request.json()
|
| 60 |
+
model_short = body.get("model", "gemini-embedding-001")
|
| 61 |
+
# 自动补全模型名称为完整路径
|
| 62 |
+
model_full = MODEL_MAP.get(model_short, f"models/{model_short}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
input_text = body.get("input", "")
|
| 64 |
result = client.models.embed_content(
|
| 65 |
model=model_full,
|
|
|
|
| 72 |
"embedding": result.embeddings[0].values,
|
| 73 |
"index": 0
|
| 74 |
}],
|
| 75 |
+
"model": model_short
|
| 76 |
})
|
| 77 |
except Exception as e:
|
| 78 |
return JSONResponse({"error": str(e)}, status_code=500)
|