Commit
·
1ee2461
1 Parent(s): f2d5eaa
Fix code block extraction and increase completion length
Browse files
app.py
CHANGED
|
@@ -29,6 +29,7 @@ check_import("transformers", lambda: __import__("transformers").__version__)
|
|
| 29 |
check_import("datasets", lambda: __import__("datasets").__version__)
|
| 30 |
check_import("peft", lambda: __import__("peft").__version__)
|
| 31 |
check_import("trl", lambda: __import__("trl").__version__)
|
|
|
|
| 32 |
|
| 33 |
try:
|
| 34 |
from trl import GRPOConfig, GRPOTrainer
|
|
@@ -103,6 +104,7 @@ def get_status():
|
|
| 103 |
|
| 104 |
|
| 105 |
def extract_code_block(text: str) -> str:
|
|
|
|
| 106 |
pattern = r"```python\s*(.*?)```"
|
| 107 |
matches = re.findall(pattern, text, re.DOTALL)
|
| 108 |
if matches:
|
|
@@ -111,6 +113,18 @@ def extract_code_block(text: str) -> str:
|
|
| 111 |
matches = re.findall(pattern, text, re.DOTALL)
|
| 112 |
if matches:
|
| 113 |
return matches[-1].strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
return text.strip()
|
| 115 |
|
| 116 |
|
|
@@ -591,7 +605,7 @@ def run_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_con
|
|
| 591 |
save_steps=999999,
|
| 592 |
report_to="none",
|
| 593 |
remove_unused_columns=False,
|
| 594 |
-
max_completion_length=
|
| 595 |
num_generations=4,
|
| 596 |
)
|
| 597 |
|
|
@@ -627,7 +641,7 @@ def run_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_con
|
|
| 627 |
with torch.no_grad():
|
| 628 |
outputs = model.generate(
|
| 629 |
**inputs,
|
| 630 |
-
max_new_tokens=
|
| 631 |
do_sample=True,
|
| 632 |
temperature=0.7,
|
| 633 |
top_p=0.9,
|
|
|
|
| 29 |
check_import("datasets", lambda: __import__("datasets").__version__)
|
| 30 |
check_import("peft", lambda: __import__("peft").__version__)
|
| 31 |
check_import("trl", lambda: __import__("trl").__version__)
|
| 32 |
+
check_import("huggingface_hub", lambda: __import__("huggingface_hub").__version__)
|
| 33 |
|
| 34 |
try:
|
| 35 |
from trl import GRPOConfig, GRPOTrainer
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
def extract_code_block(text: str) -> str:
|
| 107 |
+
# Prefer closed fences
|
| 108 |
pattern = r"```python\s*(.*?)```"
|
| 109 |
matches = re.findall(pattern, text, re.DOTALL)
|
| 110 |
if matches:
|
|
|
|
| 113 |
matches = re.findall(pattern, text, re.DOTALL)
|
| 114 |
if matches:
|
| 115 |
return matches[-1].strip()
|
| 116 |
+
|
| 117 |
+
# Handle unclosed fences (common when generation truncates)
|
| 118 |
+
if "```python" in text:
|
| 119 |
+
after = text.split("```python", 1)[1]
|
| 120 |
+
if "```" in after:
|
| 121 |
+
after = after.split("```", 1)[0]
|
| 122 |
+
return after.strip()
|
| 123 |
+
if "```" in text:
|
| 124 |
+
after = text.split("```", 1)[1]
|
| 125 |
+
if "```" in after:
|
| 126 |
+
after = after.split("```", 1)[0]
|
| 127 |
+
return after.strip()
|
| 128 |
return text.strip()
|
| 129 |
|
| 130 |
|
|
|
|
| 605 |
save_steps=999999,
|
| 606 |
report_to="none",
|
| 607 |
remove_unused_columns=False,
|
| 608 |
+
max_completion_length=2048,
|
| 609 |
num_generations=4,
|
| 610 |
)
|
| 611 |
|
|
|
|
| 641 |
with torch.no_grad():
|
| 642 |
outputs = model.generate(
|
| 643 |
**inputs,
|
| 644 |
+
max_new_tokens=1024,
|
| 645 |
do_sample=True,
|
| 646 |
temperature=0.7,
|
| 647 |
top_p=0.9,
|