eeshwar143 commited on
Commit
a979929
·
1 Parent(s): 4c46ae2

Support both old and new OpenAI Python clients

Browse files
Files changed (1) hide show
  1. inference.py +42 -15
inference.py CHANGED
@@ -5,7 +5,12 @@ import json
5
  import os
6
  from typing import Any, List
7
 
8
- from openai import OpenAI
 
 
 
 
 
9
 
10
  from support_queue_env.client import SupportQueueEnv
11
  from support_queue_env.models import TaskCard, SupportQueueAction, SupportQueueObservation
@@ -40,8 +45,17 @@ def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> No
40
  )
41
 
42
 
 
 
 
 
 
 
 
 
 
43
  def get_model_message(
44
- client: OpenAI,
45
  step: int,
46
  observation: SupportQueueObservation,
47
  last_reward: float,
@@ -52,17 +66,30 @@ def get_model_message(
52
  f"Step: {step}. Last reward: {last_reward:.4f}. History: {history[-4:]}. Observation: {observation.model_dump_json()}"
53
  )
54
  try:
55
- completion = client.chat.completions.create(
56
- model=MODEL_NAME,
57
- messages=[
58
- {"role": "system", "content": "You are assisting a support triage agent."},
59
- {"role": "user", "content": prompt},
60
- ],
61
- temperature=0.0,
62
- max_tokens=MAX_TOKENS,
63
- stream=False,
64
- )
65
- text = (completion.choices[0].message.content or "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  return text if text else "hello"
67
  except Exception as exc:
68
  print(f"[DEBUG] Model request failed: {exc}", flush=True)
@@ -188,7 +215,7 @@ def heuristic_action(observation: SupportQueueObservation) -> SupportQueueAction
188
  )
189
 
190
 
191
- async def run_task(client: OpenAI, env: SupportQueueEnv, task: TaskCard) -> dict[str, Any]:
192
  history: List[str] = []
193
  rewards: List[float] = []
194
  steps_taken = 0
@@ -252,7 +279,7 @@ async def run_task(client: OpenAI, env: SupportQueueEnv, task: TaskCard) -> dict
252
 
253
 
254
  async def main() -> None:
255
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "placeholder")
256
  tasks = available_tasks()
257
  results: list[dict[str, Any]] = []
258
  env: SupportQueueEnv | None = None
 
5
  import os
6
  from typing import Any, List
7
 
8
+ try:
9
+ from openai import OpenAI
10
+ import openai as openai_module
11
+ except ImportError:
12
+ OpenAI = None
13
+ import openai as openai_module
14
 
15
  from support_queue_env.client import SupportQueueEnv
16
  from support_queue_env.models import TaskCard, SupportQueueAction, SupportQueueObservation
 
45
  )
46
 
47
 
48
+ def create_openai_client() -> Any:
49
+ if OpenAI is not None:
50
+ return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "placeholder")
51
+
52
+ openai_module.api_base = API_BASE_URL
53
+ openai_module.api_key = HF_TOKEN or "placeholder"
54
+ return openai_module
55
+
56
+
57
  def get_model_message(
58
+ client: Any,
59
  step: int,
60
  observation: SupportQueueObservation,
61
  last_reward: float,
 
66
  f"Step: {step}. Last reward: {last_reward:.4f}. History: {history[-4:]}. Observation: {observation.model_dump_json()}"
67
  )
68
  try:
69
+ if hasattr(client, "chat") and hasattr(client.chat, "completions"):
70
+ completion = client.chat.completions.create(
71
+ model=MODEL_NAME,
72
+ messages=[
73
+ {"role": "system", "content": "You are assisting a support triage agent."},
74
+ {"role": "user", "content": prompt},
75
+ ],
76
+ temperature=0.0,
77
+ max_tokens=MAX_TOKENS,
78
+ stream=False,
79
+ )
80
+ text = (completion.choices[0].message.content or "").strip()
81
+ else:
82
+ completion = client.ChatCompletion.create(
83
+ model=MODEL_NAME,
84
+ messages=[
85
+ {"role": "system", "content": "You are assisting a support triage agent."},
86
+ {"role": "user", "content": prompt},
87
+ ],
88
+ temperature=0.0,
89
+ max_tokens=MAX_TOKENS,
90
+ stream=False,
91
+ )
92
+ text = (completion["choices"][0]["message"]["content"] or "").strip()
93
  return text if text else "hello"
94
  except Exception as exc:
95
  print(f"[DEBUG] Model request failed: {exc}", flush=True)
 
215
  )
216
 
217
 
218
+ async def run_task(client: Any, env: SupportQueueEnv, task: TaskCard) -> dict[str, Any]:
219
  history: List[str] = []
220
  rewards: List[float] = []
221
  steps_taken = 0
 
279
 
280
 
281
  async def main() -> None:
282
+ client = create_openai_client()
283
  tasks = available_tasks()
284
  results: list[dict[str, Any]] = []
285
  env: SupportQueueEnv | None = None