scicoqa / parsing.py
timbmg's picture
initial commit
4caa453 unverified
"""Discrepancy parsing logic for extracting discrepancies from LLM output."""
import logging
import re
# Module-level logger named after this module.
# NOTE(review): parse_discrepancies does not reference it; confirm whether it
# is still needed.
logger = logging.getLogger(__name__)
def parse_discrepancies(text: str) -> list[str] | None:
"""
Extract list items (discrepancies) from model output.
Replicates the _extract_list_items logic from scicoqa/inference/discrepancy_eval.py
Args:
text: Raw text output from LLM
Returns:
List of discrepancy strings, or None if no discrepancies found
"""
if not text:
return None
# Remove redacted reasoning if present
if "</think>" in text:
text = text.split("</think>")[1]
# Detect YAML or dashed list format
if "```yaml\ndiscrepancies:" in text:
text = text.split("```yaml\ndiscrepancies:")[-1]
yaml_or_dashed = True
elif "```yaml" in text:
text = text.split("```yaml")[-1]
yaml_or_dashed = True
elif "discrepancies:" in text:
text = text.split("discrepancies:")[1]
yaml_or_dashed = True
elif re.search(r"# Discrepancies[\s\r\n]*-", text, re.IGNORECASE):
text = re.split(
r"# Discrepancies[\s\r\n]*-", text, maxsplit=1, flags=re.IGNORECASE
)[1]
text = "- " + text
yaml_or_dashed = True
else:
yaml_or_dashed = False
if yaml_or_dashed:
# Clean up the text
text = text.strip("\n").strip().strip("```yaml").strip("```").strip("\n")
text = (
text.strip("discrepancies:").strip("discrepancies").strip("\n").strip()
)
# Split by list item pattern
pattern = r"\n\s{0,2}-\s+"
parts = re.split(pattern, text)
items = []
for part in parts:
cleaned = " ".join(part.split())
if cleaned and not cleaned.startswith("discrepancies:"):
# Multiple cleaning passes
cleaned = cleaned.strip().strip("-").strip()
cleaned = cleaned.strip().strip("-").strip()
cleaned = cleaned.strip().strip("|").strip()
cleaned = cleaned.strip().strip(">-").strip()
cleaned = cleaned.strip().strip(">").strip()
cleaned = cleaned.strip().strip('"').strip()
cleaned = cleaned.strip().strip("'").strip()
cleaned = cleaned.strip("summary: |\n")
cleaned = cleaned.strip("summary: ")
cleaned = cleaned.strip("|")
cleaned = cleaned.strip("\n").strip()
# Remove numbered prefixes
cleaned = re.sub(r"^[0-9]+[\.\)]\s*", "", cleaned)
if cleaned: # Only add non-empty items
items.append(cleaned)
else:
items = None
# Handle empty list case
if items and len(items) == 1 and items[0].strip() == "[]":
items = None
return items if items else None