Spaces:
Paused
Paused
Commit
·
cfdf66d
1
Parent(s):
d9e170e
added search models
Browse files- src/config.py +33 -4
- src/google_api_client.py +10 -2
- src/openai_transformers.py +10 -3
src/config.py
CHANGED
|
@@ -35,8 +35,8 @@ DEFAULT_SAFETY_SETTINGS = [
|
|
| 35 |
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
|
| 36 |
]
|
| 37 |
|
| 38 |
-
#
|
| 39 |
-
|
| 40 |
{
|
| 41 |
"name": "models/gemini-2.5-pro-preview-05-06",
|
| 42 |
"version": "001",
|
|
@@ -93,7 +93,7 @@ SUPPORTED_MODELS = [
|
|
| 93 |
"name": "models/gemini-2.5-flash-preview-04-17",
|
| 94 |
"version": "001",
|
| 95 |
"displayName": "Gemini 2.5 Flash Preview 04-17",
|
| 96 |
-
"description": "Preview version of Gemini 2.5 Flash from
|
| 97 |
"inputTokenLimit": 1048576,
|
| 98 |
"outputTokenLimit": 65535,
|
| 99 |
"supportedGenerationMethods": ["generateContent", "streamGenerateContent"],
|
|
@@ -115,4 +115,33 @@ SUPPORTED_MODELS = [
|
|
| 115 |
"topP": 0.95,
|
| 116 |
"topK": 64
|
| 117 |
}
|
| 118 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
|
| 36 |
]
|
| 37 |
|
| 38 |
+
# Base Models (without search variants)
|
| 39 |
+
BASE_MODELS = [
|
| 40 |
{
|
| 41 |
"name": "models/gemini-2.5-pro-preview-05-06",
|
| 42 |
"version": "001",
|
|
|
|
| 93 |
"name": "models/gemini-2.5-flash-preview-04-17",
|
| 94 |
"version": "001",
|
| 95 |
"displayName": "Gemini 2.5 Flash Preview 04-17",
|
| 96 |
+
"description": "Preview version of Gemini 2.5 Flash from April 17th",
|
| 97 |
"inputTokenLimit": 1048576,
|
| 98 |
"outputTokenLimit": 65535,
|
| 99 |
"supportedGenerationMethods": ["generateContent", "streamGenerateContent"],
|
|
|
|
| 115 |
"topP": 0.95,
|
| 116 |
"topK": 64
|
| 117 |
}
|
| 118 |
+
]
|
| 119 |
+
|
| 120 |
+
# Generate search variants for applicable models
|
| 121 |
+
def _generate_search_variants():
|
| 122 |
+
"""Generate search variants for models that support content generation."""
|
| 123 |
+
search_models = []
|
| 124 |
+
for model in BASE_MODELS:
|
| 125 |
+
# Only add search variants for models that support content generation
|
| 126 |
+
if "generateContent" in model["supportedGenerationMethods"]:
|
| 127 |
+
search_variant = model.copy()
|
| 128 |
+
search_variant["name"] = model["name"] + "-search"
|
| 129 |
+
search_variant["displayName"] = model["displayName"] + " with Google Search"
|
| 130 |
+
search_variant["description"] = model["description"] + " (includes Google Search grounding)"
|
| 131 |
+
search_models.append(search_variant)
|
| 132 |
+
return search_models
|
| 133 |
+
|
| 134 |
+
# Supported Models (includes both base models and search variants)
|
| 135 |
+
SUPPORTED_MODELS = BASE_MODELS + _generate_search_variants()
|
| 136 |
+
|
| 137 |
+
# Helper function to get base model name from search variant
|
| 138 |
+
def get_base_model_name(model_name):
|
| 139 |
+
"""Convert search variant model name to base model name."""
|
| 140 |
+
if model_name.endswith("-search"):
|
| 141 |
+
return model_name[:-7] # Remove "-search" suffix
|
| 142 |
+
return model_name
|
| 143 |
+
|
| 144 |
+
# Helper function to check if model uses search grounding
|
| 145 |
+
def is_search_model(model_name):
|
| 146 |
+
"""Check if model name indicates search grounding should be enabled."""
|
| 147 |
+
return model_name.endswith("-search")
|
src/google_api_client.py
CHANGED
|
@@ -11,7 +11,7 @@ from google.auth.transport.requests import Request as GoogleAuthRequest
|
|
| 11 |
|
| 12 |
from .auth import get_credentials, save_credentials, get_user_project_id, onboard_user
|
| 13 |
from .utils import get_user_agent
|
| 14 |
-
from .config import CODE_ASSIST_ENDPOINT, DEFAULT_SAFETY_SETTINGS
|
| 15 |
import asyncio
|
| 16 |
|
| 17 |
|
|
@@ -310,7 +310,15 @@ def build_gemini_payload_from_native(native_request: dict, model_from_path: str)
|
|
| 310 |
native_request["generationConfig"]["thinkingConfig"]["includeThoughts"] = True
|
| 311 |
native_request["generationConfig"]["thinkingConfig"]["thinkingBudget"] = -1
|
| 312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
return {
|
| 314 |
-
"model": model_from_path,
|
| 315 |
"request": native_request
|
| 316 |
}
|
|
|
|
| 11 |
|
| 12 |
from .auth import get_credentials, save_credentials, get_user_project_id, onboard_user
|
| 13 |
from .utils import get_user_agent
|
| 14 |
+
from .config import CODE_ASSIST_ENDPOINT, DEFAULT_SAFETY_SETTINGS, get_base_model_name, is_search_model
|
| 15 |
import asyncio
|
| 16 |
|
| 17 |
|
|
|
|
| 310 |
native_request["generationConfig"]["thinkingConfig"]["includeThoughts"] = True
|
| 311 |
native_request["generationConfig"]["thinkingConfig"]["thinkingBudget"] = -1
|
| 312 |
|
| 313 |
+
# Add Google Search grounding for search models
|
| 314 |
+
if is_search_model(model_from_path):
|
| 315 |
+
if "tools" not in native_request:
|
| 316 |
+
native_request["tools"] = []
|
| 317 |
+
# Add googleSearch tool if not already present
|
| 318 |
+
if not any(tool.get("googleSearch") for tool in native_request["tools"]):
|
| 319 |
+
native_request["tools"].append({"googleSearch": {}})
|
| 320 |
+
|
| 321 |
return {
|
| 322 |
+
"model": get_base_model_name(model_from_path), # Use base model name for API call
|
| 323 |
"request": native_request
|
| 324 |
}
|
src/openai_transformers.py
CHANGED
|
@@ -8,7 +8,7 @@ import uuid
|
|
| 8 |
from typing import Dict, Any
|
| 9 |
|
| 10 |
from .models import OpenAIChatCompletionRequest, OpenAIChatCompletionResponse
|
| 11 |
-
from .config import DEFAULT_SAFETY_SETTINGS
|
| 12 |
|
| 13 |
|
| 14 |
def openai_request_to_gemini(openai_request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
|
|
@@ -91,12 +91,19 @@ def openai_request_to_gemini(openai_request: OpenAIChatCompletionRequest) -> Dic
|
|
| 91 |
if openai_request.response_format.get("type") == "json_object":
|
| 92 |
generation_config["responseMimeType"] = "application/json"
|
| 93 |
|
| 94 |
-
|
|
|
|
| 95 |
"contents": contents,
|
| 96 |
"generationConfig": generation_config,
|
| 97 |
"safetySettings": DEFAULT_SAFETY_SETTINGS,
|
| 98 |
-
"model": openai_request.model
|
| 99 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
|
| 102 |
def gemini_response_to_openai(gemini_response: Dict[str, Any], model: str) -> Dict[str, Any]:
|
|
|
|
| 8 |
from typing import Dict, Any
|
| 9 |
|
| 10 |
from .models import OpenAIChatCompletionRequest, OpenAIChatCompletionResponse
|
| 11 |
+
from .config import DEFAULT_SAFETY_SETTINGS, is_search_model, get_base_model_name
|
| 12 |
|
| 13 |
|
| 14 |
def openai_request_to_gemini(openai_request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
|
|
|
|
| 91 |
if openai_request.response_format.get("type") == "json_object":
|
| 92 |
generation_config["responseMimeType"] = "application/json"
|
| 93 |
|
| 94 |
+
# Build the request payload
|
| 95 |
+
request_payload = {
|
| 96 |
"contents": contents,
|
| 97 |
"generationConfig": generation_config,
|
| 98 |
"safetySettings": DEFAULT_SAFETY_SETTINGS,
|
| 99 |
+
"model": get_base_model_name(openai_request.model) # Use base model name for API call
|
| 100 |
}
|
| 101 |
+
|
| 102 |
+
# Add Google Search grounding for search models
|
| 103 |
+
if is_search_model(openai_request.model):
|
| 104 |
+
request_payload["tools"] = [{"googleSearch": {}}]
|
| 105 |
+
|
| 106 |
+
return request_payload
|
| 107 |
|
| 108 |
|
| 109 |
def gemini_response_to_openai(gemini_response: Dict[str, Any], model: str) -> Dict[str, Any]:
|