Spaces:
Paused
Paused
Commit
·
5d7dc12
1
Parent(s):
cdf27f4
added reasoning support
Browse files- app/routes/chat_api.py +85 -31
app/routes/chat_api.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import asyncio
|
|
|
|
| 2 |
import json # Needed for error streaming
|
| 3 |
import random
|
| 4 |
from fastapi import APIRouter, Depends, Request
|
|
@@ -7,6 +8,7 @@ from typing import List, Dict, Any
|
|
| 7 |
|
| 8 |
# Google and OpenAI specific imports
|
| 9 |
from google.genai import types
|
|
|
|
| 10 |
from google import genai
|
| 11 |
import openai
|
| 12 |
from credentials_manager import _refresh_auth
|
|
@@ -229,7 +231,6 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
|
|
| 229 |
async for chunk in stream_response:
|
| 230 |
try:
|
| 231 |
chunk_as_dict = chunk.model_dump(exclude_unset=True, exclude_none=True)
|
| 232 |
-
print(chunk_as_dict)
|
| 233 |
|
| 234 |
# Safely navigate and check for thought flag
|
| 235 |
choices = chunk_as_dict.get('choices')
|
|
@@ -253,6 +254,7 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
|
|
| 253 |
del delta['extra_content']
|
| 254 |
|
| 255 |
# Yield the (potentially modified) dictionary as JSON
|
|
|
|
| 256 |
yield f"data: {json.dumps(chunk_as_dict)}\n\n"
|
| 257 |
|
| 258 |
except Exception as chunk_processing_error: # Catch errors from dict manipulation or json.dumps
|
|
@@ -290,39 +292,91 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
|
|
| 290 |
extra_body=openai_extra_body
|
| 291 |
)
|
| 292 |
response_dict = response.model_dump(exclude_unset=True, exclude_none=True)
|
| 293 |
-
|
| 294 |
-
# Process reasoning_tokens for non-streaming response
|
| 295 |
try:
|
| 296 |
usage = response_dict.get('usage')
|
|
|
|
|
|
|
| 297 |
if usage and isinstance(usage, dict):
|
| 298 |
-
|
| 299 |
-
if completion_details and isinstance(completion_details, dict):
|
| 300 |
-
num_reasoning_tokens = completion_details.get('reasoning_tokens')
|
| 301 |
-
|
| 302 |
-
if isinstance(num_reasoning_tokens, int) and num_reasoning_tokens > 0:
|
| 303 |
-
choices = response_dict.get('choices')
|
| 304 |
-
if choices and isinstance(choices, list) and len(choices) > 0:
|
| 305 |
-
# Ensure choices[0] and message are dicts, model_dump makes them so
|
| 306 |
-
message_dict = choices[0].get('message')
|
| 307 |
-
if message_dict and isinstance(message_dict, dict):
|
| 308 |
-
full_content = message_dict.get('content')
|
| 309 |
-
if isinstance(full_content, str): # Ensure content is a string
|
| 310 |
-
reasoning_text = full_content[:num_reasoning_tokens]
|
| 311 |
-
actual_content = full_content[num_reasoning_tokens:]
|
| 312 |
-
|
| 313 |
-
message_dict['reasoning_content'] = reasoning_text
|
| 314 |
-
message_dict['content'] = actual_content
|
| 315 |
-
|
| 316 |
-
# Clean up Vertex-specific field
|
| 317 |
-
del completion_details['reasoning_tokens']
|
| 318 |
-
if not completion_details: # If dict is now empty
|
| 319 |
-
del usage['completion_tokens_details']
|
| 320 |
-
if not usage: # If dict is now empty
|
| 321 |
-
del response_dict['usage']
|
| 322 |
-
except Exception as e_non_stream_reasoning:
|
| 323 |
-
print(f"WARNING: Could not process non-streaming reasoning tokens for model {request.model}: {e_non_stream_reasoning}. Response will be returned as is from Vertex.")
|
| 324 |
-
# Fallthrough to return response_dict as is if processing fails
|
| 325 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
return JSONResponse(content=response_dict)
|
| 327 |
except Exception as generate_error:
|
| 328 |
error_msg_generate = f"Error calling OpenAI client for {request.model}: {str(generate_error)}"
|
|
@@ -396,4 +450,4 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
|
|
| 396 |
except Exception as e:
|
| 397 |
error_msg = f"Unexpected error in chat_completions endpoint: {str(e)}"
|
| 398 |
print(error_msg)
|
| 399 |
-
return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
|
|
|
|
| 1 |
import asyncio
|
| 2 |
+
import base64 # Ensure base64 is imported
|
| 3 |
import json # Needed for error streaming
|
| 4 |
import random
|
| 5 |
from fastapi import APIRouter, Depends, Request
|
|
|
|
| 8 |
|
| 9 |
# Google and OpenAI specific imports
|
| 10 |
from google.genai import types
|
| 11 |
+
from google.genai.types import HttpOptions # Added for compute_tokens
|
| 12 |
from google import genai
|
| 13 |
import openai
|
| 14 |
from credentials_manager import _refresh_auth
|
|
|
|
| 231 |
async for chunk in stream_response:
|
| 232 |
try:
|
| 233 |
chunk_as_dict = chunk.model_dump(exclude_unset=True, exclude_none=True)
|
|
|
|
| 234 |
|
| 235 |
# Safely navigate and check for thought flag
|
| 236 |
choices = chunk_as_dict.get('choices')
|
|
|
|
| 254 |
del delta['extra_content']
|
| 255 |
|
| 256 |
# Yield the (potentially modified) dictionary as JSON
|
| 257 |
+
print(chunk_as_dict)
|
| 258 |
yield f"data: {json.dumps(chunk_as_dict)}\n\n"
|
| 259 |
|
| 260 |
except Exception as chunk_processing_error: # Catch errors from dict manipulation or json.dumps
|
|
|
|
| 292 |
extra_body=openai_extra_body
|
| 293 |
)
|
| 294 |
response_dict = response.model_dump(exclude_unset=True, exclude_none=True)
|
| 295 |
+
|
|
|
|
| 296 |
try:
|
| 297 |
usage = response_dict.get('usage')
|
| 298 |
+
vertex_completion_tokens = 0
|
| 299 |
+
|
| 300 |
if usage and isinstance(usage, dict):
|
| 301 |
+
vertex_completion_tokens = usage.get('completion_tokens')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
+
choices = response_dict.get('choices')
|
| 304 |
+
if choices and isinstance(choices, list) and len(choices) > 0:
|
| 305 |
+
message_dict = choices[0].get('message')
|
| 306 |
+
if message_dict and isinstance(message_dict, dict):
|
| 307 |
+
# Always remove extra_content from the message if it exists, before any splitting
|
| 308 |
+
if 'extra_content' in message_dict:
|
| 309 |
+
del message_dict['extra_content']
|
| 310 |
+
print("DEBUG: Removed 'extra_content' from response message.")
|
| 311 |
+
|
| 312 |
+
if isinstance(vertex_completion_tokens, int) and vertex_completion_tokens > 0:
|
| 313 |
+
full_content = message_dict.get('content')
|
| 314 |
+
if isinstance(full_content, str) and full_content:
|
| 315 |
+
|
| 316 |
+
def _get_token_strings_and_split_texts_sync(creds, proj_id, loc, model_id_for_tokenizer, text_to_tokenize, num_completion_tokens_from_usage):
|
| 317 |
+
sync_tokenizer_client = genai.Client(
|
| 318 |
+
vertexai=True, credentials=creds, project=proj_id, location=loc,
|
| 319 |
+
http_options=HttpOptions(api_version="v1")
|
| 320 |
+
)
|
| 321 |
+
if not text_to_tokenize: return "", text_to_tokenize, [] # No reasoning, original content, empty token list
|
| 322 |
+
|
| 323 |
+
token_compute_response = sync_tokenizer_client.models.compute_tokens(
|
| 324 |
+
model=model_id_for_tokenizer, contents=text_to_tokenize
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
all_final_token_strings = []
|
| 328 |
+
if token_compute_response.tokens_info:
|
| 329 |
+
for token_info_item in token_compute_response.tokens_info:
|
| 330 |
+
for api_token_bytes in token_info_item.tokens:
|
| 331 |
+
intermediate_str = api_token_bytes.decode('utf-8', errors='replace')
|
| 332 |
+
final_token_text = ""
|
| 333 |
+
try:
|
| 334 |
+
b64_decoded_bytes = base64.b64decode(intermediate_str)
|
| 335 |
+
final_token_text = b64_decoded_bytes.decode('utf-8', errors='replace')
|
| 336 |
+
except Exception:
|
| 337 |
+
final_token_text = intermediate_str
|
| 338 |
+
all_final_token_strings.append(final_token_text)
|
| 339 |
+
|
| 340 |
+
if not all_final_token_strings: # Should not happen if text_to_tokenize is not empty
|
| 341 |
+
return "", text_to_tokenize, []
|
| 342 |
+
|
| 343 |
+
if not (0 < num_completion_tokens_from_usage <= len(all_final_token_strings)):
|
| 344 |
+
print(f"WARNING_TOKEN_SPLIT: num_completion_tokens_from_usage ({num_completion_tokens_from_usage}) is invalid for total client-tokenized tokens ({len(all_final_token_strings)}). Returning full content as 'content'.")
|
| 345 |
+
return "", "".join(all_final_token_strings), all_final_token_strings
|
| 346 |
+
|
| 347 |
+
completion_part_tokens = all_final_token_strings[-num_completion_tokens_from_usage:]
|
| 348 |
+
reasoning_part_tokens = all_final_token_strings[:-num_completion_tokens_from_usage]
|
| 349 |
+
|
| 350 |
+
reasoning_output_str = "".join(reasoning_part_tokens)
|
| 351 |
+
completion_output_str = "".join(completion_part_tokens)
|
| 352 |
+
|
| 353 |
+
return reasoning_output_str, completion_output_str, all_final_token_strings
|
| 354 |
+
|
| 355 |
+
model_id_for_tokenizer = base_model_name
|
| 356 |
+
|
| 357 |
+
reasoning_text, actual_content, dbg_all_tokens = await asyncio.to_thread(
|
| 358 |
+
_get_token_strings_and_split_texts_sync,
|
| 359 |
+
rotated_credentials, PROJECT_ID, LOCATION,
|
| 360 |
+
model_id_for_tokenizer, full_content, vertex_completion_tokens
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
message_dict['content'] = actual_content # Set the new content (potentially from joined tokens)
|
| 364 |
+
if reasoning_text: # Only add reasoning_content if it's not empty
|
| 365 |
+
message_dict['reasoning_content'] = reasoning_text
|
| 366 |
+
print(f"DEBUG_REASONING_SPLIT_DIRECT_JOIN: Successful. Reasoning len: {len(reasoning_text)}. Content len: {len(actual_content)}")
|
| 367 |
+
print(f" Vertex completion_tokens: {vertex_completion_tokens}. Our tokenizer total tokens: {len(dbg_all_tokens)}")
|
| 368 |
+
elif "".join(dbg_all_tokens) != full_content : # Content was re-joined from tokens but no reasoning
|
| 369 |
+
print(f"INFO: Content reconstructed from tokens. Original len: {len(full_content)}, Reconstructed len: {len(actual_content)}")
|
| 370 |
+
# else: No reasoning, and content is original full_content because num_completion_tokens was invalid or zero.
|
| 371 |
+
|
| 372 |
+
else:
|
| 373 |
+
print(f"WARNING: Full content is not a string or is empty. Cannot perform split. Content: {full_content}")
|
| 374 |
+
else:
|
| 375 |
+
print(f"INFO: No positive vertex_completion_tokens ({vertex_completion_tokens}) found in usage, or no message content. No split performed.")
|
| 376 |
+
|
| 377 |
+
except Exception as e_reasoning_processing:
|
| 378 |
+
print(f"WARNING: Error during non-streaming reasoning token processing for model {request.model} due to: {e_reasoning_processing}.")
|
| 379 |
+
|
| 380 |
return JSONResponse(content=response_dict)
|
| 381 |
except Exception as generate_error:
|
| 382 |
error_msg_generate = f"Error calling OpenAI client for {request.model}: {str(generate_error)}"
|
|
|
|
| 450 |
except Exception as e:
|
| 451 |
error_msg = f"Unexpected error in chat_completions endpoint: {str(e)}"
|
| 452 |
print(error_msg)
|
| 453 |
+
return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
|