Spaces:
Sleeping
Sleeping
Updated app.py
Browse files
app.py
CHANGED
|
@@ -79,8 +79,41 @@ def parse_model_string(model):
|
|
| 79 |
return "openai", model
|
| 80 |
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
# Extract the last user message
|
| 86 |
last_user_message = ""
|
|
@@ -109,7 +142,7 @@ def build_onyx_payload(messages, model_provider, model_version, temperature, cha
|
|
| 109 |
"message": full_message,
|
| 110 |
"chat_session_id": chat_session_id,
|
| 111 |
"parent_message_id": parent_message_id if parent_message_id else None,
|
| 112 |
-
"stream":
|
| 113 |
"llm_override": {
|
| 114 |
"model_provider": model_provider,
|
| 115 |
"model_version": model_version,
|
|
@@ -126,38 +159,83 @@ def build_onyx_payload(messages, model_provider, model_version, temperature, cha
|
|
| 126 |
|
| 127 |
|
| 128 |
def parse_onyx_stream_chunk(chunk_text):
|
| 129 |
-
"""Parse a chunk from Onyx stream and extract the text content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
if not chunk_text or not chunk_text.strip():
|
| 131 |
-
return None, None
|
| 132 |
|
| 133 |
try:
|
| 134 |
data = json.loads(chunk_text)
|
| 135 |
|
| 136 |
-
if isinstance(data, dict):
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
elif
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
elif
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
except json.JSONDecodeError:
|
| 156 |
# Not JSON, might be raw text
|
| 157 |
if chunk_text.strip() and not chunk_text.strip().startswith('{'):
|
| 158 |
-
return chunk_text.strip(), None
|
| 159 |
|
| 160 |
-
return None, None
|
| 161 |
|
| 162 |
|
| 163 |
def generate_openai_stream_chunk(content, model, chunk_id, finish_reason=None):
|
|
@@ -183,8 +261,8 @@ def stream_onyx_response(payload, model, session_key):
|
|
| 183 |
|
| 184 |
# Try alternate endpoints if needed
|
| 185 |
endpoints = [
|
| 186 |
-
f"{ONYX_BASE_URL}/api/chat/send-message",
|
| 187 |
-
f"{ONYX_BASE_URL}/api/chat/send-
|
| 188 |
]
|
| 189 |
|
| 190 |
# Send initial chunk with role
|
|
@@ -250,20 +328,22 @@ def stream_onyx_response(payload, model, session_key):
|
|
| 250 |
if line == '[DONE]':
|
| 251 |
continue
|
| 252 |
|
| 253 |
-
content, msg_id = parse_onyx_stream_chunk(line)
|
| 254 |
if msg_id:
|
| 255 |
last_message_id = msg_id
|
| 256 |
-
if
|
|
|
|
|
|
|
| 257 |
yield generate_openai_stream_chunk(content, model, chunk_id)
|
| 258 |
|
| 259 |
# Process remaining buffer
|
| 260 |
if buffer.strip():
|
| 261 |
if buffer.strip().startswith('data: '):
|
| 262 |
buffer = buffer.strip()[6:]
|
| 263 |
-
content, msg_id = parse_onyx_stream_chunk(buffer.strip())
|
| 264 |
if msg_id:
|
| 265 |
last_message_id = msg_id
|
| 266 |
-
if content:
|
| 267 |
yield generate_openai_stream_chunk(content, model, chunk_id)
|
| 268 |
|
| 269 |
# Update session with last message ID
|
|
@@ -287,8 +367,8 @@ def collect_full_response(payload, model, session_key):
|
|
| 287 |
last_message_id = None
|
| 288 |
|
| 289 |
endpoints = [
|
| 290 |
-
f"{ONYX_BASE_URL}/api/chat/send-message",
|
| 291 |
-
f"{ONYX_BASE_URL}/api/chat/send-
|
| 292 |
]
|
| 293 |
|
| 294 |
for url in endpoints:
|
|
@@ -338,19 +418,21 @@ def collect_full_response(payload, model, session_key):
|
|
| 338 |
if line == '[DONE]':
|
| 339 |
continue
|
| 340 |
|
| 341 |
-
content, msg_id = parse_onyx_stream_chunk(line)
|
| 342 |
if msg_id:
|
| 343 |
last_message_id = msg_id
|
| 344 |
-
if
|
|
|
|
|
|
|
| 345 |
full_content += content
|
| 346 |
|
| 347 |
if buffer.strip():
|
| 348 |
if buffer.strip().startswith('data: '):
|
| 349 |
buffer = buffer.strip()[6:]
|
| 350 |
-
content, msg_id = parse_onyx_stream_chunk(buffer.strip())
|
| 351 |
if msg_id:
|
| 352 |
last_message_id = msg_id
|
| 353 |
-
if content:
|
| 354 |
full_content += content
|
| 355 |
|
| 356 |
# Update session
|
|
@@ -424,8 +506,9 @@ def chat_completions():
|
|
| 424 |
}
|
| 425 |
}), 400
|
| 426 |
|
| 427 |
-
# Parse model string
|
| 428 |
model_provider, model_version = parse_model_string(model)
|
|
|
|
| 429 |
print(f"Model provider: {model_provider}, version: {model_version}")
|
| 430 |
|
| 431 |
# Get or create chat session
|
|
@@ -446,7 +529,8 @@ def chat_completions():
|
|
| 446 |
model_version=model_version,
|
| 447 |
temperature=temperature,
|
| 448 |
chat_session_id=session_info['session_id'],
|
| 449 |
-
parent_message_id=session_info.get('parent_message_id')
|
|
|
|
| 450 |
)
|
| 451 |
|
| 452 |
if stream:
|
|
|
|
| 79 |
return "openai", model
|
| 80 |
|
| 81 |
|
| 82 |
+
# Known provider name mappings.
# Update these based on what's configured in the Onyx Cloud admin panel.
# Keys are lowercase user-supplied names; values are the exact provider
# identifiers Onyx expects. Unknown providers pass through lowercased.
PROVIDER_ALIASES = {
    "openai": "openai",
    "anthropic": "anthropic",
    "google": "google",
    "azure": "azure",
    "bedrock": "bedrock",
    "cohere": "cohere",
    "mistral": "mistral",
    # Add more aliases as needed
}


def normalize_provider_name(provider):
    """Normalize a provider name to match the Onyx configuration.

    Handles case sensitivity, surrounding whitespace, and known aliases.

    Args:
        provider: Raw provider name (e.g. " OpenAI ").

    Returns:
        The canonical provider identifier; unknown names are returned
        lowercased and stripped, unchanged otherwise.
    """
    provider_lower = provider.lower().strip()
    # Fall back to the normalized input so unlisted providers still work.
    return PROVIDER_ALIASES.get(provider_lower, provider_lower)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def build_onyx_payload(messages, model_provider, model_version, temperature, chat_session_id, parent_message_id=None, stream=True):
|
| 106 |
+
"""Convert OpenAI format to Onyx payload
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
messages: List of OpenAI format messages
|
| 110 |
+
model_provider: LLM provider name (e.g., 'openai', 'anthropic')
|
| 111 |
+
model_version: Model version (e.g., 'gpt-4', 'claude-3-opus-20240229')
|
| 112 |
+
temperature: Temperature setting for generation
|
| 113 |
+
chat_session_id: Onyx chat session ID
|
| 114 |
+
parent_message_id: Optional parent message ID for threading
|
| 115 |
+
stream: Whether to stream the response (default True)
|
| 116 |
+
"""
|
| 117 |
|
| 118 |
# Extract the last user message
|
| 119 |
last_user_message = ""
|
|
|
|
| 142 |
"message": full_message,
|
| 143 |
"chat_session_id": chat_session_id,
|
| 144 |
"parent_message_id": parent_message_id if parent_message_id else None,
|
| 145 |
+
"stream": stream, # Now respects caller's preference
|
| 146 |
"llm_override": {
|
| 147 |
"model_provider": model_provider,
|
| 148 |
"model_version": model_version,
|
|
|
|
| 159 |
|
| 160 |
|
| 161 |
def parse_onyx_stream_chunk(chunk_text):
    """Parse a chunk from an Onyx stream and extract the text content.

    New Onyx API uses a packet-based format:
    - First packet: {"user_message_id": int, "reserved_assistant_message_id": int}
    - Content packets: {"ind": int, "obj": {"type": "message_delta", "content": "..."}}
    - Stop packet: {"ind": int, "obj": {"type": "stop"}}

    The legacy flat format ({"answer_piece": ...} / {"text": ...} /
    {"content": ...}) is still accepted as a fallback, as is raw
    (non-JSON) text.

    Args:
        chunk_text: One line/chunk of the Onyx response stream.

    Returns:
        tuple: (content, message_id, packet_type) — any element may be
        None when not present in this chunk.
    """
    if not chunk_text or not chunk_text.strip():
        return None, None, None

    try:
        data = json.loads(chunk_text)

        # Anything that isn't a JSON object carries no usable content.
        if not isinstance(data, dict):
            return None, None, None

        # Handle first packet (message IDs)
        if 'user_message_id' in data or 'reserved_assistant_message_id' in data:
            return None, data.get('reserved_assistant_message_id'), 'message_ids'

        # Handle new packet-based format
        if 'obj' in data:
            obj = data['obj']
            packet_type = obj.get('type', '')

            if packet_type == 'message_delta':
                # This is the actual content!
                content = obj.get('content', '')
                return content, None, 'content'

            elif packet_type == 'message_start':
                # Contains final_documents, not content
                return None, None, 'message_start'

            elif packet_type == 'stop':
                # End of stream
                return None, None, 'stop'

            elif packet_type == 'error':
                error_msg = obj.get('message', obj.get('error', 'Unknown error'))
                return f"[Error: {error_msg}]", None, 'error'

            elif packet_type == 'citation_delta':
                # Citation info, not content
                return None, None, 'citation'

            elif packet_type in ['reasoning_start', 'reasoning_delta', 'reasoning_done']:
                # Reasoning packets
                return None, None, 'reasoning'

            else:
                # Other packet types (search, tools, etc.)
                return None, None, packet_type

        # FALLBACK: Old format support (for backward compatibility)
        message_id = data.get('message_id')

        if 'answer_piece' in data:
            return data['answer_piece'], message_id, 'legacy'
        elif 'text' in data:
            return data['text'], message_id, 'legacy'
        elif 'content' in data and isinstance(data['content'], str):
            return data['content'], message_id, 'legacy'
        elif 'error' in data:
            return f"[Error: {data['error']}]", message_id, 'error'

        return None, None, None

    except json.JSONDecodeError:
        # Not JSON, might be raw text
        if chunk_text.strip() and not chunk_text.strip().startswith('{'):
            return chunk_text.strip(), None, 'raw'

        return None, None, None
|
| 239 |
|
| 240 |
|
| 241 |
def generate_openai_stream_chunk(content, model, chunk_id, finish_reason=None):
|
|
|
|
| 261 |
|
| 262 |
# Try alternate endpoints if needed
|
| 263 |
endpoints = [
|
| 264 |
+
f"{ONYX_BASE_URL}/api/chat/send-chat-message", # Primary (new)
|
| 265 |
+
f"{ONYX_BASE_URL}/api/chat/send-message", # Fallback (deprecated)
|
| 266 |
]
|
| 267 |
|
| 268 |
# Send initial chunk with role
|
|
|
|
| 328 |
if line == '[DONE]':
|
| 329 |
continue
|
| 330 |
|
| 331 |
+
content, msg_id, packet_type = parse_onyx_stream_chunk(line)
|
| 332 |
if msg_id:
|
| 333 |
last_message_id = msg_id
|
| 334 |
+
if packet_type == 'stop':
|
| 335 |
+
break
|
| 336 |
+
if content and packet_type in ['content', 'legacy', 'raw', 'error']:
|
| 337 |
yield generate_openai_stream_chunk(content, model, chunk_id)
|
| 338 |
|
| 339 |
# Process remaining buffer
|
| 340 |
if buffer.strip():
|
| 341 |
if buffer.strip().startswith('data: '):
|
| 342 |
buffer = buffer.strip()[6:]
|
| 343 |
+
content, msg_id, packet_type = parse_onyx_stream_chunk(buffer.strip())
|
| 344 |
if msg_id:
|
| 345 |
last_message_id = msg_id
|
| 346 |
+
if content and packet_type in ['content', 'legacy', 'raw', 'error']:
|
| 347 |
yield generate_openai_stream_chunk(content, model, chunk_id)
|
| 348 |
|
| 349 |
# Update session with last message ID
|
|
|
|
| 367 |
last_message_id = None
|
| 368 |
|
| 369 |
endpoints = [
|
| 370 |
+
f"{ONYX_BASE_URL}/api/chat/send-chat-message", # Primary (new)
|
| 371 |
+
f"{ONYX_BASE_URL}/api/chat/send-message", # Fallback (deprecated)
|
| 372 |
]
|
| 373 |
|
| 374 |
for url in endpoints:
|
|
|
|
| 418 |
if line == '[DONE]':
|
| 419 |
continue
|
| 420 |
|
| 421 |
+
content, msg_id, packet_type = parse_onyx_stream_chunk(line)
|
| 422 |
if msg_id:
|
| 423 |
last_message_id = msg_id
|
| 424 |
+
if packet_type == 'stop':
|
| 425 |
+
break
|
| 426 |
+
if content and packet_type in ['content', 'legacy', 'raw', 'error']:
|
| 427 |
full_content += content
|
| 428 |
|
| 429 |
if buffer.strip():
|
| 430 |
if buffer.strip().startswith('data: '):
|
| 431 |
buffer = buffer.strip()[6:]
|
| 432 |
+
content, msg_id, packet_type = parse_onyx_stream_chunk(buffer.strip())
|
| 433 |
if msg_id:
|
| 434 |
last_message_id = msg_id
|
| 435 |
+
if content and packet_type in ['content', 'legacy', 'raw', 'error']:
|
| 436 |
full_content += content
|
| 437 |
|
| 438 |
# Update session
|
|
|
|
| 506 |
}
|
| 507 |
}), 400
|
| 508 |
|
| 509 |
+
# Parse model string and normalize provider name
|
| 510 |
model_provider, model_version = parse_model_string(model)
|
| 511 |
+
model_provider = normalize_provider_name(model_provider)
|
| 512 |
print(f"Model provider: {model_provider}, version: {model_version}")
|
| 513 |
|
| 514 |
# Get or create chat session
|
|
|
|
| 529 |
model_version=model_version,
|
| 530 |
temperature=temperature,
|
| 531 |
chat_session_id=session_info['session_id'],
|
| 532 |
+
parent_message_id=session_info.get('parent_message_id'),
|
| 533 |
+
stream=stream # Pass client's streaming preference
|
| 534 |
)
|
| 535 |
|
| 536 |
if stream:
|