Spaces:
Running
Running
Update gen.py
Browse files
gen.py
CHANGED
|
@@ -941,4 +941,352 @@ def return_models_openai():
|
|
| 941 |
"owned_by": "inferenceport-ai"
|
| 942 |
}
|
| 943 |
]
|
| 944 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 941 |
"owned_by": "inferenceport-ai"
|
| 942 |
}
|
| 943 |
]
|
| 944 |
+
}
|
| 945 |
+
|
| 946 |
+
from uuid import uuid4
|
| 947 |
+
from time import time
|
| 948 |
+
from typing import Any, Dict, List, Optional
|
| 949 |
+
import json
|
| 950 |
+
import os
|
| 951 |
+
import random
|
| 952 |
+
import httpx
|
| 953 |
+
|
| 954 |
+
from fastapi import Request, HTTPException, Header
|
| 955 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 956 |
+
|
| 957 |
+
def _resp_id(prefix: str) -> str:
|
| 958 |
+
return f"{prefix}_{uuid4().hex}"
|
| 959 |
+
|
| 960 |
+
def _resp_ts() -> int:
|
| 961 |
+
return int(time())
|
| 962 |
+
|
| 963 |
+
def _content_to_text(content: Any) -> str:
|
| 964 |
+
if isinstance(content, str):
|
| 965 |
+
return content
|
| 966 |
+
if isinstance(content, list):
|
| 967 |
+
parts = []
|
| 968 |
+
for item in content:
|
| 969 |
+
if isinstance(item, dict):
|
| 970 |
+
t = item.get("type")
|
| 971 |
+
if t in ("input_text", "output_text", "text"):
|
| 972 |
+
txt = item.get("text")
|
| 973 |
+
if isinstance(txt, str):
|
| 974 |
+
parts.append(txt)
|
| 975 |
+
return "".join(parts)
|
| 976 |
+
return ""
|
| 977 |
+
|
| 978 |
+
def _responses_input_to_messages(input_data: Any, instructions: Optional[str] = None) -> List[Dict[str, Any]]:
|
| 979 |
+
messages: List[Dict[str, Any]] = []
|
| 980 |
+
if instructions:
|
| 981 |
+
messages.append({"role": "developer", "content": instructions})
|
| 982 |
+
|
| 983 |
+
if isinstance(input_data, str):
|
| 984 |
+
messages.append({"role": "user", "content": input_data})
|
| 985 |
+
return messages
|
| 986 |
+
|
| 987 |
+
if isinstance(input_data, list):
|
| 988 |
+
for item in input_data:
|
| 989 |
+
if isinstance(item, str):
|
| 990 |
+
messages.append({"role": "user", "content": item})
|
| 991 |
+
continue
|
| 992 |
+
if not isinstance(item, dict):
|
| 993 |
+
continue
|
| 994 |
+
role = item.get("role", "user")
|
| 995 |
+
content = item.get("content", "")
|
| 996 |
+
text = _content_to_text(content)
|
| 997 |
+
if text:
|
| 998 |
+
messages.append({"role": role, "content": text})
|
| 999 |
+
|
| 1000 |
+
return messages
|
| 1001 |
+
|
| 1002 |
+
def _openai_responses_payload(model: str, text: str, input_tokens: int = 0, output_tokens: int = 0) -> Dict[str, Any]:
    """Build a completed ``response`` object wrapping one assistant message.

    Args:
        model: Value placed in the ``"model"`` field.
        text: Assistant text; used both as ``"output_text"`` and as the single
            ``output_text`` content part of the output message.
        input_tokens: Reported input token count.
        output_tokens: Reported output token count.

    Returns:
        A dict with fresh ``resp_``/``msg_`` identifiers, the current
        timestamp for ``created_at`` and ``completed_at``, and a ``usage``
        section whose ``total_tokens`` is the sum of the two counts.
    """
    payload: Dict[str, Any] = {
        "id": _resp_id("resp"),
        "object": "response",
        "created_at": _resp_ts(),
        "status": "completed",
        "completed_at": _resp_ts(),
        "error": None,
        "incomplete_details": None,
        "instructions": None,
        "max_output_tokens": None,
        "model": model,
    }

    text_part = {"type": "output_text", "text": text, "annotations": []}
    assistant_message = {
        "id": _resp_id("msg"),
        "type": "message",
        "role": "assistant",
        "status": "completed",
        "content": [text_part],
    }

    # Keys are added in the same order as the serialized object is expected
    # to list them: output, output_text, usage.
    payload["output"] = [assistant_message]
    payload["output_text"] = text
    payload["usage"] = {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
    }
    return payload
|
| 1036 |
+
|
| 1037 |
+
def _responses_provider_settings(provider: str) -> Optional[Dict[str, str]]:
    """Return the key env var, endpoint URL and missing-key message for ``provider``, or None."""
    navy_url = "https://api.navy/v1/chat/completions"
    table = {
        "groq": {
            "env": "GROQKEY",
            "url": "https://api.groq.com/openai/v1/chat/completions",
            "missing": "Missing GROQKEYs",
        },
        "navy vision": {
            "env": "NAVYKEY",
            "url": navy_url,
            "missing": "Missing NAVYKEYs",
        },
        "navy": {
            "env": "NAVYTEXTONLY",
            "url": navy_url,
            "missing": "Missing NAVY TEXT ONLY keys",
        },
    }
    return table.get(provider)


def _responses_pick_api_key(env_name: str, missing_detail: str) -> str:
    """Pick one key at random from a comma-separated environment variable, or raise HTTP 500."""
    raw = os.getenv(env_name)
    keys = [k.strip() for k in raw.split(",") if k.strip()] if raw else []
    if not keys:
        raise HTTPException(status_code=500, detail=missing_detail)
    return random.choice(keys)


async def _responses_post_chat(url: str, apikey: str, model: str, messages: List[Dict[str, Any]]) -> Any:
    """POST a non-streaming chat-completions request and return the decoded JSON body.

    Raises:
        HTTPException: with the upstream status code and the first 1000
            characters of the upstream body when the status is not 200.
    """
    headers = {"Authorization": f"Bearer {apikey}", "Content-Type": "application/json"}
    payload = {"model": model, "messages": messages, "stream": False}
    # timeout=None switches off httpx's timeouts, so a stalled upstream keeps
    # this call waiting for as long as the connection stays open.
    async with httpx.AsyncClient(timeout=None) as client:
        r = await client.post(url, json=payload, headers=headers)
    if r.status_code != 200:
        raise HTTPException(status_code=r.status_code, detail=r.text[:1000])
    return r.json()


def _responses_first_choice_text(data: Any) -> str:
    """Return ``data["choices"][0]["message"]["content"]``, or "" when absent or falsy."""
    try:
        return data["choices"][0]["message"]["content"] or ""
    except (KeyError, IndexError, TypeError):
        # The body did not have the expected chat-completions shape.
        return ""


async def _generate_text_from_messages(
    request: Request,
    messages: List[Dict[str, Any]],
    authorization: Optional[str],
    xclientid: Optional[str],
) -> Dict[str, Any]:
    """Choose a model/provider for ``messages``, call it, and return the reply.

    A complexity score (capped at 10) is built from prompt heuristics and
    used, together with image/code detection, to pick a model and provider.
    Prompts too large for Groq are rerouted to ``gpt-4o-mini`` on "navy".
    The rate-limit check runs after routing and before any key lookup or
    upstream request.

    Args:
        request: Incoming request, forwarded to ``checkchatratelimit``.
        messages: Chat messages sent upstream unchanged.
        authorization: Forwarded to ``checkchatratelimit``.
        xclientid: Forwarded to ``checkchatratelimit``.

    Returns:
        ``{"text", "model", "provider", "raw"}`` where ``raw`` is the decoded
        upstream JSON and ``text`` is the first choice's content ("" when the
        body does not have the expected shape).

    Raises:
        HTTPException: 500 when no API key is configured for the provider or
            the provider is unknown; the upstream status code when the
            upstream call does not return 200. Whatever
            ``checkchatratelimit`` raises propagates unchanged.
    """
    # The heuristics below (calculatemessagessize, extractusertext, ...) and
    # the constants REASONINGKEYWORDS / MAXGROQPROMPT* are not defined in this
    # block; presumably they live elsewhere in this module — verify.
    totalchars, totalbytes = calculatemessagessize(messages)
    prompttext = extractusertext(messages)

    # NOTE(review): hard-coded to False, so the tool-routing branch below
    # never runs from this code path.
    usestools = False
    longcontext = islongcontext(messages)
    codepresent = containscode(prompttext)
    mathheavy = ismathheavyprompt(prompttext)
    structuredtask = isstructuredtask(prompttext)
    multiq = multiplequestions(prompttext)
    codeheavy = iscodeheavyprompt(prompttext, codepresent, longcontext)

    score = 0
    if longcontext:
        score += 3
    if mathheavy:
        score += 3
    if structuredtask:
        score += 2
    if codepresent:
        score += 2
    if multiq:
        score += 1
    # One extra point per reasoning keyword found in the prompt.
    score += sum(1 for kw in REASONINGKEYWORDS if kw in prompttext)
    score = min(score, 10)

    # Default route when nothing below overrides it.
    chosenmodel = "llama-3.1-8b-instant"
    provider = "groq"
    hasimages = containsimages(messages)

    if hasimages:
        chosenmodel = "gpt-4o-mini"
        provider = "navy vision"
    elif usestools:
        if score >= 6:
            chosenmodel = "nemotron-3-super"
            provider = "navy"
        elif score >= 4:
            chosenmodel = "openai/gpt-oss-120b"
            provider = "groq"
        else:
            chosenmodel = "openai/gpt-oss-20b"
            provider = "groq"
    elif codepresent:
        if codeheavy and score >= 6:
            chosenmodel = "o3-mini"
            provider = "navy"
        elif score >= 4:
            chosenmodel = "llama-3.3-70b-versatile"
            provider = "groq"
    elif score >= 4:
        chosenmodel = "meta-llama/llama-4-scout-17b-16e-instruct"
        provider = "groq"
    elif score >= 6:
        # NOTE(review): any score >= 6 already matches the ``score >= 4``
        # branch above, so this branch is unreachable as written — confirm
        # the intended order before relying on "sonar" routing.
        chosenmodel = "sonar"
        provider = "navy"

    # Oversized prompts are moved off Groq.
    if provider == "groq" and (totalchars > MAXGROQPROMPTCHARS or totalbytes > MAXGROQPROMPTBYTES):
        provider = "navy"
        chosenmodel = "gpt-4o-mini"

    await checkchatratelimit(request, authorization, xclientid)

    settings = _responses_provider_settings(provider)
    if settings is None:
        raise HTTPException(status_code=500, detail="Unknown provider routing error")

    apikey = _responses_pick_api_key(settings["env"], settings["missing"])
    data = await _responses_post_chat(settings["url"], apikey, chosenmodel, messages)
    text = _responses_first_choice_text(data)
    return {"text": text, "model": chosenmodel, "provider": provider, "raw": data}
|
| 1173 |
+
|
| 1174 |
+
@router.post("/responses")
async def create_responses(
    request: Request,
    authorization: Optional[str] = Header(None),
    xclientid: Optional[str] = Header(None),
):
    """Handle ``POST /responses``: generate a reply as JSON or as an SSE stream.

    The JSON body must contain ``model`` and ``input``; ``instructions`` is
    optional. When ``stream`` is the JSON literal ``false`` a single response
    object is returned. Otherwise the reply is generated in full and then
    replayed as server-sent events: ``response.created``, zero or more
    ``response.output_text.delta`` events of up to 64 characters,
    ``response.completed``, and a final ``[DONE]`` marker.

    Args:
        request: Incoming request; its JSON body carries the parameters.
        authorization: ``Authorization`` header, forwarded to generation.
        xclientid: Client id header, forwarded to generation.

    Returns:
        ``JSONResponse`` in non-streaming mode, ``StreamingResponse``
        (``text/event-stream``) otherwise.

    Raises:
        HTTPException: 400 for an invalid/non-object JSON body, a missing
            ``model`` or ``input``, or input that yields no messages; 500 if
            non-streaming generation returns no text. In non-streaming mode
            errors from generation propagate; in streaming mode they are
            reported as a ``response.error`` event instead.
    """
    # NOTE(review): the parameter name has no underscores, so FastAPI reads a
    # header literally named "xclientid" — confirm clients do not send
    # "X-Client-Id" instead.
    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        raise HTTPException(status_code=400, detail="request body must be valid JSON") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="request body must be a JSON object")

    model = body.get("model")
    input_data = body.get("input")
    instructions = body.get("instructions")
    # NOTE(review): a missing "stream" field selects streaming, and only the
    # literal false disables it. The OpenAI Responses API documents a default
    # of false — confirm this default is intentional.
    stream = body.get("stream", True)

    if not model:
        raise HTTPException(status_code=400, detail="model is required")
    if input_data is None:
        raise HTTPException(status_code=400, detail="input is required")

    messages = _responses_input_to_messages(input_data, instructions=instructions)
    if not messages:
        raise HTTPException(status_code=400, detail="input could not be parsed")

    if stream is False:
        result = await _generate_text_from_messages(
            request=request,
            messages=messages,
            authorization=authorization,
            xclientid=xclientid,
        )
        if "text" not in result:
            raise HTTPException(status_code=500, detail="upstream generation failed")
        return JSONResponse(content=_openai_responses_payload(model, result["text"]))

    def sse(event: Dict[str, Any]) -> str:
        """Serialize one event as a server-sent-events data frame."""
        return f"data: {json.dumps(event)}\n\n"

    async def event_stream():
        response_id = _resp_id("resp")
        # Computed once so the created and completed events report the same
        # creation time for this response.
        created_at = _resp_ts()
        created = {
            "type": "response.created",
            "response": {
                "id": response_id,
                "object": "response",
                "created_at": created_at,
                "status": "in_progress",
                "model": model,
            },
        }
        yield sse(created)

        # The HTTP status line has already been sent by now, so a failure can
        # no longer become an HTTP error response; report it in-band instead
        # of letting the exception abort the stream.
        try:
            result = await _generate_text_from_messages(
                request=request,
                messages=messages,
                authorization=authorization,
                xclientid=xclientid,
            )
        except HTTPException as exc:
            failure = {
                "type": "response.error",
                "error": {"message": str(exc.detail), "code": exc.status_code},
            }
            yield sse(failure)
            yield "data: [DONE]\n\n"
            return
        except httpx.HTTPError as exc:
            failure = {
                "type": "response.error",
                "error": {"message": f"upstream request failed: {exc}"},
            }
            yield sse(failure)
            yield "data: [DONE]\n\n"
            return

        if "text" not in result:
            err = {
                "type": "response.error",
                "error": result.get("error", {"message": "upstream error"}),
            }
            yield sse(err)
            yield "data: [DONE]\n\n"
            return

        text = result["text"]
        if text:
            chunk_size = 64
            for start in range(0, len(text), chunk_size):
                evt = {
                    "type": "response.output_text.delta",
                    "response_id": response_id,
                    "delta": text[start:start + chunk_size],
                }
                yield sse(evt)

        completed = {
            "type": "response.completed",
            "response": {
                "id": response_id,
                "object": "response",
                "created_at": created_at,
                "status": "completed",
                "completed_at": _resp_ts(),
                "model": model,
                "output_text": text,
                "output": [
                    {
                        "id": _resp_id("msg"),
                        "type": "message",
                        "role": "assistant",
                        "status": "completed",
                        "content": [
                            {
                                "type": "output_text",
                                "text": text,
                                "annotations": [],
                            }
                        ],
                    }
                ],
                "usage": {
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0,
                },
            },
        }
        yield sse(completed)
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|