shiv-4567892009 committed on
Commit
7457d51
·
verified ·
1 Parent(s): c491677

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -38
app.py CHANGED
@@ -79,8 +79,41 @@ def parse_model_string(model):
79
  return "openai", model
80
 
81
 
82
- def build_onyx_payload(messages, model_provider, model_version, temperature, chat_session_id, parent_message_id=None):
83
- """Convert OpenAI format to Onyx payload"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  # Extract the last user message
86
  last_user_message = ""
@@ -109,7 +142,7 @@ def build_onyx_payload(messages, model_provider, model_version, temperature, cha
109
  "message": full_message,
110
  "chat_session_id": chat_session_id,
111
  "parent_message_id": parent_message_id if parent_message_id else None,
112
- "stream": True,
113
  "llm_override": {
114
  "model_provider": model_provider,
115
  "model_version": model_version,
@@ -126,38 +159,83 @@ def build_onyx_payload(messages, model_provider, model_version, temperature, cha
126
 
127
 
128
  def parse_onyx_stream_chunk(chunk_text):
129
- """Parse a chunk from Onyx stream and extract the text content"""
 
 
 
 
 
 
 
 
 
130
  if not chunk_text or not chunk_text.strip():
131
- return None, None
132
 
133
  try:
134
  data = json.loads(chunk_text)
135
 
136
- if isinstance(data, dict):
137
- # Extract message ID for tracking conversation
138
- message_id = data.get('message_id')
 
 
 
 
 
 
 
 
139
 
140
- # Check for different content fields
141
- if 'answer_piece' in data:
142
- return data['answer_piece'], message_id
143
- elif 'text' in data:
144
- return data['text'], message_id
145
- elif 'content' in data:
146
- return data['content'], message_id
147
- elif 'message' in data and isinstance(data['message'], str):
148
- return data['message'], message_id
149
- elif 'error' in data:
150
- return f"[Error: {data['error']}]", message_id
151
-
152
- elif isinstance(data, str):
153
- return data, None
 
 
 
 
 
 
 
 
 
 
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  except json.JSONDecodeError:
156
  # Not JSON, might be raw text
157
  if chunk_text.strip() and not chunk_text.strip().startswith('{'):
158
- return chunk_text.strip(), None
159
 
160
- return None, None
161
 
162
 
163
  def generate_openai_stream_chunk(content, model, chunk_id, finish_reason=None):
@@ -183,8 +261,8 @@ def stream_onyx_response(payload, model, session_key):
183
 
184
  # Try alternate endpoints if needed
185
  endpoints = [
186
- f"{ONYX_BASE_URL}/api/chat/send-message",
187
- f"{ONYX_BASE_URL}/api/chat/send-chat-message",
188
  ]
189
 
190
  # Send initial chunk with role
@@ -250,20 +328,22 @@ def stream_onyx_response(payload, model, session_key):
250
  if line == '[DONE]':
251
  continue
252
 
253
- content, msg_id = parse_onyx_stream_chunk(line)
254
  if msg_id:
255
  last_message_id = msg_id
256
- if content:
 
 
257
  yield generate_openai_stream_chunk(content, model, chunk_id)
258
 
259
  # Process remaining buffer
260
  if buffer.strip():
261
  if buffer.strip().startswith('data: '):
262
  buffer = buffer.strip()[6:]
263
- content, msg_id = parse_onyx_stream_chunk(buffer.strip())
264
  if msg_id:
265
  last_message_id = msg_id
266
- if content:
267
  yield generate_openai_stream_chunk(content, model, chunk_id)
268
 
269
  # Update session with last message ID
@@ -287,8 +367,8 @@ def collect_full_response(payload, model, session_key):
287
  last_message_id = None
288
 
289
  endpoints = [
290
- f"{ONYX_BASE_URL}/api/chat/send-message",
291
- f"{ONYX_BASE_URL}/api/chat/send-chat-message",
292
  ]
293
 
294
  for url in endpoints:
@@ -338,19 +418,21 @@ def collect_full_response(payload, model, session_key):
338
  if line == '[DONE]':
339
  continue
340
 
341
- content, msg_id = parse_onyx_stream_chunk(line)
342
  if msg_id:
343
  last_message_id = msg_id
344
- if content:
 
 
345
  full_content += content
346
 
347
  if buffer.strip():
348
  if buffer.strip().startswith('data: '):
349
  buffer = buffer.strip()[6:]
350
- content, msg_id = parse_onyx_stream_chunk(buffer.strip())
351
  if msg_id:
352
  last_message_id = msg_id
353
- if content:
354
  full_content += content
355
 
356
  # Update session
@@ -424,8 +506,9 @@ def chat_completions():
424
  }
425
  }), 400
426
 
427
- # Parse model string
428
  model_provider, model_version = parse_model_string(model)
 
429
  print(f"Model provider: {model_provider}, version: {model_version}")
430
 
431
  # Get or create chat session
@@ -446,7 +529,8 @@ def chat_completions():
446
  model_version=model_version,
447
  temperature=temperature,
448
  chat_session_id=session_info['session_id'],
449
- parent_message_id=session_info.get('parent_message_id')
 
450
  )
451
 
452
  if stream:
 
79
  return "openai", model
80
 
81
 
82
+ # Known provider name mappings
83
+ # Update these based on what's configured in Onyx Cloud admin panel
84
+ PROVIDER_ALIASES = {
85
+ "openai": "openai",
86
+ "anthropic": "anthropic",
87
+ "google": "google",
88
+ "azure": "azure",
89
+ "bedrock": "bedrock",
90
+ "cohere": "cohere",
91
+ "mistral": "mistral",
92
+ # Add more aliases as needed
93
+ }
94
+
95
+
96
+ def normalize_provider_name(provider):
97
+ """
98
+ Normalize provider name to match Onyx configuration.
99
+ Handles case sensitivity and common aliases.
100
+ """
101
+ provider_lower = provider.lower().strip()
102
+ return PROVIDER_ALIASES.get(provider_lower, provider_lower)
103
+
104
+
105
+ def build_onyx_payload(messages, model_provider, model_version, temperature, chat_session_id, parent_message_id=None, stream=True):
106
+ """Convert OpenAI format to Onyx payload
107
+
108
+ Args:
109
+ messages: List of OpenAI format messages
110
+ model_provider: LLM provider name (e.g., 'openai', 'anthropic')
111
+ model_version: Model version (e.g., 'gpt-4', 'claude-3-opus-20240229')
112
+ temperature: Temperature setting for generation
113
+ chat_session_id: Onyx chat session ID
114
+ parent_message_id: Optional parent message ID for threading
115
+ stream: Whether to stream the response (default True)
116
+ """
117
 
118
  # Extract the last user message
119
  last_user_message = ""
 
142
  "message": full_message,
143
  "chat_session_id": chat_session_id,
144
  "parent_message_id": parent_message_id if parent_message_id else None,
145
+ "stream": stream, # Now respects caller's preference
146
  "llm_override": {
147
  "model_provider": model_provider,
148
  "model_version": model_version,
 
159
 
160
 
161
  def parse_onyx_stream_chunk(chunk_text):
162
+ """Parse a chunk from Onyx stream and extract the text content.
163
+
164
+ New Onyx API uses packet-based format:
165
+ - First packet: {"user_message_id": int, "reserved_assistant_message_id": int}
166
+ - Content packets: {"ind": int, "obj": {"type": "message_delta", "content": "..."}}
167
+ - Stop packet: {"ind": int, "obj": {"type": "stop"}}
168
+
169
+ Returns:
170
+ tuple: (content, message_id, packet_type)
171
+ """
172
  if not chunk_text or not chunk_text.strip():
173
+ return None, None, None
174
 
175
  try:
176
  data = json.loads(chunk_text)
177
 
178
+ if not isinstance(data, dict):
179
+ return None, None, None
180
+
181
+ # Handle first packet (message IDs)
182
+ if 'user_message_id' in data or 'reserved_assistant_message_id' in data:
183
+ return None, data.get('reserved_assistant_message_id'), 'message_ids'
184
+
185
+ # Handle new packet-based format
186
+ if 'obj' in data:
187
+ obj = data['obj']
188
+ packet_type = obj.get('type', '')
189
 
190
+ if packet_type == 'message_delta':
191
+ # This is the actual content!
192
+ content = obj.get('content', '')
193
+ return content, None, 'content'
194
+
195
+ elif packet_type == 'message_start':
196
+ # Contains final_documents, not content
197
+ return None, None, 'message_start'
198
+
199
+ elif packet_type == 'stop':
200
+ # End of stream
201
+ return None, None, 'stop'
202
+
203
+ elif packet_type == 'error':
204
+ error_msg = obj.get('message', obj.get('error', 'Unknown error'))
205
+ return f"[Error: {error_msg}]", None, 'error'
206
+
207
+ elif packet_type == 'citation_delta':
208
+ # Citation info, not content
209
+ return None, None, 'citation'
210
+
211
+ elif packet_type in ['reasoning_start', 'reasoning_delta', 'reasoning_done']:
212
+ # Reasoning packets
213
+ return None, None, 'reasoning'
214
 
215
+ else:
216
+ # Other packet types (search, tools, etc.)
217
+ return None, None, packet_type
218
+
219
+ # FALLBACK: Old format support (for backward compatibility)
220
+ message_id = data.get('message_id')
221
+
222
+ if 'answer_piece' in data:
223
+ return data['answer_piece'], message_id, 'legacy'
224
+ elif 'text' in data:
225
+ return data['text'], message_id, 'legacy'
226
+ elif 'content' in data and isinstance(data['content'], str):
227
+ return data['content'], message_id, 'legacy'
228
+ elif 'error' in data:
229
+ return f"[Error: {data['error']}]", message_id, 'error'
230
+
231
+ return None, None, None
232
+
233
  except json.JSONDecodeError:
234
  # Not JSON, might be raw text
235
  if chunk_text.strip() and not chunk_text.strip().startswith('{'):
236
+ return chunk_text.strip(), None, 'raw'
237
 
238
+ return None, None, None
239
 
240
 
241
  def generate_openai_stream_chunk(content, model, chunk_id, finish_reason=None):
 
261
 
262
  # Try alternate endpoints if needed
263
  endpoints = [
264
+ f"{ONYX_BASE_URL}/api/chat/send-chat-message", # Primary (new)
265
+ f"{ONYX_BASE_URL}/api/chat/send-message", # Fallback (deprecated)
266
  ]
267
 
268
  # Send initial chunk with role
 
328
  if line == '[DONE]':
329
  continue
330
 
331
+ content, msg_id, packet_type = parse_onyx_stream_chunk(line)
332
  if msg_id:
333
  last_message_id = msg_id
334
+ if packet_type == 'stop':
335
+ break
336
+ if content and packet_type in ['content', 'legacy', 'raw', 'error']:
337
  yield generate_openai_stream_chunk(content, model, chunk_id)
338
 
339
  # Process remaining buffer
340
  if buffer.strip():
341
  if buffer.strip().startswith('data: '):
342
  buffer = buffer.strip()[6:]
343
+ content, msg_id, packet_type = parse_onyx_stream_chunk(buffer.strip())
344
  if msg_id:
345
  last_message_id = msg_id
346
+ if content and packet_type in ['content', 'legacy', 'raw', 'error']:
347
  yield generate_openai_stream_chunk(content, model, chunk_id)
348
 
349
  # Update session with last message ID
 
367
  last_message_id = None
368
 
369
  endpoints = [
370
+ f"{ONYX_BASE_URL}/api/chat/send-chat-message", # Primary (new)
371
+ f"{ONYX_BASE_URL}/api/chat/send-message", # Fallback (deprecated)
372
  ]
373
 
374
  for url in endpoints:
 
418
  if line == '[DONE]':
419
  continue
420
 
421
+ content, msg_id, packet_type = parse_onyx_stream_chunk(line)
422
  if msg_id:
423
  last_message_id = msg_id
424
+ if packet_type == 'stop':
425
+ break
426
+ if content and packet_type in ['content', 'legacy', 'raw', 'error']:
427
  full_content += content
428
 
429
  if buffer.strip():
430
  if buffer.strip().startswith('data: '):
431
  buffer = buffer.strip()[6:]
432
+ content, msg_id, packet_type = parse_onyx_stream_chunk(buffer.strip())
433
  if msg_id:
434
  last_message_id = msg_id
435
+ if content and packet_type in ['content', 'legacy', 'raw', 'error']:
436
  full_content += content
437
 
438
  # Update session
 
506
  }
507
  }), 400
508
 
509
+ # Parse model string and normalize provider name
510
  model_provider, model_version = parse_model_string(model)
511
+ model_provider = normalize_provider_name(model_provider)
512
  print(f"Model provider: {model_provider}, version: {model_version}")
513
 
514
  # Get or create chat session
 
529
  model_version=model_version,
530
  temperature=temperature,
531
  chat_session_id=session_info['session_id'],
532
+ parent_message_id=session_info.get('parent_message_id'),
533
+ stream=stream # Pass client's streaming preference
534
  )
535
 
536
  if stream: