# Provenance (from the hosting page): uploaded by Rishi-Jain-27,
# commit da653f3 "Created data generator and data and finetune.py", 12.5 kB.
"""Core synthetic-data engine for the code -> Mermaid flowchart dataset.
The whole design goal is *correctness by construction*. Each example's source
code and its flowchart are emitted from the same builder state, so:
* Every node's `<linemap>` line number is the live 1-based line number that
`CodeBuilder.add()` returned when the corresponding statement was written.
Injecting comments / docstrings / blank lines shifts the line numbers for
free and the map stays correct, because nothing is hard-coded.
* Decision labels are *paraphrased* plain-English questions (per the strict
system-prompt constraint), never raw code/operators/quotes/brackets.
`Mermaid` renders the graph + the `<linemap>` block; `validate_example()` is the
hard gate that every produced example must pass.
"""
from __future__ import annotations
import re
import string
from dataclasses import dataclass, field
from typing import Callable, Optional
# --------------------------------------------------------------------------- #
# Code builder
# --------------------------------------------------------------------------- #
class CodeBuilder:
    """Append-only buffer of source lines that reports live line numbers.

    Every method that writes a line returns the 1-based number of the line it
    just wrote. Templates keep the number returned by ``add()`` for real
    statements (to wire into the flowchart); ``blank()``, ``comment()`` and
    ``maybe_docstring()`` emit filler that shifts later line numbers but is
    not meant to map to any node.
    """

    def __init__(self, indent_unit: str = " ", comment_prefix: str = "# "):
        self.lines: list[str] = []
        # Text repeated once per indent level in front of each added line.
        self.iu = indent_unit
        # Text placed in front of comment bodies.
        self.cpre = comment_prefix

    def add(self, text: str, indent: int = 0) -> int:
        """Append ``text`` at the given indent level; return its line number."""
        prefix = self.iu * indent
        self.lines.append(prefix + text)
        return len(self.lines)

    def blank(self) -> int:
        """Append an empty line; return its line number."""
        return self.add("")

    def comment(self, text: str, indent: int = 0) -> int:
        """Append a comment line (prefix + text); return its line number."""
        commented = self.cpre + text
        return self.add(commented, indent)

    def maybe_blank(self, rng, p: float = 0.18) -> None:
        """Append a blank line when ``rng.random()`` falls below ``p``."""
        if not rng.random() < p:
            return
        self.blank()

    def maybe_comment(self, rng, comments, indent: int = 0, p: float = 0.3) -> None:
        """With probability ``p``, append a comment picked from ``comments``."""
        if not rng.random() < p:
            return
        self.comment(rng.choice(comments), indent)

    def maybe_docstring(self, rng, texts, indent: int = 1, p: float = 0.3) -> None:
        """With probability ``p``, append a Python-style one-line docstring.

        The text is picked from ``texts`` and wrapped in triple double quotes;
        like other filler it only shifts the line numbers that follow.
        """
        if not rng.random() < p:
            return
        quotes = '"""'
        self.add(quotes + rng.choice(texts) + quotes, indent)

    def render_numbered(self) -> str:
        """Return the buffer with each line prefixed by ``<lineno>| ``."""
        numbered = []
        for lineno, text in enumerate(self.lines, start=1):
            numbered.append(f"{lineno}| {text}")
        return "\n".join(numbered)

    def source(self) -> str:
        """Return the raw buffer joined by newlines (no line prefixes)."""
        return "\n".join(self.lines)

    @property
    def n_lines(self) -> int:
        """Number of lines written so far."""
        return len(self.lines)
# --------------------------------------------------------------------------- #
# Mermaid builder
# --------------------------------------------------------------------------- #
_LETTERS = string.ascii_uppercase
def _node_id(n: int) -> str:
"""1 -> A, 26 -> Z, 27 -> AA, ..."""
if n <= 26:
return _LETTERS[n - 1]
first = (n - 1) // 26 - 1
second = (n - 1) % 26
return _LETTERS[first] + _LETTERS[second]
# Allowed characters inside a rendered label: ASCII letters, digits, space and
# the punctuation ? : _ , . / -  (nothing else). Quotes, parentheses, brackets,
# braces and comparison/assignment symbols are therefore rejected, in line with
# the label restrictions described in the module docstring.
# NOTE(review): `/` and `-` are allowed even though they can read as
# operators -- confirm that matches the system prompt.
# NOTE: `$` also matches just before a single trailing newline, so
# `.match()` accepts "text\n"; use `.fullmatch()` when that matters.
_LABEL_RE = re.compile(r"^[A-Za-z0-9 ?:_,./\-]+$")
class Mermaid:
    """Collects flowchart nodes and edges and renders them as Mermaid text.

    Nodes receive sequential letter ids in the order they are added. A node
    may carry a source-line reference, which is emitted in a separate
    ``<linemap>`` block by ``render_linemap()``.
    """

    # shape name -> (opening delimiter, closing delimiter) around the label
    SHAPES = {
        "rect": ("[", "]"),
        "decision": ("{", "}"),
        "stadium": ("([", "])"),
        "round": ("(", ")"),
    }

    def __init__(self, header: str = "graph TD"):
        self.header = header
        # Count of nodes allocated so far; drives id generation.
        self._n = 0
        # One (node id, shape, label) triple per node, in insertion order.
        self.nodes: list[tuple[str, str, str]] = []
        # One (source id, edge label or None, target id) triple per edge.
        self.edges: list[tuple[str, Optional[str], str]] = []
        # node id -> line reference, stored as a string.
        self.lines: dict[str, str] = {}
        # Number of edges that were flagged as loop edges.
        self.loop_count = 0

    def add(self, label: str, shape: str = "rect", line=None) -> str:
        """Register a node and return its freshly allocated id.

        ``line`` (when not None) is stringified and remembered for the
        linemap. ``shape`` is not checked here; it is looked up in ``SHAPES``
        only when the node is rendered.
        """
        self._n += 1
        new_id = _node_id(self._n)
        self.nodes.append((new_id, shape, label))
        if line is None:
            return new_id
        self.lines[new_id] = str(line)
        return new_id

    def edge(self, a: str, b: str, label: Optional[str] = None, loop: bool = False) -> None:
        """Record an edge from ``a`` to ``b``; count it when ``loop`` is set."""
        self.edges.append((a, label, b))
        self.loop_count += 1 if loop else 0

    # rendering -------------------------------------------------------------- #
    def _render_node(self, nid: str, shape: str, label: str) -> str:
        """Render one node definition line using the delimiters of ``shape``."""
        opener, closer = self.SHAPES[shape]
        return f" {nid}{opener}{label}{closer}"

    def render_graph(self) -> str:
        """Render header, then all node definitions, then all edges."""
        node_rows = [self._render_node(*node) for node in self.nodes]
        edge_rows = [
            f" {src} --> {dst}" if text is None else f" {src} -- {text} --> {dst}"
            for src, text, dst in self.edges
        ]
        return "\n".join([self.header, *node_rows, *edge_rows])

    def render_linemap(self) -> Optional[str]:
        """Render the ``<linemap>`` block, or None when no node has a line.

        Rows follow node insertion order and have the form ``<id>: <line>``.
        """
        entries = []
        for nid, _shape, _label in self.nodes:
            if nid in self.lines:
                entries.append(f"{nid}: {self.lines[nid]}")
        if not entries:
            return None
        body = "\n".join(entries)
        return "<linemap>\n" + body + "\n</linemap>"
# --------------------------------------------------------------------------- #
# Example container
# --------------------------------------------------------------------------- #
_TERMINAL_PREFIXES = ("Return", "Raise", "Throw", "Print", "Yield", "Error", "End", "Exit", "Break")
def _is_terminal(label: str) -> bool:
return label.startswith(_TERMINAL_PREFIXES)
@dataclass
class Example:
    """One assembled example: the user turn, the assistant turn and metadata."""

    # Presumably the source language and the name of the generating template
    # -- neither is interpreted in this part of the file.
    language: str
    template: str
    # User turn: the line-numbered source.
    code: str
    # Assistant turn: thinking + mermaid + linemap.
    output: str
    # Raw source without line prefixes (for syntax checking).
    source: str
    # Presumably the node count of the flowchart -- TODO confirm with callers.
    n_nodes: int
def _pluralize(n: int, singular: str) -> str:
if n == 1:
return f"{n} {singular}"
suffix = "es" if singular.endswith(("ch", "sh", "s", "x", "z")) else "s"
return f"{n} {singular}{suffix}"
def build_thinking(rng, m: Mermaid) -> str:
    """Compose the three-point ``<thinking>`` block describing graph ``m``.

    Point 1 summarises control structures (decision nodes, loop edges,
    terminal nodes), point 2 lists every node id with its label, and point 3
    lists the first six node-to-line mappings. The lead-in phrases and the
    mapping verb are drawn from ``rng``; the three draws happen in a fixed
    order so that a seeded ``rng`` yields reproducible text.
    """
    n_decisions = sum(1 for _, shape, _ in m.nodes if shape == "decision")
    n_terminals = sum(1 for _, _, label in m.nodes if _is_terminal(label))

    # -- point 1: structural summary ---------------------------------------
    summary = []
    if n_decisions:
        summary.append(_pluralize(n_decisions, "decision point"))
    if m.loop_count:
        summary.append(_pluralize(m.loop_count, "loop"))
    # At least one terminal branch is always reported, even if none was found.
    summary.append(_pluralize(n_terminals or 1, "terminal branch"))
    structure_lead = rng.choice(["Control structures:", "Structural parse:", "Control flow detected:"])
    point1 = f"1. {structure_lead} " + ", ".join(summary) + "."

    # -- point 2: node inventory -------------------------------------------
    inventory = ", ".join(f"{nid} {label}" for nid, _, label in m.nodes)
    nodes_lead = rng.choice(["Nodes mapped chronologically:", "Execution nodes in order:", "Node sequence:"])
    point2 = f"2. {nodes_lead} {inventory}."

    # -- point 3: line mappings (capped at six) ----------------------------
    all_mappings = [(nid, m.lines[nid]) for nid, _, _ in m.nodes if nid in m.lines]
    preview = all_mappings[:6]
    verb = rng.choice(["maps to line", "is line", "at line"])
    mapping_text = ", ".join(f"{nid} {verb} {ln}" for nid, ln in preview)
    if len(all_mappings) > len(preview):
        mapping_text += ", and so on"
    point3 = f"3. Source lines: {mapping_text}."

    body = "\n".join([point1, point2, point3])
    return "<thinking>\n" + body + "\n</thinking>"
def assemble_output(rng, m: Mermaid) -> str:
    """Build the assistant turn: thinking block, graph, then optional linemap.

    The sections are joined with single newlines. The linemap section is
    left out when ``m.render_linemap()`` returns None.
    """
    sections = [build_thinking(rng, m), m.render_graph()]
    linemap = m.render_linemap()
    if linemap is not None:
        sections.append(linemap)
    return "\n".join(sections)
# --------------------------------------------------------------------------- #
# Validation (the hard gate)
# --------------------------------------------------------------------------- #
class ValidationError(Exception):
    """Raised by the validators in this module when an example is malformed."""
# Matches one complete node-definition line: a node id (a letter followed by
# letters/digits) immediately followed by exactly one delimited label.
# The named group that participates tells which shape matched. The stadium
# form `([..])` is listed before the round form `(..)` because the round
# alternative would otherwise also accept it (its label may contain brackets).
_NODE_DEF_RE = re.compile(
# node id
r"^([A-Za-z][A-Za-z0-9]*)"
# stadium: ([label])
r"(?:\(\[(?P<stadium>[^\]]*)\]\)"
# decision: {label}
r"|\{(?P<decision>[^{}]*)\}"
# rect: [label]
r"|\[(?P<rect>[^\]]*)\]"
# round: (label)
r"|\((?P<round>[^()]*)\))$"
)
# Matches one complete edge line, either `A --> B` or `A -- label --> B`.
# Group 1 is the source id, group 3 the target id; the optional `label` group
# is lazy, cannot contain `>`, and the whitespace around it is consumed outside
# the group. It is None for the plain `-->` form.
_EDGE_RE = re.compile(
r"^([A-Za-z][A-Za-z0-9]*)\s*(?:--\s*(?P<label>[^>]*?)\s*-->|-->)\s*([A-Za-z][A-Za-z0-9]*)$"
)
def _check_label(label: str) -> None:
    """Reject a label that is blank or contains a character outside ``_LABEL_RE``.

    ``fullmatch`` is used instead of ``match`` because the pattern's ``$``
    anchor also matches just before a trailing newline, so ``match`` would
    let a label such as ``"Done\\n"`` through.

    Args:
        label: Node or edge label text exactly as it appears in the graph.

    Raises:
        ValidationError: If the label is empty/whitespace-only, or contains
            any character not permitted by ``_LABEL_RE``. The message lists
            the offending characters, sorted and de-duplicated.
    """
    if not label.strip():
        raise ValidationError("empty label")
    if _LABEL_RE.fullmatch(label) is None:
        # Test each character on its own to report exactly which ones are banned.
        bad = sorted({c for c in label if _LABEL_RE.fullmatch(c) is None})
        raise ValidationError(f"label has banned chars {bad!r}: {label!r}")
def validate_mermaid_block(graph_text: str) -> set[str]:
    """Parse rendered Mermaid text and return the ids of all defined nodes.

    Blank lines are ignored. The first remaining line must be a
    ``graph``/``flowchart`` header with a direction; every later line must be
    either an edge or a node definition.

    Raises:
        ValidationError: On an empty graph, a bad header, an unparseable
            line, a label rejected by ``_check_label``, a duplicate node id,
            an edge endpoint that is never defined, or no nodes at all.
    """
    rows = [row.rstrip() for row in graph_text.splitlines() if row.strip()]
    if not rows:
        raise ValidationError("empty graph")

    header, *statements = rows
    if re.match(r"^(graph|flowchart)\s+(TD|LR|TB|RL|BT)$", header) is None:
        raise ValidationError(f"bad header: {header!r}")

    defined: set[str] = set()
    referenced: set[str] = set()
    for statement in statements:
        body = statement.strip()

        # Edges are tried first; their endpoints are checked after the loop
        # because a node may be defined after the edge that uses it.
        edge = _EDGE_RE.match(body)
        if edge is not None:
            referenced.update((edge.group(1), edge.group(3)))
            edge_label = edge.group("label")
            if edge_label:
                _check_label(edge_label)
            continue

        node = _NODE_DEF_RE.match(body)
        if node is None:
            raise ValidationError(f"unparseable mermaid line: {body!r}")

        nid = node.group(1)
        # Exactly one shape group participates in a match; it holds the label.
        label = next(text for text in node.groupdict().values() if text is not None)
        _check_label(label)
        if nid in defined:
            raise ValidationError(f"duplicate node id {nid}")
        defined.add(nid)

    missing = referenced - defined
    if missing:
        raise ValidationError(f"edges reference undefined nodes: {sorted(missing)}")
    if not defined:
        raise ValidationError("no nodes defined")
    return defined
def validate_example(ex: Example) -> None:
    """Run the full structural validation of one assembled example.

    Checks, in order: exactly one thinking block; no markdown fence; no
    banned boilerplate phrase after the thinking block; the diagram starts
    right after the thinking block; the linemap (if any) is closed; the graph
    itself parses; every linemap row names a defined node once and stays
    within the source's line range; every line of ``ex.code`` carries its
    ``N| `` prefix.

    Raises:
        ValidationError: On the first check that fails.
    """
    out = ex.output
    if out.count("<thinking>") != 1 or out.count("</thinking>") != 1:
        raise ValidationError("thinking block malformed")
    if "```" in out:
        raise ValidationError("markdown fence leaked")

    # Only the post-thinking region is scanned for banned phrases, to avoid
    # false positives inside the thinking text.
    post = out.split("</thinking>", 1)[1]
    for banned in ("Here is", "As requested", "Explanation:", "Note:"):
        if banned in post:
            raise ValidationError(f"banned phrase leaked: {banned!r}")

    after = post.lstrip("\n")
    if not after.startswith(("graph ", "flowchart ")):
        raise ValidationError("diagram does not start with graph/flowchart after thinking")

    # Split graph vs linemap.
    graph_text, opener, rest = after.partition("<linemap>")
    if opener:
        if not rest.rstrip().endswith("</linemap>"):
            raise ValidationError("linemap not closed")
        linemap_body = rest.split("</linemap>", 1)[0].strip("\n")
    else:
        linemap_body = ""

    defined = validate_mermaid_block(graph_text)

    # Number of lines in the user code bounds every linemap reference.
    code_lines = ex.code.splitlines()
    n_src = len(code_lines)

    # Each row is "<id>: <line>" or "<id>: <lo>-<hi>".
    row_re = re.compile(r"^([A-Za-z][A-Za-z0-9]*):\s*(\d+)(?:-(\d+))?$")
    seen_ids = set()
    for raw_row in linemap_body.splitlines():
        row = raw_row.strip()
        if not row:
            continue
        parsed = row_re.match(row)
        if parsed is None:
            raise ValidationError(f"bad linemap row: {row!r}")
        nid = parsed.group(1)
        lo = int(parsed.group(2))
        hi = parsed.group(3)
        if nid not in defined:
            raise ValidationError(f"linemap references unknown node {nid}")
        if nid in seen_ids:
            raise ValidationError(f"duplicate linemap entry for {nid}")
        seen_ids.add(nid)
        if lo < 1 or lo > n_src:
            raise ValidationError(f"linemap line {lo} out of range 1..{n_src}")
        if hi is not None and (int(hi) < lo or int(hi) > n_src):
            raise ValidationError(f"linemap range {lo}-{hi} invalid (max {n_src})")

    # The user code lines must all carry the "N| " prefix; a bare "N|" is
    # also accepted.
    for i, ln in enumerate(code_lines, 1):
        if ln.startswith(f"{i}| ") or ln == f"{i}|":
            continue
        raise ValidationError(f"line {i} not properly prefixed: {ln!r}")