Update app/main.py
Browse files- app/main.py +145 -107
app/main.py
CHANGED
|
@@ -273,94 +273,66 @@ async def startup_event():
|
|
| 273 |
print("WARNING: Failed to initialize Vertex AI authentication")
|
| 274 |
|
| 275 |
# Conversion functions
|
| 276 |
-
|
|
|
|
|
|
|
|
|
|
| 277 |
"""
|
| 278 |
Convert OpenAI messages to Gemini format.
|
| 279 |
-
Returns
|
| 280 |
"""
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
for part in message.content:
|
| 286 |
-
if isinstance(part, dict) and part.get('type') == 'image_url':
|
| 287 |
-
has_images = True
|
| 288 |
-
break
|
| 289 |
-
elif isinstance(part, ContentPartImage):
|
| 290 |
-
has_images = True
|
| 291 |
-
break
|
| 292 |
-
if has_images:
|
| 293 |
-
break
|
| 294 |
|
| 295 |
-
#
|
| 296 |
-
|
| 297 |
-
|
|
|
|
| 298 |
|
| 299 |
-
#
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
|
|
|
|
|
|
|
|
|
| 307 |
else:
|
| 308 |
-
#
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
elif message.role == "user":
|
| 314 |
-
prompt += f"Human: {content_text}\n"
|
| 315 |
-
elif message.role == "assistant":
|
| 316 |
-
prompt += f"AI: {content_text}\n"
|
| 317 |
|
| 318 |
-
#
|
| 319 |
-
|
| 320 |
-
prompt += "AI: "
|
| 321 |
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
# If images are present, create a list of content parts
|
| 325 |
-
gemini_contents = []
|
| 326 |
-
|
| 327 |
-
# Process all messages in their original order
|
| 328 |
-
for message in messages:
|
| 329 |
-
|
| 330 |
-
# For string content, add as text
|
| 331 |
if isinstance(message.content, str):
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
# For list content, process each part
|
| 336 |
elif isinstance(message.content, list):
|
| 337 |
-
#
|
| 338 |
-
text_content = ""
|
| 339 |
-
|
| 340 |
for part in message.content:
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
elif isinstance(part, ContentPartText):
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
# Add the combined text content if any
|
| 348 |
-
if text_content:
|
| 349 |
-
prefix = "Human: " if message.role == "user" else "AI: "
|
| 350 |
-
gemini_contents.append(f"{prefix}{text_content}")
|
| 351 |
-
|
| 352 |
-
# Then process image parts
|
| 353 |
-
for part in message.content:
|
| 354 |
-
# Handle image parts
|
| 355 |
-
if isinstance(part, dict) and part.get('type') == 'image_url':
|
| 356 |
-
image_url = part.get('image_url', {}).get('url', '')
|
| 357 |
-
if image_url.startswith('data:'):
|
| 358 |
-
# Extract mime type and base64 data
|
| 359 |
-
mime_match = re.match(r'data:([^;]+);base64,(.+)', image_url)
|
| 360 |
-
if mime_match:
|
| 361 |
-
mime_type, b64_data = mime_match.groups()
|
| 362 |
-
image_bytes = base64.b64decode(b64_data)
|
| 363 |
-
gemini_contents.append(types.Part.from_bytes(data=image_bytes, mime_type=mime_type))
|
| 364 |
elif isinstance(part, ContentPartImage):
|
| 365 |
image_url = part.image_url.url
|
| 366 |
if image_url.startswith('data:'):
|
|
@@ -369,14 +341,38 @@ def create_gemini_prompt(messages: List[OpenAIMessage]) -> Union[str, List[Any]]
|
|
| 369 |
if mime_match:
|
| 370 |
mime_type, b64_data = mime_match.groups()
|
| 371 |
image_bytes = base64.b64decode(b64_data)
|
| 372 |
-
|
| 373 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
|
| 375 |
-
def create_encrypted_gemini_prompt(messages: List[OpenAIMessage]) -> Union[
|
| 376 |
"""
|
| 377 |
Convert OpenAI messages to Gemini format with special encoding for the encrypt model.
|
| 378 |
This function URL-encodes user messages and adds specific system instructions.
|
| 379 |
"""
|
|
|
|
|
|
|
| 380 |
# Check if any message contains image content
|
| 381 |
has_images = False
|
| 382 |
for message in messages:
|
|
@@ -395,8 +391,8 @@ def create_encrypted_gemini_prompt(messages: List[OpenAIMessage]) -> Union[str,
|
|
| 395 |
if has_images:
|
| 396 |
return create_gemini_prompt(messages)
|
| 397 |
|
| 398 |
-
# Define the
|
| 399 |
-
|
| 400 |
STRICT OPERATING PROTOCOL:
|
| 401 |
1. **Critical Output Format:** Your responses MUST ALWAYS be standard **UNENCODED PLAIN TEXT**. Absolutely NO URL encoding or any other form of encoding in your output. This is non-negotiable.
|
| 402 |
2. **Input Processing (Internal Only):** User messages and conversation history provided to you will be URL-encoded for privacy. Decode these inputs internally before processing.
|
|
@@ -422,8 +418,8 @@ Ready for your request."""
|
|
| 422 |
# Create a new list of messages with the pre-messages and encoded content
|
| 423 |
new_messages = []
|
| 424 |
|
| 425 |
-
# Add a system message with
|
| 426 |
-
new_messages.append(OpenAIMessage(role="system", content=
|
| 427 |
|
| 428 |
# Add pre-messages
|
| 429 |
new_messages.extend(pre_messages)
|
|
@@ -431,19 +427,7 @@ Ready for your request."""
|
|
| 431 |
# Process all messages in their original order
|
| 432 |
for i, message in enumerate(messages):
|
| 433 |
if message.role == "system":
|
| 434 |
-
#
|
| 435 |
-
# if isinstance(message.content, str):
|
| 436 |
-
# system_content = message.content
|
| 437 |
-
# elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
|
| 438 |
-
# system_content = message.content[0]['text']
|
| 439 |
-
# else:
|
| 440 |
-
# system_content = str(message.content)
|
| 441 |
-
|
| 442 |
-
# # URL encode the system message content
|
| 443 |
-
# new_messages.append(OpenAIMessage(
|
| 444 |
-
# role="system",
|
| 445 |
-
# content=urllib.parse.quote(system_content)
|
| 446 |
-
# ))
|
| 447 |
new_messages.append(message)
|
| 448 |
|
| 449 |
elif message.role == "user":
|
|
@@ -454,12 +438,26 @@ Ready for your request."""
|
|
| 454 |
content=urllib.parse.quote(message.content)
|
| 455 |
))
|
| 456 |
elif isinstance(message.content, list):
|
| 457 |
-
#
|
| 458 |
-
|
| 459 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
else:
|
| 461 |
-
# For
|
| 462 |
-
# Check if this is the last
|
| 463 |
is_last_assistant = True
|
| 464 |
for remaining_msg in messages[i+1:]:
|
| 465 |
if remaining_msg.role != "user":
|
|
@@ -473,13 +471,30 @@ Ready for your request."""
|
|
| 473 |
role=message.role,
|
| 474 |
content=urllib.parse.quote(message.content)
|
| 475 |
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
else:
|
| 477 |
-
# For non-string content, keep as is
|
| 478 |
new_messages.append(message)
|
| 479 |
else:
|
| 480 |
-
# For other
|
| 481 |
new_messages.append(message)
|
| 482 |
|
|
|
|
| 483 |
# Now use the standard function to convert to Gemini format
|
| 484 |
return create_gemini_prompt(new_messages)
|
| 485 |
|
|
@@ -826,6 +841,23 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
|
|
| 826 |
prompt = create_encrypted_gemini_prompt(request.messages)
|
| 827 |
else:
|
| 828 |
prompt = create_gemini_prompt(request.messages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 829 |
|
| 830 |
if request.stream:
|
| 831 |
# Handle streaming response
|
|
@@ -838,10 +870,13 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
|
|
| 838 |
# If multiple candidates are requested, we'll generate them sequentially
|
| 839 |
for candidate_index in range(candidate_count):
|
| 840 |
# Generate content with streaming
|
| 841 |
-
# Handle
|
|
|
|
|
|
|
|
|
|
| 842 |
responses = client.models.generate_content_stream(
|
| 843 |
model=gemini_model,
|
| 844 |
-
contents=prompt,
|
| 845 |
config=generation_config,
|
| 846 |
)
|
| 847 |
|
|
@@ -873,10 +908,13 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
|
|
| 873 |
# Make sure generation_config has candidate_count set
|
| 874 |
if "candidate_count" not in generation_config:
|
| 875 |
generation_config["candidate_count"] = request.n
|
| 876 |
-
# Handle
|
|
|
|
|
|
|
|
|
|
| 877 |
response = client.models.generate_content(
|
| 878 |
model=gemini_model,
|
| 879 |
-
contents=prompt,
|
| 880 |
config=generation_config,
|
| 881 |
)
|
| 882 |
|
|
|
|
| 273 |
print("WARNING: Failed to initialize Vertex AI authentication")
|
| 274 |
|
| 275 |
# Conversion functions
|
| 276 |
+
# Define supported roles for Gemini API
|
| 277 |
+
SUPPORTED_ROLES = ["user", "model"]
|
| 278 |
+
|
| 279 |
+
def create_gemini_prompt(messages: List[OpenAIMessage]) -> Union[types.Content, List[types.Content]]:
|
| 280 |
"""
|
| 281 |
Convert OpenAI messages to Gemini format.
|
| 282 |
+
Returns a Content object or list of Content objects as required by the Gemini API.
|
| 283 |
"""
|
| 284 |
+
print("Converting OpenAI messages to Gemini format...")
|
| 285 |
+
|
| 286 |
+
# Create a list to hold the Gemini-formatted messages
|
| 287 |
+
gemini_messages = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
+
# Process all messages in their original order
|
| 290 |
+
for idx, message in enumerate(messages):
|
| 291 |
+
# Map OpenAI roles to Gemini roles
|
| 292 |
+
role = message.role
|
| 293 |
|
| 294 |
+
# If role is "system", use "user" as specified
|
| 295 |
+
if role == "system":
|
| 296 |
+
role = "user"
|
| 297 |
+
# If role is "assistant", map to "model"
|
| 298 |
+
elif role == "assistant":
|
| 299 |
+
role = "model"
|
| 300 |
+
|
| 301 |
+
# Handle unsupported roles as per user's feedback
|
| 302 |
+
if role not in SUPPORTED_ROLES:
|
| 303 |
+
if role == "tool":
|
| 304 |
+
role = "user"
|
| 305 |
else:
|
| 306 |
+
# If it's the last message, treat it as a user message
|
| 307 |
+
if idx == len(messages) - 1:
|
| 308 |
+
role = "user"
|
| 309 |
+
else:
|
| 310 |
+
role = "model"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
+
# Create parts list for this message
|
| 313 |
+
parts = []
|
|
|
|
| 314 |
|
| 315 |
+
# Handle different content types
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
if isinstance(message.content, str):
|
| 317 |
+
# Simple string content
|
| 318 |
+
parts.append(types.Part(text=message.content))
|
|
|
|
|
|
|
| 319 |
elif isinstance(message.content, list):
|
| 320 |
+
# List of content parts (may include text and images)
|
|
|
|
|
|
|
| 321 |
for part in message.content:
|
| 322 |
+
if isinstance(part, dict):
|
| 323 |
+
if part.get('type') == 'text':
|
| 324 |
+
parts.append(types.Part(text=part.get('text', '')))
|
| 325 |
+
elif part.get('type') == 'image_url':
|
| 326 |
+
image_url = part.get('image_url', {}).get('url', '')
|
| 327 |
+
if image_url.startswith('data:'):
|
| 328 |
+
# Extract mime type and base64 data
|
| 329 |
+
mime_match = re.match(r'data:([^;]+);base64,(.+)', image_url)
|
| 330 |
+
if mime_match:
|
| 331 |
+
mime_type, b64_data = mime_match.groups()
|
| 332 |
+
image_bytes = base64.b64decode(b64_data)
|
| 333 |
+
parts.append(types.Part.from_bytes(data=image_bytes, mime_type=mime_type))
|
| 334 |
elif isinstance(part, ContentPartText):
|
| 335 |
+
parts.append(types.Part(text=part.text))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
elif isinstance(part, ContentPartImage):
|
| 337 |
image_url = part.image_url.url
|
| 338 |
if image_url.startswith('data:'):
|
|
|
|
| 341 |
if mime_match:
|
| 342 |
mime_type, b64_data = mime_match.groups()
|
| 343 |
image_bytes = base64.b64decode(b64_data)
|
| 344 |
+
parts.append(types.Part.from_bytes(data=image_bytes, mime_type=mime_type))
|
| 345 |
+
else:
|
| 346 |
+
# Fallback for unexpected format
|
| 347 |
+
parts.append(types.Part(text=str(message.content)))
|
| 348 |
+
|
| 349 |
+
# Create a Content object with role and parts
|
| 350 |
+
content = types.Content(
|
| 351 |
+
role=role,
|
| 352 |
+
parts=parts
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
# Add to our list
|
| 356 |
+
gemini_messages.append(content)
|
| 357 |
+
|
| 358 |
+
print(f"Converted to {len(gemini_messages)} Gemini messages")
|
| 359 |
+
|
| 360 |
+
# If there's only one message, return it directly
|
| 361 |
+
if len(gemini_messages) == 1:
|
| 362 |
+
return gemini_messages[0]
|
| 363 |
+
|
| 364 |
+
# Otherwise return the list
|
| 365 |
+
return gemini_messages
|
| 366 |
+
|
| 367 |
+
# No need for the separate image handling branch as we now handle all content types in one flow
|
| 368 |
|
| 369 |
+
def create_encrypted_gemini_prompt(messages: List[OpenAIMessage]) -> Union[types.Content, List[types.Content]]:
|
| 370 |
"""
|
| 371 |
Convert OpenAI messages to Gemini format with special encoding for the encrypt model.
|
| 372 |
This function URL-encodes user messages and adds specific system instructions.
|
| 373 |
"""
|
| 374 |
+
print("Creating encrypted Gemini prompt...")
|
| 375 |
+
|
| 376 |
# Check if any message contains image content
|
| 377 |
has_images = False
|
| 378 |
for message in messages:
|
|
|
|
| 391 |
if has_images:
|
| 392 |
return create_gemini_prompt(messages)
|
| 393 |
|
| 394 |
+
# Define the instructions to be added
|
| 395 |
+
instructions = """// AI Assistant Configuration //
|
| 396 |
STRICT OPERATING PROTOCOL:
|
| 397 |
1. **Critical Output Format:** Your responses MUST ALWAYS be standard **UNENCODED PLAIN TEXT**. Absolutely NO URL encoding or any other form of encoding in your output. This is non-negotiable.
|
| 398 |
2. **Input Processing (Internal Only):** User messages and conversation history provided to you will be URL-encoded for privacy. Decode these inputs internally before processing.
|
|
|
|
| 418 |
# Create a new list of messages with the pre-messages and encoded content
|
| 419 |
new_messages = []
|
| 420 |
|
| 421 |
+
# Add a system message with instructions at the beginning
|
| 422 |
+
new_messages.append(OpenAIMessage(role="system", content=instructions))
|
| 423 |
|
| 424 |
# Add pre-messages
|
| 425 |
new_messages.extend(pre_messages)
|
|
|
|
| 427 |
# Process all messages in their original order
|
| 428 |
for i, message in enumerate(messages):
|
| 429 |
if message.role == "system":
|
| 430 |
+
# Pass system messages through as is
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
new_messages.append(message)
|
| 432 |
|
| 433 |
elif message.role == "user":
|
|
|
|
| 438 |
content=urllib.parse.quote(message.content)
|
| 439 |
))
|
| 440 |
elif isinstance(message.content, list):
|
| 441 |
+
# For list content (like with images), we need to handle each part
|
| 442 |
+
encoded_parts = []
|
| 443 |
+
for part in message.content:
|
| 444 |
+
if isinstance(part, dict) and part.get('type') == 'text':
|
| 445 |
+
# URL encode text parts
|
| 446 |
+
encoded_parts.append({
|
| 447 |
+
'type': 'text',
|
| 448 |
+
'text': urllib.parse.quote(part.get('text', ''))
|
| 449 |
+
})
|
| 450 |
+
else:
|
| 451 |
+
# Pass through non-text parts (like images)
|
| 452 |
+
encoded_parts.append(part)
|
| 453 |
+
|
| 454 |
+
new_messages.append(OpenAIMessage(
|
| 455 |
+
role=message.role,
|
| 456 |
+
content=encoded_parts
|
| 457 |
+
))
|
| 458 |
else:
|
| 459 |
+
# For assistant messages
|
| 460 |
+
# Check if this is the last assistant message in the conversation
|
| 461 |
is_last_assistant = True
|
| 462 |
for remaining_msg in messages[i+1:]:
|
| 463 |
if remaining_msg.role != "user":
|
|
|
|
| 471 |
role=message.role,
|
| 472 |
content=urllib.parse.quote(message.content)
|
| 473 |
))
|
| 474 |
+
elif isinstance(message.content, list):
|
| 475 |
+
# Handle list content similar to user messages
|
| 476 |
+
encoded_parts = []
|
| 477 |
+
for part in message.content:
|
| 478 |
+
if isinstance(part, dict) and part.get('type') == 'text':
|
| 479 |
+
encoded_parts.append({
|
| 480 |
+
'type': 'text',
|
| 481 |
+
'text': urllib.parse.quote(part.get('text', ''))
|
| 482 |
+
})
|
| 483 |
+
else:
|
| 484 |
+
encoded_parts.append(part)
|
| 485 |
+
|
| 486 |
+
new_messages.append(OpenAIMessage(
|
| 487 |
+
role=message.role,
|
| 488 |
+
content=encoded_parts
|
| 489 |
+
))
|
| 490 |
else:
|
| 491 |
+
# For non-string/list content, keep as is
|
| 492 |
new_messages.append(message)
|
| 493 |
else:
|
| 494 |
+
# For other assistant messages, keep as is
|
| 495 |
new_messages.append(message)
|
| 496 |
|
| 497 |
+
print(f"Created encrypted prompt with {len(new_messages)} messages")
|
| 498 |
# Now use the standard function to convert to Gemini format
|
| 499 |
return create_gemini_prompt(new_messages)
|
| 500 |
|
|
|
|
| 841 |
prompt = create_encrypted_gemini_prompt(request.messages)
|
| 842 |
else:
|
| 843 |
prompt = create_gemini_prompt(request.messages)
|
| 844 |
+
|
| 845 |
+
# Log the structure of the prompt (without exposing sensitive content)
|
| 846 |
+
if isinstance(prompt, list):
|
| 847 |
+
print(f"Prompt structure: {len(prompt)} messages")
|
| 848 |
+
for i, msg in enumerate(prompt):
|
| 849 |
+
role = msg.role if hasattr(msg, 'role') else 'unknown'
|
| 850 |
+
parts_count = len(msg.parts) if hasattr(msg, 'parts') else 0
|
| 851 |
+
parts_types = [type(p).__name__ for p in (msg.parts if hasattr(msg, 'parts') else [])]
|
| 852 |
+
print(f" Message {i+1}: role={role}, parts={parts_count}, types={parts_types}")
|
| 853 |
+
elif isinstance(prompt, types.Content):
|
| 854 |
+
print("Prompt structure: 1 message")
|
| 855 |
+
role = prompt.role if hasattr(prompt, 'role') else 'unknown'
|
| 856 |
+
parts_count = len(prompt.parts) if hasattr(prompt, 'parts') else 0
|
| 857 |
+
parts_types = [type(p).__name__ for p in (prompt.parts if hasattr(prompt, 'parts') else [])]
|
| 858 |
+
print(f" Message 1: role={role}, parts={parts_count}, types={parts_types}")
|
| 859 |
+
else:
|
| 860 |
+
print("Prompt structure: Unknown format")
|
| 861 |
|
| 862 |
if request.stream:
|
| 863 |
# Handle streaming response
|
|
|
|
| 870 |
# If multiple candidates are requested, we'll generate them sequentially
|
| 871 |
for candidate_index in range(candidate_count):
|
| 872 |
# Generate content with streaming
|
| 873 |
+
# Handle the new message format for streaming using Gemini types
|
| 874 |
+
print(f"Sending streaming request to Gemini API")
|
| 875 |
+
|
| 876 |
+
# The prompt is now either a Content object or a list of Content objects
|
| 877 |
responses = client.models.generate_content_stream(
|
| 878 |
model=gemini_model,
|
| 879 |
+
contents=prompt,
|
| 880 |
config=generation_config,
|
| 881 |
)
|
| 882 |
|
|
|
|
| 908 |
# Make sure generation_config has candidate_count set
|
| 909 |
if "candidate_count" not in generation_config:
|
| 910 |
generation_config["candidate_count"] = request.n
|
| 911 |
+
# Handle the new message format using Gemini types
|
| 912 |
+
print(f"Sending request to Gemini API")
|
| 913 |
+
|
| 914 |
+
# The prompt is now either a Content object or a list of Content objects
|
| 915 |
response = client.models.generate_content(
|
| 916 |
model=gemini_model,
|
| 917 |
+
contents=prompt,
|
| 918 |
config=generation_config,
|
| 919 |
)
|
| 920 |
|