Update app.py
app.py CHANGED
@@ -629,6 +629,10 @@ def update_context_display(provider, model_name):
 
 def is_vision_model(provider, model_name):
     """Check if a model supports vision/images"""
+    # Safety check for None model name
+    if model_name is None:
+        return False
+
     if provider in VISION_MODELS:
         if model_name in VISION_MODELS[provider]:
             return True
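The guard added above assumes VISION_MODELS maps each provider name to a collection of model names. A minimal sketch of that lookup under this assumption; the registry entries below are illustrative, not taken from app.py:

```python
# Assumed shape of the VISION_MODELS registry; entries are illustrative.
VISION_MODELS = {
    "OpenAI": {"gpt-4o", "gpt-4o-mini"},
    "OpenRouter": {"openai/gpt-4o", "google/gemini-pro-1.5"},
}

def is_vision_model(provider, model_name):
    """Check if a model supports vision/images"""
    # Safety check for None model name: Gradio dropdowns can pass None
    # to change-callbacks before a model has been selected.
    if model_name is None:
        return False
    if provider in VISION_MODELS:
        if model_name in VISION_MODELS[provider]:
            return True
    return False

assert is_vision_model("OpenAI", "gpt-4o")
assert not is_vision_model("OpenAI", None)
```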
@@ -1132,6 +1136,7 @@ def extract_ai_response(result, provider):
 # ==========================================================
 
 def openrouter_streaming_handler(response, history, message):
+    """Handle streaming responses from OpenRouter"""
     try:
         updated_history = history + [{"role": "user", "content": message}]
         assistant_response = ""
@@ -1163,66 +1168,54 @@ def openrouter_streaming_handler(response, history, message):
             # Add error message to the current response
             yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
 
-def openai_streaming_handler(response,
+def openai_streaming_handler(response, history, message):
+    """Handle streaming responses from OpenAI"""
     try:
-
-
-
-
-        full_response = ""
+        updated_history = history + [{"role": "user", "content": message}]
+        assistant_response = ""
+
         for chunk in response:
             if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
                 content = chunk.choices[0].delta.content
-
-
-                yield chatbot
-
+                assistant_response += content
+                yield updated_history + [{"role": "assistant", "content": assistant_response}]
     except Exception as e:
         logger.error(f"Error in OpenAI streaming handler: {str(e)}")
         # Add error message to the current response
-
-        yield chatbot
+        yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
 
-def groq_streaming_handler(response,
+def groq_streaming_handler(response, history, message):
+    """Handle streaming responses from Groq"""
     try:
-
-
-
-
-        full_response = ""
+        updated_history = history + [{"role": "user", "content": message}]
+        assistant_response = ""
+
         for chunk in response:
             if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
                 content = chunk.choices[0].delta.content
-
-
-                yield chatbot
-
+                assistant_response += content
+                yield updated_history + [{"role": "assistant", "content": assistant_response}]
     except Exception as e:
         logger.error(f"Error in Groq streaming handler: {str(e)}")
         # Add error message to the current response
-
-        yield chatbot
+        yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
 
-def together_streaming_handler(response,
+def together_streaming_handler(response, history, message):
+    """Handle streaming responses from Together"""
     try:
-
-
-
-
-        full_response = ""
+        updated_history = history + [{"role": "user", "content": message}]
+        assistant_response = ""
+
         for chunk in response:
             if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
                 content = chunk.choices[0].delta.content
-
-
-                yield chatbot
-
+                assistant_response += content
+                yield updated_history + [{"role": "assistant", "content": assistant_response}]
     except Exception as e:
         logger.error(f"Error in Together streaming handler: {str(e)}")
         # Add error message to the current response
-
-
-
+        yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
+
 # ==========================================================
 # MAIN FUNCTION TO ASK AI
 # ==========================================================
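All three rewritten handlers follow the same pattern: append the user turn once, then re-yield the full history in Gradio's messages format with the growing assistant turn, instead of mutating a chatbot list in place as the removed code did. A self-contained sketch of how such a generator behaves, using stand-in chunk objects shaped like OpenAI-style streaming deltas; the Fake* classes are hypothetical test doubles, not app.py code:

```python
class FakeDelta:
    def __init__(self, content):
        self.content = content

class FakeChoice:
    def __init__(self, content):
        self.delta = FakeDelta(content)

class FakeChunk:
    def __init__(self, content):
        self.choices = [FakeChoice(content)]

def streaming_handler(response, history, message):
    """Same shape as the handlers above, minus logging."""
    updated_history = history + [{"role": "user", "content": message}]
    assistant_response = ""
    for chunk in response:
        if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
            assistant_response += chunk.choices[0].delta.content
            # Each yield is a complete messages-format history, which is
            # what a gr.Chatbot(type="messages") component can render.
            yield updated_history + [{"role": "assistant", "content": assistant_response}]

stream = [FakeChunk("Hel"), FakeChunk("lo"), FakeChunk(None)]
for h in streaming_handler(stream, [], "hi"):
    print(h[-1]["content"])   # prints "Hel", then "Hello"
```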
@@ -1236,11 +1229,8 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
     if not message.strip() and not images and not documents:
         return history
 
-    #
-
-
-    # Create messages from chat history
-    messages = format_to_message_dict(chat_history)
+    # Create messages from chat history for API requests
+    messages = format_to_message_dict(history)
 
     # Add system message if provided
     if system_message and system_message.strip():
@@ -1252,7 +1242,7 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
     # Prepare message with images and documents if any
     content = prepare_message_with_media(message, images, documents)
 
-    # Add current message
+    # Add current message to API messages
     messages.append({"role": "user", "content": content})
 
     # Common parameters for all providers
@@ -1272,8 +1262,11 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
             model_id, _ = get_model_info(provider, model_choice)
             if not model_id:
                 error_message = f"Error: Model '{model_choice}' not found in OpenRouter"
-
-                return
+                # Use proper message format
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
             # Build OpenRouter payload
             payload = {
@@ -1319,13 +1312,35 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
 
             # Handle streaming response
             if stream_output and response.status_code == 200:
-                # Add
-
+                # Add message to history
+                updated_history = history + [{"role": "user", "content": message}]
 
                 # Set up generator for streaming updates
                 def streaming_generator():
-
-
+                    assistant_response = ""
+                    for line in response.iter_lines():
+                        if not line:
+                            continue
+
+                        line = line.decode('utf-8')
+                        if not line.startswith('data: '):
+                            continue
+
+                        data = line[6:]
+                        if data.strip() == '[DONE]':
+                            break
+
+                        try:
+                            chunk = json.loads(data)
+                            if "choices" in chunk and len(chunk["choices"]) > 0:
+                                delta = chunk["choices"][0].get("delta", {})
+                                if "content" in delta and delta["content"]:
+                                    # Update the current response
+                                    assistant_response += delta["content"]
+                                    # Yield updated history with current assistant response
+                                    yield updated_history + [{"role": "assistant", "content": assistant_response}]
+                        except json.JSONDecodeError:
+                            logger.error(f"Failed to parse JSON from chunk: {data}")
 
                 return streaming_generator()
 
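The generator added above parses OpenRouter's server-sent-event stream by hand: skip keep-alive blanks, take only `data: ` lines, stop at the `[DONE]` sentinel, and accumulate `choices[0].delta.content` from each JSON chunk. The same logic run against canned bytes, so it can be checked without a live API; the payloads are made up but mirror OpenAI-style chunks:

```python
import json

raw_lines = [
    b'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    b'',                                # keep-alive / separator line
    b': comment line',                  # non-data SSE field, ignored
    b'data: {"choices": [{"delta": {"content": "lo"}}]}',
    b'data: [DONE]',                    # end-of-stream sentinel
]

assistant_response = ""
for line in raw_lines:
    if not line:
        continue
    line = line.decode('utf-8')
    if not line.startswith('data: '):
        continue
    data = line[6:]
    if data.strip() == '[DONE]':
        break
    chunk = json.loads(data)
    if "choices" in chunk and len(chunk["choices"]) > 0:
        delta = chunk["choices"][0].get("delta", {})
        if "content" in delta and delta["content"]:
            assistant_response += delta["content"]

print(assistant_response)   # -> "Hello"
```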
@@ -1337,9 +1352,11 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
                 # Extract AI response
                 ai_response = extract_ai_response(result, provider)
 
-                # Add response to history
-
-
+                # Add response to history with proper format
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": ai_response}
+                ]
 
             # Handle error response
             else:
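This is the return shape used from here on: every success and error path hands Gradio the prior history plus one user/assistant pair of messages-format dicts, so errors surface in the chat itself rather than being swallowed. A tiny illustration with made-up turn contents:

```python
history = [{"role": "user", "content": "2 + 2?"},
           {"role": "assistant", "content": "4"}]

# One ask_ai call appends exactly one user/assistant pair:
history = history + [
    {"role": "user", "content": "And 3 + 3?"},
    {"role": "assistant", "content": "6"},
]

assert [m["role"] for m in history] == ["user", "assistant"] * 2
```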
@@ -1351,16 +1368,20 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
                     error_message += f"\n\nResponse: {response.text}"
 
                 logger.error(error_message)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
         elif provider == "OpenAI":
             # Get model ID from registry
             model_id, _ = get_model_info(provider, model_choice)
             if not model_id:
                 error_message = f"Error: Model '{model_choice}' not found in OpenAI"
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
             # Build OpenAI payload
             payload = {
@@ -1381,34 +1402,44 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
 
                 # Handle streaming response
                 if stream_output:
-                    # Add
-
+                    # Add message to history
+                    updated_history = history + [{"role": "user", "content": message}]
 
                     # Set up generator for streaming updates
                     def streaming_generator():
-
-
+                        assistant_response = ""
+                        for chunk in response:
+                            if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
+                                content = chunk.choices[0].delta.content
+                                assistant_response += content
+                                yield updated_history + [{"role": "assistant", "content": assistant_response}]
 
                     return streaming_generator()
 
                 # Handle normal response
                 else:
                     ai_response = extract_ai_response(response, provider)
-
-
+                    return history + [
+                        {"role": "user", "content": message},
+                        {"role": "assistant", "content": ai_response}
+                    ]
             except Exception as e:
                 error_message = f"OpenAI API Error: {str(e)}"
                 logger.error(error_message)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
         elif provider == "HuggingFace":
             # Get model ID from registry
             model_id, _ = get_model_info(provider, model_choice)
             if not model_id:
                 error_message = f"Error: Model '{model_choice}' not found in HuggingFace"
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
             # Build HuggingFace payload
             payload = {
@@ -1426,21 +1457,27 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
 
                 # Extract response
                 ai_response = extract_ai_response(response, provider)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": ai_response}
+                ]
             except Exception as e:
                 error_message = f"HuggingFace API Error: {str(e)}"
                 logger.error(error_message)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
         elif provider == "Groq":
             # Get model ID from registry
             model_id, _ = get_model_info(provider, model_choice)
             if not model_id:
                 error_message = f"Error: Model '{model_choice}' not found in Groq"
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
             # Build Groq payload
             payload = {
@@ -1460,34 +1497,44 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
 
                 # Handle streaming response
                 if stream_output:
-                    # Add
-
+                    # Add message to history
+                    updated_history = history + [{"role": "user", "content": message}]
 
                     # Set up generator for streaming updates
                     def streaming_generator():
-
-
+                        assistant_response = ""
+                        for chunk in response:
+                            if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
+                                content = chunk.choices[0].delta.content
+                                assistant_response += content
+                                yield updated_history + [{"role": "assistant", "content": assistant_response}]
 
                     return streaming_generator()
 
                 # Handle normal response
                 else:
                     ai_response = extract_ai_response(response, provider)
-
-
+                    return history + [
+                        {"role": "user", "content": message},
+                        {"role": "assistant", "content": ai_response}
+                    ]
             except Exception as e:
                 error_message = f"Groq API Error: {str(e)}"
                 logger.error(error_message)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
         elif provider == "Cohere":
             # Get model ID from registry
             model_id, _ = get_model_info(provider, model_choice)
             if not model_id:
                 error_message = f"Error: Model '{model_choice}' not found in Cohere"
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
             # Build Cohere payload (doesn't support streaming the same way)
             payload = {
@@ -1505,21 +1552,27 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
 
                 # Extract response
                 ai_response = extract_ai_response(response, provider)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": ai_response}
+                ]
             except Exception as e:
                 error_message = f"Cohere API Error: {str(e)}"
                 logger.error(error_message)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
         elif provider == "Together":
             # Get model ID from registry
             model_id, _ = get_model_info(provider, model_choice)
             if not model_id:
                 error_message = f"Error: Model '{model_choice}' not found in Together"
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
             # Build Together payload
             payload = {
@@ -1538,34 +1591,44 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
 
                 # Handle streaming response
                 if stream_output:
-                    # Add
-
+                    # Add message to history
+                    updated_history = history + [{"role": "user", "content": message}]
 
                     # Set up generator for streaming updates
                     def streaming_generator():
-
-
+                        assistant_response = ""
+                        for chunk in response:
+                            if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
+                                content = chunk.choices[0].delta.content
+                                assistant_response += content
+                                yield updated_history + [{"role": "assistant", "content": assistant_response}]
 
                     return streaming_generator()
 
                 # Handle normal response
                 else:
                     ai_response = extract_ai_response(response, provider)
-
-
+                    return history + [
+                        {"role": "user", "content": message},
+                        {"role": "assistant", "content": ai_response}
+                    ]
             except Exception as e:
                 error_message = f"Together API Error: {str(e)}"
                 logger.error(error_message)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
         elif provider == "OVH":
             # Get model ID from registry
             model_id, _ = get_model_info(provider, model_choice)
             if not model_id:
                 error_message = f"Error: Model '{model_choice}' not found in OVH"
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
             # Build OVH payload
             payload = {
@@ -1583,21 +1646,27 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
 
                 # Extract response
                 ai_response = extract_ai_response(response, provider)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": ai_response}
+                ]
             except Exception as e:
                 error_message = f"OVH API Error: {str(e)}"
                 logger.error(error_message)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
         elif provider == "Cerebras":
             # Get model ID from registry
             model_id, _ = get_model_info(provider, model_choice)
             if not model_id:
                 error_message = f"Error: Model '{model_choice}' not found in Cerebras"
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
             # Build Cerebras payload
             payload = {
@@ -1615,21 +1684,27 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
 
                 # Extract response
                 ai_response = extract_ai_response(response, provider)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": ai_response}
+                ]
             except Exception as e:
                 error_message = f"Cerebras API Error: {str(e)}"
                 logger.error(error_message)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
         elif provider == "GoogleAI":
            # Get model ID from registry
            model_id, _ = get_model_info(provider, model_choice)
            if not model_id:
                error_message = f"Error: Model '{model_choice}' not found in GoogleAI"
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
             # Build GoogleAI payload
             payload = {
@@ -1648,24 +1723,32 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
 
                 # Extract response
                 ai_response = extract_ai_response(response, provider)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": ai_response}
+                ]
             except Exception as e:
                 error_message = f"GoogleAI API Error: {str(e)}"
                 logger.error(error_message)
-
-
+                return history + [
+                    {"role": "user", "content": message},
+                    {"role": "assistant", "content": error_message}
+                ]
 
         else:
             error_message = f"Error: Unsupported provider '{provider}'"
-
-
+            return history + [
+                {"role": "user", "content": message},
+                {"role": "assistant", "content": error_message}
+            ]
 
     except Exception as e:
         error_message = f"Error: {str(e)}"
         logger.error(f"Exception during API call: {error_message}")
-
-
+        return history + [
+            {"role": "user", "content": message},
+            {"role": "assistant", "content": error_message}
+        ]
 
 def clear_chat():
     """Reset all inputs"""
@@ -2160,14 +2243,20 @@ def create_app():
 
     def update_vision_indicator(provider, model_choice):
         """Update the vision capability indicator"""
+        # Safety check for None model
+        if model_choice is None:
+            return False
         return is_vision_model(provider, model_choice)
 
     def update_image_upload_visibility(provider, model_choice):
         """Show/hide image upload based on model vision capabilities"""
+        # Safety check for None model
+        if model_choice is None:
+            return gr.update(visible=False)
         is_vision = is_vision_model(provider, model_choice)
         return gr.update(visible=is_vision)
 
-    # Search model function
+    # Search model function
     def search_openrouter_models(search_term):
         """Filter OpenRouter models based on search term"""
         all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
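For context, these helpers are typically wired to dropdown change events; clearing the model dropdown passes None to the callback, which the new guards absorb. A minimal wiring sketch under that assumption; the component names and the stub registry are hypothetical, not taken from app.py:

```python
import gradio as gr

# Hypothetical stub registry so the sketch is self-contained.
VISION_MODELS = {"OpenAI": {"gpt-4o"}}

def is_vision_model(provider, model_name):
    if model_name is None:
        return False
    return provider in VISION_MODELS and model_name in VISION_MODELS[provider]

with gr.Blocks() as demo:
    provider = gr.Dropdown(["OpenAI", "OpenRouter"], label="Provider")
    model = gr.Dropdown(["gpt-4o", "gpt-4o-mini"], label="Model")
    vision_flag = gr.Checkbox(label="Supports vision", interactive=False)
    image_upload = gr.File(label="Upload images", visible=False)

    def update_vision_indicator(provider, model_choice):
        if model_choice is None:          # guard added in this commit
            return False
        return is_vision_model(provider, model_choice)

    def update_image_upload_visibility(provider, model_choice):
        if model_choice is None:          # guard added in this commit
            return gr.update(visible=False)
        return gr.update(visible=is_vision_model(provider, model_choice))

    # Re-evaluate whenever the model selection changes; clearing the
    # dropdown fires the event with None.
    model.change(update_vision_indicator, [provider, model], vision_flag)
    model.change(update_image_upload_visibility, [provider, model], image_upload)
```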
@@ -2588,9 +2677,11 @@ def create_app():
 
         # Check if model is selected
         if not model_choice:
-
-
-
+            error_message = f"Error: No model selected for provider {provider}"
+            return history + [
+                {"role": "user", "content": message},
+                {"role": "assistant", "content": error_message}
+            ]
 
         # Select the appropriate API key based on the provider
         api_key_override = None
|