yashppawar commited on
Commit
da9e926
·
verified ·
1 Parent(s): 30c52ad

Use LOCAL_IMAGE_NAME per spec

Browse files
Files changed (1) hide show
  1. inference.py +3 -3
inference.py CHANGED
@@ -8,7 +8,7 @@ Required environment variables:
8
  API_BASE_URL The LLM endpoint (OpenAI-compatible)
9
  MODEL_NAME The model id to use
10
  HF_TOKEN API key for the LLM provider
11
- IMAGE_NAME (optional) Docker image for the env server
12
  Default: disk-panic:latest
13
 
14
  Stdout format (one per episode):
@@ -33,7 +33,7 @@ except ImportError:
33
 
34
  # -- config ----------------------------------------------------------------
35
 
36
- IMAGE_NAME = os.getenv("IMAGE_NAME", "disk-panic:latest")
37
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
38
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.groq.com/openai/v1"
39
  MODEL_NAME = os.getenv("MODEL_NAME") or "llama-3.3-70b-versatile"
@@ -218,7 +218,7 @@ async def run_episode(client: OpenAI, env: DiskPanicEnv, task: str) -> float:
218
 
219
  async def main() -> None:
220
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
221
- env = await DiskPanicEnv.from_docker_image(IMAGE_NAME)
222
  try:
223
  for task in TASKS:
224
  await run_episode(client, env, task)
 
8
  API_BASE_URL The LLM endpoint (OpenAI-compatible)
9
  MODEL_NAME The model id to use
10
  HF_TOKEN API key for the LLM provider
11
+ LOCAL_IMAGE_NAME (optional) Docker image for the env server
12
  Default: disk-panic:latest
13
 
14
  Stdout format (one per episode):
 
33
 
34
  # -- config ----------------------------------------------------------------
35
 
36
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "disk-panic:latest")
37
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
38
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.groq.com/openai/v1"
39
  MODEL_NAME = os.getenv("MODEL_NAME") or "llama-3.3-70b-versatile"
 
218
 
219
  async def main() -> None:
220
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
221
+ env = await DiskPanicEnv.from_docker_image(LOCAL_IMAGE_NAME)
222
  try:
223
  for task in TASKS:
224
  await run_episode(client, env, task)