jeevank0 commited on
Commit
bfd41b5
·
1 Parent(s): 4ad742c

Inference.py

Browse files
Files changed (1) hide show
  1. inference.py +15 -4
inference.py CHANGED
@@ -7,8 +7,16 @@ from pathlib import Path
7
  from typing import Any, Dict, Optional
8
  from urllib.parse import urlparse
9
 
10
- from dotenv import load_dotenv
11
- from openai import OpenAI
 
 
 
 
 
 
 
 
12
 
13
  from env.farm_env import FarmAction, FarmEnv, FarmState
14
  from tasks.graders import grade_all
@@ -95,7 +103,10 @@ def build_prompt(state: FarmState, step: int, recent_actions: list[dict[str, flo
95
  )
96
 
97
 
98
- def build_client() -> Optional[OpenAI]:
 
 
 
99
  if not API_BASE_URL:
100
  raise RuntimeError(
101
  "Missing required environment variable 'API_BASE_URL'.")
@@ -195,7 +206,7 @@ def choose_fallback_action(state: FarmState, recent_actions: list[dict[str, floa
195
 
196
 
197
  def choose_action(
198
- client: Optional[OpenAI],
199
  state: FarmState,
200
  step: int,
201
  recent_actions: list[dict[str, float]],
 
7
  from typing import Any, Dict, Optional
8
  from urllib.parse import urlparse
9
 
10
+ try:
11
+ from dotenv import load_dotenv
12
+ except Exception:
13
+ def load_dotenv(*_args: Any, **_kwargs: Any) -> bool:
14
+ return False
15
+
16
+ try:
17
+ from openai import OpenAI
18
+ except Exception:
19
+ OpenAI = None # type: ignore[assignment]
20
 
21
  from env.farm_env import FarmAction, FarmEnv, FarmState
22
  from tasks.graders import grade_all
 
103
  )
104
 
105
 
106
+ def build_client() -> Optional[Any]:
107
+ if OpenAI is None:
108
+ raise RuntimeError("openai_sdk_unavailable")
109
+
110
  if not API_BASE_URL:
111
  raise RuntimeError(
112
  "Missing required environment variable 'API_BASE_URL'.")
 
206
 
207
 
208
  def choose_action(
209
+ client: Optional[Any],
210
  state: FarmState,
211
  step: int,
212
  recent_actions: list[dict[str, float]],