# afridialeval / src / generation_utils.py
# millicentochieng's picture
# Upload folder using huggingface_hub
# e2b8b61 verified
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# Template lines that (after stripping whitespace) start a new chat message;
# the text following "# " is used as the message role.
ROLE_HEADERS = {"# system", "# user", "# assistant"}
# Marker that separates the prompt part of a template from the optional
# JSON schema block that follows it.
JSON_SCHEMA_HEADER = "# JSON schema"
def fill_template_file(template_path: str, data: Dict[str, Any]) -> Tuple[List[Dict[str, str]], Optional[Dict[str, Any]]]:
    """
    Load a DAS-style prompt template from disk and render it with ``data``.

    The template may use:
      - role headers on their own line: ``# system``, ``# user``, ``# assistant``
      - placeholders: ``[key]`` or ``{{key}}``
      - loop sections, which may be nested:
            # start field_name
            ...
            # end field_name
        where ``data[field_name]`` is a list (typically of dicts)
      - an optional trailing schema block:
            # JSON schema
            { ... valid JSON ... }

    Args:
        template_path: Path of the UTF-8 encoded template file.
        data: Values used to fill placeholders and drive loop sections.

    Returns:
        A pair ``(messages, response_format)``. ``messages`` is a list of
        ``{"role": ..., "content": ...}`` dicts; ``response_format`` is the
        converted schema block, or ``None`` when the template has no
        (non-empty) schema block.

    Raises:
        ValueError: If the expanded template or the schema block is rejected
            by the helper functions (``json.JSONDecodeError`` from an invalid
            schema block is a ``ValueError`` subclass).
    """
    template_source = Path(template_path).read_text(encoding="utf-8")
    prompt_text, schema_text = _split_prompt_and_schema(template_source)

    # Messages are built before the schema is parsed, so template errors
    # surface ahead of schema errors.
    messages = _parse_role_markdown(_expand_template(prompt_text, data))

    if not schema_text:
        return messages, None
    return messages, _schema_to_response_format(json.loads(schema_text))
def _split_prompt_and_schema(text: str) -> Tuple[str, Optional[str]]:
    """
    Separate the prompt body from the optional JSON schema block.

    The split happens at the first occurrence of ``JSON_SCHEMA_HEADER``.

    Returns:
        ``(prompt, schema_text)``. ``schema_text`` is ``None`` when the header
        is absent or nothing but whitespace follows it. The prompt has its
        trailing whitespace removed only when a schema block is returned.
    """
    prompt_part, header, schema_part = text.partition(JSON_SCHEMA_HEADER)
    if not header:
        # Header not present: the whole text is the prompt.
        return text, None

    schema_text = schema_part.strip()
    if schema_text:
        return prompt_part.rstrip(), schema_text
    return prompt_part, None
def _schema_to_response_format(schema_payload: Dict[str, Any]) -> Dict[str, Any]:
"""
Expects DAS-style schema payload, e.g.
{
"name": "decoded_das",
"schema": {...},
"strict": true
}
"""
if "name" not in schema_payload or "schema" not in schema_payload:
raise ValueError("JSON schema block must contain at least 'name' and 'schema' keys.")
return {
"type": "json_schema",
"json_schema": {
"name": schema_payload["name"],
"schema": schema_payload["schema"],
"strict": schema_payload.get("strict", True),
},
}
def _expand_template(text: str, data: Dict[str, Any]) -> str:
    """
    Expand loop sections and placeholders of a template's prompt text.

    Args:
        text: Prompt part of the template (without the JSON schema block).
        data: Values used for placeholders and loop sections.

    Returns:
        The expanded text with surrounding whitespace stripped.

    Raises:
        ValueError: If a ``# end ...`` line has no matching ``# start ...``,
            or if ``_process_block`` rejects the template (missing
            ``# end``, non-list loop field).
    """
    lines = text.splitlines()
    expanded_lines, stop_idx = _process_block(lines, 0, data)

    # _process_block stops early only when it meets a '# end ...' line. At the
    # top level no loop is open, so that line is unmatched; raise instead of
    # silently dropping every line after it.
    if stop_idx < len(lines):
        raise ValueError(
            f"Unmatched '{lines[stop_idx].strip()}' in template (no open '# start' block)."
        )

    # _process_block has already substituted placeholders on every emitted
    # line using the correct (loop-aware) context. A second substitution over
    # the joined text would re-interpret bracketed text inside substituted
    # values as placeholders, so it is deliberately not performed.
    return "\n".join(expanded_lines).strip()
def _process_block(lines: List[str], start_idx: int, context: Dict[str, Any]) -> Tuple[List[str], int]:
    """
    Expand ``lines`` from ``start_idx`` onward, recursing into loop sections.

    Ordinary lines have their placeholders filled from ``context``. A
    ``# start <field>`` line opens a loop: the lines up to the matching
    ``# end <field>`` are expanded once per element of the list found under
    ``<field>``.

    Returns:
        ``(expanded_lines, index)`` where ``index`` is the position of the
        ``# end ...`` line that stopped processing, or ``len(lines)`` when the
        input was consumed completely.

    Raises:
        ValueError: If the loop field resolves to something that is neither
            ``None`` nor a list, or if the loop has no matching end line.
    """
    start_marker = "# start "
    end_marker = "# end "

    output: List[str] = []
    cursor = start_idx
    while cursor < len(lines):
        raw_line = lines[cursor]
        directive = raw_line.strip()

        if directive.startswith(end_marker):
            # Leave the cursor on the end line for the caller.
            break

        if not directive.startswith(start_marker):
            output.append(_replace_placeholders(raw_line, context))
            cursor += 1
            continue

        field_name = directive[len(start_marker):].strip()
        block_lines, end_idx = _collect_loop_block(lines, cursor + 1, field_name)

        loop_value = _resolve_key(field_name, context)
        if loop_value is None:
            loop_value = []
        # An unknown field resolves to "" (not None), so it is rejected here too.
        if not isinstance(loop_value, list):
            raise ValueError(f"Loop field '{field_name}' must be a list, got {type(loop_value).__name__}.")

        for position, entry in enumerate(loop_value, start=1):
            # Each iteration sees the outer context, overlaid with the entry's
            # own keys (for dict entries), the 1-based "$index", and the entry
            # itself under the loop's field name.
            scope = dict(context)
            if isinstance(entry, dict):
                scope.update(entry)
            scope["$index"] = position
            scope[field_name] = entry

            rendered, _ = _process_block(block_lines, 0, scope)
            output.extend(rendered)

        # Resume after the loop's closing line.
        cursor = end_idx + 1

    return output, cursor
def _collect_loop_block(lines: List[str], start_idx: int, field_name: str) -> Tuple[List[str], int]:
"""
Collects lines until the matching # end field_name, respecting nested loops.
Returns (block_lines, end_index).
"""
block: List[str] = []
depth = 1
i = start_idx
while i < len(lines):
stripped = lines[i].strip()
if stripped.startswith("# start "):
nested_name = stripped.replace("# start ", "", 1).strip()
if nested_name == field_name:
depth += 1
block.append(lines[i])
i += 1
continue
if stripped.startswith("# end "):
end_name = stripped.replace("# end ", "", 1).strip()
if end_name == field_name:
depth -= 1
if depth == 0:
return block, i
block.append(lines[i])
i += 1
continue
block.append(lines[i])
i += 1
raise ValueError(f"Missing matching '# end {field_name}' in template.")
def _parse_role_markdown(text: str) -> List[Dict[str, str]]:
    """
    Split expanded template text into chat messages at role header lines.

    A line whose stripped form is one of ``ROLE_HEADERS`` starts a new
    message; the lines up to the next header form its content (cleaned with
    ``_clean_content``). Lines that appear before the first header are
    discarded.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in template order.

    Raises:
        ValueError: If the text contains no role header.
    """
    sections: List[Tuple[str, List[str]]] = []
    for line in text.splitlines():
        header = line.strip()
        if header in ROLE_HEADERS:
            # "# user" -> "user"
            sections.append((header.replace("# ", ""), []))
        elif sections:
            sections[-1][1].append(line)

    if not sections:
        raise ValueError(
            "Template must contain at least one role header: '# system', '# user', or '# assistant'."
        )

    return [
        {"role": role, "content": _clean_content("\n".join(body))}
        for role, body in sections
    ]
def _clean_content(text: str) -> str:
lines = text.splitlines()
while lines and not lines[0].strip():
lines.pop(0)
while lines and not lines[-1].strip():
lines.pop()
if not lines:
return ""
min_indent = None
for line in lines:
if not line.strip():
continue
indent = len(line) - len(line.lstrip(" "))
if min_indent is None or indent < min_indent:
min_indent = indent
min_indent = min_indent or 0
cleaned = "\n".join(line[min_indent:] if len(line) >= min_indent else line for line in lines)
return cleaned.strip()
def _replace_placeholders(text: str, context: Dict[str, Any]) -> str:
    """
    Substitute ``[key]`` and ``{{key}}`` placeholders in ``text``.

    Supports both:
        [key]
        {{key}}
    including dotted keys:
        [speaker.name]
        {{speaker.name}}
    and the special loop index:
        [$index]
        {{$index}}

    Keys are resolved with ``_resolve_key`` and rendered with ``_stringify``.
    Both placeholder syntaxes are handled in a single left-to-right pass, so
    text produced by one substitution is never scanned again as a placeholder.

    Args:
        text: Text that may contain placeholders.
        context: Mapping used to resolve placeholder keys.

    Returns:
        ``text`` with every placeholder replaced by its rendered value.
    """
    def substitute(match: re.Match) -> str:
        # Exactly one alternative matched; lastindex is its capture group.
        key = match.group(match.lastindex).strip()
        return _stringify(_resolve_key(key, context))

    # One combined pattern instead of two sequential re.sub calls: with two
    # passes, a value inserted for "[key]" that itself contained "{{...}}"
    # was expanded a second time.
    return re.sub(r"\[([^\[\]]+)\]|\{\{([^{}]+)\}\}", substitute, text)
def _resolve_key(key: str, context: Dict[str, Any]) -> Any:
if key in context:
return context[key]
if "." not in key:
return ""
current: Any = context
for part in key.split("."):
part = part.strip()
if isinstance(current, dict) and part in current:
current = current[part]
else:
return ""
return current
def _stringify(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value
if isinstance(value, (int, float, bool)):
return str(value)
if isinstance(value, list):
return ", ".join(_stringify(v) for v in value)
if isinstance(value, dict):
return json.dumps(value, ensure_ascii=False)
return str(value)