update inference.py
Browse files- inference.py +18 -22
inference.py
CHANGED
|
@@ -6,9 +6,8 @@ from openai import OpenAI
|
|
| 6 |
import requests
|
| 7 |
|
| 8 |
|
| 9 |
-
API_BASE_URL = os.
|
| 10 |
-
|
| 11 |
-
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 12 |
|
| 13 |
# Quant-Gym configuration
|
| 14 |
BASE_URL = os.getenv("BASE_URL", "http://localhost:8000")
|
|
@@ -111,11 +110,6 @@ class QuantGymClient:
|
|
| 111 |
def get_model_action(client: OpenAI, step: int, observation: dict, history: List[str]) -> str:
|
| 112 |
"""Get action from LLM using the judge's proxy"""
|
| 113 |
|
| 114 |
-
# CRITICAL: Must use client with API_BASE_URL from judge
|
| 115 |
-
if not API_BASE_URL:
|
| 116 |
-
print("[WARNING] API_BASE_URL not set! Using fallback.", flush=True)
|
| 117 |
-
return fallback_strategy(observation)
|
| 118 |
-
|
| 119 |
user_prompt = textwrap.dedent(
|
| 120 |
f"""
|
| 121 |
Step: {step}
|
|
@@ -130,9 +124,9 @@ def get_model_action(client: OpenAI, step: int, observation: dict, history: List
|
|
| 130 |
).strip()
|
| 131 |
|
| 132 |
try:
|
| 133 |
-
# This MUST go through their proxy
|
| 134 |
completion = client.chat.completions.create(
|
| 135 |
-
model=
|
| 136 |
messages=[
|
| 137 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 138 |
{"role": "user", "content": user_prompt},
|
|
@@ -181,18 +175,24 @@ def fallback_strategy(observation: dict) -> str:
|
|
| 181 |
async def main() -> None:
|
| 182 |
print("[INFO] Starting Quant-Gym Inference", flush=True)
|
| 183 |
|
| 184 |
-
# CRITICAL:
|
| 185 |
if not API_BASE_URL:
|
| 186 |
print("[ERROR] API_BASE_URL environment variable not set!", flush=True)
|
| 187 |
print("[ERROR] This must be provided by the hackathon judge.", flush=True)
|
| 188 |
-
|
| 189 |
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
client = OpenAI(
|
| 193 |
base_url=API_BASE_URL, # Their proxy URL
|
| 194 |
-
api_key=
|
| 195 |
-
)
|
| 196 |
|
| 197 |
env = QuantGymClient(BASE_URL)
|
| 198 |
|
|
@@ -202,18 +202,14 @@ async def main() -> None:
|
|
| 202 |
success = False
|
| 203 |
final_score = 0.0
|
| 204 |
|
| 205 |
-
log_start(task=TASK_NAME, env=BENCHMARK, model=
|
| 206 |
|
| 207 |
try:
|
| 208 |
result = env.reset()
|
| 209 |
observation = result.get('observation', {})
|
| 210 |
|
| 211 |
for step in range(1, MAX_STEPS + 1):
|
| 212 |
-
|
| 213 |
-
if client:
|
| 214 |
-
action_str = get_model_action(client, step, observation, history)
|
| 215 |
-
else:
|
| 216 |
-
action_str = fallback_strategy(observation)
|
| 217 |
|
| 218 |
result = env.step(action_str)
|
| 219 |
observation = result.get('observation', {})
|
|
|
|
| 6 |
import requests
|
| 7 |
|
| 8 |
|
| 9 |
+
API_BASE_URL = os.environ.get("API_BASE_URL")
|
| 10 |
+
API_KEY = os.environ.get("API_KEY")
|
|
|
|
| 11 |
|
| 12 |
# Quant-Gym configuration
|
| 13 |
BASE_URL = os.getenv("BASE_URL", "http://localhost:8000")
|
|
|
|
| 110 |
def get_model_action(client: OpenAI, step: int, observation: dict, history: List[str]) -> str:
|
| 111 |
"""Get action from LLM using the judge's proxy"""
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
user_prompt = textwrap.dedent(
|
| 114 |
f"""
|
| 115 |
Step: {step}
|
|
|
|
| 124 |
).strip()
|
| 125 |
|
| 126 |
try:
|
| 127 |
+
# CRITICAL: This MUST go through their proxy using BOTH env vars
|
| 128 |
completion = client.chat.completions.create(
|
| 129 |
+
model="gpt-3.5-turbo", # Their proxy expects this
|
| 130 |
messages=[
|
| 131 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 132 |
{"role": "user", "content": user_prompt},
|
|
|
|
| 175 |
async def main() -> None:
|
| 176 |
print("[INFO] Starting Quant-Gym Inference", flush=True)
|
| 177 |
|
| 178 |
+
# CRITICAL CHECK: Both environment variables MUST be set
|
| 179 |
if not API_BASE_URL:
|
| 180 |
print("[ERROR] API_BASE_URL environment variable not set!", flush=True)
|
| 181 |
print("[ERROR] This must be provided by the hackathon judge.", flush=True)
|
| 182 |
+
return
|
| 183 |
|
| 184 |
+
if not API_KEY:
|
| 185 |
+
print("[ERROR] API_KEY environment variable not set!", flush=True)
|
| 186 |
+
print("[ERROR] This must be provided by the hackathon judge.", flush=True)
|
| 187 |
+
return
|
| 188 |
+
|
| 189 |
+
print(f"[INFO] Using API_BASE_URL: {API_BASE_URL}", flush=True)
|
| 190 |
+
|
| 191 |
+
# Initialize OpenAI client with judge's proxy - MUST use BOTH
|
| 192 |
client = OpenAI(
|
| 193 |
base_url=API_BASE_URL, # Their proxy URL
|
| 194 |
+
api_key=API_KEY, # Their API key
|
| 195 |
+
)
|
| 196 |
|
| 197 |
env = QuantGymClient(BASE_URL)
|
| 198 |
|
|
|
|
| 202 |
success = False
|
| 203 |
final_score = 0.0
|
| 204 |
|
| 205 |
+
log_start(task=TASK_NAME, env=BENCHMARK, model="gpt-3.5-turbo")
|
| 206 |
|
| 207 |
try:
|
| 208 |
result = env.reset()
|
| 209 |
observation = result.get('observation', {})
|
| 210 |
|
| 211 |
for step in range(1, MAX_STEPS + 1):
|
| 212 |
+
action_str = get_model_action(client, step, observation, history)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
result = env.step(action_str)
|
| 215 |
observation = result.get('observation', {})
|