File size: 11,388 Bytes
5dadca5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
"""

test_pushback_ws.py — Full multi-actor negotiation + mid-episode complication test.

Shows:
  1. Actor pushback (rejection with specific fixes)
  2. Mid-episode complication injection (knee injury at step 4)
  3. Agent must re-consult actors after injury and revise plan
  4. Final acceptance by all actors

Requires a FitCoach environment server listening on http://localhost:8000.

Run with: python test_pushback_ws.py
"""

import asyncio, json, sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from models import FitcoachAction, FitcoachObservation
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from typing import Dict


class FitcoachEnv(EnvClient[FitcoachAction, FitcoachObservation, State]):
    """Typed client for the FitCoach environment.

    Provides the three conversion methods this script relies on: building the
    step payload from a ``FitcoachAction``, and turning the server's reply
    dicts back into a ``StepResult`` or a ``State``.
    """

    def _step_payload(self, action):
        """Return the step-request dict for *action*.

        ``action_type`` and both plan strings are always included;
        ``actor_target`` and ``reasoning`` are added only when not ``None``.
        """
        body = {
            "action_type": action.action_type,
            "workout_plan": action.workout_plan,
            "nutrition_plan": action.nutrition_plan,
        }
        optional = (
            ("actor_target", action.actor_target),
            ("reasoning", action.reasoning),
        )
        for key, value in optional:
            if value is not None:
                body[key] = value
        return body

    def _parse_result(self, payload):
        """Convert a server reply dict into a ``StepResult``.

        Observation fields are read from ``payload["observation"]`` with empty
        defaults; ``done`` and ``reward`` are read from the top level of
        *payload* and copied onto both the observation and the result.
        """
        raw = payload.get("observation", {})
        done = payload.get("done", False)
        reward = payload.get("reward")
        # Rebuilt on every call so the mutable defaults are never shared
        # between observations.
        defaults = {
            "client_profile": {},
            "progress_data": {},
            "complications": [],
            "actor_response": {},
            "actors_consulted": [],
            "active_conflicts": [],
            "feedback": "",
            "score_breakdown": {},
            "task_id": "",
            "phase": "",
            "step_count": 0,
            "best_score": 0.0,
            "metadata": {},
        }
        fields = {name: raw.get(name, fallback) for name, fallback in defaults.items()}
        return StepResult(
            observation=FitcoachObservation(done=done, reward=reward, **fields),
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload):
        """Convert a state reply dict into a ``State`` (episode id and step count only)."""
        episode_id = payload.get("episode_id")
        step_count = payload.get("step_count", 0)
        return State(episode_id=episode_id, step_count=step_count)


# ── Plans ─────────────────────────────────────────────────────────────────────
# Plans are pre-serialised to JSON strings, matching the "{}" string
# placeholders main() sends for consult steps.

# Deliberately wrong first attempt: heavy barbell powerlifting work declared as
# 30 weekly sets, paired with a 1200 kcal / 40 g protein diet and no meals.
BAD_WORKOUT = json.dumps({
    "days": [{"name": "Day 1", "focus": "strength", "exercises": [
        {"name": "Barbell Squat",       "sets": 5, "reps": "5", "rest_seconds": 180, "weight_kg": 100},
        {"name": "Barbell Deadlift",    "sets": 5, "reps": "5", "rest_seconds": 180, "weight_kg": 120},
        {"name": "Barbell Bench Press", "sets": 5, "reps": "5", "rest_seconds": 180, "weight_kg": 80},
    ]}],
    "weekly_volume_sets": 30,
    "notes": "powerlifting — wrong for this client"
})

BAD_NUTRITION = json.dumps({
    "daily_targets": {"calories": 1200, "protein_g": 40, "carbs_g": 100, "fats_g": 20},
    "meals": []
})

# After knee injury — no lunges, no deep squats, knee-safe exercises only.
# Per-exercise sets sum to 18 (7 + 7 + 4), matching weekly_volume_sets.
REVISED_WORKOUT = json.dumps({
    "days": [
        {"name": "Day 1 - Upper", "focus": "chest/back", "exercises": [
            {"name": "Barbell Bench Press",     "sets": 3, "reps": "8-12",  "rest_seconds": 90, "weight_kg": 50},
            {"name": "Barbell Row",             "sets": 3, "reps": "8-12",  "rest_seconds": 90, "weight_kg": 45},
            {"name": "Dumbbell Shoulder Press", "sets": 1, "reps": "10-12", "rest_seconds": 60, "weight_kg": 14},
        ]},
        {"name": "Day 2 - Lower (knee-safe)", "focus": "glutes/hamstrings", "exercises": [
            {"name": "Hip Thrust",                 "sets": 3, "reps": "10-12", "rest_seconds": 90, "weight_kg": 60},
            {"name": "Dumbbell Romanian Deadlift", "sets": 3, "reps": "10-12", "rest_seconds": 90, "weight_kg": 22},
            {"name": "Leg Press (shallow range)",  "sets": 1, "reps": "12-15", "rest_seconds": 60, "weight_kg": 80},
        ]},
        {"name": "Day 3 - Upper", "focus": "back/biceps", "exercises": [
            {"name": "Cable Row",    "sets": 3, "reps": "10-12", "rest_seconds": 90, "weight_kg": 35},
            {"name": "Lat Pulldown", "sets": 1, "reps": "10-12", "rest_seconds": 60, "weight_kg": 30},
        ]},
    ],
    "weekly_volume_sets": 18,
    "notes": "Knee-safe plan — 18 sets within intermediate range. No knee-contraindicated exercises."
})

# Priya: weight_loss, tdee=2100, target ~1700 kcal (tdee - 400).
# Meal calories (1700) and protein (143 g) sum exactly to the daily targets.
REVISED_NUTRITION = json.dumps({
    "daily_targets": {"calories": 1700, "protein_g": 143, "carbs_g": 155, "fats_g": 47},
    "meals": [
        {"meal_name": "Breakfast",
         "foods": ["3 boiled eggs", "1 slice whole grain bread", "100g curd"],
         "calories": 380, "protein_g": 28},
        {"meal_name": "Lunch",
         "foods": ["120g chicken breast", "100g brown rice", "100g broccoli"],
         "calories": 450, "protein_g": 42},
        {"meal_name": "Snack",
         "foods": ["30g almonds", "1 banana"],
         "calories": 270, "protein_g": 8},
        {"meal_name": "Dinner",
         "foods": ["150g fish", "100g spinach", "1 roti"],
         "calories": 380, "protein_g": 38},
        {"meal_name": "Post-workout",
         "foods": ["200ml milk", "30g whey protein"],
         "calories": 220, "protein_g": 27},
    ]
})


def sep(title=""):
    """Print a blank line and a 60-char '=' rule; if *title* is truthy, also
    print it indented by two spaces followed by a second rule."""
    rule = "=" * 60
    print("\n" + rule)
    if not title:
        return
    print(f"  {title}")
    print(rule)


def _fmt_reward(reward):
    """Format a step reward to two decimals, or 'n/a' when the server sent none."""
    # payload.get("reward") may yield None; ':.2f' on None raises TypeError.
    return "n/a" if reward is None else f"{reward:.2f}"


async def _consult(env, actor_target):
    """Send a consult_actor step to *actor_target* with empty plans; return the StepResult."""
    return await env.step(FitcoachAction(
        action_type="consult_actor", actor_target=actor_target,
        workout_plan="{}", nutrition_plan="{}",
    ))


async def _run_scenario(env):
    """Drive the six-step pushback + injury scenario against *env*.

    Returns the StepResult of the revised-plan submission (step 6), or
    ``None`` if the episode ended right after the bad plan (step 4).
    """
    sep("FITCOACH-RL PUSHBACK + MID-EPISODE COMPLICATION TEST")
    print("Task: plateau_adaptation (Priya Menon, intermediate, weight_loss)")
    print("Complication schedule: knee injury injected at step 4")

    # Reset
    result = await env.reset()
    obs = result.observation
    client = obs.client_profile
    print(f"\nClient:     {client.get('name')}")
    print(f"Level:      {client.get('fitness_level')} | Goal: {client.get('goal')}")
    print(f"Equipment:  {client.get('available_equipment')}")
    print(f"Injuries:   {client.get('injuries') or 'none (so far)'}")
    print(f"Complications: {obs.complications}")

    # Step 1: Consult fitness
    sep("Step 1: consult fitness_advisor")
    result = await _consult(env, "fitness_advisor")
    fa_constraints = result.observation.actor_response.get("constraints", {})
    print(f"Volume range: {fa_constraints.get('weekly_sets_min')}–{fa_constraints.get('weekly_sets_max')} sets")
    print(f"Banned:       {fa_constraints.get('must_avoid_exercises', [])}")
    print(f"Consulted:    {result.observation.actors_consulted}")

    # Step 2: Consult nutrition
    sep("Step 2: consult nutrition_advisor")
    result = await _consult(env, "nutrition_advisor")
    na_constraints = result.observation.actor_response.get("constraints", {})
    print(f"Calorie target: {na_constraints.get('calories_target')} kcal")
    print(f"Protein min:    {na_constraints.get('protein_minimum_g')}g")
    print(f"Consulted:      {result.observation.actors_consulted}")

    # Step 3: Consult progress analyst
    sep("Step 3: consult progress_analyst")
    result = await _consult(env, "progress_analyst")
    pa = result.observation.actor_response
    print(f"Plateau status: {pa.get('recommendations', {}).get('plateau_status')}")
    print(f"Must adapt:     {pa.get('constraints', {}).get('must_adapt_if_plateau')}")
    print(f"Conflicts:      {len(result.observation.active_conflicts)} detected")
    print(f"Consulted:      {result.observation.actors_consulted}")

    # Step 4: Submit BAD plan — AND knee injury injected this step
    sep("Step 4: submit BAD plan + 🚨 KNEE INJURY INJECTION")
    print("(Submitting wrong plan: barbell-heavy, 30 sets, 1200 kcal)")
    print("(Also: complication_schedule fires → new_injury:knee injected)")
    result = await env.step(FitcoachAction(
        action_type="submit_plan",
        workout_plan=BAD_WORKOUT,
        nutrition_plan=BAD_NUTRITION,
        reasoning="initial attempt",
    ))
    print(f"\nReward: {_fmt_reward(result.reward)} | Done: {result.done}")
    print(f"\nFeedback:\n{result.observation.feedback}")

    if result.done:
        print("\n[Episode ended early]")
        return None

    # Step 5: Re-consult fitness_advisor — NOW sees knee injury in client
    sep("Step 5: re-consult fitness_advisor (post-injury)")
    print("Agent re-consults after injury injection to get updated constraints")
    result = await _consult(env, "fitness_advisor")
    fa2 = result.observation.actor_response
    print(f"Updated banned exercises: {fa2.get('constraints', {}).get('must_avoid_exercises', [])}")
    print(f"Client injuries now:      {result.observation.client_profile.get('injuries')}")
    print(f"Feedback:\n{result.observation.feedback[:300]}")

    # Step 6: Submit REVISED plan — knee-safe, correct macros, right volume.
    # Read the set count from the plan itself so the reasoning text cannot
    # drift from what is actually submitted (it previously claimed 21).
    revised_sets = json.loads(REVISED_WORKOUT)["weekly_volume_sets"]
    sep("Step 6: submit REVISED plan (knee-safe + correct macros)")
    print("Revised: no lunges/squats, hip thrusts instead, 1700 kcal for weight loss")
    result = await env.step(FitcoachAction(
        action_type="submit_plan",
        workout_plan=REVISED_WORKOUT,
        nutrition_plan=REVISED_NUTRITION,
        reasoning=(
            "Revised after actor rejections: "
            "replaced barbell-only exercises with barbell+cable alternatives, "
            f"reduced volume to {revised_sets} sets (within 12-18 intermediate range), "
            "adjusted calories to 1700 kcal for weight_loss goal, "
            "removed all knee-contraindicated exercises (lunges, deep squats) "
            "after new_injury:knee was injected at step 4."
        ),
    ))
    print(f"\nReward: {_fmt_reward(result.reward)} | Done: {result.done}")
    print(f"\nFeedback:\n{result.observation.feedback}")
    print("\nScore breakdown:")
    for name, value in result.observation.score_breakdown.items():
        icon = "✓" if value >= 0.8 else ("~" if value >= 0.5 else "✗")
        print(f"  {icon} {name}: {value:.2f}")
    return result


async def main():
    """Run the pushback + mid-episode complication scenario against a local
    server and print a pass/partial summary.

    The client is always closed, even if a step raises.
    """
    pass_threshold = 0.85
    env = FitcoachEnv(base_url="http://localhost:8000")
    try:
        final = await _run_scenario(env)
    finally:
        await env.close()

    if final is None:
        return  # episode ended before the revised plan could be submitted

    sep("TEST COMPLETE")
    print(f"Final reward: {_fmt_reward(final.reward)}")
    if final.reward is not None and final.reward >= pass_threshold:
        print("✓ Agent successfully negotiated with all actors and adapted to mid-episode injury!")
    else:
        print("~ Plan partially accepted — more revision needed")


if __name__ == "__main__":
    # Guarded so that importing this module (e.g. pytest collecting
    # test_*.py files) does not open a server connection and run the scenario.
    asyncio.run(main())