cjovs committed on
Commit
821145f
·
verified ·
1 Parent(s): bcab3b5

Add streaming support for /v1/responses compatibility

Browse files
Files changed (1) hide show
  1. app/routes.py +195 -7
app/routes.py CHANGED
@@ -6,6 +6,7 @@ import re
6
  import threading
7
  import time
8
  import urllib.request
 
9
  from uuid import uuid4
10
 
11
  from fastapi import APIRouter, HTTPException, Request
@@ -196,6 +197,22 @@ def _call_local_chat_completions(chat_body: dict, auth_header: str, x_api_key: s
196
  return 500, {"error": str(exc)}
197
 
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  def _chat_completion_to_response_payload(chat_payload: dict) -> dict:
200
  choice = ((chat_payload.get("choices") or [{}])[0])
201
  message = choice.get("message") or {}
@@ -245,16 +262,176 @@ def _chat_completion_to_response_payload(chat_payload: dict) -> dict:
245
  }
246
 
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  @router.post("/v1/responses")
249
  async def responses_api(request: Request):
250
  try:
251
  req_data = await request.json()
252
- if bool(req_data.get("stream", False)):
253
- return JSONResponse(
254
- status_code=400,
255
- content={"error": "/v1/responses streaming is not supported yet. Use /v1/chat/completions with stream=true."},
256
- )
257
-
258
  model = req_data.get("model")
259
  messages = _responses_input_to_messages(req_data)
260
  if not model or not messages:
@@ -266,7 +443,7 @@ async def responses_api(request: Request):
266
  chat_body = {
267
  "model": model,
268
  "messages": messages,
269
- "stream": False,
270
  }
271
  if "tools" in req_data:
272
  chat_body["tools"] = _responses_tools_to_chat_tools(req_data.get("tools"))
@@ -277,6 +454,17 @@ async def responses_api(request: Request):
277
  if "top_p" in req_data:
278
  chat_body["top_p"] = req_data.get("top_p")
279
 
 
 
 
 
 
 
 
 
 
 
 
280
  status_code, chat_payload = await __import__("asyncio").to_thread(
281
  _call_local_chat_completions,
282
  chat_body,
 
6
  import threading
7
  import time
8
  import urllib.request
9
+ import urllib.parse
10
  from uuid import uuid4
11
 
12
  from fastapi import APIRouter, HTTPException, Request
 
197
  return 500, {"error": str(exc)}
198
 
199
 
200
def _open_local_chat_completions_stream(chat_body: dict, auth_header: str, x_api_key: str):
    """Open a streaming POST against the chat-completions endpoint on 127.0.0.1.

    The port is read from the ``PORT`` environment variable (default ``7860``).
    ``chat_body`` is sent as UTF-8 JSON and the request asks for
    ``text/event-stream``.  ``auth_header`` / ``x_api_key`` are forwarded as
    ``Authorization`` / ``x-api-key`` only when they are non-empty.

    Returns whatever ``urllib.request.urlopen`` returns (opened with a
    180-second timeout); any exception it raises propagates to the caller.
    """
    request_headers = {"Content-Type": "application/json", "Accept": "text/event-stream"}
    # Forward only the credentials the caller actually supplied.
    optional_headers = (("Authorization", auth_header), ("x-api-key", x_api_key))
    request_headers.update({name: value for name, value in optional_headers if value})

    port = os.getenv("PORT", "7860")
    endpoint = f"http://127.0.0.1:{port}/v1/chat/completions"
    encoded_body = json.dumps(chat_body).encode("utf-8")

    outbound = urllib.request.Request(
        endpoint,
        data=encoded_body,
        headers=request_headers,
        method="POST",
    )
    return urllib.request.urlopen(outbound, timeout=180)
214
+
215
+
216
  def _chat_completion_to_response_payload(chat_payload: dict) -> dict:
217
  choice = ((chat_payload.get("choices") or [{}])[0])
218
  message = choice.get("message") or {}
 
262
  }
263
 
264
 
265
+ def _sse_event(data: dict) -> str:
266
+ return f"event: {data['type']}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
267
+
268
+
269
+ def _response_base(response_id: str, model: str) -> dict:
270
+ return {
271
+ "id": response_id,
272
+ "object": "response",
273
+ "created_at": int(time.time()),
274
+ "status": "in_progress",
275
+ "error": None,
276
+ "incomplete_details": None,
277
+ "instructions": None,
278
+ "model": model,
279
+ "output": [],
280
+ "parallel_tool_calls": True,
281
+ "tools": [],
282
+ "top_p": 1,
283
+ "temperature": 1,
284
+ "text": {"format": {"type": "text"}},
285
+ "metadata": {},
286
+ "usage": None,
287
+ }
288
+
289
+
290
def _stream_responses_from_chat(chat_body: dict, auth_header: str, x_api_key: str, model: str):
    """Re-emit a local chat-completions SSE stream as Responses-API SSE events.

    Generator yielding SSE frames (strings).  It first announces a new
    response with one assistant message item and one ``output_text`` part,
    then forwards each text delta read from the chat-completions stream as a
    ``response.output_text.delta`` event, and finally emits the ``*.done``
    events plus ``response.completed`` carrying the full text and token usage.

    If opening or reading the upstream stream fails, a single ``error`` event
    is yielded and the generator stops without emitting the closing events.

    Args:
        chat_body: Chat-completions request body forwarded upstream.
        auth_header: ``Authorization`` header value to forward (may be empty).
        x_api_key: ``x-api-key`` header value to forward (may be empty).
        model: Model name reported in the response snapshots.

    NOTE(review): only ``delta.content`` text is translated; any tool-call
    deltas in the upstream stream are not forwarded.
    """
    # Imported here so HTTPError is bound explicitly rather than relying on
    # urllib.request having loaded urllib.error as a side effect.
    from urllib.error import HTTPError

    response_id = f"resp_{uuid4().hex}"
    message_id = f"msg_{uuid4().hex}"
    # Build the envelope once so every snapshot of this response reports the
    # same created_at; it is serialized immediately at each yield, never mutated.
    base = _response_base(response_id, model)
    seq = 1

    yield _sse_event({"type": "response.created", "response": base, "sequence_number": seq})
    seq += 1
    yield _sse_event({"type": "response.in_progress", "response": base, "sequence_number": seq})
    seq += 1
    yield _sse_event(
        {
            "type": "response.output_item.added",
            "output_index": 0,
            "item": {
                "id": message_id,
                "type": "message",
                "status": "in_progress",
                "role": "assistant",
                "content": [],
            },
            "sequence_number": seq,
        }
    )
    seq += 1
    yield _sse_event(
        {
            "type": "response.content_part.added",
            "item_id": message_id,
            "output_index": 0,
            "content_index": 0,
            "part": {"type": "output_text", "text": "", "annotations": []},
            "sequence_number": seq,
        }
    )
    seq += 1

    full_text = []
    usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

    try:
        with _open_local_chat_completions_stream(chat_body, auth_header, x_api_key) as resp:
            for raw_line in resp:
                line = raw_line.decode("utf-8", errors="replace").strip()
                # SSE permits "data:" with or without a following space.
                if not line.startswith("data:"):
                    continue
                payload = line[5:].lstrip()
                if payload == "[DONE]":
                    break
                try:
                    chunk = json.loads(payload)
                except json.JSONDecodeError:
                    continue
                if not isinstance(chunk, dict):
                    # Skip anything that is not a chat-completions chunk object
                    # instead of aborting the whole stream.
                    continue
                choice = (chunk.get("choices") or [{}])[0] or {}
                delta = choice.get("delta") or {}
                text_delta = delta.get("content")
                if isinstance(text_delta, str) and text_delta:
                    full_text.append(text_delta)
                    yield _sse_event(
                        {
                            "type": "response.output_text.delta",
                            "item_id": message_id,
                            "output_index": 0,
                            "content_index": 0,
                            "delta": text_delta,
                            "sequence_number": seq,
                        }
                    )
                    seq += 1
                chunk_usage = chunk.get("usage")
                if isinstance(chunk_usage, dict):
                    usage = {
                        "input_tokens": chunk_usage.get("prompt_tokens", 0),
                        "output_tokens": chunk_usage.get("completion_tokens", 0),
                        "total_tokens": chunk_usage.get("total_tokens", 0),
                    }
    except HTTPError as exc:
        raw = exc.read().decode("utf-8", errors="replace")
        message = raw
        try:
            parsed = json.loads(raw)
        except ValueError:
            parsed = None
        if isinstance(parsed, dict):
            message = parsed.get("error") or parsed.get("detail") or raw
            if isinstance(message, dict):
                # Error objects such as {"error": {"message": ...}}: surface the
                # human-readable text rather than a Python dict repr.
                message = message.get("message") or json.dumps(message, ensure_ascii=False)
        yield _sse_event({"type": "error", "error": {"message": str(message)}, "sequence_number": seq})
        return
    except Exception as exc:
        # Stream boundary: report any other failure to the client as an error
        # event instead of tearing down the HTTP response mid-stream.
        yield _sse_event({"type": "error", "error": {"message": str(exc)}, "sequence_number": seq})
        return

    final_text = "".join(full_text)
    yield _sse_event(
        {
            "type": "response.output_text.done",
            "item_id": message_id,
            "output_index": 0,
            "content_index": 0,
            "text": final_text,
            "sequence_number": seq,
        }
    )
    seq += 1
    yield _sse_event(
        {
            "type": "response.content_part.done",
            "item_id": message_id,
            "output_index": 0,
            "content_index": 0,
            "part": {"type": "output_text", "text": final_text, "annotations": []},
            "sequence_number": seq,
        }
    )
    seq += 1
    completed_item = {
        "id": message_id,
        "type": "message",
        "status": "completed",
        "role": "assistant",
        "content": [{"type": "output_text", "text": final_text, "annotations": []}],
    }
    yield _sse_event(
        {
            "type": "response.output_item.done",
            "output_index": 0,
            "item": completed_item,
            "sequence_number": seq,
        }
    )
    seq += 1
    # Derive the final snapshot from the same base so id/model/created_at match
    # the earlier response.created / response.in_progress events.
    completed = {**base, "status": "completed", "output": [completed_item], "usage": usage}
    yield _sse_event({"type": "response.completed", "response": completed, "sequence_number": seq})
429
+
430
+
431
  @router.post("/v1/responses")
432
  async def responses_api(request: Request):
433
  try:
434
  req_data = await request.json()
 
 
 
 
 
 
435
  model = req_data.get("model")
436
  messages = _responses_input_to_messages(req_data)
437
  if not model or not messages:
 
443
  chat_body = {
444
  "model": model,
445
  "messages": messages,
446
+ "stream": bool(req_data.get("stream", False)),
447
  }
448
  if "tools" in req_data:
449
  chat_body["tools"] = _responses_tools_to_chat_tools(req_data.get("tools"))
 
454
  if "top_p" in req_data:
455
  chat_body["top_p"] = req_data.get("top_p")
456
 
457
+ if bool(req_data.get("stream", False)):
458
+ return StreamingResponse(
459
+ _stream_responses_from_chat(
460
+ chat_body,
461
+ request.headers.get("Authorization", ""),
462
+ request.headers.get("x-api-key", ""),
463
+ model,
464
+ ),
465
+ media_type="text/event-stream",
466
+ )
467
+
468
  status_code, chat_payload = await __import__("asyncio").to_thread(
469
  _call_local_chat_completions,
470
  chat_body,