jayantaggarwal-sketch commited on
Commit
0194e2e
·
1 Parent(s): 6762657

Fix TRL compatibility in GRPO training and Space API example

Browse files
Files changed (2) hide show
  1. HF_README.md +1 -1
  2. training/train_grpo.py +39 -5
HF_README.md CHANGED
@@ -32,7 +32,7 @@ curl -X POST "https://jayant2304-commitment-os.hf.space/reset?task_id=easy_001"
32
  # Make a tool call
33
  curl -X POST "https://jayant2304-commitment-os.hf.space/step" \
34
  -H "Content-Type: application/json" \
35
- -d '{"action_type": "view_calendar", "date": "2026-04-25"}'
36
 
37
  # Get state
38
  curl "https://jayant2304-commitment-os.hf.space/state"
 
32
  # Make a tool call
33
  curl -X POST "https://jayant2304-commitment-os.hf.space/step" \
34
  -H "Content-Type: application/json" \
35
+ -d '{"action": {"action_type": "view_calendar", "date": "2026-04-25"}}'
36
 
37
  # Get state
38
  curl "https://jayant2304-commitment-os.hf.space/state"
training/train_grpo.py CHANGED
@@ -66,12 +66,46 @@ def build_dataset(num_scenarios: int = 15) -> List[Dict[str, Any]]:
66
  return dataset
67
 
68
 
69
- def reward_function(completions: List[str], **kwargs: Any) -> List[float]:
70
  """Reward function for GRPO — evaluates completions against CommitmentOS."""
71
  from training.env_factory import CommitmentOSEnvFactory
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  factory = CommitmentOSEnvFactory(max_turns=8)
74
- return factory(completions)
 
75
 
76
 
77
  def main() -> None:
@@ -95,7 +129,7 @@ def main() -> None:
95
 
96
  model = AutoModelForCausalLM.from_pretrained(
97
  args.model,
98
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
99
  device_map="auto" if torch.cuda.is_available() else None,
100
  trust_remote_code=True,
101
  )
@@ -122,7 +156,7 @@ def main() -> None:
122
  save_steps=50,
123
  bf16=torch.cuda.is_available(),
124
  gradient_accumulation_steps=2,
125
- warmup_ratio=0.1,
126
  max_completion_length=512,
127
  num_generations=args.group_size,
128
  report_to="none",
@@ -131,7 +165,7 @@ def main() -> None:
131
  print("Initialising GRPOTrainer...")
132
  trainer = GRPOTrainer(
133
  model=model,
134
- config=training_config,
135
  train_dataset=dataset,
136
  processing_class=tokenizer,
137
  reward_funcs=reward_function,
 
66
  return dataset
67
 
68
 
69
+ def reward_function(completions: List[Any], **kwargs: Any) -> List[float]:
70
  """Reward function for GRPO — evaluates completions against CommitmentOS."""
71
  from training.env_factory import CommitmentOSEnvFactory
72
 
73
+ def _completion_to_text(completion: Any) -> str:
74
+ """Normalize TRL completion payloads across versions.
75
+
76
+ Depending on TRL/Transformers version, completions can arrive as
77
+ strings, dicts, or nested lists of chat/message objects.
78
+ """
79
+ if isinstance(completion, str):
80
+ return completion
81
+ if isinstance(completion, dict):
82
+ content = completion.get("content", completion.get("text", ""))
83
+ if isinstance(content, str):
84
+ return content
85
+ if isinstance(content, list):
86
+ return "\n".join(str(item) for item in content)
87
+ return str(content)
88
+ if isinstance(completion, list):
89
+ parts: List[str] = []
90
+ for item in completion:
91
+ if isinstance(item, str):
92
+ parts.append(item)
93
+ elif isinstance(item, dict):
94
+ content = item.get("content", item.get("text", ""))
95
+ if isinstance(content, list):
96
+ content = " ".join(
97
+ block.get("text", str(block)) if isinstance(block, dict) else str(block)
98
+ for block in content
99
+ )
100
+ parts.append(str(content))
101
+ else:
102
+ parts.append(str(item))
103
+ return "\n".join(part for part in parts if part)
104
+ return str(completion)
105
+
106
  factory = CommitmentOSEnvFactory(max_turns=8)
107
+ normalized = [_completion_to_text(completion) for completion in completions]
108
+ return factory(normalized)
109
 
110
 
111
  def main() -> None:
 
129
 
130
  model = AutoModelForCausalLM.from_pretrained(
131
  args.model,
132
+ dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
133
  device_map="auto" if torch.cuda.is_available() else None,
134
  trust_remote_code=True,
135
  )
 
156
  save_steps=50,
157
  bf16=torch.cuda.is_available(),
158
  gradient_accumulation_steps=2,
159
+ warmup_steps=5,
160
  max_completion_length=512,
161
  num_generations=args.group_size,
162
  report_to="none",
 
165
  print("Initialising GRPOTrainer...")
166
  trainer = GRPOTrainer(
167
  model=model,
168
+ args=training_config,
169
  train_dataset=dataset,
170
  processing_class=tokenizer,
171
  reward_funcs=reward_function,