sharktide commited on
Commit
6a3237f
·
verified ·
1 Parent(s): adba311

Update gen.py

Browse files
Files changed (1) hide show
  1. gen.py +349 -1
gen.py CHANGED
@@ -941,4 +941,352 @@ def return_models_openai():
941
  "owned_by": "inferenceport-ai"
942
  }
943
  ]
944
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
941
  "owned_by": "inferenceport-ai"
942
  }
943
  ]
944
+ }
945
+
946
+ from uuid import uuid4
947
+ from time import time
948
+ from typing import Any, Dict, List, Optional
949
+ import json
950
+ import os
951
+ import random
952
+ import httpx
953
+
954
+ from fastapi import Request, HTTPException, Header
955
+ from fastapi.responses import JSONResponse, StreamingResponse
956
+
957
+ def _resp_id(prefix: str) -> str:
958
+ return f"{prefix}_{uuid4().hex}"
959
+
960
+ def _resp_ts() -> int:
961
+ return int(time())
962
+
963
+ def _content_to_text(content: Any) -> str:
964
+ if isinstance(content, str):
965
+ return content
966
+ if isinstance(content, list):
967
+ parts = []
968
+ for item in content:
969
+ if isinstance(item, dict):
970
+ t = item.get("type")
971
+ if t in ("input_text", "output_text", "text"):
972
+ txt = item.get("text")
973
+ if isinstance(txt, str):
974
+ parts.append(txt)
975
+ return "".join(parts)
976
+ return ""
977
+
978
+ def _responses_input_to_messages(input_data: Any, instructions: Optional[str] = None) -> List[Dict[str, Any]]:
979
+ messages: List[Dict[str, Any]] = []
980
+ if instructions:
981
+ messages.append({"role": "developer", "content": instructions})
982
+
983
+ if isinstance(input_data, str):
984
+ messages.append({"role": "user", "content": input_data})
985
+ return messages
986
+
987
+ if isinstance(input_data, list):
988
+ for item in input_data:
989
+ if isinstance(item, str):
990
+ messages.append({"role": "user", "content": item})
991
+ continue
992
+ if not isinstance(item, dict):
993
+ continue
994
+ role = item.get("role", "user")
995
+ content = item.get("content", "")
996
+ text = _content_to_text(content)
997
+ if text:
998
+ messages.append({"role": role, "content": text})
999
+
1000
+ return messages
1001
+
1002
+ def _openai_responses_payload(model: str, text: str, input_tokens: int = 0, output_tokens: int = 0) -> Dict[str, Any]:
1003
+ return {
1004
+ "id": _resp_id("resp"),
1005
+ "object": "response",
1006
+ "created_at": _resp_ts(),
1007
+ "status": "completed",
1008
+ "completed_at": _resp_ts(),
1009
+ "error": None,
1010
+ "incomplete_details": None,
1011
+ "instructions": None,
1012
+ "max_output_tokens": None,
1013
+ "model": model,
1014
+ "output": [
1015
+ {
1016
+ "id": _resp_id("msg"),
1017
+ "type": "message",
1018
+ "role": "assistant",
1019
+ "status": "completed",
1020
+ "content": [
1021
+ {
1022
+ "type": "output_text",
1023
+ "text": text,
1024
+ "annotations": []
1025
+ }
1026
+ ]
1027
+ }
1028
+ ],
1029
+ "output_text": text,
1030
+ "usage": {
1031
+ "input_tokens": input_tokens,
1032
+ "output_tokens": output_tokens,
1033
+ "total_tokens": input_tokens + output_tokens
1034
+ }
1035
+ }
1036
+
1037
+ async def _generate_text_from_messages(
1038
+ request: Request,
1039
+ messages: List[Dict[str, Any]],
1040
+ authorization: Optional[str],
1041
+ xclientid: Optional[str],
1042
+ ) -> Dict[str, Any]:
1043
+ totalchars, totalbytes = calculatemessagessize(messages)
1044
+ prompttext = extractusertext(messages)
1045
+
1046
+ usestools = False
1047
+ longcontext = islongcontext(messages)
1048
+ codepresent = containscode(prompttext)
1049
+ mathheavy = ismathheavyprompt(prompttext)
1050
+ structuredtask = isstructuredtask(prompttext)
1051
+ multiq = multiplequestions(prompttext)
1052
+ codeheavy = iscodeheavyprompt(prompttext, codepresent, longcontext)
1053
+
1054
+ score = 0
1055
+ if longcontext:
1056
+ score += 3
1057
+ if mathheavy:
1058
+ score += 3
1059
+ if structuredtask:
1060
+ score += 2
1061
+ if codepresent:
1062
+ score += 2
1063
+ if multiq:
1064
+ score += 1
1065
+ for kw in REASONINGKEYWORDS:
1066
+ if kw in prompttext:
1067
+ score += 1
1068
+ if score > 10:
1069
+ score = 10
1070
+
1071
+ chosenmodel = "llama-3.1-8b-instant"
1072
+ provider = "groq"
1073
+ hasimages = containsimages(messages)
1074
+
1075
+ if hasimages:
1076
+ chosenmodel = "gpt-4o-mini"
1077
+ provider = "navy vision"
1078
+ else:
1079
+ if usestools:
1080
+ if score >= 6:
1081
+ chosenmodel = "nemotron-3-super"
1082
+ provider = "navy"
1083
+ elif score >= 4:
1084
+ chosenmodel = "openai/gpt-oss-120b"
1085
+ provider = "groq"
1086
+ else:
1087
+ chosenmodel = "openai/gpt-oss-20b"
1088
+ provider = "groq"
1089
+ elif codepresent:
1090
+ if codeheavy and score >= 6:
1091
+ chosenmodel = "o3-mini"
1092
+ provider = "navy"
1093
+ elif score >= 4:
1094
+ chosenmodel = "llama-3.3-70b-versatile"
1095
+ provider = "groq"
1096
+ elif score >= 4:
1097
+ chosenmodel = "meta-llama/llama-4-scout-17b-16e-instruct"
1098
+ provider = "groq"
1099
+ elif score >= 6:
1100
+ chosenmodel = "sonar"
1101
+ provider = "navy"
1102
+
1103
+ if provider == "groq" and (totalchars > MAXGROQPROMPTCHARS or totalbytes > MAXGROQPROMPTBYTES):
1104
+ provider = "navy"
1105
+ chosenmodel = "gpt-4o-mini"
1106
+
1107
+ await checkchatratelimit(request, authorization, xclientid)
1108
+
1109
+ if provider == "groq":
1110
+ groqkeys = os.getenv("GROQKEY")
1111
+ groqkeyslist = [k.strip() for k in groqkeys.split(",") if k.strip()] if groqkeys else []
1112
+ if not groqkeyslist:
1113
+ raise HTTPException(status_code=500, detail="Missing GROQKEYs")
1114
+ apikey = random.choice(groqkeyslist)
1115
+ url = "https://api.groq.com/openai/v1/chat/completions"
1116
+ headers = {"Authorization": f"Bearer {apikey}", "Content-Type": "application/json"}
1117
+ payload = {"model": chosenmodel, "messages": messages, "stream": False}
1118
+ async with httpx.AsyncClient(timeout=None) as client:
1119
+ r = await client.post(url, json=payload, headers=headers)
1120
+ if r.status_code != 200:
1121
+ raise HTTPException(status_code=r.status_code, detail=r.text[:1000])
1122
+ data = r.json()
1123
+ text = ""
1124
+ try:
1125
+ text = data["choices"][0]["message"]["content"] or ""
1126
+ except Exception:
1127
+ text = ""
1128
+ return {"text": text, "model": chosenmodel, "provider": provider, "raw": data}
1129
+
1130
+ if provider == "navy vision":
1131
+ navykeys = os.getenv("NAVYKEY")
1132
+ navykeyslist = [k.strip() for k in navykeys.split(",") if k.strip()] if navykeys else []
1133
+ if not navykeyslist:
1134
+ raise HTTPException(status_code=500, detail="Missing NAVYKEYs")
1135
+ apikey = random.choice(navykeyslist)
1136
+ url = "https://api.navy/v1/chat/completions"
1137
+ headers = {"Authorization": f"Bearer {apikey}", "Content-Type": "application/json"}
1138
+ payload = {"model": chosenmodel, "messages": messages, "stream": False}
1139
+ async with httpx.AsyncClient(timeout=None) as client:
1140
+ r = await client.post(url, json=payload, headers=headers)
1141
+ if r.status_code != 200:
1142
+ raise HTTPException(status_code=r.status_code, detail=r.text[:1000])
1143
+ data = r.json()
1144
+ text = ""
1145
+ try:
1146
+ text = data["choices"][0]["message"]["content"] or ""
1147
+ except Exception:
1148
+ text = ""
1149
+ return {"text": text, "model": chosenmodel, "provider": provider, "raw": data}
1150
+
1151
+ if provider == "navy":
1152
+ navykeys = os.getenv("NAVYTEXTONLY")
1153
+ navykeyslist = [k.strip() for k in navykeys.split(",") if k.strip()] if navykeys else []
1154
+ if not navykeyslist:
1155
+ raise HTTPException(status_code=500, detail="Missing NAVY TEXT ONLY keys")
1156
+ apikey = random.choice(navykeyslist)
1157
+ url = "https://api.navy/v1/chat/completions"
1158
+ headers = {"Authorization": f"Bearer {apikey}", "Content-Type": "application/json"}
1159
+ payload = {"model": chosenmodel, "messages": messages, "stream": False}
1160
+ async with httpx.AsyncClient(timeout=None) as client:
1161
+ r = await client.post(url, json=payload, headers=headers)
1162
+ if r.status_code != 200:
1163
+ raise HTTPException(status_code=r.status_code, detail=r.text[:1000])
1164
+ data = r.json()
1165
+ text = ""
1166
+ try:
1167
+ text = data["choices"][0]["message"]["content"] or ""
1168
+ except Exception:
1169
+ text = ""
1170
+ return {"text": text, "model": chosenmodel, "provider": provider, "raw": data}
1171
+
1172
+ raise HTTPException(status_code=500, detail="Unknown provider routing error")
1173
+
1174
+ @router.post("/responses")
1175
+ async def create_responses(
1176
+ request: Request,
1177
+ authorization: Optional[str] = Header(None),
1178
+ xclientid: Optional[str] = Header(None),
1179
+ ):
1180
+ body = await request.json()
1181
+ model = body.get("model")
1182
+ input_data = body.get("input")
1183
+ instructions = body.get("instructions")
1184
+ stream = body.get("stream", True)
1185
+ response_format = body.get("response_format")
1186
+
1187
+ if not model:
1188
+ raise HTTPException(status_code=400, detail="model is required")
1189
+ if input_data is None:
1190
+ raise HTTPException(status_code=400, detail="input is required")
1191
+
1192
+ messages = _responses_input_to_messages(input_data, instructions=instructions)
1193
+ if not messages:
1194
+ raise HTTPException(status_code=400, detail="input could not be parsed")
1195
+
1196
+ if stream is False:
1197
+ result = await _generate_text_from_messages(
1198
+ request=request,
1199
+ messages=messages,
1200
+ authorization=authorization,
1201
+ xclientid=xclientid,
1202
+ )
1203
+ if "text" not in result:
1204
+ raise HTTPException(status_code=500, detail="upstream generation failed")
1205
+ return JSONResponse(content=_openai_responses_payload(model, result["text"]))
1206
+
1207
+ async def event_stream():
1208
+ response_id = _resp_id("resp")
1209
+ created = {
1210
+ "type": "response.created",
1211
+ "response": {
1212
+ "id": response_id,
1213
+ "object": "response",
1214
+ "created_at": _resp_ts(),
1215
+ "status": "in_progress",
1216
+ "model": model
1217
+ }
1218
+ }
1219
+ yield f"data: {json.dumps(created)}\n\n"
1220
+
1221
+ result = await _generate_text_from_messages(
1222
+ request=request,
1223
+ messages=messages,
1224
+ authorization=authorization,
1225
+ xclientid=xclientid,
1226
+ )
1227
+
1228
+ if "text" in result:
1229
+ text = result["text"]
1230
+ if text:
1231
+ chunk_size = 64
1232
+ for i in range(0, len(text), chunk_size):
1233
+ delta = text[i:i + chunk_size]
1234
+ evt = {
1235
+ "type": "response.output_text.delta",
1236
+ "response_id": response_id,
1237
+ "delta": delta
1238
+ }
1239
+ yield f"data: {json.dumps(evt)}\n\n"
1240
+
1241
+ completed = {
1242
+ "type": "response.completed",
1243
+ "response": {
1244
+ "id": response_id,
1245
+ "object": "response",
1246
+ "created_at": _resp_ts(),
1247
+ "status": "completed",
1248
+ "completed_at": _resp_ts(),
1249
+ "model": model,
1250
+ "output_text": result["text"],
1251
+ "output": [
1252
+ {
1253
+ "id": _resp_id("msg"),
1254
+ "type": "message",
1255
+ "role": "assistant",
1256
+ "status": "completed",
1257
+ "content": [
1258
+ {
1259
+ "type": "output_text",
1260
+ "text": result["text"],
1261
+ "annotations": []
1262
+ }
1263
+ ]
1264
+ }
1265
+ ],
1266
+ "usage": {
1267
+ "input_tokens": 0,
1268
+ "output_tokens": 0,
1269
+ "total_tokens": 0
1270
+ }
1271
+ }
1272
+ }
1273
+ yield f"data: {json.dumps(completed)}\n\n"
1274
+ yield "data: [DONE]\n\n"
1275
+ return
1276
+
1277
+ err = {
1278
+ "type": "response.error",
1279
+ "error": result.get("error", {"message": "upstream error"})
1280
+ }
1281
+ yield f"data: {json.dumps(err)}\n\n"
1282
+ yield "data: [DONE]\n\n"
1283
+
1284
+ return StreamingResponse(
1285
+ event_stream(),
1286
+ media_type="text/event-stream",
1287
+ headers={
1288
+ "Cache-Control": "no-cache",
1289
+ "Connection": "keep-alive",
1290
+ "X-Accel-Buffering": "no",
1291
+ },
1292
+ )