Commit ·
b5abced
1
Parent(s): da7a18e
Changed OpenAI CoT streaming handling; added round-robin mode for credentials; various refactoring.
Browse files- app/express_key_manager.py +97 -0
- app/openai_handler.py +284 -0
app/express_key_manager.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from typing import List, Optional, Tuple
|
| 3 |
+
import config as app_config
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ExpressKeyManager:
    """
    Manager for Vertex Express API keys with support for both random and round-robin selection strategies.
    Similar to CredentialManager but specifically for Express API keys.
    """

    def __init__(self):
        """Initialize the Express Key Manager with API keys from config."""
        # Keys are loaded from the application config at construction time.
        self.express_keys: List[str] = app_config.VERTEX_EXPRESS_API_KEY_VAL
        # Cursor for round-robin selection; kept within [0, len(express_keys)).
        self.round_robin_index: int = 0

    def get_total_keys(self) -> int:
        """Get the total number of available Express API keys."""
        return len(self.express_keys)

    def _get_key_with_index(self, key: str, index: int) -> Tuple[str, int]:
        """Return a tuple of (key, original_index) for logging purposes."""
        return (key, index)

    def get_random_express_key(self) -> Optional[Tuple[str, int]]:
        """
        Get a random Express API key.

        Returns:
            (key, original_index) tuple, or None if no keys are available.
        """
        if not self.express_keys:
            print("WARNING: No Express API keys available for selection.")
            return None

        print("DEBUG: Using random Express API key selection strategy.")

        # Pick a single uniformly random index directly. This is O(1) and has
        # the same distribution as the previous implementation, which built
        # list(enumerate(keys)), shuffled the whole list, and took element 0.
        original_idx = random.randrange(len(self.express_keys))
        key = self.express_keys[original_idx]
        return self._get_key_with_index(key, original_idx)

    def get_roundrobin_express_key(self) -> Optional[Tuple[str, int]]:
        """
        Get an Express API key using round-robin selection.

        Returns:
            (key, original_index) tuple, or None if no keys are available.
        """
        if not self.express_keys:
            print("WARNING: No Express API keys available for selection.")
            return None

        print("DEBUG: Using round-robin Express API key selection strategy.")

        # Clamp the cursor in case the key list shrank since the last call.
        if self.round_robin_index >= len(self.express_keys):
            self.round_robin_index = 0

        # Select the key at the current cursor position.
        original_idx = self.round_robin_index
        key = self.express_keys[original_idx]

        # Advance the cursor for the next call, wrapping at the end.
        self.round_robin_index = (self.round_robin_index + 1) % len(self.express_keys)

        return self._get_key_with_index(key, original_idx)

    def get_express_api_key(self) -> Optional[Tuple[str, int]]:
        """
        Get an Express API key based on the configured selection strategy.

        Checks the ROUNDROBIN config flag and delegates to the appropriate
        strategy method.

        Returns:
            (key, original_index) tuple, or None if no keys are available.
        """
        if app_config.ROUNDROBIN:
            return self.get_roundrobin_express_key()
        return self.get_random_express_key()

    def get_all_keys_indexed(self) -> List[Tuple[int, str]]:
        """
        Get all Express API keys with their indices.

        Useful for retry logic where we need to try all keys.

        Returns:
            List of (original_index, key) tuples.
        """
        return list(enumerate(self.express_keys))

    def refresh_keys(self) -> None:
        """
        Refresh the Express API keys from config.

        This allows for dynamic updates if the config is reloaded.
        """
        self.express_keys = app_config.VERTEX_EXPRESS_API_KEY_VAL
        # Reset the round-robin cursor if it now points past the end of the
        # (possibly shorter) refreshed key list.
        if self.round_robin_index >= len(self.express_keys):
            self.round_robin_index = 0
        print(f"INFO: Express API keys refreshed. Total keys: {self.get_total_keys()}")
|
app/openai_handler.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenAI handler module for creating clients and processing OpenAI Direct mode responses.
|
| 3 |
+
This module encapsulates all OpenAI-specific logic that was previously in chat_api.py.
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
import asyncio
|
| 8 |
+
from typing import Dict, Any, AsyncGenerator
|
| 9 |
+
|
| 10 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 11 |
+
import openai
|
| 12 |
+
from google.auth.transport.requests import Request as AuthRequest
|
| 13 |
+
|
| 14 |
+
from models import OpenAIRequest
|
| 15 |
+
from config import VERTEX_REASONING_TAG
|
| 16 |
+
import config as app_config
|
| 17 |
+
from api_helpers import (
|
| 18 |
+
create_openai_error_response,
|
| 19 |
+
openai_fake_stream_generator,
|
| 20 |
+
StreamingReasoningProcessor
|
| 21 |
+
)
|
| 22 |
+
from message_processing import extract_reasoning_by_tags
|
| 23 |
+
from credentials_manager import _refresh_auth
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class OpenAIDirectHandler:
    """Handles OpenAI Direct mode operations including client creation and response processing.

    Wraps a Vertex AI OpenAI-compatible endpoint behind the OpenAI Python
    client, injecting Google safety settings and a reasoning-tag marker, and
    post-processes responses to split tagged reasoning out of the content.
    """

    def __init__(self, credential_manager):
        # Provider of rotated GCP credentials; must expose get_credentials()
        # returning (credentials, project_id) — see process_request below.
        self.credential_manager = credential_manager
        # All harm categories are explicitly set to "OFF", i.e. safety
        # filtering is disabled for this direct path.
        self.safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
            {"category": 'HARM_CATEGORY_CIVIC_INTEGRITY', "threshold": 'OFF'}
        ]

    def create_openai_client(self, project_id: str, gcp_token: str, location: str = "global") -> openai.AsyncOpenAI:
        """Create an OpenAI client configured for Vertex AI endpoint."""
        # Vertex exposes an OpenAI-compatible surface under this endpoint;
        # the OAuth bearer token is passed where the client expects an API key.
        endpoint_url = (
            f"https://aiplatform.googleapis.com/v1beta1/"
            f"projects/{project_id}/locations/{location}/endpoints/openapi"
        )

        return openai.AsyncOpenAI(
            base_url=endpoint_url,
            api_key=gcp_token,  # OAuth token
        )

    def prepare_openai_params(self, request: OpenAIRequest, model_id: str) -> Dict[str, Any]:
        """Prepare parameters for OpenAI API call.

        Builds the kwargs for chat.completions.create from the incoming
        request, dropping any fields the caller left unset (None). Note this
        also drops stream=False / n=None etc.; the stream flag is re-asserted
        explicitly in the streaming/non-streaming handlers.
        """
        params = {
            "model": model_id,
            "messages": [msg.model_dump(exclude_unset=True) for msg in request.messages],
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "top_p": request.top_p,
            "stream": request.stream,
            "stop": request.stop,
            "seed": request.seed,
            "n": request.n,
        }
        # Remove None values
        return {k: v for k, v in params.items() if v is not None}

    def prepare_extra_body(self) -> Dict[str, Any]:
        """Prepare extra body parameters for OpenAI API call.

        The 'google' vendor extension carries the disabled safety settings and
        the marker tag the model uses to delimit its reasoning ("thoughts"),
        which the streaming/non-streaming handlers later strip back out.
        """
        return {
            "extra_body": {
                'google': {
                    'safety_settings': self.safety_settings,
                    'thought_tag_marker': VERTEX_REASONING_TAG
                }
            }
        }

    async def handle_streaming_response(
        self,
        openai_client: openai.AsyncOpenAI,
        openai_params: Dict[str, Any],
        openai_extra_body: Dict[str, Any],
        request: OpenAIRequest
    ) -> StreamingResponse:
        """Handle streaming responses for OpenAI Direct mode.

        Chooses between fake streaming (single upstream call replayed as SSE,
        when FAKE_STREAMING_ENABLED) and true token-by-token streaming.
        """
        if app_config.FAKE_STREAMING_ENABLED:
            print(f"INFO: OpenAI Fake Streaming (SSE Simulation) ENABLED for model '{request.model}'.")
            return StreamingResponse(
                openai_fake_stream_generator(
                    openai_client=openai_client,
                    openai_params=openai_params,
                    openai_extra_body=openai_extra_body,
                    request_obj=request,
                    is_auto_attempt=False
                ),
                media_type="text/event-stream"
            )
        else:
            print(f"INFO: OpenAI True Streaming ENABLED for model '{request.model}'.")
            return StreamingResponse(
                self._true_stream_generator(openai_client, openai_params, openai_extra_body, request),
                media_type="text/event-stream"
            )

    async def _true_stream_generator(
        self,
        openai_client: openai.AsyncOpenAI,
        openai_params: Dict[str, Any],
        openai_extra_body: Dict[str, Any],
        request: OpenAIRequest
    ) -> AsyncGenerator[str, None]:
        """Generate true streaming response.

        Yields SSE-formatted chunks. Each upstream delta is passed through
        StreamingReasoningProcessor, which splits tag-delimited reasoning
        (VERTEX_REASONING_TAG) from regular content across chunk boundaries;
        reasoning is re-emitted under delta['reasoning_content']. Errors are
        surfaced as an OpenAI-style error event followed by [DONE].
        """
        try:
            # Ensure stream=True is explicitly passed for real streaming
            openai_params_for_stream = {**openai_params, "stream": True}
            stream_response = await openai_client.chat.completions.create(
                **openai_params_for_stream,
                extra_body=openai_extra_body['extra_body']
            )

            # Create processor for tag-based extraction across chunks
            reasoning_processor = StreamingReasoningProcessor(VERTEX_REASONING_TAG)

            async for chunk in stream_response:
                try:
                    chunk_as_dict = chunk.model_dump(exclude_unset=True, exclude_none=True)

                    choices = chunk_as_dict.get('choices')
                    if choices and isinstance(choices, list) and len(choices) > 0:
                        delta = choices[0].get('delta')
                        if delta and isinstance(delta, dict):
                            # Always remove extra_content if present
                            if 'extra_content' in delta:
                                del delta['extra_content']

                            content = delta.get('content', '')
                            if content:
                                # Use the processor to extract reasoning
                                processed_content, current_reasoning = reasoning_processor.process_chunk(content)

                                # Update delta with processed content
                                if current_reasoning:
                                    delta['reasoning_content'] = current_reasoning
                                if processed_content:
                                    delta['content'] = processed_content
                                elif 'content' in delta:
                                    # Whole chunk was reasoning (or buffered):
                                    # drop the now-empty content field.
                                    del delta['content']

                    yield f"data: {json.dumps(chunk_as_dict)}\n\n"

                except Exception as chunk_error:
                    # Per-chunk failure: emit an error event and terminate the
                    # stream cleanly rather than raising into the transport.
                    error_msg = f"Error processing OpenAI chunk for {request.model}: {str(chunk_error)}"
                    print(f"ERROR: {error_msg}")
                    if len(error_msg) > 1024:
                        error_msg = error_msg[:1024] + "..."
                    error_response = create_openai_error_response(500, error_msg, "server_error")
                    yield f"data: {json.dumps(error_response)}\n\n"
                    yield "data: [DONE]\n\n"
                    return

            # Handle any remaining buffer content
            if reasoning_processor.tag_buffer and not reasoning_processor.inside_tag:
                # Output any remaining content
                final_chunk = {
                    "id": f"chatcmpl-{int(time.time())}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [{"index": 0, "delta": {"content": reasoning_processor.tag_buffer}, "finish_reason": None}]
                }
                yield f"data: {json.dumps(final_chunk)}\n\n"
            elif reasoning_processor.inside_tag and reasoning_processor.reasoning_buffer:
                # We were inside a tag but never found the closing tag
                # NOTE(review): the partial reasoning is logged but not
                # emitted to the client — appears intentional; confirm.
                print(f"WARNING: Unclosed reasoning tag detected. Partial reasoning: {reasoning_processor.reasoning_buffer[:100]}...")

            yield "data: [DONE]\n\n"

        except Exception as stream_error:
            # Failure before/at stream creation or outside the chunk loop.
            error_msg = str(stream_error)
            if len(error_msg) > 1024:
                error_msg = error_msg[:1024] + "..."
            error_msg_full = f"Error during OpenAI streaming for {request.model}: {error_msg}"
            print(f"ERROR: {error_msg_full}")
            error_response = create_openai_error_response(500, error_msg_full, "server_error")
            yield f"data: {json.dumps(error_response)}\n\n"
            yield "data: [DONE]\n\n"

    async def handle_non_streaming_response(
        self,
        openai_client: openai.AsyncOpenAI,
        openai_params: Dict[str, Any],
        openai_extra_body: Dict[str, Any],
        request: OpenAIRequest
    ) -> JSONResponse:
        """Handle non-streaming responses for OpenAI Direct mode.

        Makes a single completion call, then splits tag-delimited reasoning
        out of the message content into message['reasoning_content'].
        Reasoning post-processing errors are logged but do not fail the
        request; upstream call errors return a 500 JSON error body.
        """
        try:
            # Ensure stream=False is explicitly passed
            openai_params_non_stream = {**openai_params, "stream": False}
            response = await openai_client.chat.completions.create(
                **openai_params_non_stream,
                extra_body=openai_extra_body['extra_body']
            )
            response_dict = response.model_dump(exclude_unset=True, exclude_none=True)

            try:
                choices = response_dict.get('choices')
                if choices and isinstance(choices, list) and len(choices) > 0:
                    message_dict = choices[0].get('message')
                    if message_dict and isinstance(message_dict, dict):
                        # Always remove extra_content from the message if it exists
                        if 'extra_content' in message_dict:
                            del message_dict['extra_content']

                        # Extract reasoning from content
                        full_content = message_dict.get('content')
                        actual_content = full_content if isinstance(full_content, str) else ""

                        if actual_content:
                            print(f"INFO: OpenAI Direct Non-Streaming - Applying tag extraction with fixed marker: '{VERTEX_REASONING_TAG}'")
                            reasoning_text, actual_content = extract_reasoning_by_tags(actual_content, VERTEX_REASONING_TAG)
                            message_dict['content'] = actual_content
                            if reasoning_text:
                                message_dict['reasoning_content'] = reasoning_text
                                print(f"DEBUG: Tag extraction success. Reasoning len: {len(reasoning_text)}, Content len: {len(actual_content)}")
                            else:
                                print(f"DEBUG: No content found within fixed tag '{VERTEX_REASONING_TAG}'.")
                        else:
                            print(f"WARNING: OpenAI Direct Non-Streaming - No initial content found in message.")
                            message_dict['content'] = ""

            except Exception as e_reasoning:
                # Best-effort post-processing: fall through and return the
                # (possibly unmodified) upstream response.
                print(f"WARNING: Error during non-streaming reasoning processing for model {request.model}: {e_reasoning}")

            return JSONResponse(content=response_dict)

        except Exception as e:
            error_msg = f"Error calling OpenAI client for {request.model}: {str(e)}"
            print(f"ERROR: {error_msg}")
            return JSONResponse(
                status_code=500,
                content=create_openai_error_response(500, error_msg, "server_error")
            )

    async def process_request(self, request: OpenAIRequest, base_model_name: str):
        """Main entry point for processing OpenAI Direct mode requests.

        Acquires rotated GCP credentials, refreshes an OAuth token, builds
        the Vertex-backed OpenAI client, and dispatches to the streaming or
        non-streaming handler based on request.stream. Returns a
        StreamingResponse or JSONResponse (500 JSON error on credential or
        token failure).
        """
        print(f"INFO: Using OpenAI Direct Path for model: {request.model}")

        # Get credentials
        rotated_credentials, rotated_project_id = self.credential_manager.get_credentials()

        if not rotated_credentials or not rotated_project_id:
            error_msg = "OpenAI Direct Mode requires GCP credentials, but none were available or loaded successfully."
            print(f"ERROR: {error_msg}")
            return JSONResponse(
                status_code=500,
                content=create_openai_error_response(500, error_msg, "server_error")
            )

        print(f"INFO: [OpenAI Direct Path] Using credentials for project: {rotated_project_id}")
        gcp_token = _refresh_auth(rotated_credentials)

        if not gcp_token:
            error_msg = f"Failed to obtain valid GCP token for OpenAI client (Project: {rotated_project_id})."
            print(f"ERROR: {error_msg}")
            return JSONResponse(
                status_code=500,
                content=create_openai_error_response(500, error_msg, "server_error")
            )

        # Create client and prepare parameters
        openai_client = self.create_openai_client(rotated_project_id, gcp_token)
        # Vertex's OpenAI-compatible endpoint namespaces Google models.
        model_id = f"google/{base_model_name}"
        openai_params = self.prepare_openai_params(request, model_id)
        openai_extra_body = self.prepare_extra_body()

        # Handle streaming vs non-streaming
        if request.stream:
            return await self.handle_streaming_response(
                openai_client, openai_params, openai_extra_body, request
            )
        else:
            return await self.handle_non_streaming_response(
                openai_client, openai_params, openai_extra_body, request
            )
|