Spaces:
Paused
Paused
Update app/main.py
Browse files- app/main.py +220 -6
app/main.py
CHANGED
|
@@ -15,6 +15,8 @@ import random
|
|
| 15 |
import urllib.parse
|
| 16 |
from google.oauth2 import service_account
|
| 17 |
import config
|
|
|
|
|
|
|
| 18 |
|
| 19 |
from google.genai import types
|
| 20 |
|
|
@@ -1149,6 +1151,15 @@ async def list_models(api_key: str = Depends(get_api_key)):
|
|
| 1149 |
"root": "gemini-2.5-pro-exp-03-25",
|
| 1150 |
"parent": None,
|
| 1151 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1152 |
{
|
| 1153 |
"id": "gemini-2.5-pro-preview-03-25",
|
| 1154 |
"object": "model",
|
|
@@ -1336,6 +1347,15 @@ def create_openai_error_response(status_code: int, message: str, error_type: str
|
|
| 1336 |
}
|
| 1337 |
}
|
| 1338 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1339 |
@app.post("/v1/chat/completions")
|
| 1340 |
async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)): # Add request parameter
|
| 1341 |
try:
|
|
@@ -1348,10 +1368,193 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
|
|
| 1348 |
)
|
| 1349 |
return JSONResponse(status_code=400, content=error_response)
|
| 1350 |
|
| 1351 |
-
#
|
| 1352 |
-
|
| 1353 |
-
|
| 1354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1355 |
is_encrypted_full_model = request.model.endswith("-encrypt-full")
|
| 1356 |
is_nothinking_model = request.model.endswith("-nothinking")
|
| 1357 |
is_max_thinking_model = request.model.endswith("-max")
|
|
@@ -1418,7 +1621,8 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
|
|
| 1418 |
types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
|
| 1419 |
types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
|
| 1420 |
types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
|
| 1421 |
-
types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF")
|
|
|
|
| 1422 |
]
|
| 1423 |
generation_config["safety_settings"] = safety_settings
|
| 1424 |
|
|
@@ -1518,8 +1722,18 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
|
|
| 1518 |
# --- Main Logic ---
|
| 1519 |
last_error = None
|
| 1520 |
|
| 1521 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1522 |
print(f"Processing auto model: {request.model}")
|
|
|
|
| 1523 |
# Define encryption instructions for system_instruction
|
| 1524 |
encryption_instructions = [
|
| 1525 |
"// AI Assistant Configuration //",
|
|
|
|
| 15 |
import urllib.parse
|
| 16 |
from google.oauth2 import service_account
|
| 17 |
import config
|
| 18 |
+
import openai # Added import
|
| 19 |
+
from google.auth.transport.requests import Request as AuthRequest # Added import
|
| 20 |
|
| 21 |
from google.genai import types
|
| 22 |
|
|
|
|
| 1151 |
"root": "gemini-2.5-pro-exp-03-25",
|
| 1152 |
"parent": None,
|
| 1153 |
},
|
| 1154 |
+
{ # Added new model entry for OpenAI endpoint
|
| 1155 |
+
"id": "gemini-2.5-pro-exp-03-25-openai",
|
| 1156 |
+
"object": "model",
|
| 1157 |
+
"created": int(time.time()),
|
| 1158 |
+
"owned_by": "google",
|
| 1159 |
+
"permission": [],
|
| 1160 |
+
"root": "gemini-2.5-pro-exp-03-25", # Underlying model
|
| 1161 |
+
"parent": None,
|
| 1162 |
+
},
|
| 1163 |
{
|
| 1164 |
"id": "gemini-2.5-pro-preview-03-25",
|
| 1165 |
"object": "model",
|
|
|
|
| 1347 |
}
|
| 1348 |
}
|
| 1349 |
|
| 1350 |
+
# Helper for token refresh
|
| 1351 |
+
def _refresh_auth(credentials):
    """Refresh ``credentials`` and return the resulting access token.

    The refresh is performed through a fresh ``AuthRequest`` transport.
    This is a best-effort helper: any exception raised while refreshing
    or while reading the token is printed and swallowed.

    Args:
        credentials: Object exposing ``refresh(request)`` and ``token``.

    Returns:
        The value of ``credentials.token`` after a successful refresh,
        or ``None`` when an exception was raised.
    """
    token = None
    try:
        transport = AuthRequest()
        credentials.refresh(transport)
        # Read the token inside the try so a failing attribute access is
        # reported the same way as a failing refresh.
        token = credentials.token
    except Exception as refresh_error:
        print(f"Error refreshing GCP token: {refresh_error}")
    return token
|
| 1358 |
+
|
| 1359 |
@app.post("/v1/chat/completions")
|
| 1360 |
async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)): # Add request parameter
|
| 1361 |
try:
|
|
|
|
| 1368 |
)
|
| 1369 |
return JSONResponse(status_code=400, content=error_response)
|
| 1370 |
|
| 1371 |
+
# --- Handle specific OpenAI client model ---
|
| 1372 |
+
if request.model == "gemini-2.5-pro-exp-03-25-openai":
|
| 1373 |
+
print(f"INFO: Using OpenAI library path for model: {request.model}")
|
| 1374 |
+
|
| 1375 |
+
# --- Determine Credentials for OpenAI Client (Correct Priority) ---
|
| 1376 |
+
credentials_to_use = None
|
| 1377 |
+
project_id_to_use = None
|
| 1378 |
+
credential_source = "unknown"
|
| 1379 |
+
|
| 1380 |
+
# Priority 1: GOOGLE_CREDENTIALS_JSON (JSON String in Env Var)
|
| 1381 |
+
credentials_json_str = os.environ.get("GOOGLE_CREDENTIALS_JSON")
|
| 1382 |
+
if credentials_json_str:
|
| 1383 |
+
try:
|
| 1384 |
+
credentials_info = json.loads(credentials_json_str)
|
| 1385 |
+
if not isinstance(credentials_info, dict): raise ValueError("JSON is not a dict")
|
| 1386 |
+
required = ["type", "project_id", "private_key_id", "private_key", "client_email"]
|
| 1387 |
+
if any(f not in credentials_info for f in required): raise ValueError("Missing required fields")
|
| 1388 |
+
|
| 1389 |
+
credentials = service_account.Credentials.from_service_account_info(
|
| 1390 |
+
credentials_info, scopes=['https://www.googleapis.com/auth/cloud-platform']
|
| 1391 |
+
)
|
| 1392 |
+
project_id = credentials.project_id
|
| 1393 |
+
credentials_to_use = credentials
|
| 1394 |
+
project_id_to_use = project_id
|
| 1395 |
+
credential_source = "GOOGLE_CREDENTIALS_JSON env var"
|
| 1396 |
+
print(f"INFO: [OpenAI Path] Using credentials from {credential_source} for project: {project_id_to_use}")
|
| 1397 |
+
except Exception as e:
|
| 1398 |
+
print(f"WARNING: [OpenAI Path] Error processing GOOGLE_CREDENTIALS_JSON: {e}. Trying next method.")
|
| 1399 |
+
credentials_to_use = None # Ensure reset if failed
|
| 1400 |
+
|
| 1401 |
+
# Priority 2: Credential Manager (Rotated Files)
|
| 1402 |
+
if credentials_to_use is None:
|
| 1403 |
+
print(f"INFO: [OpenAI Path] Checking Credential Manager (directory: {credential_manager.credentials_dir})")
|
| 1404 |
+
rotated_credentials, rotated_project_id = credential_manager.get_next_credentials()
|
| 1405 |
+
if rotated_credentials and rotated_project_id:
|
| 1406 |
+
credentials_to_use = rotated_credentials
|
| 1407 |
+
project_id_to_use = rotated_project_id
|
| 1408 |
+
credential_source = f"Credential Manager file (Index: {credential_manager.current_index -1 if credential_manager.current_index > 0 else len(credential_manager.credentials_files) - 1})"
|
| 1409 |
+
print(f"INFO: [OpenAI Path] Using credentials from {credential_source} for project: {project_id_to_use}")
|
| 1410 |
+
else:
|
| 1411 |
+
print(f"INFO: [OpenAI Path] No credentials loaded via Credential Manager.")
|
| 1412 |
+
|
| 1413 |
+
# Priority 3: GOOGLE_APPLICATION_CREDENTIALS (File Path in Env Var)
|
| 1414 |
+
if credentials_to_use is None:
|
| 1415 |
+
file_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
|
| 1416 |
+
if file_path:
|
| 1417 |
+
print(f"INFO: [OpenAI Path] Checking GOOGLE_APPLICATION_CREDENTIALS file path: {file_path}")
|
| 1418 |
+
if os.path.exists(file_path):
|
| 1419 |
+
try:
|
| 1420 |
+
credentials = service_account.Credentials.from_service_account_file(
|
| 1421 |
+
file_path, scopes=['https://www.googleapis.com/auth/cloud-platform']
|
| 1422 |
+
)
|
| 1423 |
+
project_id = credentials.project_id
|
| 1424 |
+
credentials_to_use = credentials
|
| 1425 |
+
project_id_to_use = project_id
|
| 1426 |
+
credential_source = "GOOGLE_APPLICATION_CREDENTIALS file path"
|
| 1427 |
+
print(f"INFO: [OpenAI Path] Using credentials from {credential_source} for project: {project_id_to_use}")
|
| 1428 |
+
except Exception as e:
|
| 1429 |
+
print(f"ERROR: [OpenAI Path] Failed to load credentials from GOOGLE_APPLICATION_CREDENTIALS path ({file_path}): {e}")
|
| 1430 |
+
else:
|
| 1431 |
+
print(f"ERROR: [OpenAI Path] GOOGLE_APPLICATION_CREDENTIALS file does not exist at path: {file_path}")
|
| 1432 |
+
|
| 1433 |
+
# Error if no credentials found after all checks
|
| 1434 |
+
if credentials_to_use is None or project_id_to_use is None:
|
| 1435 |
+
error_msg = "No valid credentials found for OpenAI client path. Tried GOOGLE_CREDENTIALS_JSON, Credential Manager, and GOOGLE_APPLICATION_CREDENTIALS."
|
| 1436 |
+
print(f"ERROR: {error_msg}")
|
| 1437 |
+
error_response = create_openai_error_response(500, error_msg, "server_error")
|
| 1438 |
+
return JSONResponse(status_code=500, content=error_response)
|
| 1439 |
+
# --- Credentials Determined ---
|
| 1440 |
+
|
| 1441 |
+
# Get/Refresh GCP Token from the chosen credentials (credentials_to_use)
|
| 1442 |
+
gcp_token = None
|
| 1443 |
+
if credentials_to_use.expired or not credentials_to_use.token:
|
| 1444 |
+
print(f"INFO: [OpenAI Path] Refreshing GCP token (Source: {credential_source})...")
|
| 1445 |
+
gcp_token = _refresh_auth(credentials_to_use)
|
| 1446 |
+
else:
|
| 1447 |
+
gcp_token = credentials_to_use.token
|
| 1448 |
+
|
| 1449 |
+
if not gcp_token:
|
| 1450 |
+
error_msg = f"Failed to obtain valid GCP token for OpenAI client (Source: {credential_source})."
|
| 1451 |
+
print(f"ERROR: {error_msg}")
|
| 1452 |
+
error_response = create_openai_error_response(500, error_msg, "server_error")
|
| 1453 |
+
return JSONResponse(status_code=500, content=error_response)
|
| 1454 |
+
|
| 1455 |
+
# Configuration using determined Project ID
|
| 1456 |
+
PROJECT_ID = project_id_to_use
|
| 1457 |
+
LOCATION = "us-central1" # Assuming same location as genai client
|
| 1458 |
+
VERTEX_AI_OPENAI_ENDPOINT_URL = (
|
| 1459 |
+
f"https://{LOCATION}-aiplatform.googleapis.com/v1beta1/"
|
| 1460 |
+
f"projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi"
|
| 1461 |
+
)
|
| 1462 |
+
UNDERLYING_MODEL_ID = "gemini-2.5-pro-exp-03-25" # As specified
|
| 1463 |
+
|
| 1464 |
+
# Initialize Async OpenAI Client
|
| 1465 |
+
openai_client = openai.AsyncOpenAI(
|
| 1466 |
+
base_url=VERTEX_AI_OPENAI_ENDPOINT_URL,
|
| 1467 |
+
api_key=gcp_token,
|
| 1468 |
+
)
|
| 1469 |
+
|
| 1470 |
+
# Define standard safety settings (as used elsewhere)
|
| 1471 |
+
openai_safety_settings = [
|
| 1472 |
+
{
|
| 1473 |
+
"category": "HARM_CATEGORY_HARASSMENT",
|
| 1474 |
+
"threshold": "OFF"
|
| 1475 |
+
},
|
| 1476 |
+
{
|
| 1477 |
+
"category": "HARM_CATEGORY_HATE_SPEECH",
|
| 1478 |
+
"threshold": "OFF"
|
| 1479 |
+
},
|
| 1480 |
+
{
|
| 1481 |
+
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
| 1482 |
+
"threshold": "OFF"
|
| 1483 |
+
},
|
| 1484 |
+
{
|
| 1485 |
+
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
| 1486 |
+
"threshold": "OFF"
|
| 1487 |
+
},
|
| 1488 |
+
{
|
| 1489 |
+
"category": 'HARM_CATEGORY_CIVIC_INTEGRITY',
|
| 1490 |
+
"threshold": 'OFF'
|
| 1491 |
+
}
|
| 1492 |
+
]
|
| 1493 |
+
|
| 1494 |
+
# Prepare parameters for OpenAI client call
|
| 1495 |
+
openai_params = {
|
| 1496 |
+
"model": UNDERLYING_MODEL_ID,
|
| 1497 |
+
"messages": [msg.model_dump(exclude_unset=True) for msg in request.messages],
|
| 1498 |
+
"temperature": request.temperature,
|
| 1499 |
+
"max_tokens": request.max_tokens,
|
| 1500 |
+
"top_p": request.top_p,
|
| 1501 |
+
"stream": request.stream,
|
| 1502 |
+
"stop": request.stop,
|
| 1503 |
+
# "presence_penalty": request.presence_penalty,
|
| 1504 |
+
# "frequency_penalty": request.frequency_penalty,
|
| 1505 |
+
"seed": request.seed,
|
| 1506 |
+
"n": request.n,
|
| 1507 |
+
# Note: logprobs/response_logprobs mapping might need adjustment
|
| 1508 |
+
# Note: top_k is not directly supported by standard OpenAI API spec
|
| 1509 |
+
}
|
| 1510 |
+
# Add safety settings via extra_body
|
| 1511 |
+
openai_extra_body = {
|
| 1512 |
+
'google': {
|
| 1513 |
+
'safety_settings': openai_safety_settings
|
| 1514 |
+
}
|
| 1515 |
+
}
|
| 1516 |
+
openai_params = {k: v for k, v in openai_params.items() if v is not None}
|
| 1517 |
+
|
| 1518 |
+
|
| 1519 |
+
# Make the call using OpenAI client
|
| 1520 |
+
if request.stream:
|
| 1521 |
+
async def openai_stream_generator():
|
| 1522 |
+
try:
|
| 1523 |
+
stream = await openai_client.chat.completions.create(
|
| 1524 |
+
**openai_params,
|
| 1525 |
+
extra_body=openai_extra_body # Pass safety settings here
|
| 1526 |
+
)
|
| 1527 |
+
async for chunk in stream:
|
| 1528 |
+
yield f"data: {chunk.model_dump_json()}\n\n"
|
| 1529 |
+
yield "data: [DONE]\n\n"
|
| 1530 |
+
except Exception as stream_error:
|
| 1531 |
+
error_msg = f"Error during OpenAI client streaming for {request.model}: {str(stream_error)}"
|
| 1532 |
+
print(error_msg)
|
| 1533 |
+
error_response_content = create_openai_error_response(500, error_msg, "server_error")
|
| 1534 |
+
yield f"data: {json.dumps(error_response_content)}\n\n"
|
| 1535 |
+
yield "data: [DONE]\n\n"
|
| 1536 |
+
|
| 1537 |
+
return StreamingResponse(openai_stream_generator(), media_type="text/event-stream")
|
| 1538 |
+
else:
|
| 1539 |
+
try:
|
| 1540 |
+
response = await openai_client.chat.completions.create(
|
| 1541 |
+
**openai_params,
|
| 1542 |
+
extra_body=openai_extra_body # Pass safety settings here
|
| 1543 |
+
)
|
| 1544 |
+
return JSONResponse(content=response.model_dump(exclude_unset=True))
|
| 1545 |
+
except Exception as generate_error:
|
| 1546 |
+
error_msg = f"Error calling OpenAI client for {request.model}: {str(generate_error)}"
|
| 1547 |
+
print(error_msg)
|
| 1548 |
+
error_response = create_openai_error_response(500, error_msg, "server_error")
|
| 1549 |
+
return JSONResponse(status_code=500, content=error_response)
|
| 1550 |
+
|
| 1551 |
+
# --- End of specific OpenAI client model handling ---
|
| 1552 |
+
|
| 1553 |
+
# Check model type and extract base model name (Changed to elif)
|
| 1554 |
+
elif request.model.endswith("-auto"):
|
| 1555 |
+
is_auto_model = True
|
| 1556 |
+
is_grounded_search = False
|
| 1557 |
+
is_encrypted_model = False
|
| 1558 |
is_encrypted_full_model = request.model.endswith("-encrypt-full")
|
| 1559 |
is_nothinking_model = request.model.endswith("-nothinking")
|
| 1560 |
is_max_thinking_model = request.model.endswith("-max")
|
|
|
|
| 1621 |
types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
|
| 1622 |
types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
|
| 1623 |
types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
|
| 1624 |
+
types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
|
| 1625 |
+
types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF")
|
| 1626 |
]
|
| 1627 |
generation_config["safety_settings"] = safety_settings
|
| 1628 |
|
|
|
|
| 1722 |
# --- Main Logic ---
|
| 1723 |
last_error = None
|
| 1724 |
|
| 1725 |
+
# --- Main Logic --- (Ensure flags are correctly set if the first 'if' wasn't met)
|
| 1726 |
+
# Re-evaluate flags based on elif structure for clarity if needed, or rely on the fact that the first 'if' returned.
|
| 1727 |
+
is_auto_model = request.model.endswith("-auto") # This will be False if the first 'if' was True
|
| 1728 |
+
is_grounded_search = request.model.endswith("-search")
|
| 1729 |
+
is_encrypted_model = request.model.endswith("-encrypt")
|
| 1730 |
+
is_encrypted_full_model = request.model.endswith("-encrypt-full")
|
| 1731 |
+
is_nothinking_model = request.model.endswith("-nothinking")
|
| 1732 |
+
is_max_thinking_model = request.model.endswith("-max")
|
| 1733 |
+
|
| 1734 |
+
if is_auto_model: # This remains the primary check after the openai specific one
|
| 1735 |
print(f"Processing auto model: {request.model}")
|
| 1736 |
+
base_model_name = request.model.replace("-auto", "") # Ensure base_model_name is set here too
|
| 1737 |
# Define encryption instructions for system_instruction
|
| 1738 |
encryption_instructions = [
|
| 1739 |
"// AI Assistant Configuration //",
|