Spaces:
Running
Running
| import json | |
| import re | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
# Lines (after stripping) that open a new chat message; the text after "# "
# becomes the message role.
ROLE_HEADERS = {f"# {role}" for role in ("system", "user", "assistant")}

# Header that separates the prompt section from the trailing JSON schema payload.
JSON_SCHEMA_HEADER = "# JSON schema"
def fill_template_file(
    template_path: str,
    data: Dict[str, Any],
) -> Tuple[List[Dict[str, str]], Optional[Dict[str, Any]]]:
    """
    Render a DAS-style prompt template file.

    Template syntax:
      - role headers, each on its own line: ``# system``, ``# user``, ``# assistant``
      - placeholders: ``[key]`` or ``{{key}}`` (dotted keys such as
        ``speaker.name`` walk nested dicts)
      - loop sections, which may nest::

            # start field_name
            ...
            # end field_name

        where ``data[field_name]`` is a list (dict items expose their keys
        as placeholders inside the loop)
      - an optional trailing ``# JSON schema`` section followed by a JSON
        object with at least ``name`` and ``schema`` keys

    Args:
        template_path: Path of the template file; it is read as UTF-8.
        data: Values used for placeholders and loop sections.

    Returns:
        A pair ``(messages, response_format)``. ``messages`` is a list of
        ``{"role": ..., "content": ...}`` dicts. ``response_format`` is
        ``{"type": "json_schema", "json_schema": {...}}`` when the template
        has a schema section, otherwise None.

    Raises:
        ValueError: on structural template problems reported by the helpers.
        json.JSONDecodeError: if the schema section is not valid JSON.
    """
    template_source = Path(template_path).read_text(encoding="utf-8")
    prompt_section, schema_section = _split_prompt_and_schema(template_source)

    # Loops and placeholders are expanded before the text is cut into messages.
    messages = _parse_role_markdown(_expand_template(prompt_section, data))

    # The schema section is parsed verbatim; placeholders are not applied to it.
    if schema_section:
        response_format = _schema_to_response_format(json.loads(schema_section))
    else:
        response_format = None

    return messages, response_format
def _split_prompt_and_schema(text: str) -> Tuple[str, Optional[str]]:
    """
    Split raw template text into its prompt part and its JSON schema part.

    The schema section begins at the first line whose stripped content is
    exactly JSON_SCHEMA_HEADER; everything after that line is the schema text.
    Matching a whole line (the same way role headers are recognised) keeps
    prompt prose that merely mentions "# JSON schema" inside a sentence from
    being cut in two.

    Args:
        text: Raw template text.

    Returns:
        ``(prompt_text, schema_text)``.
        - No header line: the text is returned unchanged together with None.
        - Header line present: ``prompt_text`` is everything before it with
          trailing whitespace removed.
        - ``schema_text`` is the stripped remainder, or None if only
          whitespace follows the header.
    """
    lines = text.splitlines(keepends=True)
    for idx, line in enumerate(lines):
        if line.strip() == JSON_SCHEMA_HEADER:
            prompt_text = "".join(lines[:idx]).rstrip()
            schema_text = "".join(lines[idx + 1:]).strip()
            # An empty schema section is treated as "no schema".
            return prompt_text, (schema_text or None)
    return text, None
| def _schema_to_response_format(schema_payload: Dict[str, Any]) -> Dict[str, Any]: | |
| """ | |
| Expects DAS-style schema payload, e.g. | |
| { | |
| "name": "decoded_das", | |
| "schema": {...}, | |
| "strict": true | |
| } | |
| """ | |
| if "name" not in schema_payload or "schema" not in schema_payload: | |
| raise ValueError("JSON schema block must contain at least 'name' and 'schema' keys.") | |
| return { | |
| "type": "json_schema", | |
| "json_schema": { | |
| "name": schema_payload["name"], | |
| "schema": schema_payload["schema"], | |
| "strict": schema_payload.get("strict", True), | |
| }, | |
| } | |
def _expand_template(text: str, data: Dict[str, Any]) -> str:
    """
    Expand loop sections and placeholders in ``text`` using ``data``.

    ``_process_block`` already substitutes placeholders on every line it emits,
    using the per-iteration context inside loops. For that reason no second
    substitution pass is run over the joined result. Such a pass would
    re-interpret substituted data values (a value like ``"arr[0]"`` would lose
    its ``[0]``). It would also match bracketed spans that cross line breaks,
    such as a multi-line JSON array written in the prompt, and delete them.

    Args:
        text: Template prompt text (without the JSON schema section).
        data: Top-level placeholder/loop context.

    Returns:
        The expanded text with leading and trailing whitespace stripped.
    """
    expanded_lines, _ = _process_block(text.splitlines(), 0, data)
    return "\n".join(expanded_lines).strip()
def _process_block(lines: List[str], start_idx: int, context: Dict[str, Any]) -> Tuple[List[str], int]:
    """
    Expand loop sections and substitute placeholders in ``lines[start_idx:]``.

    A ``# start <field>`` line opens a loop whose body runs up to the matching
    ``# end <field>`` line, which is located by ``_collect_loop_block``.
    - The body is rendered once per element of the list that ``<field>``
      resolves to in ``context``.
    - A missing field, None or "" renders zero iterations.
    - Any other non-list value raises.
    - For each iteration the context is a copy of the outer one, updated with
      the item's keys when the item is a dict, plus ``$index`` (1-based) and
      ``<field>`` bound to the item itself.

    Every other line is emitted with placeholders substituted from ``context``.

    Args:
        lines: Template lines.
        start_idx: Index of the first line to process.
        context: Placeholder/loop values visible at this level.

    Returns:
        ``(expanded_lines, next_index)``, where ``next_index`` is ``len(lines)``
        once processing completes.

    Raises:
        ValueError: on an ``# end`` line with no open loop, a loop without its
            closing marker, or a loop field that is not a list.
    """
    output: List[str] = []
    i = start_idx
    while i < len(lines):
        line = lines[i]
        stripped = line.strip()

        if stripped.startswith("# end "):
            # Loop bodies passed to this function never contain their own
            # closing marker (_collect_loop_block excludes it), so an "# end"
            # seen here is unmatched. Returning early would silently drop
            # every line after it, so fail loudly instead.
            raise ValueError(f"Unexpected '{stripped}' without a matching '# start' in template.")

        if stripped.startswith("# start "):
            field_name = stripped.replace("# start ", "", 1).strip()
            block_lines, end_idx = _collect_loop_block(lines, i + 1, field_name)
            loop_value = _resolve_key(field_name, context)
            # _resolve_key yields "" (not None) for a missing key; treat both as
            # "no items", matching the lenient handling of missing placeholders.
            if loop_value is None or (isinstance(loop_value, str) and not loop_value):
                loop_value = []
            if not isinstance(loop_value, list):
                raise ValueError(f"Loop field '{field_name}' must be a list, got {type(loop_value).__name__}.")
            for idx, item in enumerate(loop_value, start=1):
                child_context = dict(context)
                if isinstance(item, dict):
                    child_context.update(item)
                child_context["$index"] = idx
                # Bound last so it wins over an item key of the same name.
                child_context[field_name] = item
                expanded_child, _ = _process_block(block_lines, 0, child_context)
                output.extend(expanded_child)
            i = end_idx + 1
            continue

        output.append(_replace_placeholders(line, context))
        i += 1
    return output, i
| def _collect_loop_block(lines: List[str], start_idx: int, field_name: str) -> Tuple[List[str], int]: | |
| """ | |
| Collects lines until the matching # end field_name, respecting nested loops. | |
| Returns (block_lines, end_index). | |
| """ | |
| block: List[str] = [] | |
| depth = 1 | |
| i = start_idx | |
| while i < len(lines): | |
| stripped = lines[i].strip() | |
| if stripped.startswith("# start "): | |
| nested_name = stripped.replace("# start ", "", 1).strip() | |
| if nested_name == field_name: | |
| depth += 1 | |
| block.append(lines[i]) | |
| i += 1 | |
| continue | |
| if stripped.startswith("# end "): | |
| end_name = stripped.replace("# end ", "", 1).strip() | |
| if end_name == field_name: | |
| depth -= 1 | |
| if depth == 0: | |
| return block, i | |
| block.append(lines[i]) | |
| i += 1 | |
| continue | |
| block.append(lines[i]) | |
| i += 1 | |
| raise ValueError(f"Missing matching '# end {field_name}' in template.") | |
def _parse_role_markdown(text: str) -> List[Dict[str, str]]:
    """
    Split expanded template text into chat messages.

    A line whose stripped text is one of ROLE_HEADERS starts a new message.
    The role is the header without its leading "# ". The lines up to the next
    header, cleaned by ``_clean_content``, form the message content. Any text
    before the first header is discarded.

    Args:
        text: Expanded template text.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in template order.

    Raises:
        ValueError: if the text contains no role header at all.
    """
    sections: List[Tuple[str, List[str]]] = []
    for line in text.splitlines():
        header = line.strip()
        if header in ROLE_HEADERS:
            sections.append((header[len("# "):], []))
        elif sections:
            sections[-1][1].append(line)

    if not sections:
        raise ValueError(
            "Template must contain at least one role header: '# system', '# user', or '# assistant'."
        )
    return [
        {"role": role, "content": _clean_content("\n".join(body))}
        for role, body in sections
    ]
| def _clean_content(text: str) -> str: | |
| lines = text.splitlines() | |
| while lines and not lines[0].strip(): | |
| lines.pop(0) | |
| while lines and not lines[-1].strip(): | |
| lines.pop() | |
| if not lines: | |
| return "" | |
| min_indent = None | |
| for line in lines: | |
| if not line.strip(): | |
| continue | |
| indent = len(line) - len(line.lstrip(" ")) | |
| if min_indent is None or indent < min_indent: | |
| min_indent = indent | |
| min_indent = min_indent or 0 | |
| cleaned = "\n".join(line[min_indent:] if len(line) >= min_indent else line for line in lines) | |
| return cleaned.strip() | |
# Matches either a {{key}} or a [key] placeholder. Group 1 is the brace key and
# group 2 is the square-bracket key. One combined pattern lets every placeholder
# be resolved in a single left-to-right scan.
_PLACEHOLDER_PATTERN = re.compile(r"\{\{([^{}]+)\}\}|\[([^\[\]]+)\]")


def _replace_placeholders(text: str, context: Dict[str, Any]) -> str:
    """
    Substitute ``[key]`` and ``{{key}}`` placeholders in ``text``.

    Supported keys:
      - plain keys: ``[name]``, ``{{name}}``
      - dotted keys walking nested dicts: ``[speaker.name]``, ``{{speaker.name}}``
      - the loop index: ``[$index]``, ``{{$index}}``

    Keys are resolved with ``_resolve_key`` and rendered with ``_stringify``.
    The substitution runs in one pass, so text inserted for a placeholder is
    never scanned again. A value containing ``[...]`` or ``{{...}}`` therefore
    appears verbatim instead of being expanded a second time.

    Args:
        text: Text containing placeholders.
        context: Values to substitute.

    Returns:
        The text with every placeholder replaced.
    """
    def repl(match: re.Match) -> str:
        key = match.group(1) if match.group(1) is not None else match.group(2)
        return _stringify(_resolve_key(key.strip(), context))

    return _PLACEHOLDER_PATTERN.sub(repl, text)
| def _resolve_key(key: str, context: Dict[str, Any]) -> Any: | |
| if key in context: | |
| return context[key] | |
| if "." not in key: | |
| return "" | |
| current: Any = context | |
| for part in key.split("."): | |
| part = part.strip() | |
| if isinstance(current, dict) and part in current: | |
| current = current[part] | |
| else: | |
| return "" | |
| return current | |
| def _stringify(value: Any) -> str: | |
| if value is None: | |
| return "" | |
| if isinstance(value, str): | |
| return value | |
| if isinstance(value, (int, float, bool)): | |
| return str(value) | |
| if isinstance(value, list): | |
| return ", ".join(_stringify(v) for v in value) | |
| if isinstance(value, dict): | |
| return json.dumps(value, ensure_ascii=False) | |
| return str(value) |