sharktide commited on
Commit
9f437ef
·
verified ·
1 Parent(s): 1f8b39e

Update gen.py

Browse files
Files changed (1) hide show
  1. gen.py +308 -0
gen.py CHANGED
@@ -942,3 +942,311 @@ def return_models_openai():
942
  }
943
  ]
944
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
942
  }
943
  ]
944
  }
945
+
946
+ import time
947
+ import uuid
948
+
949
+ # -----------------------------
950
+ # RESPONSES API (OpenAI-compatible)
951
+ # -----------------------------
952
+
953
+ def _normalize_responses_input(input_field) -> list[dict]:
954
+ """
955
+ Coerce the Responses API `input` field into a standard messages[] list.
956
+
957
+ Accepted shapes:
958
+ • str → [{"role":"user","content":"..."}]
959
+ • list of message-like dicts → pass through, normalising content parts
960
+ """
961
+ if isinstance(input_field, str):
962
+ return [{"role": "user", "content": input_field}]
963
+
964
+ messages = []
965
+ for item in input_field:
966
+ role = item.get("role", "user")
967
+ content = item.get("content", "")
968
+
969
+ # Content can be a plain string or a list of content parts
970
+ if isinstance(content, list):
971
+ # Translate Responses-style parts to Chat-style parts
972
+ parts = []
973
+ for part in content:
974
+ ptype = part.get("type", "")
975
+ if ptype == "text":
976
+ parts.append({"type": "text", "text": part.get("text", "")})
977
+ elif ptype == "image_url":
978
+ parts.append({"type": "image_url", "image_url": part.get("image_url", {})})
979
+ elif ptype == "input_audio":
980
+ # Not supported downstream — skip gracefully
981
+ pass
982
+ else:
983
+ # Forward unknown parts as-is so nothing is silently dropped
984
+ parts.append(part)
985
+ messages.append({"role": role, "content": parts})
986
+ else:
987
+ messages.append({"role": role, "content": content})
988
+
989
+ return messages
990
+
991
+
992
+ def _wrap_responses_output(chat_payload: dict, model_name: str) -> dict:
993
+ """
994
+ Wrap a standard chat-completions JSON response into the Responses API shape.
995
+ """
996
+ choices = chat_payload.get("choices", [])
997
+ output = []
998
+
999
+ for choice in choices:
1000
+ msg = choice.get("message", {})
1001
+ content_text = msg.get("content") or ""
1002
+ tool_calls = msg.get("tool_calls")
1003
+
1004
+ if tool_calls:
1005
+ for tc in tool_calls:
1006
+ fn = tc.get("function", {})
1007
+ output.append({
1008
+ "type": "function_call",
1009
+ "id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
1010
+ "call_id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
1011
+ "name": fn.get("name", ""),
1012
+ "arguments": fn.get("arguments", ""),
1013
+ })
1014
+ else:
1015
+ output.append({
1016
+ "type": "message",
1017
+ "id": f"msg_{uuid.uuid4().hex[:12]}",
1018
+ "role": msg.get("role", "assistant"),
1019
+ "content": [{"type": "output_text", "text": content_text, "annotations": []}],
1020
+ "status": "completed",
1021
+ })
1022
+
1023
+ usage = chat_payload.get("usage", {})
1024
+
1025
+ return {
1026
+ "id": f"resp_{uuid.uuid4().hex[:24]}",
1027
+ "object": "response",
1028
+ "created_at": int(time.time()),
1029
+ "model": model_name,
1030
+ "output": output,
1031
+ "usage": {
1032
+ "input_tokens": usage.get("prompt_tokens", 0),
1033
+ "output_tokens": usage.get("completion_tokens", 0),
1034
+ "total_tokens": usage.get("total_tokens", 0),
1035
+ },
1036
+ "status": "completed",
1037
+ "error": None,
1038
+ }
1039
+
1040
+
1041
+ def _wrap_responses_stream_chunk(
1042
+ line: str, response_id: str, model_name: str, sent_created: bool
1043
+ ) -> tuple[str, bool]:
1044
+ """
1045
+ Translate a single SSE line from chat-completions streaming format
1046
+ into Responses API streaming events.
1047
+
1048
+ Returns (translated_line, sent_created).
1049
+ """
1050
+ if not line.startswith("data:"):
1051
+ return line + "\n", sent_created
1052
+
1053
+ raw = line[5:].strip()
1054
+ if raw == "[DONE]":
1055
+ done_event = json.dumps({
1056
+ "type": "response.completed",
1057
+ "response": {
1058
+ "id": response_id,
1059
+ "object": "response",
1060
+ "model": model_name,
1061
+ "status": "completed",
1062
+ "output": [],
1063
+ "usage": None,
1064
+ },
1065
+ })
1066
+ return f"data: {done_event}\n\n", sent_created
1067
+
1068
+ try:
1069
+ chunk = json.loads(raw)
1070
+ except json.JSONDecodeError:
1071
+ # Forward as-is — could be our router_metadata injection
1072
+ return line + "\n", sent_created
1073
+
1074
+ # router_metadata forwarded from generate_text — pass through unchanged
1075
+ if "router_metadata" in chunk:
1076
+ return f"data: {json.dumps(chunk)}\n\n", sent_created
1077
+
1078
+ out_lines = []
1079
+
1080
+ if not sent_created:
1081
+ created_event = {
1082
+ "type": "response.created",
1083
+ "response": {
1084
+ "id": response_id,
1085
+ "object": "response",
1086
+ "model": model_name,
1087
+ "status": "in_progress",
1088
+ "output": [],
1089
+ },
1090
+ }
1091
+ out_lines.append(f"data: {json.dumps(created_event)}\n\n")
1092
+ sent_created = True
1093
+
1094
+ choices = chunk.get("choices", [])
1095
+ for choice in choices:
1096
+ delta = choice.get("delta", {})
1097
+ finish_reason = choice.get("finish_reason")
1098
+
1099
+ # Text delta
1100
+ text_delta = delta.get("content")
1101
+ if text_delta:
1102
+ delta_event = {
1103
+ "type": "response.output_text.delta",
1104
+ "item_id": f"msg_{response_id[-12:]}",
1105
+ "output_index": 0,
1106
+ "content_index": 0,
1107
+ "delta": text_delta,
1108
+ }
1109
+ out_lines.append(f"data: {json.dumps(delta_event)}\n\n")
1110
+
1111
+ # Tool call delta
1112
+ tool_calls = delta.get("tool_calls")
1113
+ if tool_calls:
1114
+ for tc in tool_calls:
1115
+ fn = tc.get("function", {})
1116
+ tc_event = {
1117
+ "type": "response.function_call_arguments.delta",
1118
+ "item_id": tc.get("id", ""),
1119
+ "output_index": tc.get("index", 0),
1120
+ "call_id": tc.get("id", ""),
1121
+ "delta": fn.get("arguments", ""),
1122
+ }
1123
+ out_lines.append(f"data: {json.dumps(tc_event)}\n\n")
1124
+
1125
+ if finish_reason:
1126
+ done_text_event = {
1127
+ "type": "response.output_text.done",
1128
+ "item_id": f"msg_{response_id[-12:]}",
1129
+ "output_index": 0,
1130
+ "content_index": 0,
1131
+ "text": "", # full text not echoed here; client accumulates deltas
1132
+ }
1133
+ out_lines.append(f"data: {json.dumps(done_text_event)}\n\n")
1134
+
1135
+ return "".join(out_lines), sent_created
1136
+
1137
+
1138
+ @router.post("/responses")
1139
+ async def create_response(
1140
+ request: Request,
1141
+ authorization: Optional[str] = Header(None),
1142
+ x_client_id: Optional[str] = Header(None),
1143
+ ):
1144
+ """
1145
+ OpenAI Responses API-compatible endpoint.
1146
+
1147
+ Accepts the Responses API request shape, normalises it into the chat
1148
+ completions format, routes it through the existing generate_text logic,
1149
+ and wraps the result back into the Responses API shape.
1150
+
1151
+ Supported fields:
1152
+ • input (str | list) — required
1153
+ • model (str) — accepted but ignored (router decides)
1154
+ • stream (bool) — optional, default False
1155
+ • tools (list) — optional, forwarded as-is
1156
+ • tool_choice (str|dict) — optional, forwarded as-is
1157
+ • temperature (float) — optional, forwarded
1158
+ • max_output_tokens (int) — mapped to max_tokens
1159
+ """
1160
+ body = await request.json()
1161
+
1162
+ input_field = body.get("input")
1163
+ if not input_field and input_field != "":
1164
+ raise HTTPException(400, "`input` is required")
1165
+
1166
+ # --- Normalise into chat-completions shape ---
1167
+ messages = _normalize_responses_input(input_field)
1168
+
1169
+ chat_body: dict = {"messages": messages}
1170
+
1171
+ # Forward compatible fields
1172
+ for field in ("tools", "tool_choice", "temperature", "top_p", "stream"):
1173
+ if field in body:
1174
+ chat_body[field] = body[field]
1175
+
1176
+ if "max_output_tokens" in body:
1177
+ chat_body["max_tokens"] = body["max_output_tokens"]
1178
+
1179
+ # Mutate the request body so generate_text can read it via request.json()
1180
+ # We call it directly instead, reusing its inner logic via a sub-request shim.
1181
+ # Simpler: re-invoke the routing logic inline by building a new Request.
1182
+ from starlette.requests import Request as StarletteRequest
1183
+ from starlette.datastructures import Headers as StarletteHeaders
1184
+
1185
+ raw_body = json.dumps(chat_body).encode()
1186
+
1187
+ scope = dict(request.scope)
1188
+ scope["path"] = "/gen/chat/completions"
1189
+ scope["headers"] = [
1190
+ (k, v) for k, v in request.scope["headers"]
1191
+ if k.lower() not in (b"content-length",)
1192
+ ] + [(b"content-length", str(len(raw_body)).encode())]
1193
+
1194
+ async def new_receive():
1195
+ return {"type": "http.request", "body": raw_body, "more_body": False}
1196
+
1197
+ sub_request = StarletteRequest(scope, new_receive)
1198
+
1199
+ # --- Delegate to generate_text ---
1200
+ response = await generate_text(sub_request, authorization, x_client_id)
1201
+
1202
+ # --- Streaming path ---
1203
+ if chat_body.get("stream"):
1204
+ response_id = f"resp_{uuid.uuid4().hex[:24]}"
1205
+ model_label = MODEL_MAP.get("lightning", "lightning")
1206
+
1207
+ async def responses_stream():
1208
+ sent_created = False
1209
+ async for chunk in response.body_iterator:
1210
+ if isinstance(chunk, bytes):
1211
+ chunk = chunk.decode("utf-8", errors="replace")
1212
+ for line in chunk.splitlines():
1213
+ translated, sent_created = _wrap_responses_stream_chunk(
1214
+ line, response_id, model_label, sent_created
1215
+ )
1216
+ if translated:
1217
+ yield translated
1218
+
1219
+ return StreamingResponse(
1220
+ responses_stream(),
1221
+ media_type="text/event-stream",
1222
+ headers={
1223
+ "Cache-Control": "no-cache",
1224
+ "Connection": "keep-alive",
1225
+ "X-Accel-Buffering": "no",
1226
+ },
1227
+ )
1228
+
1229
+ # --- Non-streaming path ---
1230
+ if hasattr(response, "body"):
1231
+ raw = response.body
1232
+ else:
1233
+ raw = b""
1234
+ async for chunk in response.body_iterator:
1235
+ raw += chunk if isinstance(chunk, bytes) else chunk.encode()
1236
+
1237
+ try:
1238
+ chat_payload = json.loads(raw)
1239
+ except json.JSONDecodeError:
1240
+ raise HTTPException(502, "Upstream returned unparseable JSON")
1241
+
1242
+ if response.status_code >= 400:
1243
+ return JSONResponse(status_code=response.status_code, content=chat_payload)
1244
+
1245
+ model_label = MODEL_MAP.get(
1246
+ chat_payload.get("model", ""),
1247
+ chat_payload.get("model", "lightning"),
1248
+ )
1249
+ return JSONResponse(
1250
+ status_code=200,
1251
+ content=_wrap_responses_output(chat_payload, model_label),
1252
+ )