xinjjj committed on
Commit
92a557f
·
1 Parent(s): a7dabe5

feat(gpt): add GPT-5.4 support

Browse files
embodied_gen/utils/gpt_clients.py CHANGED
@@ -46,6 +46,9 @@ __all__ = [
46
  _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
47
  CONFIG_FILE = os.path.join(_CURRENT_DIR, "gpt_config.yaml")
48
  DEFAULT_GPT_TIMEOUT = float(os.environ.get("GPT_TIMEOUT", 120))
 
 
 
49
 
50
 
51
  def combine_images_to_grid(
@@ -148,6 +151,11 @@ class GPTclient:
148
 
149
  logger.info(f"Using GPT model: {self.model_name}.")
150
 
 
 
 
 
 
151
  @retry(
152
  retry=retry_if_not_exception_type(openai.BadRequestError),
153
  wait=wait_random_exponential(min=1, max=10),
@@ -215,21 +223,49 @@ class GPTclient:
215
  }
216
  )
217
 
218
- payload = {
219
- "messages": [
220
- {"role": "system", "content": system_role},
221
- {"role": "user", "content": content_user},
222
- ],
223
- "temperature": 0.1,
224
- "max_tokens": 500,
225
- "top_p": 0.1,
226
- "frequency_penalty": 0,
227
- "presence_penalty": 0,
228
- "stop": None,
229
- "model": self.model_name,
230
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  if params:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  payload.update(params)
234
 
235
  response = None
@@ -253,15 +289,19 @@ class GPTclient:
253
  ConnectionError: If connection fails.
254
  """
255
  try:
256
- response = self.completion_with_backoff(
257
  messages=[
258
  {"role": "system", "content": "You are a test system."},
259
  {"role": "user", "content": "Hello"},
260
  ],
261
  model=self.model_name,
262
- temperature=0,
263
- max_tokens=100,
264
  )
 
 
 
 
 
 
265
  response.choices[0].message.content
266
  logger.info("Connection check success.")
267
  except Exception:
 
46
  _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
47
  CONFIG_FILE = os.path.join(_CURRENT_DIR, "gpt_config.yaml")
48
  DEFAULT_GPT_TIMEOUT = float(os.environ.get("GPT_TIMEOUT", 120))
49
+ # GPT-5.x counts reasoning tokens against this cap, so it must be high
50
+ # enough to leave room for both reasoning and the visible reply.
51
+ GPT5_DEFAULT_MAX_COMPLETION_TOKENS = 8192
52
 
53
 
54
  def combine_images_to_grid(
 
151
 
152
  logger.info(f"Using GPT model: {self.model_name}.")
153
 
154
+ @staticmethod
155
+ def _is_gpt5_model(model_name: str) -> bool:
156
+ name = (model_name or "").lower()
157
+ return "gpt-5" in name or "gpt5" in name
158
+
159
  @retry(
160
  retry=retry_if_not_exception_type(openai.BadRequestError),
161
  wait=wait_random_exponential(min=1, max=10),
 
223
  }
224
  )
225
 
226
+ is_gpt5 = self._is_gpt5_model(self.model_name)
227
+ if is_gpt5:
228
+ # GPT-5.x only supports default temperature/top_p and uses
229
+ # `max_completion_tokens` instead of `max_tokens`.
230
+ payload = {
231
+ "messages": [
232
+ {"role": "system", "content": system_role},
233
+ {"role": "user", "content": content_user},
234
+ ],
235
+ "max_completion_tokens": GPT5_DEFAULT_MAX_COMPLETION_TOKENS,
236
+ "model": self.model_name,
237
+ }
238
+ else:
239
+ payload = {
240
+ "messages": [
241
+ {"role": "system", "content": system_role},
242
+ {"role": "user", "content": content_user},
243
+ ],
244
+ "temperature": 0.1,
245
+ "max_tokens": 500,
246
+ "top_p": 0.1,
247
+ "frequency_penalty": 0,
248
+ "presence_penalty": 0,
249
+ "stop": None,
250
+ "model": self.model_name,
251
+ }
252
 
253
  if params:
254
+ params = dict(params)
255
+ if is_gpt5:
256
+ # GPT-5.x rejects custom temperature/top_p/penalty/stop and
257
+ # uses `max_completion_tokens` instead of `max_tokens`.
258
+ if "max_tokens" in params and "max_completion_tokens" not in params:
259
+ params["max_completion_tokens"] = params.pop("max_tokens")
260
+ for k in (
261
+ "temperature",
262
+ "top_p",
263
+ "frequency_penalty",
264
+ "presence_penalty",
265
+ "stop",
266
+ "max_tokens",
267
+ ):
268
+ params.pop(k, None)
269
  payload.update(params)
270
 
271
  response = None
 
289
  ConnectionError: If connection fails.
290
  """
291
  try:
292
+ probe_kwargs = dict(
293
  messages=[
294
  {"role": "system", "content": "You are a test system."},
295
  {"role": "user", "content": "Hello"},
296
  ],
297
  model=self.model_name,
 
 
298
  )
299
+ if self._is_gpt5_model(self.model_name):
300
+ probe_kwargs["max_completion_tokens"] = 100
301
+ else:
302
+ probe_kwargs["temperature"] = 0
303
+ probe_kwargs["max_tokens"] = 100
304
+ response = self.completion_with_backoff(**probe_kwargs)
305
  response.choices[0].message.content
306
  logger.info("Connection check success.")
307
  except Exception:
embodied_gen/utils/gpt_config.yaml CHANGED
@@ -1,5 +1,5 @@
1
  # config.yaml
2
- agent_type: "qwen2.5-vl" # gpt-4o or qwen2.5-vl
3
 
4
  gpt-4o:
5
  endpoint: https://xxx.openai.azure.com
@@ -7,6 +7,12 @@ gpt-4o:
7
  api_version: 2025-xx-xx
8
  model_name: yfb-gpt-4o
9
 
 
 
 
 
 
 
10
  qwen2.5-vl:
11
  endpoint: https://openrouter.ai/api/v1
12
  api_key: sk-or-v1-xxx
 
1
  # config.yaml
2
+ agent_type: "gpt-5.4" # gpt-4o, gpt-5.4 or qwen2.5-vl
3
 
4
  gpt-4o:
5
  endpoint: https://xxx.openai.azure.com
 
7
  api_version: 2025-xx-xx
8
  model_name: yfb-gpt-4o
9
 
10
+ gpt-5.4:
11
+ endpoint: https://yfb-openai-sweden.openai.azure.com/
12
+ api_key: xxx
13
+ api_version: 2024-12-01-preview
14
+ model_name: gpt-5.4
15
+
16
  qwen2.5-vl:
17
  endpoint: https://openrouter.ai/api/v1
18
  api_key: sk-or-v1-xxx