parthpethia commited on
Commit
285e6b6
·
1 Parent(s): 63df97b

Use validator's API_KEY and API_BASE_URL from environment

Browse files
Files changed (1) hide show
  1. inference.py +29 -24
inference.py CHANGED
@@ -18,10 +18,12 @@ except ImportError:
18
  from environment.env import EmailTriageEnv
19
  from environment.types import Action, EmailCategory, Team
20
 
21
- # Environment variables
 
 
 
22
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
23
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
24
- API_KEY = os.getenv("OPENAI_API_KEY")
25
 
26
  # Configuration
27
  MAX_STEPS = 50
@@ -200,9 +202,21 @@ def main() -> None:
200
  tasks = ["spam_detection", "multi_class_routing", "context_aware_triage"]
201
  all_scores = []
202
 
203
- if not API_KEY or not OPENAI_AVAILABLE:
204
- # Demo/Validation mode: No API key provided
205
- print("[WARNING] OPENAI_API_KEY not set. Running in validation mode.",
 
 
 
 
 
 
 
 
 
 
 
 
206
  flush=True)
207
 
208
  for task in tasks:
@@ -289,25 +303,16 @@ def main() -> None:
289
  flush=True)
290
 
291
  else:
292
- # Normal mode: Use OpenAI API
293
- try:
294
- client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL if API_BASE_URL
295
- else None)
296
- except Exception as e:
297
- print(f"[ERROR] Failed to initialize OpenAI client: {e}",
298
- file=sys.stderr, flush=True)
299
- all_scores = [0.0, 0.0, 0.0]
300
-
301
- if client is not None:
302
- for task in tasks:
303
- try:
304
- success, steps, score, rewards = run_task(client, task)
305
- all_scores.append(score)
306
- print(f"[TASK_SUMMARY] {task}: score={score:.3f} steps={steps}",
307
- flush=True)
308
- except Exception as e:
309
- print(f"[TASK_ERROR] {task}: {e}", file=sys.stderr, flush=True)
310
- all_scores.append(0.0)
311
 
312
  # Final summary
313
  avg_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
 
18
  from environment.env import EmailTriageEnv
19
  from environment.types import Action, EmailCategory, Team
20
 
21
+ # Environment variables - check both formats
22
+ # Validator provides: API_KEY and API_BASE_URL
23
+ # Local usage: OPENAI_API_KEY
24
+ API_KEY = os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY")
25
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
26
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
 
27
 
28
  # Configuration
29
  MAX_STEPS = 50
 
202
  tasks = ["spam_detection", "multi_class_routing", "context_aware_triage"]
203
  all_scores = []
204
 
205
+ # Try to initialize OpenAI client if API key is available
206
+ client = None
207
+ if API_KEY and OPENAI_AVAILABLE:
208
+ try:
209
+ # Initialize with validator's provided API_BASE_URL and API_KEY
210
+ client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
211
+ print(f"[INFO] Using API endpoint: {API_BASE_URL}", flush=True)
212
+ except Exception as e:
213
+ print(f"[WARNING] Failed to initialize OpenAI client: {e}",
214
+ file=sys.stderr, flush=True)
215
+ client = None
216
+
217
+ if client is None:
218
+ # Demo/Validation mode: No API key or OpenAI not available
219
+ print("[WARNING] No API credentials available. Running in validation mode.",
220
  flush=True)
221
 
222
  for task in tasks:
 
303
  flush=True)
304
 
305
  else:
306
+ # Normal mode: Use OpenAI API (through validator's proxy if available)
307
+ for task in tasks:
308
+ try:
309
+ success, steps, score, rewards = run_task(client, task)
310
+ all_scores.append(score)
311
+ print(f"[TASK_SUMMARY] {task}: score={score:.3f} steps={steps}",
312
+ flush=True)
313
+ except Exception as e:
314
+ print(f"[TASK_ERROR] {task}: {e}", file=sys.stderr, flush=True)
315
+ all_scores.append(0.0)
 
 
 
 
 
 
 
 
 
316
 
317
  # Final summary
318
  avg_score = sum(all_scores) / len(all_scores) if all_scores else 0.0