File size: 25,863 Bytes
8555ea6
 
 
 
 
 
 
0446283
8555ea6
0446283
 
 
8555ea6
 
0446283
8555ea6
0446283
8555ea6
 
 
 
 
 
 
 
 
 
0446283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8555ea6
 
0446283
 
 
 
 
 
 
 
 
8555ea6
 
 
 
0446283
 
 
 
 
 
 
 
 
 
 
 
8555ea6
0446283
8555ea6
 
 
0446283
8555ea6
 
0446283
8555ea6
 
0446283
 
 
8555ea6
 
0446283
 
 
 
 
8555ea6
 
 
 
 
 
0446283
8555ea6
 
0446283
8555ea6
 
0446283
8555ea6
 
0446283
8555ea6
0446283
 
 
 
 
 
 
 
 
 
 
8555ea6
0446283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8555ea6
 
0446283
 
 
 
 
 
 
8555ea6
 
 
 
0446283
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
FitScript Environment Implementation.

Simulates a real-world fitness prescription task: generating, evaluating,
and refining personalized workout plans. Supports three tasks of increasing
difficulty with deterministic graders.
"""

import json
from uuid import uuid4
from typing import Dict, Any, Tuple

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    from ..models import FitscriptAction, FitscriptObservation
except ImportError:
    from models import FitscriptAction, FitscriptObservation


# ---------------------------------------------------------------------------
# Grader base class
# ---------------------------------------------------------------------------

class BaseTask:
    """Common interface implemented by every FitScript task.

    Concrete tasks override the two class attributes below and provide
    their own :meth:`grade` implementation.
    """

    # Client description returned to the agent with each observation.
    client_profile: dict = {}
    # Number of steps after which the environment marks the episode done.
    max_steps: int = 5

    def grade(
        self,
        action: FitscriptAction,
        step: int,
    ) -> Tuple[float, Dict[str, float], str]:
        """
        Score a submitted action.

        Args:
            action: The agent's submission.
            step: Step counter value passed in by the environment.

        Returns:
            A ``(score, breakdown, feedback)`` tuple: ``score`` is a float in
            [0, 1], ``breakdown`` maps criterion names to partial scores, and
            ``feedback`` is a human-readable summary.

        Raises:
            NotImplementedError: Always; concrete tasks must override this.
        """
        raise NotImplementedError


# ---------------------------------------------------------------------------
# Task 1 - EASY: Basic Plan Generation
# ---------------------------------------------------------------------------

class BasicPlanTask(BaseTask):
    """
    Task 1 (easy): basic plan generation.

    Scenario: 35-year-old beginner, no injuries, 3 days/week, home, no equipment.
    Grader: 4 criteria worth 0.25 each.
    The environment ends the episode once the score reaches 0.99 or after
    ``max_steps`` (3) steps.
    """

    client_profile = {
        "age": 35,
        "fitness_level": "beginner",
        "goal": "general fitness",
        "equipment": [],
        "injuries": [],
        "days_per_week": 3,
    }
    max_steps = 3

    # Equipment keywords; any appearance in the raw plan text fails criterion 2.
    EQUIPMENT_EXERCISES = {
        "barbell", "dumbbell", "kettlebell", "cable", "machine",
        "bench press", "squat rack", "pull-up bar", "resistance band",
        "treadmill", "stationary bike",
    }

    # Advanced movements not appropriate for beginners (criterion 4).
    ADVANCED_MOVEMENTS = {
        "muscle-up", "muscle up", "handstand push-up", "handstand pushup",
        "pistol squat", "one-arm push-up", "planche", "front lever",
        "back lever", "dragon flag",
    }

    @staticmethod
    def _day_is_well_formed(day) -> bool:
        """Return True if *day* lists 4-8 exercises that all define sets and reps."""
        if not isinstance(day, dict):
            return False
        exercises = day.get("exercises", [])
        if not isinstance(exercises, list) or not 4 <= len(exercises) <= 8:
            return False
        return all(
            isinstance(ex, dict) and ex.get("sets") and ex.get("reps")
            for ex in exercises
        )

    def grade(self, action: FitscriptAction, step: int) -> Tuple[float, Dict[str, float], str]:
        """
        Score a submitted plan against the four beginner-plan criteria.

        Args:
            action: Submission whose ``plan`` attribute holds the plan as a
                JSON string. Days are read from the ``"days"`` key (or
                ``"workout_days"``); each day lists ``"exercises"`` with
                ``"sets"`` and ``"reps"``.
            step: Current step number (not used by this grader).

        Returns:
            ``(score, breakdown, feedback)`` where ``breakdown`` holds 0.25 or
            0.0 for each of ``three_days``, ``bodyweight_only``,
            ``exercise_structure`` and ``beginner_appropriate``, and ``score``
            is their sum.
        """
        scores: Dict[str, float] = {}
        feedback_parts = []

        plan_text = action.plan or ""
        try:
            plan = json.loads(plan_text) if plan_text else {}
        except json.JSONDecodeError:
            plan = {}
        # Valid JSON that is not an object (a list, number, string) has no
        # days to grade; treat it like an empty plan instead of crashing.
        if not isinstance(plan, dict):
            plan = {}

        # Keyword criteria scan the raw text so exercise names are caught
        # regardless of where they appear in the JSON structure.
        plan_text_lower = plan_text.lower()

        # Criterion 1: Plan contains exactly 3 workout days
        days = plan.get("days", plan.get("workout_days", []))
        if isinstance(days, list) and len(days) == 3:
            scores["three_days"] = 0.25
            feedback_parts.append("✓ Plan has exactly 3 workout days.")
        else:
            scores["three_days"] = 0.0
            found = len(days) if isinstance(days, list) else "unknown"
            feedback_parts.append(f"✗ Expected 3 workout days, found {found}.")

        # Criterion 2: All exercises are bodyweight-only.
        # Sorted so the feedback text does not depend on set iteration order.
        equipment_found = sorted(
            e for e in self.EQUIPMENT_EXERCISES if e in plan_text_lower
        )
        if not equipment_found:
            scores["bodyweight_only"] = 0.25
            feedback_parts.append("✓ No equipment required --- all bodyweight exercises.")
        else:
            scores["bodyweight_only"] = 0.0
            feedback_parts.append(f"✗ Equipment-dependent exercises found: {equipment_found[:3]}.")

        # Criterion 3: Each day has 4-8 exercises with sets and reps defined
        if isinstance(days, list) and days:
            days_ok = sum(1 for day in days if self._day_is_well_formed(day))
            if days_ok == len(days):
                scores["exercise_structure"] = 0.25
                feedback_parts.append("✓ Each day has 4-8 exercises with sets and reps defined.")
            else:
                scores["exercise_structure"] = 0.0
                feedback_parts.append(
                    f"✗ {days_ok}/{len(days)} days have 4-8 exercises with sets+reps. "
                    "Ensure every exercise has 'sets' and 'reps' fields."
                )
        else:
            scores["exercise_structure"] = 0.0
            feedback_parts.append("✗ Cannot evaluate exercise structure: no days found.")

        # Criterion 4: Beginner-appropriate (reps <= 15, no advanced movements)
        advanced_found = sorted(
            m for m in self.ADVANCED_MOVEMENTS if m in plan_text_lower
        )
        # Only walk the plan when the days container is a list; any other
        # shape has no exercises whose reps could be checked.
        reps_too_high = isinstance(days, list) and _check_reps_exceed(plan, max_reps=15)
        if not advanced_found and not reps_too_high:
            scores["beginner_appropriate"] = 0.25
            feedback_parts.append("✓ Plan is beginner-appropriate (no advanced movements, reps ≤ 15).")
        else:
            scores["beginner_appropriate"] = 0.0
            if advanced_found:
                feedback_parts.append(f"✗ Advanced movements not suitable for beginners: {advanced_found}.")
            if reps_too_high:
                feedback_parts.append("✗ Some exercises have reps > 15 --- too high for a beginner.")

        score = sum(scores.values())
        feedback = " ".join(feedback_parts)
        return score, scores, feedback


# ---------------------------------------------------------------------------
# Task 2 - MEDIUM: Injury-Safe Plan Modification
# ---------------------------------------------------------------------------

class InjurySafeTask(BaseTask):
    """
    Task 2 (medium): injury-safe plan modification.

    Scenario: Intermediate client with lower-back injury. Pre-generated plan
    contains back squats, deadlifts, and bent-over rows. Agent must modify safely.
    Grader: 4 keyword-based criteria worth 0.25 each, evaluated on the raw
    plan text.
    The environment ends the episode once the score reaches 0.99 or after
    ``max_steps`` (5) steps.
    """

    client_profile = {
        "age": 30,
        "fitness_level": "intermediate",
        "goal": "strength maintenance",
        "equipment": ["barbell", "dumbbells", "cables", "machines"],
        "injuries": ["lower back"],
        "days_per_week": 4,
        "initial_plan": {
            "days": [
                {
                    "name": "Day 1 - Lower Body",
                    "exercises": [
                        {"name": "Back Squat", "sets": 4, "reps": 8},
                        {"name": "Deadlift", "sets": 3, "reps": 5},
                        {"name": "Leg Press", "sets": 3, "reps": 10},
                        {"name": "Calf Raises", "sets": 4, "reps": 15},
                    ],
                },
                {
                    "name": "Day 2 - Upper Body",
                    "exercises": [
                        {"name": "Bench Press", "sets": 4, "reps": 8},
                        {"name": "Bent-Over Row", "sets": 4, "reps": 8},
                        {"name": "Overhead Press", "sets": 3, "reps": 10},
                        {"name": "Pull-Up", "sets": 3, "reps": "max"},
                    ],
                },
            ]
        },
    }
    max_steps = 5

    # Accepted substitutes for the conventional deadlift. Entries that contain
    # the word "deadlift" are masked out before criterion 1 looks for it.
    DEADLIFT_REPLACEMENTS = {
        "romanian deadlift", "rdl", "leg press", "leg curl",
        "hip thrust", "glute bridge", "trap bar deadlift",
    }
    SQUAT_REPLACEMENTS = {
        "goblet squat", "wall sit", "wall squat", "leg press",
        "box squat", "safety bar squat", "hack squat",
    }
    ROW_REPLACEMENTS = {
        "seated cable row", "seated row", "machine row",
        "chest-supported row", "chest supported row",
        "t-bar row", "seal row",
    }
    ORIGINAL_MUSCLE_GROUPS = {"quads", "hamstrings", "glutes", "back", "chest", "shoulders"}

    def grade(self, action: FitscriptAction, step: int) -> Tuple[float, Dict[str, float], str]:
        """
        Score a modified plan for the lower-back-injured client.

        All checks are substring matches on the lower-cased plan text.

        Args:
            action: Submission whose ``plan`` attribute holds the modified plan.
            step: Current step number (not used by this grader).

        Returns:
            ``(score, breakdown, feedback)`` where ``breakdown`` holds 0.25 or
            0.0 for each of ``deadlift_removed``, ``squat_replaced``,
            ``rows_replaced`` and ``muscle_targets_retained``, and ``score`` is
            their sum.
        """
        scores: Dict[str, float] = {}
        feedback_parts = []
        plan_text_lower = (action.plan or "").lower()
        # Hyphen-insensitive copy so "bent-over row" / "bent over row" and
        # "trap-bar deadlift" / "trap bar deadlift" are treated alike.
        normalized = plan_text_lower.replace("-", " ")

        # Criterion 1: Deadlifts removed or replaced with safe alternatives.
        # Mask the accepted deadlift variants first, then look for any
        # remaining "deadlift": this accepts the variants themselves but still
        # catches a conventional deadlift left in the plan next to one.
        residual = normalized
        for variant in self.DEADLIFT_REPLACEMENTS:
            if "deadlift" in variant:
                residual = residual.replace(variant, " ")
        raw_deadlift = "deadlift" in residual
        if not raw_deadlift:
            scores["deadlift_removed"] = 0.25
            feedback_parts.append("✓ Conventional deadlift removed or replaced safely.")
        else:
            scores["deadlift_removed"] = 0.0
            feedback_parts.append(
                "✗ Conventional deadlift still present. Replace with Romanian deadlift, leg press, or hip thrust."
            )

        # Criterion 2: Back squats replaced with safe alternatives
        has_back_squat = "back squat" in normalized
        if not has_back_squat:
            scores["squat_replaced"] = 0.25
            feedback_parts.append("✓ Back squat removed or replaced safely.")
        else:
            scores["squat_replaced"] = 0.0
            feedback_parts.append(
                "✗ Back squat still present. Replace with goblet squat, wall sit, or leg press."
            )

        # Criterion 3: Bent-over rows replaced with seated/machine variants
        has_bent_over_row = "bent over row" in normalized
        if not has_bent_over_row:
            scores["rows_replaced"] = 0.25
            feedback_parts.append("✓ Bent-over rows removed or replaced with spine-neutral variant.")
        else:
            scores["rows_replaced"] = 0.0
            feedback_parts.append(
                "✗ Bent-over rows still present. Replace with seated cable rows or machine rows."
            )

        # Criterion 4: Plan retains same muscle group targets
        # Proxy: check that back/leg work still appears in the plan
        back_work = any(
            t in plan_text_lower
            for t in ["row", "pull", "lat", "back", "rhomboid"]
        )
        # A bare "press" is deliberately not a leg keyword: it would match
        # bench/overhead press, and "leg" already covers "leg press".
        leg_work = any(
            t in plan_text_lower
            for t in ["squat", "lunge", "hip", "glute", "quad", "hamstring", "leg"]
        )
        if back_work and leg_work:
            scores["muscle_targets_retained"] = 0.25
            feedback_parts.append("✓ Original muscle groups (back, legs) still targeted despite modifications.")
        else:
            scores["muscle_targets_retained"] = 0.0
            missing = []
            if not back_work:
                missing.append("back")
            if not leg_work:
                missing.append("legs")
            feedback_parts.append(
                f"✗ Missing muscle group coverage: {missing}. Ensure modifications keep the same target areas."
            )

        score = sum(scores.values())
        feedback = " ".join(feedback_parts)
        return score, scores, feedback


# ---------------------------------------------------------------------------
# Task 3 - HARD: Periodized 4-Week Program
# ---------------------------------------------------------------------------

class PeriodizedProgramTask(BaseTask):
    """
    Task 3 (hard): periodized 4-week program.

    Scenario: Advanced powerlifter, 5 days/week, full gym, competition in 5 weeks.
    Needs 4-week block with deload in week 4.
    Grader: 5 criteria worth up to 0.2 each (some award 0.1 partial credit).
    The environment ends the episode once the score reaches 0.99 or after
    ``max_steps`` (8) steps.
    """

    client_profile = {
        "age": 27,
        "fitness_level": "advanced",
        "goal": "powerlifting competition prep",
        "equipment": ["full gym", "barbell", "squat rack", "bench", "deadlift platform"],
        "injuries": [],
        "days_per_week": 5,
        "competition_weeks_out": 5,
        "weak_points": ["upper back", "lockout strength"],
        "current_maxes": {"squat": 180, "bench": 120, "deadlift": 220},
    }
    max_steps = 8

    COMPETITION_LIFTS = {"squat", "bench", "bench press", "deadlift"}

    # Week-level keys consulted, in order, for an explicit intensity value.
    INTENSITY_KEYS = ("intensity", "avg_rpe", "percentage")

    @staticmethod
    def _training_days(week) -> list:
        """Return the week's list of training days, or ``[]`` if it is malformed."""
        if not isinstance(week, dict):
            return []
        days = week.get("days", week.get("training_days", []))
        return days if isinstance(days, list) else []

    @classmethod
    def _week_intensity(cls, week):
        """
        Return the week's intensity as a float, or ``None`` if it cannot be determined.

        Explicit values under ``INTENSITY_KEYS`` win; strings such as ``"80%"``
        or ``"RPE 8"`` contribute their first number. Values are always
        converted to floats so that weeks compare numerically rather than
        lexicographically or across mixed types. Without a usable explicit
        value the week's text is searched for intensity keywords.
        """
        import re  # local import: only this helper needs it

        if not isinstance(week, dict):
            return None

        for key in cls.INTENSITY_KEYS:
            value = week.get(key)
            if isinstance(value, bool):
                continue
            if isinstance(value, (int, float)):
                return float(value)
            if isinstance(value, str):
                match = re.search(r"\d+(?:\.\d+)?", value)
                if match:
                    return float(match.group())

        # Try to infer from week label/description
        desc = str(week).lower()
        if "heavy" in desc or "high" in desc:
            return 85.0
        if "moderate" in desc or "medium" in desc:
            return 75.0
        return None

    def grade(self, action: FitscriptAction, step: int) -> Tuple[float, Dict[str, float], str]:
        """
        Score a submitted 4-week program.

        Args:
            action: Submission whose ``plan`` attribute holds the program as a
                JSON string with a top-level ``"weeks"`` list. Each week lists
                its days under ``"days"`` (or ``"training_days"``).
            step: Current step number (not used by this grader).

        Returns:
            ``(score, breakdown, feedback)`` where ``breakdown`` holds 0.0, 0.1
            or 0.2 for each of ``week_structure``, ``progressive_overload``,
            ``deload_week``, ``competition_lifts`` and
            ``accessory_weak_points``, and ``score`` is their sum capped at 1.0.
        """
        scores: Dict[str, float] = {}
        feedback_parts = []

        plan_text = action.plan or ""
        try:
            plan = json.loads(plan_text) if plan_text else {}
        except json.JSONDecodeError:
            plan = {}
        # Valid JSON that is not an object has no weeks to grade.
        if not isinstance(plan, dict):
            plan = {}

        weeks = plan.get("weeks", [])

        # Criterion 1: 4 distinct weeks, each with 5 training days
        if isinstance(weeks, list) and len(weeks) == 4:
            all_five_days = all(len(self._training_days(w)) == 5 for w in weeks)
            if all_five_days:
                scores["week_structure"] = 0.2
                feedback_parts.append("✓ 4 weeks present, each with 5 training days.")
            else:
                scores["week_structure"] = 0.1
                feedback_parts.append(
                    "~ 4 weeks present but not all weeks have exactly 5 training days."
                )
        else:
            scores["week_structure"] = 0.0
            found_weeks = len(weeks) if isinstance(weeks, list) else "unknown"
            feedback_parts.append(
                f"✗ Expected 4 weeks with 5 days each. Found {found_weeks} weeks."
            )

        # Criterion 2: Weeks 1-3 show progressive overload
        if isinstance(weeks, list) and len(weeks) >= 3:
            intensities = [self._week_intensity(w) for w in weeks[:3]]
            if all(i is not None for i in intensities):
                if intensities[0] < intensities[1] < intensities[2]:
                    scores["progressive_overload"] = 0.2
                    feedback_parts.append("✓ Weeks 1-3 show clear progressive overload (increasing intensity).")
                else:
                    scores["progressive_overload"] = 0.1
                    feedback_parts.append(
                        "~ Intensity values present but progressive overload pattern not clearly ascending across weeks 1-3."
                    )
            else:
                scores["progressive_overload"] = 0.0
                feedback_parts.append(
                    "✗ Cannot verify progressive overload. Add 'intensity', 'avg_rpe', or 'percentage' fields to each week."
                )
        else:
            scores["progressive_overload"] = 0.0
            feedback_parts.append("✗ Fewer than 3 weeks present; cannot verify progressive overload.")

        # Criterion 3: Week 4 is a deload (volume reduced >= 40% vs week 3)
        if isinstance(weeks, list) and len(weeks) == 4:
            w3 = weeks[2]
            w4 = weeks[3]
            # A week that is not an object carries no volume data.
            w3_vol = _estimate_volume(w3) if isinstance(w3, dict) else 0.0
            w4_vol = _estimate_volume(w4) if isinstance(w4, dict) else 0.0
            is_deload_label = "deload" in str(w4).lower()
            if w3_vol > 0 and w4_vol > 0:
                reduction = (w3_vol - w4_vol) / w3_vol
                if reduction >= 0.40:
                    scores["deload_week"] = 0.2
                    feedback_parts.append(
                        f"✓ Week 4 deload: volume reduced by {reduction*100:.0f}% vs week 3."
                    )
                elif is_deload_label:
                    scores["deload_week"] = 0.1
                    feedback_parts.append(
                        "~ Week 4 labeled as deload but volume reduction < 40%. Reduce total sets/volume further."
                    )
                else:
                    scores["deload_week"] = 0.0
                    feedback_parts.append(
                        f"✗ Week 4 volume only reduced by {reduction*100:.0f}%. Deload requires >= 40% reduction."
                    )
            elif is_deload_label:
                scores["deload_week"] = 0.1
                feedback_parts.append(
                    "~ Week 4 labeled as deload but no volume data to verify the 40% reduction threshold."
                )
            else:
                scores["deload_week"] = 0.0
                feedback_parts.append(
                    "✗ Week 4 not identified as a deload and volume data insufficient to verify."
                )
        else:
            scores["deload_week"] = 0.0
            feedback_parts.append("✗ Fewer than 4 weeks present; cannot evaluate deload week.")

        # Criterion 4: Competition lifts present.
        # This is a keyword check on the raw plan text only; it does not
        # verify that the lifts fall on separate days.
        plan_text_lower = plan_text.lower()
        squat_present = "squat" in plan_text_lower
        bench_present = "bench" in plan_text_lower
        deadlift_present = "deadlift" in plan_text_lower
        if squat_present and bench_present and deadlift_present:
            scores["competition_lifts"] = 0.2
            feedback_parts.append("✓ All three competition lifts (squat, bench, deadlift) present as primary movements.")
        else:
            missing = []
            if not squat_present:
                missing.append("squat")
            if not bench_present:
                missing.append("bench press")
            if not deadlift_present:
                missing.append("deadlift")
            scores["competition_lifts"] = 0.0
            feedback_parts.append(f"✗ Missing competition lifts: {missing}.")

        # Criterion 5: Accessory work targets weak points (upper back, lockout)
        weak_point_keywords = ["face pull", "upper back", "row", "rdl", "pause", "lockout", "band pull apart", "rear delt"]
        accessory_bonus = sum(1 for kw in weak_point_keywords if kw in plan_text_lower)
        if accessory_bonus >= 3:
            scores["accessory_weak_points"] = 0.2
            feedback_parts.append("✓ Accessory work targets weak points (upper back, lockout strength).")
        elif accessory_bonus >= 1:
            scores["accessory_weak_points"] = 0.1
            feedback_parts.append("~ Some accessory work present but weak points (upper back, lockout) not fully addressed.")
        else:
            scores["accessory_weak_points"] = 0.0
            feedback_parts.append("✗ No accessory work targeting weak points (upper back, lockout strength).")

        score = min(1.0, sum(scores.values()))
        feedback = " ".join(feedback_parts)
        return score, scores, feedback


# ---------------------------------------------------------------------------
# Helper utilities
# ---------------------------------------------------------------------------

def _extract_exercises(plan: dict) -> list:
    """
    Flatten all exercises from all days in a plan.

    Days are read from the ``"days"`` key, falling back to ``"workout_days"``.
    The plan comes from agent-supplied JSON, so every level is type-checked:
    a non-dict plan, a non-list days container, non-dict days, and non-list
    ``"exercises"`` values are skipped instead of raising.

    Args:
        plan: Parsed plan object.

    Returns:
        The exercise entries of all well-formed days, in order. Entries are
        returned as found and are not themselves validated.
    """
    if not isinstance(plan, dict):
        return []

    days = plan.get("days", plan.get("workout_days", []))
    if not isinstance(days, (list, tuple)):
        return []

    exercises = []
    for day in days:
        if not isinstance(day, dict):
            continue
        day_exercises = day.get("exercises", [])
        # Guard against e.g. "exercises": null (extend would raise) or a
        # string (extend would splat it into single characters).
        if isinstance(day_exercises, (list, tuple)):
            exercises.extend(day_exercises)
    return exercises


def _check_reps_exceed(plan: dict, max_reps: int) -> bool:
    """
    Return True if any exercise in the plan has reps > max_reps.

    Args:
        plan: Parsed plan object; exercises are collected with
            ``_extract_exercises``.
        max_reps: Highest rep count that is still acceptable.

    Returns:
        True as soon as one exercise exceeds ``max_reps``. Numeric strings
        (e.g. ``"20"``) are compared by value so that quoting a number does
        not bypass the limit; non-numeric reps such as ``"max"`` and
        exercise entries that are not dicts are ignored.
    """
    for ex in _extract_exercises(plan):
        # Plans are agent-supplied JSON; an exercise may be a bare string.
        if not isinstance(ex, dict):
            continue
        reps = ex.get("reps")
        if isinstance(reps, str):
            try:
                reps = float(reps.strip())
            except ValueError:
                continue
        if isinstance(reps, (int, float)) and reps > max_reps:
            return True
    return False


def _estimate_volume(week: dict) -> float:
    """
    Estimate total volume (sets × reps) across all days in a week.

    Days are read from the ``"days"`` key, falling back to
    ``"training_days"``. Only exercises whose ``"sets"`` and ``"reps"`` are
    both numbers contribute. If that yields no volume, a numeric week-level
    ``"total_sets"`` is used with an assumed average of 8 reps per set.

    The week comes from agent-supplied JSON, so malformed shapes (non-dict
    week/day/exercise, non-list containers, non-numeric ``"total_sets"``)
    contribute nothing instead of raising or producing a bogus number.

    Args:
        week: Parsed week object.

    Returns:
        The estimated volume as a float; 0.0 when nothing usable is found.
    """
    if not isinstance(week, dict):
        return 0.0

    total = 0
    days = week.get("days", week.get("training_days", []))
    if isinstance(days, (list, tuple)):
        for day in days:
            if not isinstance(day, dict):
                continue
            exercises = day.get("exercises", [])
            if not isinstance(exercises, (list, tuple)):
                continue
            for ex in exercises:
                if not isinstance(ex, dict):
                    continue
                sets = ex.get("sets", 0)
                reps = ex.get("reps", 0)
                if isinstance(sets, (int, float)) and isinstance(reps, (int, float)):
                    total += sets * reps

    # Also accept a flat 'total_sets' key on the week
    if total == 0:
        total_sets = week.get("total_sets", 0)
        # Must be numeric: a string would be repeated 8 times by "*" and then
        # parsed as one huge number.
        if isinstance(total_sets, (int, float)):
            total = total_sets * 8  # assume ~8 reps avg if only sets given
    return float(total)


# ---------------------------------------------------------------------------
# Task registry
# ---------------------------------------------------------------------------

# Maps each task id to the shared task instance that supplies its client
# profile, step limit, and grader.
TASKS: Dict[str, BaseTask] = dict(
    [
        ("basic_plan", BasicPlanTask()),
        ("injury_safe_modification", InjurySafeTask()),
        ("periodized_program", PeriodizedProgramTask()),
    ]
)


# ---------------------------------------------------------------------------
# Main environment class
# ---------------------------------------------------------------------------

class FitscriptEnvironment(Environment):
    """
    FitScript fitness prescription environment.

    Three tasks of increasing difficulty:
      - basic_plan (easy): generate a 3-day bodyweight beginner plan
      - injury_safe_modification (medium): modify a plan for a lower-back-injured client
      - periodized_program (hard): design a 4-week periodized powerlifting block

    Rewards are always in [0.0, 1.0]. Episodes terminate on task completion
    (score >= 0.99) or when max_steps is reached.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Keywords that trigger the safety penalty, per injury. They are matched
    # against the lower-cased plan text with hyphens replaced by spaces, so
    # entries must be written without hyphens.
    CONTRAINDICATED: Dict[str, Tuple[str, ...]] = {
        "lower back": ("deadlift", "back squat", "good morning", "bent over row"),
        "knee": ("lunge", "leg press", "deep squat", "box jump"),
        "shoulder": ("overhead press", "upright row", "behind neck"),
    }

    # Variants that contain a contraindicated keyword but are accepted as
    # replacements by the task graders. They are masked out before keyword
    # matching so the penalty does not contradict the grader's own advice.
    SAFE_VARIANTS: Dict[str, Tuple[str, ...]] = {
        "lower back": ("romanian deadlift", "trap bar deadlift"),
    }

    # Amount subtracted from the score when the safety penalty applies.
    SAFETY_PENALTY: float = 0.3

    def __init__(self, task_id: str = "basic_plan"):
        """
        Initialize the FitScript environment.

        Args:
            task_id: One of 'basic_plan', 'injury_safe_modification', 'periodized_program'.

        Raises:
            ValueError: If ``task_id`` is not a registered task.
        """
        if task_id not in TASKS:
            raise ValueError(
                f"Unknown task_id '{task_id}'. Valid options: {list(TASKS.keys())}"
            )
        self._task_id = task_id
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Last plan that was actually graded; used to reject verbatim resubmissions.
        self._last_plan: str = ""

    def reset(self) -> FitscriptObservation:
        """
        Reset the environment for the current task.

        Starts a new episode id, zeroes the step counter, and forgets the
        previously graded plan.

        Returns:
            FitscriptObservation with the client profile and welcome message.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._last_plan = ""

        task = TASKS[self._task_id]

        return FitscriptObservation(
            client_profile=task.client_profile,
            feedback="Welcome! Review the client profile and generate a plan.",
            score_breakdown={},
            task_id=self._task_id,
            step_count=0,
            done=False,
            reward=0.0,
        )

    def _rejected(self, task: BaseTask, feedback: str) -> FitscriptObservation:
        """Build the zero-reward observation for a submission that is not graded."""
        return FitscriptObservation(
            client_profile=task.client_profile,
            feedback=feedback,
            score_breakdown={},
            task_id=self._task_id,
            step_count=self._state.step_count,
            done=self._state.step_count >= task.max_steps,
            reward=0.0,
        )

    @classmethod
    def _has_contraindicated_exercise(cls, plan_text: str, injuries) -> bool:
        """Return True if the plan text names an exercise banned for any listed injury."""
        normalized = plan_text.lower().replace("-", " ")
        for injury in injuries:
            text = normalized
            for safe in cls.SAFE_VARIANTS.get(injury, ()):
                text = text.replace(safe, " ")
            if any(term in text for term in cls.CONTRAINDICATED.get(injury, ())):
                return True
        return False

    def step(self, action: FitscriptAction) -> FitscriptObservation:  # type: ignore[override]
        """
        Execute a step: grade the submitted plan and return feedback.

        Empty plans and plans identical to the last graded one are not graded
        and earn a reward of 0.0, but still consume a step. Otherwise the
        task grader's score is reduced by ``SAFETY_PENALTY`` (once) if the
        plan contains an exercise contraindicated for one of the client's
        injuries, then clamped to [0.0, 1.0].

        Args:
            action: FitscriptAction with action_type, plan JSON string, and optional reasoning.

        Returns:
            FitscriptObservation with score breakdown and feedback.
        """
        self._state.step_count += 1
        task = TASKS[self._task_id]

        # Penalty: empty or null plan
        if not action.plan or action.plan.strip() in ("", "null", "{}"):
            return self._rejected(
                task,
                "✗ Empty or null plan submitted. Please provide a structured workout plan.",
            )

        # Penalty: identical plan submitted twice in a row
        if action.plan == self._last_plan:
            return self._rejected(
                task,
                "✗ Identical plan submitted twice. Please revise based on the previous feedback.",
            )

        self._last_plan = action.plan

        # Grade the plan
        score, breakdown, feedback = task.grade(action, self._state.step_count)

        # Safety penalty: contraindicated exercises for injured clients
        injuries = task.client_profile.get("injuries", [])
        if injuries and self._has_contraindicated_exercise(action.plan, injuries):
            score = max(0.0, score - self.SAFETY_PENALTY)
            feedback += " ⚠️ Safety penalty applied: plan contains exercises contraindicated for the client's injury."

        # Clamp to [0.0, 1.0]
        score = max(0.0, min(1.0, score))

        done = score >= 0.99 or self._state.step_count >= task.max_steps

        return FitscriptObservation(
            client_profile=task.client_profile,
            feedback=feedback,
            score_breakdown=breakdown,
            task_id=self._task_id,
            step_count=self._state.step_count,
            done=done,
            reward=score,
        )

    @property
    def state(self) -> State:
        """Get the current environment state."""
        return self._state