suhaas-code commited on
Commit
07431af
·
1 Parent(s): 96478ce
Files changed (1) hide show
  1. inference.py +29 -7
inference.py CHANGED
@@ -7,8 +7,16 @@ from pathlib import Path
7
  from typing import Any, Dict, Optional
8
  from urllib.parse import urlparse
9
 
10
- from dotenv import load_dotenv
11
- from openai import OpenAI
 
 
 
 
 
 
 
 
12
 
13
  from env.farm_env import FarmAction, FarmEnv, FarmState
14
  from tasks.graders import grade_all
@@ -95,9 +103,13 @@ def build_prompt(state: FarmState, step: int, recent_actions: list[dict[str, flo
95
  )
96
 
97
 
98
- def build_client() -> Optional[OpenAI]:
 
 
 
99
  if not API_BASE_URL:
100
- raise RuntimeError("Missing required environment variable 'API_BASE_URL'.")
 
101
 
102
  base_lower = API_BASE_URL.lower()
103
  if "huggingface.co" in base_lower:
@@ -194,11 +206,14 @@ def choose_fallback_action(state: FarmState, recent_actions: list[dict[str, floa
194
 
195
 
196
  def choose_action(
197
- client: OpenAI,
198
  state: FarmState,
199
  step: int,
200
  recent_actions: list[dict[str, float]],
201
  ) -> FarmAction:
 
 
 
202
  prompt = build_prompt(state, step=step, recent_actions=recent_actions)
203
  completion = client.chat.completions.create(
204
  model=MODEL_NAME,
@@ -263,7 +278,14 @@ def run_inference() -> None:
263
  dataset_path = Path(__file__).resolve().parent / \
264
  "farmer_advisor_dataset.csv"
265
  env = FarmEnv(dataset_path=dataset_path, seed=42, max_days=30)
266
- client = build_client()
 
 
 
 
 
 
 
267
 
268
  total_reward = 0.0
269
  total_yield = 0.0
@@ -297,7 +319,7 @@ def run_inference() -> None:
297
  except Exception as exc:
298
  llm_failures += 1
299
  step_error = f"llm_error:{exc.__class__.__name__}"
300
- raise RuntimeError(step_error) from exc
301
 
302
  try:
303
  step_result = env.step(action)
 
7
  from typing import Any, Dict, Optional
8
  from urllib.parse import urlparse
9
 
10
+ try:
11
+ from dotenv import load_dotenv
12
+ except Exception:
13
+ def load_dotenv(*_args: Any, **_kwargs: Any) -> bool:
14
+ return False
15
+
16
+ try:
17
+ from openai import OpenAI
18
+ except Exception:
19
+ OpenAI = None # type: ignore[assignment]
20
 
21
  from env.farm_env import FarmAction, FarmEnv, FarmState
22
  from tasks.graders import grade_all
 
103
  )
104
 
105
 
106
+ def build_client() -> Optional[Any]:
107
+ if OpenAI is None:
108
+ raise RuntimeError("openai_sdk_unavailable")
109
+
110
  if not API_BASE_URL:
111
+ raise RuntimeError(
112
+ "Missing required environment variable 'API_BASE_URL'.")
113
 
114
  base_lower = API_BASE_URL.lower()
115
  if "huggingface.co" in base_lower:
 
206
 
207
 
208
  def choose_action(
209
+ client: Optional[Any],
210
  state: FarmState,
211
  step: int,
212
  recent_actions: list[dict[str, float]],
213
  ) -> FarmAction:
214
+ if client is None:
215
+ raise RuntimeError("llm_client_unavailable")
216
+
217
  prompt = build_prompt(state, step=step, recent_actions=recent_actions)
218
  completion = client.chat.completions.create(
219
  model=MODEL_NAME,
 
278
  dataset_path = Path(__file__).resolve().parent / \
279
  "farmer_advisor_dataset.csv"
280
  env = FarmEnv(dataset_path=dataset_path, seed=42, max_days=30)
281
+ try:
282
+ client = build_client()
283
+ except Exception as exc:
284
+ client = None
285
+ print(
286
+ f"[WARN] llm_client_init_failed error={exc.__class__.__name__}",
287
+ flush=True,
288
+ )
289
 
290
  total_reward = 0.0
291
  total_yield = 0.0
 
319
  except Exception as exc:
320
  llm_failures += 1
321
  step_error = f"llm_error:{exc.__class__.__name__}"
322
+ action = choose_fallback_action(state, recent_actions)
323
 
324
  try:
325
  step_result = env.step(action)