Update app/main.py
Browse files- app/main.py +172 -57
app/main.py
CHANGED
|
@@ -19,6 +19,7 @@ import config
|
|
| 19 |
from google.genai import types
|
| 20 |
|
| 21 |
from google import genai
|
|
|
|
| 22 |
|
| 23 |
client = None
|
| 24 |
|
|
@@ -81,7 +82,7 @@ class CredentialManager:
|
|
| 81 |
self.credentials_files = glob.glob(pattern)
|
| 82 |
|
| 83 |
if not self.credentials_files:
|
| 84 |
-
print(f"No credential files found in {self.credentials_dir}")
|
| 85 |
return False
|
| 86 |
|
| 87 |
print(f"Found {len(self.credentials_files)} credential files: {[os.path.basename(f) for f in self.credentials_files]}")
|
|
@@ -1057,62 +1058,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
|
|
| 1057 |
]
|
| 1058 |
generation_config["safety_settings"] = safety_settings
|
| 1059 |
|
| 1060 |
-
# --- Helper function to check response validity ---
|
| 1061 |
-
def is_response_valid(response):
    """Return True when the response carries non-empty text content.

    The text is looked up, in order, on the first candidate's ``text``
    attribute, on the response's ``text`` attribute, and in the first
    candidate's ``content.parts``.  If none of those yields a non-empty
    value, the response's own ``text`` is checked once more before the
    response is reported as invalid.

    Args:
        response: Response object returned by the Gemini call, or None.

    Returns:
        bool: True if non-empty text was found, False otherwise (the caller
        uses False as the signal to retry).
    """
    if response is None:
        return False

    # Check if candidates exist
    if not hasattr(response, 'candidates') or not response.candidates:
        return False

    # Get the first candidate
    candidate = response.candidates[0]

    # Try different ways to access the text content
    text_content = None

    # Method 1: Direct text attribute on candidate
    if hasattr(candidate, 'text'):
        text_content = candidate.text
    # Method 2: Text attribute on response
    elif hasattr(response, 'text'):
        text_content = response.text
    # Method 3: Content with parts
    elif hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
        # Use the first part that actually carries text
        for part in candidate.content.parts:
            if hasattr(part, 'text') and part.text:
                text_content = part.text
                break

    # Non-empty text content found: valid response
    if text_content is not None and text_content != '':
        return True

    # Missing or empty text is not rejected immediately.  Previously every
    # branch returned here, which made the fallback and the diagnostic
    # below unreachable.

    # Also check if the response itself has text
    if hasattr(response, 'text') and response.text:
        return True

    # If we got here, the response is invalid (treated as retryable)
    print(f"Invalid response: No text content found in response structure: {str(response)[:200]}...")
    return False
|
| 1114 |
-
|
| 1115 |
-
|
| 1116 |
# --- Helper function to make the API call (handles stream/non-stream) ---
|
| 1117 |
async def make_gemini_call(model_name, prompt_func, current_gen_config):
|
| 1118 |
prompt = prompt_func(request.messages)
|
|
@@ -1133,7 +1079,11 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
|
|
| 1133 |
|
| 1134 |
|
| 1135 |
if request.stream:
|
| 1136 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1137 |
response_id = f"chatcmpl-{int(time.time())}"
|
| 1138 |
candidate_count = request.n or 1
|
| 1139 |
|
|
@@ -1320,6 +1270,171 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
|
|
| 1320 |
# Ensure we return a JSON response even for stream requests if error happens early
|
| 1321 |
return JSONResponse(status_code=500, content=error_response)
|
| 1322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1323 |
# --- Need to import asyncio ---
|
| 1324 |
# import asyncio # Add this import at the top of the file # Already added below
|
| 1325 |
|
|
|
|
| 19 |
from google.genai import types
|
| 20 |
|
| 21 |
from google import genai
|
| 22 |
+
import math
|
| 23 |
|
| 24 |
client = None
|
| 25 |
|
|
|
|
| 82 |
self.credentials_files = glob.glob(pattern)
|
| 83 |
|
| 84 |
if not self.credentials_files:
|
| 85 |
+
# print(f"No credential files found in {self.credentials_dir}")
|
| 86 |
return False
|
| 87 |
|
| 88 |
print(f"Found {len(self.credentials_files)} credential files: {[os.path.basename(f) for f in self.credentials_files]}")
|
|
|
|
| 1058 |
]
|
| 1059 |
generation_config["safety_settings"] = safety_settings
|
| 1060 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1061 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1062 |
# --- Helper function to make the API call (handles stream/non-stream) ---
|
| 1063 |
async def make_gemini_call(model_name, prompt_func, current_gen_config):
|
| 1064 |
prompt = prompt_func(request.messages)
|
|
|
|
| 1079 |
|
| 1080 |
|
| 1081 |
if request.stream:
|
| 1082 |
+
# Check if fake streaming is enabled
|
| 1083 |
+
if config.FAKE_STREAMING:
|
| 1084 |
+
return await fake_stream_generator(model_name, prompt, current_gen_config, request)
|
| 1085 |
+
|
| 1086 |
+
# Regular streaming call
|
| 1087 |
response_id = f"chatcmpl-{int(time.time())}"
|
| 1088 |
candidate_count = request.n or 1
|
| 1089 |
|
|
|
|
| 1270 |
# Ensure we return a JSON response even for stream requests if error happens early
|
| 1271 |
return JSONResponse(status_code=500, content=error_response)
|
| 1272 |
|
| 1273 |
+
# --- Helper function to check response validity ---
|
| 1274 |
+
# Moved function definition here from inside chat_completions
|
| 1275 |
+
def is_response_valid(response):
    """Check whether a Gemini response contains valid, non-empty text content.

    A response is reported as invalid (so the caller can retry) when it is
    None, has no candidates (for example because the prompt was blocked),
    when the first candidate finished for a reason other than STOP, or when
    no non-empty text can be found.

    Args:
        response: Response object returned by the Gemini client, or None.

    Returns:
        bool: True if non-empty text was found, False otherwise.
    """
    if response is None:
        return False

    # Check if candidates exist and are not empty
    candidates = getattr(response, 'candidates', None)
    if not candidates:
        # Blocked responses might lack candidates.  prompt_feedback can be
        # present but None, so block_reason is read defensively.
        prompt_feedback = getattr(response, 'prompt_feedback', None)
        block_reason = getattr(prompt_feedback, 'block_reason', None)
        if block_reason:
            print(f"Response blocked: {block_reason}")
        else:
            print("Response has no candidates.")
        return False

    # Get the first candidate
    candidate = candidates[0]

    # Check finish reason - anything other than STOP is treated as invalid.
    # The reason may be an integer (1 == STOP) or a string-valued enum
    # member named "STOP", depending on the client library, so both forms
    # are accepted.  A missing (None) reason is not treated as a failure by
    # itself; the text checks below decide.
    finish_reason = getattr(candidate, 'finish_reason', None)
    if finish_reason is not None:
        reason_name = str(getattr(finish_reason, 'name', finish_reason)).upper()
        if finish_reason != 1 and reason_name != 'STOP':
            print(f"Candidate finish reason indicates issue: {finish_reason}")
            return False

    # Walk every place text may live and accept the first non-empty value
    saw_empty_string = False
    for text in _iter_response_texts(response, candidate):
        if text:
            return True
        if text == '':
            saw_empty_string = True

    if saw_empty_string:
        print("Response text content is an empty string.")
    else:
        print("No text content found in response/candidates.")
    # Treat missing or empty text as invalid for retry
    return False


def _iter_response_texts(response, candidate):
    """Lazily yield the text values a response may expose, in priority order."""
    # Method 1: Direct text attribute on candidate (sometimes present)
    yield getattr(candidate, 'text', None)

    # Method 2: Check within candidate.content.parts (standard way).
    # Parts whose text is None are yielded too and simply skipped by the
    # caller, so a later text part is still found.
    content = getattr(candidate, 'content', None)
    for part in getattr(content, 'parts', None) or []:
        yield getattr(part, 'text', None)

    # Method 3: Direct text attribute on the root response object
    yield getattr(response, 'text', None)
|
| 1331 |
+
|
| 1332 |
+
# --- Fake streaming implementation ---
|
| 1333 |
+
async def fake_stream_generator(model_name, prompt, current_gen_config, request):
    """
    Simulate streaming by making a non-streaming API call and chunking the response.

    While the call is in flight, keep-alive chunks with empty content are sent
    to the client.  Once the full response arrives it is split into roughly
    ten pieces (at least 20 characters each) that are emitted as OpenAI-style
    ``chat.completion.chunk`` SSE messages, followed by the final chunk and
    ``[DONE]``.  Any failure is reported to the client as an SSE error
    payload followed by ``[DONE]``.

    Args:
        model_name: Gemini model name passed to ``generate_content``.
        prompt: Contents passed to ``generate_content``.
        current_gen_config: Generation config passed to ``generate_content``.
        request: Incoming request; only ``request.model`` is read here.

    Returns:
        StreamingResponse with media type ``text/event-stream``.
    """
    response_id = f"chatcmpl-{int(time.time())}"

    async def fake_stream_inner():
        # Run the non-streaming API call in the background so keep-alives
        # can be emitted while it is pending.
        print(f"FAKE STREAMING: Making non-streaming request to Gemini API (Model: {model_name})")
        api_call_task = asyncio.create_task(
            client.aio.models.generate_content(
                model=model_name,
                contents=prompt,
                config=current_gen_config,
            )
        )

        try:
            # Send keep-alive messages while waiting for the response
            while not api_call_task.done():
                keep_alive_chunk = {
                    "id": "chatcmpl-keepalive",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [{"delta": {"content": ""}, "index": 0, "finish_reason": None}]
                }
                yield f"data: {json.dumps(keep_alive_chunk)}\n\n"

                # Wait up to one interval, but wake as soon as the call
                # finishes so the first real chunk is not delayed by a
                # full sleep.  asyncio.wait does not cancel on timeout.
                await asyncio.wait({api_call_task}, timeout=config.FAKE_STREAMING_INTERVAL)

            try:
                # Get the response from the completed task (re-raises its error)
                response = api_call_task.result()

                # Check if the response is valid
                if not is_response_valid(response):
                    raise ValueError("Invalid or empty response received")

                # Extract the full text content.  Each source is tried in
                # turn, so a None/empty response.text no longer hides text
                # that is available on the candidate.
                full_text = getattr(response, 'text', None) or ""
                if not full_text and getattr(response, 'candidates', None):
                    candidate = response.candidates[0]
                    full_text = getattr(candidate, 'text', None) or ""
                    if not full_text:
                        content = getattr(candidate, 'content', None)
                        parts = getattr(content, 'parts', None) or []
                        # Skip parts without text (their text may be None)
                        full_text = "".join(
                            part.text for part in parts if getattr(part, 'text', None)
                        )

                if not full_text:
                    raise ValueError("No text content found in response")

                print(f"FAKE STREAMING: Received full response ({len(full_text)} chars), chunking into smaller pieces")

                # Aim for ~10 chunks, but with a minimum size of 20 chars
                chunk_size = max(20, math.ceil(len(full_text) / 10))

                # Send each chunk as a separate SSE message
                for start in range(0, len(full_text), chunk_size):
                    chunk_data = {
                        "id": response_id,
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": request.model,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {
                                    "content": full_text[start:start + chunk_size]
                                },
                                "finish_reason": None
                            }
                        ]
                    }
                    yield f"data: {json.dumps(chunk_data)}\n\n"

                    # Small delay between chunks to simulate streaming
                    await asyncio.sleep(0.05)

                # Send the final chunk
                yield create_final_chunk(request.model, response_id)
                yield "data: [DONE]\n\n"

            except Exception as e:
                # Top-level boundary of the stream: report the failure to
                # the client in-band instead of breaking the connection.
                error_msg = f"Error in fake streaming (Model: {model_name}): {str(e)}"
                print(error_msg)
                error_response = create_openai_error_response(500, error_msg, "server_error")
                yield f"data: {json.dumps(error_response)}\n\n"
                yield "data: [DONE]\n\n"
        finally:
            # If the generator is closed early (e.g. the client went away),
            # do not leave the upstream request running in the background.
            if not api_call_task.done():
                api_call_task.cancel()

    return StreamingResponse(fake_stream_inner(), media_type="text/event-stream")
|
| 1437 |
+
|
| 1438 |
# --- Need to import asyncio ---
|
| 1439 |
# import asyncio # Add this import at the top of the file # Already added below
|
| 1440 |
|