Spaces:
Sleeping
Sleeping
Commit ·
7889ae7
1
Parent(s): bc664e2
Update
Browse files- inference.py +190 -121
- server/__pycache__/app.cpython-313.pyc +0 -0
- server/app.py +32 -36
- server/requirements.txt +3 -4
inference.py
CHANGED
|
@@ -1,22 +1,25 @@
|
|
| 1 |
"""
|
| 2 |
-
FitScript inference.py
|
| 3 |
|
| 4 |
-
|
| 5 |
-
FITSCRIPT_TASK=basic_plan \\
|
| 6 |
-
API_BASE_URL=https://api.openai.com/v1 \\
|
| 7 |
-
MODEL_NAME=gpt-4o \\
|
| 8 |
-
HF_TOKEN=<your_key> \\
|
| 9 |
-
python inference.py
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
periodized_program (hard)
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
[START] task=<task> env=fitscript_env model=<model>
|
| 18 |
-
[STEP]
|
| 19 |
-
[END]
|
| 20 |
"""
|
| 21 |
|
| 22 |
import asyncio
|
|
@@ -24,48 +27,64 @@ import json
|
|
| 24 |
import os
|
| 25 |
import sys
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
from dotenv import load_dotenv
|
| 30 |
-
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
load_dotenv()
|
| 33 |
# ---------------------------------------------------------------------------
|
| 34 |
-
#
|
| 35 |
# ---------------------------------------------------------------------------
|
| 36 |
-
API_BASE_URL: str = os.environ
|
| 37 |
-
MODEL_NAME: str = os.environ
|
| 38 |
-
API_KEY: str = os.environ
|
| 39 |
|
| 40 |
-
TASK_NAME: str = os.getenv("FITSCRIPT_TASK", "basic_plan")
|
| 41 |
BENCHMARK: str = "fitscript_env"
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
# ---------------------------------------------------------------------------
|
| 46 |
-
# Structured log helpers (hackathon spec
|
| 47 |
# ---------------------------------------------------------------------------
|
| 48 |
|
| 49 |
-
def log_start(task: str,
|
| 50 |
-
print(f"[START] task={task} env={
|
| 51 |
|
| 52 |
|
| 53 |
def log_step(step: int, action: str, reward: float, done: bool, error) -> None:
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
action_token = action.replace("\n", " ").replace("\r", "")[:120]
|
| 57 |
print(
|
| 58 |
-
f"[STEP] step={step} action={
|
| 59 |
-
f" done={str(done).lower()} error={
|
| 60 |
flush=True,
|
| 61 |
)
|
| 62 |
|
| 63 |
|
| 64 |
def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
|
| 65 |
-
|
| 66 |
print(
|
| 67 |
f"[END] success={str(success).lower()} steps={steps}"
|
| 68 |
-
f" score={score:.
|
| 69 |
flush=True,
|
| 70 |
)
|
| 71 |
|
|
@@ -77,34 +96,33 @@ def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
|
|
| 77 |
SYSTEM_PROMPT = """You are an expert personal trainer and exercise scientist.
|
| 78 |
You will receive a client profile and must generate a structured workout plan as JSON.
|
| 79 |
|
| 80 |
-
|
| 81 |
-
Do NOT include any prose or explanation outside the JSON.
|
| 82 |
|
| 83 |
-
|
| 84 |
{
|
| 85 |
"days": [
|
| 86 |
{
|
| 87 |
-
"name": "Day 1 -
|
| 88 |
-
"focus": "
|
| 89 |
"exercises": [
|
| 90 |
-
{"name": "
|
| 91 |
]
|
| 92 |
}
|
| 93 |
]
|
| 94 |
}
|
| 95 |
|
| 96 |
-
|
| 97 |
{
|
| 98 |
"weeks": [
|
| 99 |
{
|
| 100 |
"week": 1,
|
| 101 |
-
"intensity":
|
| 102 |
-
"total_sets":
|
| 103 |
"days": [
|
| 104 |
{
|
| 105 |
-
"name": "Day 1 -
|
| 106 |
"exercises": [
|
| 107 |
-
{"name": "
|
| 108 |
]
|
| 109 |
}
|
| 110 |
]
|
|
@@ -113,81 +131,107 @@ JSON schema for a periodized 4-week program:
|
|
| 113 |
}
|
| 114 |
"""
|
| 115 |
|
| 116 |
-
|
| 117 |
# ---------------------------------------------------------------------------
|
| 118 |
-
# LLM
|
| 119 |
# ---------------------------------------------------------------------------
|
| 120 |
|
| 121 |
-
def
|
| 122 |
-
"""
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
temperature=0.7,
|
| 127 |
-
max_tokens=2048,
|
| 128 |
)
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
|
| 132 |
def build_user_message(observation) -> str:
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
task_id = observation.task_id if hasattr(observation, "task_id") else ""
|
| 138 |
|
| 139 |
parts = [
|
| 140 |
f"Task: {task_id}",
|
| 141 |
-
f"Client profile:
|
| 142 |
]
|
| 143 |
if feedback:
|
| 144 |
parts.append(f"Environment feedback: {feedback}")
|
| 145 |
if breakdown:
|
| 146 |
parts.append(f"Score breakdown: {json.dumps(breakdown, indent=2)}")
|
| 147 |
-
parts.append("
|
| 148 |
return "\n\n".join(parts)
|
| 149 |
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
# ---------------------------------------------------------------------------
|
| 152 |
-
#
|
| 153 |
# ---------------------------------------------------------------------------
|
| 154 |
|
| 155 |
-
async def run_episode() -> None:
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
| 159 |
|
| 160 |
-
log_start(
|
| 161 |
|
| 162 |
rewards: list = []
|
| 163 |
-
final_score
|
| 164 |
-
success
|
| 165 |
-
step
|
| 166 |
-
error_msg
|
| 167 |
|
| 168 |
-
env = None
|
| 169 |
try:
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
if USE_DOCKER:
|
| 173 |
-
env = await FitscriptEnv.from_docker_image(IMAGE_NAME)
|
| 174 |
-
else:
|
| 175 |
-
env = FitscriptEnv(base_url="http://localhost:8000")
|
| 176 |
-
|
| 177 |
-
# Reset
|
| 178 |
-
reset_result = env.reset()
|
| 179 |
obs = reset_result.observation
|
| 180 |
|
| 181 |
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
| 182 |
|
| 183 |
for step in range(1, MAX_STEPS + 1):
|
| 184 |
-
# Build user turn from current observation
|
| 185 |
user_content = build_user_message(obs)
|
| 186 |
messages.append({"role": "user", "content": user_content})
|
| 187 |
|
| 188 |
-
#
|
| 189 |
try:
|
| 190 |
-
assistant_reply =
|
| 191 |
except Exception as exc:
|
| 192 |
error_msg = str(exc)
|
| 193 |
log_step(step, "LLM_ERROR", 0.0, True, error_msg)
|
|
@@ -195,61 +239,86 @@ async def run_episode() -> None:
|
|
| 195 |
|
| 196 |
messages.append({"role": "assistant", "content": assistant_reply})
|
| 197 |
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
line for line in lines
|
| 204 |
-
if not line.startswith("```")
|
| 205 |
-
).strip()
|
| 206 |
-
|
| 207 |
-
# Determine action_type from task
|
| 208 |
-
if TASK_NAME == "injury_safe_modification":
|
| 209 |
-
action_type = "modify_plan"
|
| 210 |
-
elif TASK_NAME == "periodized_program":
|
| 211 |
-
action_type = "generate_plan"
|
| 212 |
-
else:
|
| 213 |
-
action_type = "generate_plan"
|
| 214 |
-
|
| 215 |
-
action = FitscriptAction(action_type=action_type, plan=plan_str)
|
| 216 |
-
|
| 217 |
-
# Step in environment
|
| 218 |
try:
|
| 219 |
-
result = env.step(action)
|
| 220 |
except Exception as exc:
|
| 221 |
error_msg = str(exc)
|
| 222 |
log_step(step, action_type, 0.0, True, error_msg)
|
| 223 |
break
|
| 224 |
|
| 225 |
-
obs
|
| 226 |
-
reward
|
| 227 |
-
done
|
| 228 |
rewards.append(reward)
|
| 229 |
-
final_score = reward
|
| 230 |
|
| 231 |
log_step(step, action_type, reward, done, None)
|
| 232 |
|
| 233 |
if done:
|
| 234 |
-
success = reward >= 0.75
|
| 235 |
break
|
| 236 |
|
|
|
|
|
|
|
| 237 |
except Exception as exc:
|
| 238 |
error_msg = str(exc)
|
| 239 |
-
print(f"[ERROR] {error_msg}",
|
| 240 |
-
finally:
|
| 241 |
-
if env is not None:
|
| 242 |
-
if USE_DOCKER:
|
| 243 |
-
await env.close()
|
| 244 |
-
else:
|
| 245 |
-
env.close()
|
| 246 |
|
| 247 |
log_end(success, step, final_score, rewards)
|
| 248 |
|
| 249 |
|
| 250 |
# ---------------------------------------------------------------------------
|
| 251 |
-
#
|
| 252 |
# ---------------------------------------------------------------------------
|
| 253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
if __name__ == "__main__":
|
| 255 |
-
asyncio.run(
|
|
|
|
| 1 |
"""
|
| 2 |
+
FitScript inference.py — required entry point for hackathon evaluation.
|
| 3 |
|
| 4 |
+
Runs all 3 tasks sequentially and emits structured stdout logs per spec.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
+
LOCAL USAGE (no Docker — start the server first in a separate terminal):
|
| 7 |
+
cd FitScript
|
| 8 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
|
|
|
| 9 |
|
| 10 |
+
Then in another terminal:
|
| 11 |
+
USE_DOCKER=false API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4o HF_TOKEN=sk-... python inference.py
|
| 12 |
+
|
| 13 |
+
SINGLE TASK (local):
|
| 14 |
+
FITSCRIPT_TASK=basic_plan USE_DOCKER=false python inference.py
|
| 15 |
+
|
| 16 |
+
DOCKER USAGE (spins up the container automatically):
|
| 17 |
+
USE_DOCKER=true LOCAL_IMAGE_NAME=fitscript-env:latest API_BASE_URL=... MODEL_NAME=... HF_TOKEN=... python inference.py
|
| 18 |
+
|
| 19 |
+
STDOUT FORMAT (exact hackathon spec):
|
| 20 |
[START] task=<task> env=fitscript_env model=<model>
|
| 21 |
+
[STEP] step=<N> action=<text> reward=<R:.2f> done=<true|false> error=<null|msg>
|
| 22 |
+
[END] success=<true|false> steps=<N> score=<score:.2f> rewards=<r1:.2f,...>
|
| 23 |
"""
|
| 24 |
|
| 25 |
import asyncio
|
|
|
|
| 27 |
import os
|
| 28 |
import sys
|
| 29 |
|
| 30 |
+
# Optional: load .env for local development
|
| 31 |
+
try:
|
| 32 |
+
from dotenv import load_dotenv
|
| 33 |
+
load_dotenv()
|
| 34 |
+
except ImportError:
|
| 35 |
+
pass
|
| 36 |
|
|
|
|
| 37 |
# ---------------------------------------------------------------------------
|
| 38 |
+
# Configuration (hackathon mandatory variables)
|
| 39 |
# ---------------------------------------------------------------------------
|
| 40 |
+
API_BASE_URL: str = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 41 |
+
MODEL_NAME: str = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 42 |
+
API_KEY: str = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "")
|
| 43 |
|
|
|
|
| 44 |
BENCHMARK: str = "fitscript_env"
|
| 45 |
+
|
| 46 |
+
# USE_DOCKER=false → connect to a local server already running (default for local dev)
|
| 47 |
+
# USE_DOCKER=true → spin up a Docker container automatically
|
| 48 |
+
USE_DOCKER: bool = os.environ.get("USE_DOCKER", "false").lower() == "true"
|
| 49 |
+
|
| 50 |
+
IMAGE_NAME: str = (
|
| 51 |
+
os.environ.get("LOCAL_IMAGE_NAME")
|
| 52 |
+
or os.environ.get("FITSCRIPT_IMAGE", "fitscript-env:latest")
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
LOCAL_SERVER_URL: str = os.environ.get("LOCAL_SERVER_URL", "http://localhost:8000")
|
| 56 |
+
|
| 57 |
+
# FITSCRIPT_TASK: set to a single task name to run only that task.
|
| 58 |
+
# Leave empty (default) to run all 3 tasks sequentially (required for hackathon).
|
| 59 |
+
FITSCRIPT_TASK: str = os.environ.get("FITSCRIPT_TASK", "")
|
| 60 |
+
|
| 61 |
+
MAX_STEPS: int = int(os.environ.get("MAX_STEPS", "8"))
|
| 62 |
+
|
| 63 |
+
ALL_TASKS = ["basic_plan", "injury_safe_modification", "periodized_program"]
|
| 64 |
|
| 65 |
# ---------------------------------------------------------------------------
|
| 66 |
+
# Structured log helpers (exact hackathon spec format — do not change)
|
| 67 |
# ---------------------------------------------------------------------------
|
| 68 |
|
| 69 |
+
def log_start(task: str, env_name: str, model: str) -> None:
|
| 70 |
+
print(f"[START] task={task} env={env_name} model={model}", flush=True)
|
| 71 |
|
| 72 |
|
| 73 |
def log_step(step: int, action: str, reward: float, done: bool, error) -> None:
|
| 74 |
+
err_str = str(error) if error else "null"
|
| 75 |
+
action_str = str(action).replace("\n", " ").replace("\r", "")[:120]
|
|
|
|
| 76 |
print(
|
| 77 |
+
f"[STEP] step={step} action={action_str} reward={reward:.2f}"
|
| 78 |
+
f" done={str(done).lower()} error={err_str}",
|
| 79 |
flush=True,
|
| 80 |
)
|
| 81 |
|
| 82 |
|
| 83 |
def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
|
| 84 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 85 |
print(
|
| 86 |
f"[END] success={str(success).lower()} steps={steps}"
|
| 87 |
+
f" score={score:.2f} rewards={rewards_str}",
|
| 88 |
flush=True,
|
| 89 |
)
|
| 90 |
|
|
|
|
| 96 |
SYSTEM_PROMPT = """You are an expert personal trainer and exercise scientist.
|
| 97 |
You will receive a client profile and must generate a structured workout plan as JSON.
|
| 98 |
|
| 99 |
+
IMPORTANT: Respond with ONLY a valid JSON object. No prose, no markdown fences, no explanation.
|
|
|
|
| 100 |
|
| 101 |
+
For a basic plan or injury-modification plan, use:
|
| 102 |
{
|
| 103 |
"days": [
|
| 104 |
{
|
| 105 |
+
"name": "Day 1 - Lower Body",
|
| 106 |
+
"focus": "legs",
|
| 107 |
"exercises": [
|
| 108 |
+
{"name": "Squat", "sets": 3, "reps": 10, "rest_seconds": 60}
|
| 109 |
]
|
| 110 |
}
|
| 111 |
]
|
| 112 |
}
|
| 113 |
|
| 114 |
+
For a periodized 4-week powerlifting program, use:
|
| 115 |
{
|
| 116 |
"weeks": [
|
| 117 |
{
|
| 118 |
"week": 1,
|
| 119 |
+
"intensity": 72.5,
|
| 120 |
+
"total_sets": 80,
|
| 121 |
"days": [
|
| 122 |
{
|
| 123 |
+
"name": "Day 1 - Squat",
|
| 124 |
"exercises": [
|
| 125 |
+
{"name": "Back Squat", "sets": 5, "reps": 5, "intensity_pct": 72.5}
|
| 126 |
]
|
| 127 |
}
|
| 128 |
]
|
|
|
|
| 131 |
}
|
| 132 |
"""
|
| 133 |
|
|
|
|
| 134 |
# ---------------------------------------------------------------------------
|
| 135 |
+
# LLM helpers
|
| 136 |
# ---------------------------------------------------------------------------
|
| 137 |
|
| 138 |
+
def _call_llm_sync(messages: list) -> str:
|
| 139 |
+
"""Synchronous Hugging Face call"""
|
| 140 |
+
from huggingface_hub import InferenceClient
|
| 141 |
+
import os
|
| 142 |
+
|
| 143 |
+
client = InferenceClient(
|
| 144 |
+
model=os.getenv("MODEL_NAME"),
|
| 145 |
+
token=os.getenv("HF_API_KEY")
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# Convert OpenAI-style messages → single prompt
|
| 149 |
+
prompt = ""
|
| 150 |
+
for m in messages:
|
| 151 |
+
role = m.get("role", "")
|
| 152 |
+
content = m.get("content", "")
|
| 153 |
+
if role == "system":
|
| 154 |
+
prompt += f"[SYSTEM]: {content}\n"
|
| 155 |
+
elif role == "user":
|
| 156 |
+
prompt += f"[USER]: {content}\n"
|
| 157 |
+
elif role == "assistant":
|
| 158 |
+
prompt += f"[ASSISTANT]: {content}\n"
|
| 159 |
+
|
| 160 |
+
prompt += "[ASSISTANT]:"
|
| 161 |
+
|
| 162 |
+
response = client.text_generation(
|
| 163 |
+
prompt,
|
| 164 |
+
max_new_tokens=2048,
|
| 165 |
temperature=0.7,
|
|
|
|
| 166 |
)
|
| 167 |
+
|
| 168 |
+
return response
|
| 169 |
+
|
| 170 |
+
async def call_llm_async(messages: list) -> str:
|
| 171 |
+
loop = asyncio.get_event_loop()
|
| 172 |
+
return await loop.run_in_executor(None, _call_llm_sync, messages)
|
| 173 |
|
| 174 |
|
| 175 |
def build_user_message(observation) -> str:
|
| 176 |
+
profile = getattr(observation, "client_profile", {})
|
| 177 |
+
feedback = getattr(observation, "feedback", "")
|
| 178 |
+
breakdown = getattr(observation, "score_breakdown", {})
|
| 179 |
+
task_id = getattr(observation, "task_id", "")
|
|
|
|
| 180 |
|
| 181 |
parts = [
|
| 182 |
f"Task: {task_id}",
|
| 183 |
+
f"Client profile:\n{json.dumps(profile, indent=2)}",
|
| 184 |
]
|
| 185 |
if feedback:
|
| 186 |
parts.append(f"Environment feedback: {feedback}")
|
| 187 |
if breakdown:
|
| 188 |
parts.append(f"Score breakdown: {json.dumps(breakdown, indent=2)}")
|
| 189 |
+
parts.append("Generate or revise the workout plan as a JSON object only.")
|
| 190 |
return "\n\n".join(parts)
|
| 191 |
|
| 192 |
|
| 193 |
+
def strip_fences(text: str) -> str:
|
| 194 |
+
"""Remove ```json ... ``` markdown fences if the LLM added them."""
|
| 195 |
+
text = text.strip()
|
| 196 |
+
if text.startswith("```"):
|
| 197 |
+
lines = [l for l in text.split("\n") if not l.startswith("```")]
|
| 198 |
+
text = "\n".join(lines).strip()
|
| 199 |
+
return text
|
| 200 |
+
|
| 201 |
+
|
| 202 |
# ---------------------------------------------------------------------------
|
| 203 |
+
# Single episode runner
|
| 204 |
# ---------------------------------------------------------------------------
|
| 205 |
|
| 206 |
+
async def run_episode(task_name: str, env) -> None:
|
| 207 |
+
"""
|
| 208 |
+
Run one episode for task_name against env (an async EnvClient).
|
| 209 |
+
Emits [START] / [STEP] / [END] to stdout.
|
| 210 |
+
"""
|
| 211 |
+
from FitScript import FitscriptAction
|
| 212 |
|
| 213 |
+
log_start(task_name, BENCHMARK, MODEL_NAME)
|
| 214 |
|
| 215 |
rewards: list = []
|
| 216 |
+
final_score = 0.0
|
| 217 |
+
success = False
|
| 218 |
+
step = 0
|
| 219 |
+
error_msg = None
|
| 220 |
|
|
|
|
| 221 |
try:
|
| 222 |
+
# reset() is async in EnvClient
|
| 223 |
+
reset_result = await env.reset()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
obs = reset_result.observation
|
| 225 |
|
| 226 |
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
| 227 |
|
| 228 |
for step in range(1, MAX_STEPS + 1):
|
|
|
|
| 229 |
user_content = build_user_message(obs)
|
| 230 |
messages.append({"role": "user", "content": user_content})
|
| 231 |
|
| 232 |
+
# LLM call (async-wrapped sync)
|
| 233 |
try:
|
| 234 |
+
assistant_reply = await call_llm_async(messages)
|
| 235 |
except Exception as exc:
|
| 236 |
error_msg = str(exc)
|
| 237 |
log_step(step, "LLM_ERROR", 0.0, True, error_msg)
|
|
|
|
| 239 |
|
| 240 |
messages.append({"role": "assistant", "content": assistant_reply})
|
| 241 |
|
| 242 |
+
plan_str = strip_fences(assistant_reply)
|
| 243 |
+
action_type = "modify_plan" if task_name == "injury_safe_modification" else "generate_plan"
|
| 244 |
+
action = FitscriptAction(action_type=action_type, plan=plan_str)
|
| 245 |
+
|
| 246 |
+
# step() is async in EnvClient
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
try:
|
| 248 |
+
result = await env.step(action)
|
| 249 |
except Exception as exc:
|
| 250 |
error_msg = str(exc)
|
| 251 |
log_step(step, action_type, 0.0, True, error_msg)
|
| 252 |
break
|
| 253 |
|
| 254 |
+
obs = result.observation
|
| 255 |
+
reward = float(result.reward or 0.0)
|
| 256 |
+
done = bool(result.done)
|
| 257 |
rewards.append(reward)
|
| 258 |
+
final_score = max(final_score, reward)
|
| 259 |
|
| 260 |
log_step(step, action_type, reward, done, None)
|
| 261 |
|
| 262 |
if done:
|
|
|
|
| 263 |
break
|
| 264 |
|
| 265 |
+
success = final_score >= 0.75
|
| 266 |
+
|
| 267 |
except Exception as exc:
|
| 268 |
error_msg = str(exc)
|
| 269 |
+
print(f"[ERROR] episode failed: {error_msg}", file=sys.stderr, flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
log_end(success, step, final_score, rewards)
|
| 272 |
|
| 273 |
|
| 274 |
# ---------------------------------------------------------------------------
|
| 275 |
+
# Main entry point
|
| 276 |
# ---------------------------------------------------------------------------
|
| 277 |
|
| 278 |
+
async def main() -> None:
|
| 279 |
+
from FitScript import FitscriptEnv
|
| 280 |
+
|
| 281 |
+
tasks_to_run = [FITSCRIPT_TASK] if FITSCRIPT_TASK else ALL_TASKS
|
| 282 |
+
|
| 283 |
+
if USE_DOCKER:
|
| 284 |
+
# Docker mode: launch one container per task.
|
| 285 |
+
# FITSCRIPT_TASK env var is passed into the container so the server
|
| 286 |
+
# initialises with the correct task_id.
|
| 287 |
+
for task_name in tasks_to_run:
|
| 288 |
+
print(
|
| 289 |
+
f"[INFO] Starting Docker container ({IMAGE_NAME}) for task={task_name}",
|
| 290 |
+
file=sys.stderr, flush=True,
|
| 291 |
+
)
|
| 292 |
+
# from_docker_image is async and returns a connected EnvClient
|
| 293 |
+
try:
|
| 294 |
+
env = await FitscriptEnv.from_docker_image(
|
| 295 |
+
IMAGE_NAME,
|
| 296 |
+
env={"FITSCRIPT_TASK": task_name},
|
| 297 |
+
)
|
| 298 |
+
except TypeError:
|
| 299 |
+
# Some versions of EnvClient don't support the env= kwarg;
|
| 300 |
+
# fall back to no extra env (server uses its own FITSCRIPT_TASK)
|
| 301 |
+
env = await FitscriptEnv.from_docker_image(IMAGE_NAME)
|
| 302 |
+
try:
|
| 303 |
+
await run_episode(task_name, env)
|
| 304 |
+
finally:
|
| 305 |
+
await env.close()
|
| 306 |
+
|
| 307 |
+
else:
|
| 308 |
+
# Local mode: server must already be running at LOCAL_SERVER_URL.
|
| 309 |
+
# Each task gets a fresh client connection (the server keeps its state
|
| 310 |
+
# per-session via WebSocket, so reconnecting is a clean reset).
|
| 311 |
+
for task_name in tasks_to_run:
|
| 312 |
+
print(
|
| 313 |
+
f"[INFO] Connecting to local server at {LOCAL_SERVER_URL} for task={task_name}",
|
| 314 |
+
file=sys.stderr, flush=True,
|
| 315 |
+
)
|
| 316 |
+
env = FitscriptEnv(base_url=LOCAL_SERVER_URL)
|
| 317 |
+
try:
|
| 318 |
+
await run_episode(task_name, env)
|
| 319 |
+
finally:
|
| 320 |
+
env.close()
|
| 321 |
+
|
| 322 |
+
|
| 323 |
if __name__ == "__main__":
|
| 324 |
+
asyncio.run(main())
|
server/__pycache__/app.cpython-313.pyc
CHANGED
|
Binary files a/server/__pycache__/app.cpython-313.pyc and b/server/__pycache__/app.cpython-313.pyc differ
|
|
|
server/app.py
CHANGED
|
@@ -5,34 +5,34 @@
|
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
"""
|
| 8 |
-
FastAPI application for the
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
|
| 13 |
Endpoints:
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
|
| 20 |
Usage:
|
| 21 |
# Development (with auto-reload):
|
| 22 |
-
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 23 |
|
| 24 |
# Production:
|
| 25 |
-
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 26 |
-
|
| 27 |
-
# Or run directly:
|
| 28 |
-
python -m server.app
|
| 29 |
"""
|
| 30 |
|
|
|
|
|
|
|
|
|
|
| 31 |
try:
|
| 32 |
from openenv.core.env_server.http_server import create_app
|
| 33 |
except Exception as e: # pragma: no cover
|
| 34 |
raise ImportError(
|
| 35 |
-
"openenv is required
|
| 36 |
) from e
|
| 37 |
|
| 38 |
try:
|
|
@@ -43,42 +43,38 @@ except ModuleNotFoundError:
|
|
| 43 |
from server.FitScript_environment import FitscriptEnvironment
|
| 44 |
|
| 45 |
|
| 46 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
app = create_app(
|
| 48 |
-
|
| 49 |
FitscriptAction,
|
| 50 |
FitscriptObservation,
|
| 51 |
env_name="FitScript",
|
| 52 |
-
max_concurrent_envs=
|
| 53 |
)
|
| 54 |
|
| 55 |
|
| 56 |
def main(host: str = "0.0.0.0", port: int = 8000):
|
| 57 |
-
"""
|
| 58 |
-
Entry point for direct execution via uv run or python -m.
|
| 59 |
-
|
| 60 |
-
This function enables running the server without Docker:
|
| 61 |
-
uv run --project . server
|
| 62 |
-
uv run --project . server --port 8001
|
| 63 |
-
python -m FitScript.server.app
|
| 64 |
-
|
| 65 |
-
Args:
|
| 66 |
-
host: Host address to bind to (default: "0.0.0.0")
|
| 67 |
-
port: Port number to listen on (default: 8000)
|
| 68 |
-
|
| 69 |
-
For production deployments, consider using uvicorn directly with
|
| 70 |
-
multiple workers:
|
| 71 |
-
uvicorn FitScript.server.app:app --workers 4
|
| 72 |
-
"""
|
| 73 |
import uvicorn
|
| 74 |
-
|
| 75 |
uvicorn.run(app, host=host, port=port)
|
| 76 |
|
| 77 |
|
| 78 |
if __name__ == "__main__":
|
| 79 |
import argparse
|
| 80 |
-
|
| 81 |
parser = argparse.ArgumentParser()
|
| 82 |
parser.add_argument("--port", type=int, default=8000)
|
| 83 |
args = parser.parse_args()
|
| 84 |
-
main(port=args.port)
|
|
|
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
"""
|
| 8 |
+
FastAPI application for the FitScript Environment.
|
| 9 |
|
| 10 |
+
The task is selected via the FITSCRIPT_TASK environment variable (default: basic_plan).
|
| 11 |
+
Valid values: basic_plan | injury_safe_modification | periodized_program
|
| 12 |
|
| 13 |
Endpoints:
|
| 14 |
+
POST /reset — Reset the environment
|
| 15 |
+
POST /step — Execute an action
|
| 16 |
+
GET /state — Get current environment state
|
| 17 |
+
GET /schema — Get action/observation schemas
|
| 18 |
+
WS /ws — WebSocket endpoint for persistent sessions
|
| 19 |
|
| 20 |
Usage:
|
| 21 |
# Development (with auto-reload):
|
| 22 |
+
FITSCRIPT_TASK=basic_plan uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 23 |
|
| 24 |
# Production:
|
| 25 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
|
|
|
|
|
|
|
|
|
| 26 |
"""
|
| 27 |
|
| 28 |
+
import os
|
| 29 |
+
import functools
|
| 30 |
+
|
| 31 |
try:
|
| 32 |
from openenv.core.env_server.http_server import create_app
|
| 33 |
except Exception as e: # pragma: no cover
|
| 34 |
raise ImportError(
|
| 35 |
+
"openenv is required. Install with: pip install openenv-core"
|
| 36 |
) from e
|
| 37 |
|
| 38 |
try:
|
|
|
|
| 43 |
from server.FitScript_environment import FitscriptEnvironment
|
| 44 |
|
| 45 |
|
| 46 |
+
# Read the task from the environment variable; default to basic_plan
|
| 47 |
+
FITSCRIPT_TASK = os.environ.get("FITSCRIPT_TASK", "basic_plan")
|
| 48 |
+
|
| 49 |
+
VALID_TASKS = {"basic_plan", "injury_safe_modification", "periodized_program"}
|
| 50 |
+
if FITSCRIPT_TASK not in VALID_TASKS:
|
| 51 |
+
raise ValueError(
|
| 52 |
+
f"Invalid FITSCRIPT_TASK='{FITSCRIPT_TASK}'. "
|
| 53 |
+
f"Must be one of: {sorted(VALID_TASKS)}"
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# Use functools.partial so create_app can instantiate the env with the right task_id
|
| 57 |
+
EnvFactory = functools.partial(FitscriptEnvironment, task_id=FITSCRIPT_TASK)
|
| 58 |
+
|
| 59 |
+
# Create the FastAPI app
|
| 60 |
app = create_app(
|
| 61 |
+
EnvFactory,
|
| 62 |
FitscriptAction,
|
| 63 |
FitscriptObservation,
|
| 64 |
env_name="FitScript",
|
| 65 |
+
max_concurrent_envs=4,
|
| 66 |
)
|
| 67 |
|
| 68 |
|
| 69 |
def main(host: str = "0.0.0.0", port: int = 8000):
|
| 70 |
+
"""Entry point for direct execution."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
import uvicorn
|
|
|
|
| 72 |
uvicorn.run(app, host=host, port=port)
|
| 73 |
|
| 74 |
|
| 75 |
if __name__ == "__main__":
|
| 76 |
import argparse
|
|
|
|
| 77 |
parser = argparse.ArgumentParser()
|
| 78 |
parser.add_argument("--port", type=int, default=8000)
|
| 79 |
args = parser.parse_args()
|
| 80 |
+
main(port=args.port)
|
server/requirements.txt
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
-
openenv
|
| 2 |
fastapi>=0.115.0
|
| 3 |
uvicorn>=0.24.0
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
|
|
|
| 1 |
+
openenv-core>=0.2.2
|
| 2 |
fastapi>=0.115.0
|
| 3 |
uvicorn>=0.24.0
|
| 4 |
+
python-dotenv>=1.0.0
|
| 5 |
+
openai>=1.0.0
|
|
|