CreativeEngineer committed on
Commit
1ee2461
·
1 Parent(s): f2d5eaa

Fix code block extraction and increase completion length

Browse files
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -29,6 +29,7 @@ check_import("transformers", lambda: __import__("transformers").__version__)
29
  check_import("datasets", lambda: __import__("datasets").__version__)
30
  check_import("peft", lambda: __import__("peft").__version__)
31
  check_import("trl", lambda: __import__("trl").__version__)
 
32
 
33
  try:
34
  from trl import GRPOConfig, GRPOTrainer
@@ -103,6 +104,7 @@ def get_status():
103
 
104
 
105
  def extract_code_block(text: str) -> str:
 
106
  pattern = r"```python\s*(.*?)```"
107
  matches = re.findall(pattern, text, re.DOTALL)
108
  if matches:
@@ -111,6 +113,18 @@ def extract_code_block(text: str) -> str:
111
  matches = re.findall(pattern, text, re.DOTALL)
112
  if matches:
113
  return matches[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
114
  return text.strip()
115
 
116
 
@@ -591,7 +605,7 @@ def run_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_con
591
  save_steps=999999,
592
  report_to="none",
593
  remove_unused_columns=False,
594
- max_completion_length=512,
595
  num_generations=4,
596
  )
597
 
@@ -627,7 +641,7 @@ def run_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_con
627
  with torch.no_grad():
628
  outputs = model.generate(
629
  **inputs,
630
- max_new_tokens=256,
631
  do_sample=True,
632
  temperature=0.7,
633
  top_p=0.9,
 
29
  check_import("datasets", lambda: __import__("datasets").__version__)
30
  check_import("peft", lambda: __import__("peft").__version__)
31
  check_import("trl", lambda: __import__("trl").__version__)
32
+ check_import("huggingface_hub", lambda: __import__("huggingface_hub").__version__)
33
 
34
  try:
35
  from trl import GRPOConfig, GRPOTrainer
 
104
 
105
 
106
  def extract_code_block(text: str) -> str:
107
+ # Prefer closed fences
108
  pattern = r"```python\s*(.*?)```"
109
  matches = re.findall(pattern, text, re.DOTALL)
110
  if matches:
 
113
  matches = re.findall(pattern, text, re.DOTALL)
114
  if matches:
115
  return matches[-1].strip()
116
+
117
+ # Handle unclosed fences (common when generation truncates)
118
+ if "```python" in text:
119
+ after = text.split("```python", 1)[1]
120
+ if "```" in after:
121
+ after = after.split("```", 1)[0]
122
+ return after.strip()
123
+ if "```" in text:
124
+ after = text.split("```", 1)[1]
125
+ if "```" in after:
126
+ after = after.split("```", 1)[0]
127
+ return after.strip()
128
  return text.strip()
129
 
130
 
 
605
  save_steps=999999,
606
  report_to="none",
607
  remove_unused_columns=False,
608
+ max_completion_length=2048,
609
  num_generations=4,
610
  )
611
 
 
641
  with torch.no_grad():
642
  outputs = model.generate(
643
  **inputs,
644
+ max_new_tokens=1024,
645
  do_sample=True,
646
  temperature=0.7,
647
  top_p=0.9,