coffeine16 committed on
Commit
caa2181
·
1 Parent(s): 7889ae7

Changes made

Browse files
Files changed (1) hide show
  1. inference.py +36 -45
inference.py CHANGED
@@ -8,13 +8,13 @@ LOCAL USAGE (no Docker — start the server first in a separate terminal):
8
  uvicorn server.app:app --host 0.0.0.0 --port 8000
9
 
10
  Then in another terminal:
11
- USE_DOCKER=false API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4o HF_TOKEN=sk-... python inference.py
12
 
13
  SINGLE TASK (local):
14
  FITSCRIPT_TASK=basic_plan USE_DOCKER=false python inference.py
15
 
16
  DOCKER USAGE (spins up the container automatically):
17
- USE_DOCKER=true LOCAL_IMAGE_NAME=fitscript-env:latest API_BASE_URL=... MODEL_NAME=... HF_TOKEN=... python inference.py
18
 
19
  STDOUT FORMAT (exact hackathon spec):
20
  [START] task=<task> env=fitscript_env model=<model>
@@ -38,13 +38,12 @@ except ImportError:
38
  # Configuration (hackathon mandatory variables)
39
  # ---------------------------------------------------------------------------
40
  API_BASE_URL: str = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
41
- MODEL_NAME: str = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
 
42
  API_KEY: str = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "")
43
 
44
  BENCHMARK: str = "fitscript_env"
45
 
46
- # USE_DOCKER=false → connect to a local server already running (default for local dev)
47
- # USE_DOCKER=true → spin up a Docker container automatically
48
  USE_DOCKER: bool = os.environ.get("USE_DOCKER", "false").lower() == "true"
49
 
50
  IMAGE_NAME: str = (
@@ -54,8 +53,6 @@ IMAGE_NAME: str = (
54
 
55
  LOCAL_SERVER_URL: str = os.environ.get("LOCAL_SERVER_URL", "http://localhost:8000")
56
 
57
- # FITSCRIPT_TASK: set to a single task name to run only that task.
58
- # Leave empty (default) to run all 3 tasks sequentially (required for hackathon).
59
  FITSCRIPT_TASK: str = os.environ.get("FITSCRIPT_TASK", "")
60
 
61
  MAX_STEPS: int = int(os.environ.get("MAX_STEPS", "8"))
@@ -132,40 +129,41 @@ For a periodized 4-week powerlifting program, use:
132
  """
133
 
134
  # ---------------------------------------------------------------------------
135
- # LLM helpers
136
  # ---------------------------------------------------------------------------
137
 
138
  def _call_llm_sync(messages: list) -> str:
139
- """Synchronous Hugging Face call"""
140
- from huggingface_hub import InferenceClient
141
- import os
142
-
143
- client = InferenceClient(
144
- model=os.getenv("MODEL_NAME"),
145
- token=os.getenv("HF_API_KEY")
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  )
147
 
148
- # Convert OpenAI-style messages → single prompt
149
- prompt = ""
150
- for m in messages:
151
- role = m.get("role", "")
152
- content = m.get("content", "")
153
- if role == "system":
154
- prompt += f"[SYSTEM]: {content}\n"
155
- elif role == "user":
156
- prompt += f"[USER]: {content}\n"
157
- elif role == "assistant":
158
- prompt += f"[ASSISTANT]: {content}\n"
159
-
160
- prompt += "[ASSISTANT]:"
161
-
162
- response = client.text_generation(
163
- prompt,
164
- max_new_tokens=2048,
165
  temperature=0.7,
166
  )
167
 
168
- return response
 
169
 
170
  async def call_llm_async(messages: list) -> str:
171
  loop = asyncio.get_event_loop()
@@ -219,7 +217,6 @@ async def run_episode(task_name: str, env) -> None:
219
  error_msg = None
220
 
221
  try:
222
- # reset() is async in EnvClient
223
  reset_result = await env.reset()
224
  obs = reset_result.observation
225
 
@@ -243,7 +240,6 @@ async def run_episode(task_name: str, env) -> None:
243
  action_type = "modify_plan" if task_name == "injury_safe_modification" else "generate_plan"
244
  action = FitscriptAction(action_type=action_type, plan=plan_str)
245
 
246
- # step() is async in EnvClient
247
  try:
248
  result = await env.step(action)
249
  except Exception as exc:
@@ -281,23 +277,17 @@ async def main() -> None:
281
  tasks_to_run = [FITSCRIPT_TASK] if FITSCRIPT_TASK else ALL_TASKS
282
 
283
  if USE_DOCKER:
284
- # Docker mode: launch one container per task.
285
- # FITSCRIPT_TASK env var is passed into the container so the server
286
- # initialises with the correct task_id.
287
  for task_name in tasks_to_run:
288
  print(
289
  f"[INFO] Starting Docker container ({IMAGE_NAME}) for task={task_name}",
290
  file=sys.stderr, flush=True,
291
  )
292
- # from_docker_image is async and returns a connected EnvClient
293
  try:
294
  env = await FitscriptEnv.from_docker_image(
295
  IMAGE_NAME,
296
  env={"FITSCRIPT_TASK": task_name},
297
  )
298
  except TypeError:
299
- # Some versions of EnvClient don't support the env= kwarg;
300
- # fall back to no extra env (server uses its own FITSCRIPT_TASK)
301
  env = await FitscriptEnv.from_docker_image(IMAGE_NAME)
302
  try:
303
  await run_episode(task_name, env)
@@ -305,9 +295,6 @@ async def main() -> None:
305
  await env.close()
306
 
307
  else:
308
- # Local mode: server must already be running at LOCAL_SERVER_URL.
309
- # Each task gets a fresh client connection (the server keeps its state
310
- # per-session via WebSocket, so reconnecting is a clean reset).
311
  for task_name in tasks_to_run:
312
  print(
313
  f"[INFO] Connecting to local server at {LOCAL_SERVER_URL} for task={task_name}",
@@ -317,7 +304,11 @@ async def main() -> None:
317
  try:
318
  await run_episode(task_name, env)
319
  finally:
320
- env.close()
 
 
 
 
321
 
322
 
323
  if __name__ == "__main__":
 
8
  uvicorn server.app:app --host 0.0.0.0 --port 8000
9
 
10
  Then in another terminal:
11
+ USE_DOCKER=false HF_TOKEN=hf_... python inference.py
12
 
13
  SINGLE TASK (local):
14
  FITSCRIPT_TASK=basic_plan USE_DOCKER=false python inference.py
15
 
16
  DOCKER USAGE (spins up the container automatically):
17
+ USE_DOCKER=true LOCAL_IMAGE_NAME=fitscript-env:latest HF_TOKEN=hf_... python inference.py
18
 
19
  STDOUT FORMAT (exact hackathon spec):
20
  [START] task=<task> env=fitscript_env model=<model>
 
38
  # Configuration (hackathon mandatory variables)
39
  # ---------------------------------------------------------------------------
40
  API_BASE_URL: str = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
41
+ MODEL_NAME: str = os.environ.get("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
42
+ # Accept HF_TOKEN or API_KEY
43
  API_KEY: str = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "")
44
 
45
  BENCHMARK: str = "fitscript_env"
46
 
 
 
47
  USE_DOCKER: bool = os.environ.get("USE_DOCKER", "false").lower() == "true"
48
 
49
  IMAGE_NAME: str = (
 
53
 
54
  LOCAL_SERVER_URL: str = os.environ.get("LOCAL_SERVER_URL", "http://localhost:8000")
55
 
 
 
56
  FITSCRIPT_TASK: str = os.environ.get("FITSCRIPT_TASK", "")
57
 
58
  MAX_STEPS: int = int(os.environ.get("MAX_STEPS", "8"))
 
129
  """
130
 
131
  # ---------------------------------------------------------------------------
132
+ # LLM helpers — using OpenAI-compatible HuggingFace router
133
  # ---------------------------------------------------------------------------
134
 
135
  def _call_llm_sync(messages: list) -> str:
136
+ """
137
+ Call HuggingFace Inference API via its OpenAI-compatible /v1/chat/completions
138
+ endpoint. Works with any model available on HF's serverless inference router.
139
+ """
140
+ try:
141
+ from openai import OpenAI
142
+ except ImportError:
143
+ raise ImportError(
144
+ "openai package is required. Install with: pip install openai"
145
+ )
146
+
147
+ if not API_KEY:
148
+ raise ValueError(
149
+ "HF_TOKEN (or API_KEY) environment variable is not set. "
150
+ "Get your token from https://huggingface.co/settings/tokens"
151
+ )
152
+
153
+ client = OpenAI(
154
+ base_url=API_BASE_URL,
155
+ api_key=API_KEY,
156
  )
157
 
158
+ response = client.chat.completions.create(
159
+ model=MODEL_NAME,
160
+ messages=messages,
161
+ max_tokens=2048,
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  temperature=0.7,
163
  )
164
 
165
+ return response.choices[0].message.content
166
+
167
 
168
  async def call_llm_async(messages: list) -> str:
169
  loop = asyncio.get_event_loop()
 
217
  error_msg = None
218
 
219
  try:
 
220
  reset_result = await env.reset()
221
  obs = reset_result.observation
222
 
 
240
  action_type = "modify_plan" if task_name == "injury_safe_modification" else "generate_plan"
241
  action = FitscriptAction(action_type=action_type, plan=plan_str)
242
 
 
243
  try:
244
  result = await env.step(action)
245
  except Exception as exc:
 
277
  tasks_to_run = [FITSCRIPT_TASK] if FITSCRIPT_TASK else ALL_TASKS
278
 
279
  if USE_DOCKER:
 
 
 
280
  for task_name in tasks_to_run:
281
  print(
282
  f"[INFO] Starting Docker container ({IMAGE_NAME}) for task={task_name}",
283
  file=sys.stderr, flush=True,
284
  )
 
285
  try:
286
  env = await FitscriptEnv.from_docker_image(
287
  IMAGE_NAME,
288
  env={"FITSCRIPT_TASK": task_name},
289
  )
290
  except TypeError:
 
 
291
  env = await FitscriptEnv.from_docker_image(IMAGE_NAME)
292
  try:
293
  await run_episode(task_name, env)
 
295
  await env.close()
296
 
297
  else:
 
 
 
298
  for task_name in tasks_to_run:
299
  print(
300
  f"[INFO] Connecting to local server at {LOCAL_SERVER_URL} for task={task_name}",
 
304
  try:
305
  await run_episode(task_name, env)
306
  finally:
307
+ # close() is async — await it properly
308
+ try:
309
+ await env.close()
310
+ except TypeError:
311
+ env.close()
312
 
313
 
314
  if __name__ == "__main__":