Spaces:
Sleeping
Sleeping
suhaas-code commited on
Commit ·
07431af
1
Parent(s): 96478ce
phase2
Browse files- inference.py +29 -7
inference.py
CHANGED
|
@@ -7,8 +7,16 @@ from pathlib import Path
|
|
| 7 |
from typing import Any, Dict, Optional
|
| 8 |
from urllib.parse import urlparse
|
| 9 |
|
| 10 |
-
|
| 11 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
from env.farm_env import FarmAction, FarmEnv, FarmState
|
| 14 |
from tasks.graders import grade_all
|
|
@@ -95,9 +103,13 @@ def build_prompt(state: FarmState, step: int, recent_actions: list[dict[str, flo
|
|
| 95 |
)
|
| 96 |
|
| 97 |
|
| 98 |
-
def build_client() -> Optional[
|
|
|
|
|
|
|
|
|
|
| 99 |
if not API_BASE_URL:
|
| 100 |
-
raise RuntimeError(
|
|
|
|
| 101 |
|
| 102 |
base_lower = API_BASE_URL.lower()
|
| 103 |
if "huggingface.co" in base_lower:
|
|
@@ -194,11 +206,14 @@ def choose_fallback_action(state: FarmState, recent_actions: list[dict[str, floa
|
|
| 194 |
|
| 195 |
|
| 196 |
def choose_action(
|
| 197 |
-
client:
|
| 198 |
state: FarmState,
|
| 199 |
step: int,
|
| 200 |
recent_actions: list[dict[str, float]],
|
| 201 |
) -> FarmAction:
|
|
|
|
|
|
|
|
|
|
| 202 |
prompt = build_prompt(state, step=step, recent_actions=recent_actions)
|
| 203 |
completion = client.chat.completions.create(
|
| 204 |
model=MODEL_NAME,
|
|
@@ -263,7 +278,14 @@ def run_inference() -> None:
|
|
| 263 |
dataset_path = Path(__file__).resolve().parent / \
|
| 264 |
"farmer_advisor_dataset.csv"
|
| 265 |
env = FarmEnv(dataset_path=dataset_path, seed=42, max_days=30)
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
total_reward = 0.0
|
| 269 |
total_yield = 0.0
|
|
@@ -297,7 +319,7 @@ def run_inference() -> None:
|
|
| 297 |
except Exception as exc:
|
| 298 |
llm_failures += 1
|
| 299 |
step_error = f"llm_error:{exc.__class__.__name__}"
|
| 300 |
-
|
| 301 |
|
| 302 |
try:
|
| 303 |
step_result = env.step(action)
|
|
|
|
| 7 |
from typing import Any, Dict, Optional
|
| 8 |
from urllib.parse import urlparse
|
| 9 |
|
| 10 |
+
try:
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
except Exception:
|
| 13 |
+
def load_dotenv(*_args: Any, **_kwargs: Any) -> bool:
|
| 14 |
+
return False
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from openai import OpenAI
|
| 18 |
+
except Exception:
|
| 19 |
+
OpenAI = None # type: ignore[assignment]
|
| 20 |
|
| 21 |
from env.farm_env import FarmAction, FarmEnv, FarmState
|
| 22 |
from tasks.graders import grade_all
|
|
|
|
| 103 |
)
|
| 104 |
|
| 105 |
|
| 106 |
+
def build_client() -> Optional[Any]:
|
| 107 |
+
if OpenAI is None:
|
| 108 |
+
raise RuntimeError("openai_sdk_unavailable")
|
| 109 |
+
|
| 110 |
if not API_BASE_URL:
|
| 111 |
+
raise RuntimeError(
|
| 112 |
+
"Missing required environment variable 'API_BASE_URL'.")
|
| 113 |
|
| 114 |
base_lower = API_BASE_URL.lower()
|
| 115 |
if "huggingface.co" in base_lower:
|
|
|
|
| 206 |
|
| 207 |
|
| 208 |
def choose_action(
|
| 209 |
+
client: Optional[Any],
|
| 210 |
state: FarmState,
|
| 211 |
step: int,
|
| 212 |
recent_actions: list[dict[str, float]],
|
| 213 |
) -> FarmAction:
|
| 214 |
+
if client is None:
|
| 215 |
+
raise RuntimeError("llm_client_unavailable")
|
| 216 |
+
|
| 217 |
prompt = build_prompt(state, step=step, recent_actions=recent_actions)
|
| 218 |
completion = client.chat.completions.create(
|
| 219 |
model=MODEL_NAME,
|
|
|
|
| 278 |
dataset_path = Path(__file__).resolve().parent / \
|
| 279 |
"farmer_advisor_dataset.csv"
|
| 280 |
env = FarmEnv(dataset_path=dataset_path, seed=42, max_days=30)
|
| 281 |
+
try:
|
| 282 |
+
client = build_client()
|
| 283 |
+
except Exception as exc:
|
| 284 |
+
client = None
|
| 285 |
+
print(
|
| 286 |
+
f"[WARN] llm_client_init_failed error={exc.__class__.__name__}",
|
| 287 |
+
flush=True,
|
| 288 |
+
)
|
| 289 |
|
| 290 |
total_reward = 0.0
|
| 291 |
total_yield = 0.0
|
|
|
|
| 319 |
except Exception as exc:
|
| 320 |
llm_failures += 1
|
| 321 |
step_error = f"llm_error:{exc.__class__.__name__}"
|
| 322 |
+
action = choose_fallback_action(state, recent_actions)
|
| 323 |
|
| 324 |
try:
|
| 325 |
step_result = env.step(action)
|