coffeine16 committed on
Commit
b7bec01
·
verified ·
1 Parent(s): 1a74e1c

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. openenv.yaml +20 -9
  2. server/FitCoach_environment.py +78 -31
  3. server/app.py +10 -30
  4. utils/curriculum.py +411 -0
openenv.yaml CHANGED
@@ -14,23 +14,34 @@ tasks:
14
  vegetarian client with dumbbells and pull-up bar only. Graded on
15
  equipment compliance, macro accuracy against IFCT 2017 database,
16
  volume appropriateness, and nutrition-training coherence.
 
17
 
18
  - id: plateau_adaptation
19
  name: Plateau Detection & Adaptation
20
  difficulty: medium
21
  description: >
22
  An intermediate client has been stuck at the same weight for 14 days.
23
- Plateau detected via linear regression on 7-day rolling mean (Hall et al. 2011).
24
- Agent must detect the plateau and adapt: increase training volume by ≥10%
25
- OR reduce calories by ≥150 kcal. Also applies progressive overload rules
26
- to previously tracked exercises.
27
 
28
  - id: conflict_resolution
29
  name: Multi-Conflict Resolution
30
  difficulty: hard
31
  description: >
32
- Three simultaneous challenges: (1) 3-week weight plateau requiring
33
- adaptation, (2) new lower-back injury making deadlifts and bent-over
34
- rows contraindicated, (3) goal change from weight_loss to maintenance
35
- requiring caloric adjustment upward. Hard safety penalty of −0.3 for
36
- any constraint or injury violation.
 
 
 
 
 
 
 
 
 
 
 
14
  vegetarian client with dumbbells and pull-up bar only. Graded on
15
  equipment compliance, macro accuracy against IFCT 2017 database,
16
  volume appropriateness, and nutrition-training coherence.
17
+ Agent must consult all 3 specialist actors before submitting.
18
 
19
  - id: plateau_adaptation
20
  name: Plateau Detection & Adaptation
21
  difficulty: medium
22
  description: >
23
  An intermediate client has been stuck at the same weight for 14 days.
24
+ Plateau detected via linear regression on 7-day rolling mean.
25
+ FitnessAdvisor and NutritionAdvisor will CONFLICT on how to adapt.
26
+ Agent must consult all actors, resolve the conflict, and submit
27
+ an adapted plan.
28
 
29
  - id: conflict_resolution
30
  name: Multi-Conflict Resolution
31
  difficulty: hard
32
  description: >
33
+ Three simultaneous challenges requiring multi-actor coordination:
34
+ (1) 3-week weight plateau, (2) new lower-back injury creating an
35
+ injury-overload conflict between actors, (3) goal change requiring
36
+ caloric adjustment. All three actors will have conflicting
37
+ recommendations that the orchestrator must resolve.
38
+
39
+ - id: curriculum
40
+ name: Adaptive Curriculum (Self-Improvement)
41
+ difficulty: adaptive
42
+ description: >
43
+ Adaptive curriculum mode — generates RANDOM clients each episode
44
+ so the agent cannot memorize answers. Difficulty starts at easy
45
+ and escalates to medium then hard as the agent scores consistently
46
+ above 0.8. Difficulty drops back if agent struggles. This implements
47
+ Theme 4 (Self-Improvement) with procedural client generation.
server/FitCoach_environment.py CHANGED
@@ -57,6 +57,7 @@ from utils.nutrition import calculate_macro_targets, verify_meal_macros
57
  from utils.actors import (
58
  fitness_actor, nutrition_actor, progress_actor, detect_actor_conflicts
59
  )
 
60
 
61
 
62
  # ── Domain constraint tables ──────────────────────────────────────────────────
@@ -248,6 +249,17 @@ TASK_CONFIGS: dict[str, dict] = {
248
  "goal_change:weight_lossβ†’maintenance",
249
  ],
250
  },
 
 
 
 
 
 
 
 
 
 
 
251
  }
252
 
253
  ALL_ACTORS = {"fitness_advisor", "nutrition_advisor", "progress_analyst"}
@@ -270,6 +282,7 @@ def grade_plan(
270
  action: FitcoachAction,
271
  config: dict,
272
  actors_consulted: list[str],
 
273
  active_conflicts: list[dict],
274
  safety_already_violated: bool,
275
  ) -> tuple[float, dict[str, float], str, bool]:
@@ -500,8 +513,7 @@ def grade_plan(
500
  scores["coherence"] = 1.0
501
  fb.append(f"βœ“ Coherence: {agent_cal:.0f} kcal supports {agent_sets} sets/wk.")
502
 
503
- # ── 8. Actor coordination (NEW) ───────────────────────────────────────────
504
- # Check which actors were relevant for this task
505
  needs_progress = bool(progress.get("weight_series") or "plateau" in comps)
506
  required_actors = {"fitness_advisor", "nutrition_advisor"}
507
  if needs_progress:
@@ -510,38 +522,59 @@ def grade_plan(
510
  consulted_set = set(actors_consulted)
511
  missing_actors = required_actors - consulted_set
512
 
513
- # Check conflicts were resolved
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
  unresolved = []
515
  for conflict in active_conflicts:
516
- conflict_type = conflict.get("type", "")
517
- if conflict_type == "plateau_volume_conflict":
518
- if scores.get("plateau_response", 1.0) < 0.5:
519
- unresolved.append(conflict_type)
520
- elif conflict_type == "volume_calorie_mismatch":
521
- if scores.get("coherence", 1.0) < 0.5:
522
- unresolved.append(conflict_type)
523
- elif conflict_type == "injury_overload_conflict":
524
- if scores.get("constraint_respect", 1.0) < 0.5:
525
- unresolved.append(conflict_type)
526
-
527
- if not missing_actors and not unresolved:
528
- scores["actor_coordination"] = 1.0
529
- fb.append(
530
- f"βœ“ Actor coordination: consulted all required actors "
531
- f"({sorted(consulted_set)}) and resolved all conflicts."
532
- )
533
- elif missing_actors:
534
- penalty = len(missing_actors) / len(required_actors)
535
- scores["actor_coordination"] = max(0.0, 1.0 - penalty)
536
- fb.append(
537
- f"βœ— Actor coordination: did not consult {sorted(missing_actors)}. "
538
- f"Always consult all relevant actors before submitting."
539
- )
540
  else:
541
- scores["actor_coordination"] = max(0.0, 0.5 - 0.2 * len(unresolved))
542
- fb.append(
543
- f"βœ— Actor coordination: {len(unresolved)} conflict(s) unresolved: {unresolved}."
544
- )
 
 
545
 
546
  # ── Aggregate ─────────────────────────────────────────────────────────────
547
  active = list(scores.values())
@@ -572,6 +605,8 @@ class FitcoachEnvironment(Environment):
572
  f"Unknown task_id '{task_id}'. Valid: {list(TASK_CONFIGS.keys())}"
573
  )
574
  self._task_id = task_id
 
 
575
  self._config = TASK_CONFIGS[task_id]
576
  self._state = State(episode_id=str(uuid4()), step_count=0)
577
  self._phase_idx = 0
@@ -583,6 +618,17 @@ class FitcoachEnvironment(Environment):
583
  self._active_conflicts: list[dict] = []
584
 
585
  def reset(self) -> FitcoachObservation:
 
 
 
 
 
 
 
 
 
 
 
586
  self._state = State(episode_id=str(uuid4()), step_count=0)
587
  self._phase_idx = 0
588
  self._best_score = 0.0
@@ -768,6 +814,7 @@ class FitcoachEnvironment(Environment):
768
  reward, breakdown, feedback, safety_now = grade_plan(
769
  action, cfg,
770
  self._actors_consulted,
 
771
  self._active_conflicts,
772
  self._safety_hit,
773
  )
 
57
  from utils.actors import (
58
  fitness_actor, nutrition_actor, progress_actor, detect_actor_conflicts
59
  )
60
+ from utils.curriculum import CurriculumManager, generate_client
61
 
62
 
63
  # ── Domain constraint tables ──────────────────────────────────────────────────
 
249
  "goal_change:weight_lossβ†’maintenance",
250
  ],
251
  },
252
+
253
+
254
+ # ── Theme 4: Adaptive curriculum ───────────────────────────────────────
255
+ "curriculum": {
256
+ "max_steps": 7,
257
+ "phases": ["initial"],
258
+ "description": "Adaptive curriculum β€” random clients, difficulty escalates with performance.",
259
+ "client": {},
260
+ "progress_data": {},
261
+ "complications": [],
262
+ },
263
  }
264
 
265
  ALL_ACTORS = {"fitness_advisor", "nutrition_advisor", "progress_analyst"}
 
282
  action: FitcoachAction,
283
  config: dict,
284
  actors_consulted: list[str],
285
+ actor_responses: dict[str, dict],
286
  active_conflicts: list[dict],
287
  safety_already_violated: bool,
288
  ) -> tuple[float, dict[str, float], str, bool]:
 
513
  scores["coherence"] = 1.0
514
  fb.append(f"βœ“ Coherence: {agent_cal:.0f} kcal supports {agent_sets} sets/wk.")
515
 
516
+ # ── 8. Actor coordination (TIGHTENED β€” verifies plan USES actor data) ─────
 
517
  needs_progress = bool(progress.get("weight_series") or "plateau" in comps)
518
  required_actors = {"fitness_advisor", "nutrition_advisor"}
519
  if needs_progress:
 
522
  consulted_set = set(actors_consulted)
523
  missing_actors = required_actors - consulted_set
524
 
525
+ # Check plan actually USES actor data (not just consulted)
526
+ usage_score = 0.0
527
+ usage_checks = 0
528
+
529
+ if "fitness_advisor" in actor_responses:
530
+ fa_c = actor_responses["fitness_advisor"].get("constraints", {})
531
+ fa_min = fa_c.get("weekly_sets_min", 0)
532
+ fa_max = fa_c.get("weekly_sets_max", 999)
533
+ if fa_min <= agent_sets <= fa_max:
534
+ usage_score += 1.0
535
+ usage_checks += 1
536
+
537
+ if "nutrition_advisor" in actor_responses:
538
+ na_cal = actor_responses["nutrition_advisor"].get("constraints", {}).get("calories_target", 0)
539
+ if na_cal > 0 and agent_cal > 0:
540
+ if abs(agent_cal - na_cal) / na_cal <= 0.15:
541
+ usage_score += 1.0
542
+ elif abs(agent_cal - na_cal) / na_cal <= 0.25:
543
+ usage_score += 0.5
544
+ usage_checks += 1
545
+
546
+ if "progress_analyst" in actor_responses:
547
+ must_adapt = actor_responses["progress_analyst"].get("constraints", {}).get("must_adapt_if_plateau", False)
548
+ if must_adapt:
549
+ usage_score += 1.0 if scores.get("plateau_response", 0) >= 0.5 else 0.0
550
+ else:
551
+ usage_score += 1.0
552
+ usage_checks += 1
553
+
554
  unresolved = []
555
  for conflict in active_conflicts:
556
+ ct = conflict.get("type", "")
557
+ if ct == "plateau_volume_conflict" and scores.get("plateau_response", 1.0) < 0.5:
558
+ unresolved.append(ct)
559
+ elif ct == "volume_calorie_mismatch" and scores.get("coherence", 1.0) < 0.5:
560
+ unresolved.append(ct)
561
+ elif ct == "injury_overload_conflict" and scores.get("constraint_respect", 1.0) < 0.5:
562
+ unresolved.append(ct)
563
+
564
+ if missing_actors:
565
+ consult_score = max(0.0, 1.0 - len(missing_actors) / len(required_actors))
566
+ scores["actor_coordination"] = consult_score * 0.5
567
+ fb.append(f"βœ— Coordination: missing {sorted(missing_actors)}.")
568
+ elif unresolved:
569
+ scores["actor_coordination"] = max(0.0, 0.4 - 0.15 * len(unresolved))
570
+ fb.append(f"βœ— Coordination: {len(unresolved)} conflict(s) unresolved.")
 
 
 
 
 
 
 
 
 
571
  else:
572
+ usage_pct = (usage_score / usage_checks) if usage_checks > 0 else 0.5
573
+ scores["actor_coordination"] = round(usage_pct, 2)
574
+ if usage_pct >= 0.8:
575
+ fb.append(f"βœ“ Coordination: plan follows all actor constraints.")
576
+ else:
577
+ fb.append(f"~ Coordination: plan partially ignores actor data ({usage_pct:.0%}).")
578
 
579
  # ── Aggregate ─────────────────────────────────────────────────────────────
580
  active = list(scores.values())
 
605
  f"Unknown task_id '{task_id}'. Valid: {list(TASK_CONFIGS.keys())}"
606
  )
607
  self._task_id = task_id
608
+ self._is_curriculum = (task_id == "curriculum")
609
+ self._curriculum = CurriculumManager() if self._is_curriculum else None
610
  self._config = TASK_CONFIGS[task_id]
611
  self._state = State(episode_id=str(uuid4()), step_count=0)
612
  self._phase_idx = 0
 
618
  self._active_conflicts: list[dict] = []
619
 
620
  def reset(self) -> FitcoachObservation:
621
+ # Record previous episode for curriculum
622
+ if self._is_curriculum and self._curriculum and self._best_score > 0:
623
+ self._curriculum.record_score(self._best_score)
624
+
625
+ # Build config (curriculum generates random clients)
626
+ if self._is_curriculum and self._curriculum:
627
+ ep = self._curriculum.get_next_episode()
628
+ self._config = ep
629
+ else:
630
+ self._config = TASK_CONFIGS[self._task_id]
631
+
632
  self._state = State(episode_id=str(uuid4()), step_count=0)
633
  self._phase_idx = 0
634
  self._best_score = 0.0
 
814
  reward, breakdown, feedback, safety_now = grade_plan(
815
  action, cfg,
816
  self._actors_consulted,
817
+ self._actor_responses,
818
  self._active_conflicts,
819
  self._safety_hit,
820
  )
server/app.py CHANGED
@@ -1,19 +1,15 @@
1
  """
2
  FastAPI application for the FitCoach RL Environment.
3
 
4
- Task is selected via FITCOACH_TASK env var (default: week1_plan).
5
- Valid: week1_plan | plateau_adaptation | conflict_resolution
6
 
7
  Usage:
8
- $env:FITCOACH_TASK="week1_plan"
9
  uvicorn server.app:app --host 0.0.0.0 --port 8000
10
  """
11
 
12
- import os
13
- import sys
14
- import functools
15
 
16
- # Ensure the FitCoach root is on sys.path so absolute imports work
17
  _HERE = os.path.dirname(os.path.abspath(__file__))
18
  _ROOT = os.path.dirname(_HERE)
19
  if _ROOT not in sys.path:
@@ -22,41 +18,25 @@ if _ROOT not in sys.path:
22
  try:
23
  from openenv.core.env_server.http_server import create_app
24
  except Exception as e:
25
- raise ImportError(
26
- "openenv is required. Install with: pip install openenv-core"
27
- ) from e
28
 
29
  from models import FitcoachAction, FitcoachObservation
30
  from server.FitCoach_environment import FitcoachEnvironment
31
 
32
  FITCOACH_TASK = os.environ.get("FITCOACH_TASK", "week1_plan")
33
- VALID_TASKS = {"week1_plan", "plateau_adaptation", "conflict_resolution"}
34
 
35
  if FITCOACH_TASK not in VALID_TASKS:
36
- raise ValueError(
37
- f"Invalid FITCOACH_TASK='{FITCOACH_TASK}'. "
38
- f"Must be one of: {sorted(VALID_TASKS)}"
39
- )
40
 
41
  EnvFactory = functools.partial(FitcoachEnvironment, task_id=FITCOACH_TASK)
42
 
43
- app = create_app(
44
- EnvFactory,
45
- FitcoachAction,
46
- FitcoachObservation,
47
- env_name="FitCoach",
48
- max_concurrent_envs=4,
49
- )
50
 
51
-
52
- def main(host: str = "0.0.0.0", port: int = 8000):
53
  import uvicorn
54
  uvicorn.run(app, host=host, port=port)
55
 
56
-
57
  if __name__ == "__main__":
58
- import argparse
59
- parser = argparse.ArgumentParser()
60
- parser.add_argument("--port", type=int, default=8000)
61
- args = parser.parse_args()
62
- main(port=args.port)
 
1
  """
2
  FastAPI application for the FitCoach RL Environment.
3
 
4
+ Valid tasks: week1_plan | plateau_adaptation | conflict_resolution | curriculum
 
5
 
6
  Usage:
7
+ $env:FITCOACH_TASK="curriculum"
8
  uvicorn server.app:app --host 0.0.0.0 --port 8000
9
  """
10
 
11
+ import os, sys, functools
 
 
12
 
 
13
  _HERE = os.path.dirname(os.path.abspath(__file__))
14
  _ROOT = os.path.dirname(_HERE)
15
  if _ROOT not in sys.path:
 
18
  try:
19
  from openenv.core.env_server.http_server import create_app
20
  except Exception as e:
21
+ raise ImportError("openenv required: pip install openenv-core") from e
 
 
22
 
23
  from models import FitcoachAction, FitcoachObservation
24
  from server.FitCoach_environment import FitcoachEnvironment
25
 
26
  FITCOACH_TASK = os.environ.get("FITCOACH_TASK", "week1_plan")
27
+ VALID_TASKS = {"week1_plan", "plateau_adaptation", "conflict_resolution", "curriculum"}
28
 
29
  if FITCOACH_TASK not in VALID_TASKS:
30
+ raise ValueError(f"Invalid FITCOACH_TASK='{FITCOACH_TASK}'. Valid: {sorted(VALID_TASKS)}")
 
 
 
31
 
32
  EnvFactory = functools.partial(FitcoachEnvironment, task_id=FITCOACH_TASK)
33
 
34
+ app = create_app(EnvFactory, FitcoachAction, FitcoachObservation,
35
+ env_name="FitCoach", max_concurrent_envs=4)
 
 
 
 
 
36
 
37
+ def main(host="0.0.0.0", port=8000):
 
38
  import uvicorn
39
  uvicorn.run(app, host=host, port=port)
40
 
 
41
  if __name__ == "__main__":
42
+ main()
 
 
 
 
utils/curriculum.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adaptive Curriculum Manager — Theme 4 (Self-Improvement).
3
+
4
+ Implements two key capabilities:
5
+ 1. Procedural client generation — randomized profiles so the agent
6
+ cannot memorize answers; must genuinely generalize.
7
+ 2. Adaptive difficulty escalation — agent earns harder tasks through
8
+ consistent performance. Difficulty drops back if agent struggles.
9
+
10
+ This is what makes Theme 4 legitimate: the environment itself adapts
11
+ to the agent's skill level, creating an automatic curriculum.
12
+
13
+ Snorkel AI sub-theme fit: "Simulated Experts-in-the-Loop with changing
14
+ requirements/preferences" — each generated client has different
15
+ preferences, restrictions, and complications.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import random
21
+ import copy
22
+ from typing import Optional
23
+
24
+
25
+ # ── Client generation pools ───────────────────────────────────────────────────
26
+
27
+ NAMES = [
28
+ ("Arjun Sharma", "male"), ("Priya Menon", "female"),
29
+ ("Rahul Verma", "male"), ("Sneha Reddy", "female"),
30
+ ("Vikram Patel", "male"), ("Ananya Iyer", "female"),
31
+ ("Karan Singh", "male"), ("Meera Nair", "female"),
32
+ ("Rohan Gupta", "male"), ("Diya Kapoor", "female"),
33
+ ("Aditya Joshi", "male"), ("Kavya Pillai", "female"),
34
+ ]
35
+
36
+ GOALS = ["muscle_gain", "weight_loss", "endurance", "maintenance"]
37
+
38
+ FITNESS_LEVELS = ["beginner", "intermediate", "advanced"]
39
+
40
+ EQUIPMENT_SETS = [
41
+ ["dumbbells", "pull_up_bar"],
42
+ ["dumbbells", "pull_up_bar", "resistance_bands"],
43
+ ["barbell", "dumbbells"],
44
+ ["barbell", "dumbbells", "cables", "machines"],
45
+ ["barbell", "dumbbells", "cables", "machines", "pull_up_bar"],
46
+ ["dumbbells", "resistance_bands", "kettlebell"],
47
+ ]
48
+
49
+ DIETARY_RESTRICTIONS = [
50
+ [],
51
+ ["vegetarian"],
52
+ ["vegan"],
53
+ ["vegetarian", "gluten_free"],
54
+ ]
55
+
56
+ INJURY_OPTIONS = [
57
+ [],
58
+ ["lower back"],
59
+ ["knee"],
60
+ ["shoulder"],
61
+ ["lower back", "knee"],
62
+ ]
63
+
64
+ COMPLICATION_TEMPLATES = {
65
+ "none": [],
66
+ "plateau": ["plateau"],
67
+ "injury": [], # filled from client injuries
68
+ "goal_change": [], # filled dynamically
69
+ "multi": [], # filled with all applicable
70
+ }
71
+
72
+
73
def generate_weight_series(
    base_weight: float,
    goal: str,
    n_days: int = 14,
    plateau: bool = False,
    seed: Optional[int] = None,
) -> list[dict]:
    """Generate a synthetic daily body-weight series.

    Args:
        base_weight: Starting weight in kilograms.
        goal: Client goal. ``"weight_loss"`` trends down by 0.05 kg/day,
            ``"muscle_gain"`` trends up by 0.03 kg/day, anything else is
            flat. Ignored when ``plateau`` is true.
        n_days: Number of consecutive daily entries to produce.
        plateau: When true, force a flat trend (noise only) regardless of
            ``goal``.
        seed: Seed for the private ``random.Random`` instance, so the same
            seed always yields the same series.

    Returns:
        A list of ``{"date": "YYYY-MM-DD", "weight_kg": float}`` dicts, one
        per day starting at 2026-04-01, weights rounded to one decimal.
    """
    # Local import: the module's import block does not bring in datetime.
    from datetime import date, timedelta

    rng = random.Random(seed)
    start = date(2026, 4, 1)
    series = []
    for i in range(n_days):
        if plateau:
            # Flat trend with noise.
            weight = base_weight + rng.uniform(-0.3, 0.3)
        elif goal == "weight_loss":
            weight = base_weight - (i * 0.05) + rng.uniform(-0.3, 0.3)
        elif goal == "muscle_gain":
            weight = base_weight + (i * 0.03) + rng.uniform(-0.2, 0.2)
        else:
            weight = base_weight + rng.uniform(-0.2, 0.2)

        series.append({
            # Real calendar arithmetic: formatting the day number into
            # "2026-04-DD" produced impossible dates (e.g. 2026-04-31)
            # once n_days exceeded 30.
            "date": (start + timedelta(days=i)).isoformat(),
            "weight_kg": round(weight, 1),
        })
    return series
100
+
101
+
102
def generate_exercise_history(
    equipment: list[str],
    fitness_level: str,
    injuries: list[str],
    seed: Optional[int] = None,
) -> dict:
    """Generate a plausible exercise history for progressive-overload checks.

    Args:
        equipment: Equipment names available to the client. Only
            ``dumbbells``, ``barbell``, ``pull_up_bar`` and ``cables`` have
            exercises in the pool; other names contribute nothing.
        fitness_level: ``beginner`` / ``intermediate`` / ``advanced``; scales
            the base load by 0.6 / 1.0 / 1.4 (unknown levels use 1.0).
        injuries: Injury names; exercises whose lower-cased name contains a
            banned term for any listed injury are excluded.
        seed: Seed for the private ``random.Random`` instance.

    Returns:
        A dict mapping exercise name to ``last_weight_kg``,
        ``last_reps_str``, ``target_reps`` and ``target_sets`` for 2-3
        randomly chosen exercises (fewer if fewer are available), or ``{}``
        when no exercise is available.
    """
    rng = random.Random(seed)

    # Pool of (name, base weight in kg, target rep range) by equipment.
    exercise_pool = {
        "dumbbells": [
            ("Dumbbell Bench Press", 20, "8-12"),
            ("Dumbbell Row", 18, "8-12"),
            ("Dumbbell Shoulder Press", 14, "8-12"),
            ("Dumbbell Romanian Deadlift", 20, "10-12"),
            ("Dumbbell Squat", 22, "8-12"),
            ("Dumbbell Curl", 12, "10-12"),
        ],
        "barbell": [
            ("Barbell Squat", 60, "6-10"),
            ("Barbell Deadlift", 80, "4-6"),
            ("Barbell Bench Press", 50, "6-10"),
            ("Barbell Row", 45, "8-12"),
        ],
        "pull_up_bar": [
            ("Pull-up", 0, "6-10"),
            ("Chin-up", 0, "6-10"),
        ],
        "cables": [
            ("Cable Row", 30, "10-12"),
            ("Lat Pulldown", 35, "8-12"),
        ],
    }

    # Substrings that rule an exercise out for a given injury.
    injury_bans = {
        "lower back": {"deadlift", "bent-over row", "good morning"},
        "knee": {"lunge", "deep squat", "leg extension"},
        "shoulder": {"overhead press", "upright row", "military press"},
    }
    banned = set()
    for injury in injuries:
        banned.update(injury_bans.get(injury, set()))

    # Keep pool order (equipment order, then table order) so that seeded
    # sampling stays reproducible.
    available_exercises = [
        entry
        for eq in equipment
        for entry in exercise_pool.get(eq, [])
        if not any(term in entry[0].lower() for term in banned)
    ]
    if not available_exercises:
        return {}

    # Pick 2-3 exercises for history.
    n = min(rng.randint(2, 3), len(available_exercises))
    chosen = rng.sample(available_exercises, n)

    # Loop-invariant: load multiplier for this client's fitness level.
    level_scale = {"beginner": 0.6, "intermediate": 1.0, "advanced": 1.4}
    scale = level_scale.get(fitness_level, 1.0)

    history = {}
    for ex_name, base_weight, target_reps in chosen:
        lo, hi = [int(x) for x in target_reps.split("-")]
        mid = (lo + hi) // 2
        # Randomize performance. The second rng.random() draw only happens
        # when the first check fails, which keeps the seeded draw sequence
        # stable.
        if rng.random() < 0.4:
            # Hit top of range on every set -> should add weight.
            reps_str = f"{hi},{hi},{hi}"
        elif rng.random() < 0.3:
            # Reps dropped off across sets -> should repeat.
            reps_str = f"{hi},{mid},{lo}"
        else:
            # In range but not at top -> repeat.
            reps_str = f"{mid},{mid},{mid}"

        # Snap the scaled load to the nearest 2.5 kg increment.
        scaled_weight = round(base_weight * scale / 2.5) * 2.5

        history[ex_name] = {
            "last_weight_kg": scaled_weight,
            "last_reps_str": reps_str,
            "target_reps": target_reps,
            "target_sets": 3,
        }

    return history
193
+
194
+
195
def generate_client(
    difficulty: str = "easy",
    seed: Optional[int] = None,
) -> dict:
    """Generate a random client profile appropriate for the difficulty level.

    Difficulty controls:
      - easy:   beginner, no injuries, no complications, simpler equipment
      - medium: 70% chance of a plateau, plus exercise history
      - hard:   guaranteed injury + plateau + goal change, plus exercise
                history (any value other than "easy"/"medium" is treated as
                hard)

    Args:
        difficulty: One of ``"easy"``, ``"medium"``, ``"hard"``.
        seed: Seed for the private ``random.Random`` instance; the same seed
            is also forwarded to the weight-series and exercise-history
            generators.

    Returns:
        A dict with keys ``client``, ``progress_data`` and ``complications``.
        The list values inside are fresh copies, so callers may mutate them
        without affecting the module-level pools.
    """
    rng = random.Random(seed)

    name, sex = rng.choice(NAMES)
    age = rng.randint(20, 50)

    if sex == "male":
        weight = round(rng.uniform(60, 95), 1)
        height = round(rng.uniform(165, 190), 1)
    else:
        weight = round(rng.uniform(48, 80), 1)
        height = round(rng.uniform(150, 175), 1)

    # Only set on the hard path; initialised so it is never unbound.
    old_goal = None

    # NOTE: every rng.choice() on a pool is wrapped in list() because
    # choice() returns the pool's own inner list; handing that out would let
    # a caller's mutation leak into every later episode.
    if difficulty == "easy":
        goal = rng.choice(["muscle_gain", "weight_loss"])
        fitness_level = "beginner"
        equipment = list(rng.choice(EQUIPMENT_SETS[:3]))       # simpler setups
        dietary = list(rng.choice(DIETARY_RESTRICTIONS[:2]))   # none or vegetarian
        injuries = []
        complications = []
        sessions = rng.choice([3, 4])
    elif difficulty == "medium":
        goal = rng.choice(GOALS)
        fitness_level = rng.choice(["beginner", "intermediate"])
        equipment = list(rng.choice(EQUIPMENT_SETS))
        dietary = list(rng.choice(DIETARY_RESTRICTIONS))
        injuries = []
        complications = ["plateau"] if rng.random() < 0.7 else []
        sessions = rng.choice([3, 4, 5])
    else:  # hard
        goal = rng.choice(GOALS)
        fitness_level = rng.choice(["intermediate", "advanced"])
        equipment = list(rng.choice(EQUIPMENT_SETS[2:]))       # needs more equipment
        dietary = list(rng.choice(DIETARY_RESTRICTIONS))
        injuries = list(rng.choice(INJURY_OPTIONS[1:]))        # guaranteed injury
        complications = ["plateau"]
        # Add goal change (the previous goal always differs from the new one).
        old_goal = rng.choice([g for g in GOALS if g != goal])
        complications.append(f"goal_change:{old_goal}→{goal}")
        # Add one injury complication per injury.
        for injury in injuries:
            complications.append(f"new_injury:{injury}")
        sessions = rng.choice([4, 5])

    # TDEE estimate: Mifflin-St Jeor BMR times a random activity factor.
    bmr = (10 * weight) + (6.25 * height) - (5 * age) + (5 if sex == "male" else -161)
    tdee = round(bmr * rng.uniform(1.4, 1.7))

    client = {
        "name": name,
        "age": age,
        "sex": sex,
        "weight_kg": weight,
        "height_cm": height,
        "goal": goal,
        "fitness_level": fitness_level,
        "dietary_restrictions": dietary,
        "available_equipment": equipment,
        "sessions_per_week": sessions,
        "tdee_estimate": float(tdee),
        "injuries": injuries,
    }

    # Build progress data.
    progress_data = {}
    if "plateau" in complications:
        progress_data["weight_series"] = generate_weight_series(
            weight, goal, n_days=14, plateau=True, seed=seed
        )
        progress_data["adherence_pct"] = rng.randint(55, 90)
        progress_data["avg_workout_rating"] = round(rng.uniform(1.5, 3.5), 1)

    if difficulty in ("medium", "hard"):
        progress_data["exercise_history"] = generate_exercise_history(
            equipment, fitness_level, injuries, seed=seed
        )
        if any("goal_change" in c for c in complications):
            progress_data["previous_goal"] = old_goal

    return {
        "client": client,
        "progress_data": progress_data,
        "complications": complications,
    }
290
+
291
+
292
+ # ── Adaptive Curriculum Manager ───────────────────────────────────────────────
293
+
294
class CurriculumManager:
    """Track agent performance and escalate/de-escalate difficulty.

    Rules:
      - Start at ``start_difficulty`` (default: easy).
      - Score >= 0.8 for 3 consecutive episodes at the current difficulty
        -> escalate one level.
      - Score < 0.5 for 2 consecutive episodes at the current difficulty
        -> de-escalate one level.
      - A new random client is generated each episode (no memorization).

    Streaks are counted only over episodes played at the current difficulty,
    so scores earned before a difficulty change never count towards the next
    change.
    """

    DIFFICULTIES = ["easy", "medium", "hard"]

    def __init__(self, start_difficulty: str = "easy"):
        self.current_difficulty = start_difficulty
        # Parallel lists: score of each recorded episode and the difficulty
        # it was recorded at.
        self.episode_scores: list[float] = []
        self.difficulty_history: list[str] = []
        self.escalation_events: list[dict] = []
        self._episode_count = 0
        # Incremented before each generation, so the first episode uses 43.
        self._seed_counter = 42

    def get_next_episode(self) -> dict:
        """Generate the next episode config at the current difficulty.

        Returns:
            A dict with keys: client, progress_data, complications,
            difficulty, max_steps, phases, description.
        """
        self._episode_count += 1
        self._seed_counter += 1

        generated = generate_client(
            difficulty=self.current_difficulty,
            seed=self._seed_counter,
        )

        # Max steps and phases by difficulty.
        if self.current_difficulty == "easy":
            max_steps = 5
            phases = ["initial"]
        elif self.current_difficulty == "medium":
            max_steps = 7
            phases = ["initial", "adaptation"]
        else:
            max_steps = 9
            phases = ["initial", "adaptation", "conflict"]

        client = generated["client"]
        desc = (
            f"[Curriculum: {self.current_difficulty.upper()} | "
            f"Episode {self._episode_count}] "
            f"Client: {client['name']}, {client['age']}y, "
            f"{client['fitness_level']} {client['goal']}. "
            f"Equipment: {client['available_equipment']}. "
            f"Injuries: {client['injuries'] or 'none'}. "
            f"Complications: {generated['complications'] or 'none'}."
        )

        return {
            "client": generated["client"],
            "progress_data": generated["progress_data"],
            "complications": generated["complications"],
            "difficulty": self.current_difficulty,
            "max_steps": max_steps,
            "phases": phases,
            "description": desc,
        }

    def _streak_at_current_level(self) -> list[float]:
        """Return the trailing run of scores recorded at the current difficulty."""
        streak = []
        for score, level in zip(
            reversed(self.episode_scores), reversed(self.difficulty_history)
        ):
            if level != self.current_difficulty:
                break
            streak.append(score)
        streak.reverse()
        return streak

    def record_score(self, score: float):
        """Record an episode score and apply escalation/de-escalation rules.

        Args:
            score: Score of the episode that just finished; it is attributed
                to the difficulty that is current at call time.
        """
        self.episode_scores.append(score)
        self.difficulty_history.append(self.current_difficulty)

        current_idx = self.DIFFICULTIES.index(self.current_difficulty)

        # Evaluating the global tail of episode_scores would let results from
        # the previous difficulty count again right after a change (easy ->
        # medium, then one more good score -> hard). Restrict the window to
        # the current level.
        streak = self._streak_at_current_level()

        # Check escalation: 3 consecutive scores >= 0.8.
        if len(streak) >= 3:
            last_3 = streak[-3:]
            if all(s >= 0.8 for s in last_3) and current_idx < len(self.DIFFICULTIES) - 1:
                old = self.current_difficulty
                self.current_difficulty = self.DIFFICULTIES[current_idx + 1]
                self.escalation_events.append({
                    "episode": self._episode_count,
                    "direction": "escalate",
                    "from": old,
                    "to": self.current_difficulty,
                    "trigger": f"3 consecutive scores ≥ 0.8: {last_3}",
                })
                return

        # Check de-escalation: 2 consecutive scores < 0.5.
        if len(streak) >= 2:
            last_2 = streak[-2:]
            if all(s < 0.5 for s in last_2) and current_idx > 0:
                old = self.current_difficulty
                self.current_difficulty = self.DIFFICULTIES[current_idx - 1]
                self.escalation_events.append({
                    "episode": self._episode_count,
                    "direction": "de-escalate",
                    "from": old,
                    "to": self.current_difficulty,
                    "trigger": f"2 consecutive scores < 0.5: {last_2}",
                })

    def get_summary(self) -> dict:
        """Return a training summary (scores, difficulty trace, events) for plotting."""
        return {
            "total_episodes": self._episode_count,
            "current_difficulty": self.current_difficulty,
            "episode_scores": self.episode_scores,
            "difficulty_history": self.difficulty_history,
            "escalation_events": self.escalation_events,
            "avg_score": (
                sum(self.episode_scores) / len(self.episode_scores)
                if self.episode_scores else 0.0
            ),
        }