|
|
|
|
|
def extract_test_output_code(model_output: str):
    """Extract the predicted assertion (or fenced code) from a model response.

    Two strategies are tried in order:

    1. If any line begins with ``assert`` once leading whitespace is
       ignored, the last such line is returned with that leading
       whitespace removed.
    2. Otherwise the contents of a Markdown code fence are returned: the
       first fence tagged ```` ```python ````/```` ```Python ```` if one
       exists, else the first fence of any kind.

    Args:
        model_output: Raw text produced by the model.

    Returns:
        The extracted assert line or the text between the chosen opening
        fence and the next fence line, or ``""`` when there is no assert
        line and no opening/closing fence pair.
    """
    outputlines = model_output.split("\n")

    # Scan backwards so the first hit is the *last* assert line. Leading
    # whitespace is ignored so asserts nested in a function body or an
    # indented snippet are still recognised; the returned line is
    # left-stripped so it is a valid standalone statement.
    for line in reversed(outputlines):
        candidate = line.lstrip()
        if candidate.startswith("assert"):
            return candidate

    # Every line that contains a fence marker, in document order.
    fence_lines = [i for i, line in enumerate(outputlines) if "```" in line]

    # Prefer an explicitly Python-tagged fence as the opening fence.
    python_start = next(
        (
            i
            for i in fence_lines
            if "```python" in outputlines[i] or "```Python" in outputlines[i]
        ),
        None,
    )
    if python_start is not None:
        # Discard fences before the tagged one so the closing fence is the
        # first marker that follows it.
        fence_lines = [python_start] + [i for i in fence_lines if i > python_start]

    # An opening fence without a closing one yields nothing.
    if len(fence_lines) < 2:
        return ""
    return "\n".join(outputlines[fence_lines[0] + 1 : fence_lines[1]])
|
|
|
|
|
|
|
|
def extract_execution_code(model_output: str, cot: bool = False):
    """Extract the predicted right-hand side of an assertion from model output.

    The text is narrowed in successive steps:

    1. When ``cot`` is true and an ``[ANSWER]`` tag is present, keep the text
       that follows the first tag (up to the next ``[ANSWER]`` tag, if any).
    2. If the text contains ``==``, keep everything after the first
       occurrence.
    3. If a ``[/ANSWER]`` tag is present, keep the text before it; otherwise
       keep only the first line.

    Args:
        model_output: Raw text produced by the model.
        cot: Whether to look for an opening ``[ANSWER]`` tag first.

    Returns:
        The extracted text with surrounding whitespace removed.
    """
    if cot and "[ANSWER]" in model_output:
        model_output = model_output.split("[ANSWER]")[1].strip()

    if "==" in model_output:
        # Split only once: the answer itself may contain "==" (for example
        # inside a string literal) and must not be cut at a later one.
        model_output = model_output.split("==", 1)[1].strip()

    if "[/ANSWER]" in model_output:
        model_output = model_output.split("[/ANSWER]")[0].strip()
    else:
        # Without a closing tag only the first line is taken as the answer.
        model_output = model_output.split("\n")[0].strip()

    return model_output.strip()