Update inference.py
Browse files- inference.py +20 -15
inference.py
CHANGED
|
@@ -7,17 +7,19 @@ from env import EmailTriageEnv
|
|
| 7 |
from app import smart_agent_logic
|
| 8 |
|
| 9 |
|
| 10 |
-
# ✅
|
| 11 |
API_BASE_URL = os.environ.get("API_BASE_URL")
|
| 12 |
API_KEY = os.environ.get("API_KEY")
|
| 13 |
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 14 |
|
| 15 |
-
TASK_NAME = os.getenv("MY_ENV_V4_TASK", "easy")
|
| 16 |
BENCHMARK = "email_triage_env"
|
| 17 |
|
| 18 |
MAX_STEPS = 20
|
| 19 |
SUCCESS_SCORE_THRESHOLD = 0.5
|
| 20 |
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def log_start(task: str, env: str, model: str) -> None:
|
| 23 |
print(f"[START] task={task} env={env} model={model}", flush=True)
|
|
@@ -42,14 +44,7 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
|
|
| 42 |
)
|
| 43 |
|
| 44 |
|
| 45 |
-
def main():
|
| 46 |
-
# ✅ REQUIRED: Initialize OpenAI client with provided proxy
|
| 47 |
-
try:
|
| 48 |
-
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 49 |
-
except Exception as e:
|
| 50 |
-
print(f"[DEBUG] OpenAI init failed: {e}", flush=True)
|
| 51 |
-
client = None
|
| 52 |
-
|
| 53 |
env = EmailTriageEnv(task=TASK_NAME)
|
| 54 |
|
| 55 |
rewards: List[float] = []
|
|
@@ -69,9 +64,9 @@ def main():
|
|
| 69 |
try:
|
| 70 |
desc = state["description"]
|
| 71 |
|
| 72 |
-
# ✅ 🔥 LLM CALL (MANDATORY)
|
| 73 |
action_list = None
|
| 74 |
|
|
|
|
| 75 |
if client:
|
| 76 |
try:
|
| 77 |
response = client.chat.completions.create(
|
|
@@ -91,17 +86,15 @@ def main():
|
|
| 91 |
)
|
| 92 |
|
| 93 |
text = response.choices[0].message.content.strip()
|
| 94 |
-
|
| 95 |
-
# Parse response
|
| 96 |
action_list = [int(x) for x in text.replace(",", " ").split()[:3]]
|
| 97 |
|
| 98 |
if len(action_list) != 3:
|
| 99 |
-
raise ValueError(
|
| 100 |
|
| 101 |
except Exception as llm_error:
|
| 102 |
print(f"[DEBUG] LLM failed: {llm_error}", flush=True)
|
| 103 |
|
| 104 |
-
#
|
| 105 |
if not action_list:
|
| 106 |
action_list = smart_agent_logic(desc)
|
| 107 |
|
|
@@ -129,5 +122,17 @@ def main():
|
|
| 129 |
log_end(success, steps_taken, score, rewards)
|
| 130 |
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if __name__ == "__main__":
|
| 133 |
main()
|
|
|
|
| 7 |
from app import smart_agent_logic
|
| 8 |
|
| 9 |
|
| 10 |
+
# ✅ REQUIRED env vars
|
| 11 |
API_BASE_URL = os.environ.get("API_BASE_URL")
|
| 12 |
API_KEY = os.environ.get("API_KEY")
|
| 13 |
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 14 |
|
|
|
|
| 15 |
BENCHMARK = "email_triage_env"
|
| 16 |
|
| 17 |
MAX_STEPS = 20
|
| 18 |
SUCCESS_SCORE_THRESHOLD = 0.5
|
| 19 |
|
| 20 |
+
# ✅ RUN ALL TASKS
|
| 21 |
+
TASKS = ["easy", "medium", "hard"]
|
| 22 |
+
|
| 23 |
|
| 24 |
def log_start(task: str, env: str, model: str) -> None:
|
| 25 |
print(f"[START] task={task} env={env} model={model}", flush=True)
|
|
|
|
| 44 |
)
|
| 45 |
|
| 46 |
|
| 47 |
+
def run_task(client, TASK_NAME):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
env = EmailTriageEnv(task=TASK_NAME)
|
| 49 |
|
| 50 |
rewards: List[float] = []
|
|
|
|
| 64 |
try:
|
| 65 |
desc = state["description"]
|
| 66 |
|
|
|
|
| 67 |
action_list = None
|
| 68 |
|
| 69 |
+
# ✅ LLM CALL
|
| 70 |
if client:
|
| 71 |
try:
|
| 72 |
response = client.chat.completions.create(
|
|
|
|
| 86 |
)
|
| 87 |
|
| 88 |
text = response.choices[0].message.content.strip()
|
|
|
|
|
|
|
| 89 |
action_list = [int(x) for x in text.replace(",", " ").split()[:3]]
|
| 90 |
|
| 91 |
if len(action_list) != 3:
|
| 92 |
+
raise ValueError()
|
| 93 |
|
| 94 |
except Exception as llm_error:
|
| 95 |
print(f"[DEBUG] LLM failed: {llm_error}", flush=True)
|
| 96 |
|
| 97 |
+
# fallback
|
| 98 |
if not action_list:
|
| 99 |
action_list = smart_agent_logic(desc)
|
| 100 |
|
|
|
|
| 122 |
log_end(success, steps_taken, score, rewards)
|
| 123 |
|
| 124 |
|
| 125 |
+
def main():
|
| 126 |
+
try:
|
| 127 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 128 |
+
except Exception as e:
|
| 129 |
+
print(f"[DEBUG] OpenAI init failed: {e}", flush=True)
|
| 130 |
+
client = None
|
| 131 |
+
|
| 132 |
+
# ✅ RUN ALL TASKS
|
| 133 |
+
for task in TASKS:
|
| 134 |
+
run_task(client, task)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
if __name__ == "__main__":
|
| 138 |
main()
|