Astocoder commited on
Commit
9ca0bab
·
1 Parent(s): cf294ef

update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +18 -22
inference.py CHANGED
@@ -6,9 +6,8 @@ from openai import OpenAI
6
  import requests
7
 
8
 
9
- API_BASE_URL = os.getenv("API_BASE_URL") # NO DEFAULT
10
- MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo") # default
11
- HF_TOKEN = os.getenv("HF_TOKEN")
12
 
13
  # Quant-Gym configuration
14
  BASE_URL = os.getenv("BASE_URL", "http://localhost:8000")
@@ -111,11 +110,6 @@ class QuantGymClient:
111
  def get_model_action(client: OpenAI, step: int, observation: dict, history: List[str]) -> str:
112
  """Get action from LLM using the judge's proxy"""
113
 
114
- # CRITICAL: Must use client with API_BASE_URL from judge
115
- if not API_BASE_URL:
116
- print("[WARNING] API_BASE_URL not set! Using fallback.", flush=True)
117
- return fallback_strategy(observation)
118
-
119
  user_prompt = textwrap.dedent(
120
  f"""
121
  Step: {step}
@@ -130,9 +124,9 @@ def get_model_action(client: OpenAI, step: int, observation: dict, history: List
130
  ).strip()
131
 
132
  try:
133
- # This MUST go through their proxy
134
  completion = client.chat.completions.create(
135
- model=MODEL_NAME,
136
  messages=[
137
  {"role": "system", "content": SYSTEM_PROMPT},
138
  {"role": "user", "content": user_prompt},
@@ -181,18 +175,24 @@ def fallback_strategy(observation: dict) -> str:
181
  async def main() -> None:
182
  print("[INFO] Starting Quant-Gym Inference", flush=True)
183
 
184
- # CRITICAL: Check if API_BASE_URL is provided by judge
185
  if not API_BASE_URL:
186
  print("[ERROR] API_BASE_URL environment variable not set!", flush=True)
187
  print("[ERROR] This must be provided by the hackathon judge.", flush=True)
188
- print("[INFO] Using fallback strategy without LLM.", flush=True)
189
 
190
- # Initialize OpenAI client with judge's proxy URL
191
- # DO NOT use default - MUST use their provided URL
 
 
 
 
 
 
192
  client = OpenAI(
193
  base_url=API_BASE_URL, # Their proxy URL
194
- api_key="dummy" # Their proxy may not need a real key
195
- ) if API_BASE_URL else None
196
 
197
  env = QuantGymClient(BASE_URL)
198
 
@@ -202,18 +202,14 @@ async def main() -> None:
202
  success = False
203
  final_score = 0.0
204
 
205
- log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
206
 
207
  try:
208
  result = env.reset()
209
  observation = result.get('observation', {})
210
 
211
  for step in range(1, MAX_STEPS + 1):
212
- # Get action - this will use their proxy if available
213
- if client:
214
- action_str = get_model_action(client, step, observation, history)
215
- else:
216
- action_str = fallback_strategy(observation)
217
 
218
  result = env.step(action_str)
219
  observation = result.get('observation', {})
 
6
  import requests
7
 
8
 
9
+ API_BASE_URL = os.environ.get("API_BASE_URL")
10
+ API_KEY = os.environ.get("API_KEY")
 
11
 
12
  # Quant-Gym configuration
13
  BASE_URL = os.getenv("BASE_URL", "http://localhost:8000")
 
110
  def get_model_action(client: OpenAI, step: int, observation: dict, history: List[str]) -> str:
111
  """Get action from LLM using the judge's proxy"""
112
 
 
 
 
 
 
113
  user_prompt = textwrap.dedent(
114
  f"""
115
  Step: {step}
 
124
  ).strip()
125
 
126
  try:
127
+ # CRITICAL: This MUST go through their proxy using BOTH env vars
128
  completion = client.chat.completions.create(
129
+ model="gpt-3.5-turbo", # Their proxy expects this
130
  messages=[
131
  {"role": "system", "content": SYSTEM_PROMPT},
132
  {"role": "user", "content": user_prompt},
 
175
  async def main() -> None:
176
  print("[INFO] Starting Quant-Gym Inference", flush=True)
177
 
178
+ # CRITICAL CHECK: Both environment variables MUST be set
179
  if not API_BASE_URL:
180
  print("[ERROR] API_BASE_URL environment variable not set!", flush=True)
181
  print("[ERROR] This must be provided by the hackathon judge.", flush=True)
182
+ return
183
 
184
+ if not API_KEY:
185
+ print("[ERROR] API_KEY environment variable not set!", flush=True)
186
+ print("[ERROR] This must be provided by the hackathon judge.", flush=True)
187
+ return
188
+
189
+ print(f"[INFO] Using API_BASE_URL: {API_BASE_URL}", flush=True)
190
+
191
+ # Initialize OpenAI client with judge's proxy - MUST use BOTH
192
  client = OpenAI(
193
  base_url=API_BASE_URL, # Their proxy URL
194
+ api_key=API_KEY, # Their API key
195
+ )
196
 
197
  env = QuantGymClient(BASE_URL)
198
 
 
202
  success = False
203
  final_score = 0.0
204
 
205
+ log_start(task=TASK_NAME, env=BENCHMARK, model="gpt-3.5-turbo")
206
 
207
  try:
208
  result = env.reset()
209
  observation = result.get('observation', {})
210
 
211
  for step in range(1, MAX_STEPS + 1):
212
+ action_str = get_model_action(client, step, observation, history)
 
 
 
 
213
 
214
  result = env.step(action_str)
215
  observation = result.get('observation', {})