|
|
""" |
|
|
Format Detection Utilities for GEPA Optimizer. |
|
|
|
|
|
This module provides utilities to automatically detect output format patterns |
|
|
from expected outputs and generate format constraints for reflection prompts. |
|
|
|
|
|
Key Features: |
|
|
1. Auto-detect JSON, key-value, list, structured-text, or free-text formats
|
|
2. Generate format specifications from examples |
|
|
3. Create format constraint strings for prompt injection |
|
|
""" |
|
|
|
|
|
import re |
|
|
import json |
|
|
from typing import List, Dict, Any, Optional, Tuple |
|
|
|
|
|
|
|
|
def _unknown_format_result() -> Dict[str, Any]:
    """Return a fresh result describing "no usable outputs to analyse"."""
    return {
        'format_type': 'unknown',
        'format_spec': 'Unknown format',
        'format_example': '',
        'format_constraint': '',
        'detected_keys': [],
        'avg_length': 0,
        'max_length': 0,
    }


def detect_output_format(expected_outputs: Optional[List[str]]) -> Dict[str, Any]:
    """
    Analyze expected outputs to detect the common format pattern.

    Detectors are tried in priority order -- JSON, key-value, bullet/numbered
    list, multi-line structured text -- and the first one that matches wins.
    If none match, the outputs are classified as free text.

    Args:
        expected_outputs: List of expected output strings from the dataset.
            ``None``, empty and whitespace-only entries are ignored.

    Returns:
        Dictionary containing:
            - format_type: 'json', 'key_value', 'list', 'structured_text',
              'free_text', or 'unknown' (when there are no usable outputs)
            - format_spec: Human-readable format specification
            - format_example: Example showing the format
            - format_constraint: Constraint text to add to prompts
              (empty for 'unknown')
            - detected_keys: List of keys/fields detected (for structured formats)
            - avg_length: Integer mean length of the usable outputs
            - max_length: Length of the longest usable output
    """
    # Drop None / empty / whitespace-only entries; they carry no format signal.
    valid_outputs = [o for o in (expected_outputs or []) if o and o.strip()]
    if not valid_outputs:
        # Same shape for "no input" and "only blank input" so callers can rely
        # on every key (including max_length) being present.
        return _unknown_format_result()

    avg_length = sum(len(o) for o in valid_outputs) // len(valid_outputs)
    max_length = max(len(o) for o in valid_outputs)

    # Most specific detectors first: JSON would also look like key-value text,
    # and lists would also look like multi-line structured text.
    detectors = (
        _detect_json_format,
        _detect_key_value_format,
        _detect_list_format,
        _detect_structured_text,
    )
    for detector in detectors:
        result = detector(valid_outputs, avg_length, max_length)
        if result:
            return result

    return _create_format_result(
        'free_text',
        f'Free-form text response (typically {avg_length} characters)',
        valid_outputs[0][:100],
        [],
        avg_length,
        max_length
    )
|
|
|
|
|
|
|
|
def _detect_json_format(outputs: List[str], avg_length: int, max_length: int) -> Optional[Dict[str, Any]]:
    """Detect whether the outputs are (mostly) JSON objects.

    An output counts as JSON only if, after stripping, it starts with ``{``,
    ends with ``}`` and parses to a ``dict``. At least 70% of the outputs must
    qualify. Keys present in at least half of the parsed objects are reported
    as the common keys.

    Args:
        outputs: Non-empty candidate output strings.
        avg_length: Mean output length, forwarded to the result.
        max_length: Longest output length, forwarded to the result.

    Returns:
        A 'json' format result, or ``None`` if the outputs are not JSON.
    """
    if not outputs:
        # Without this guard, 0 >= 0 * 0.7 would report an empty set as JSON.
        return None

    json_count = 0
    key_counts: Dict[str, int] = {}
    first_example: Optional[str] = None

    for output in outputs:
        stripped = output.strip()
        if not (stripped.startswith('{') and stripped.endswith('}')):
            continue
        try:
            parsed = json.loads(stripped)
        except json.JSONDecodeError:
            continue
        if not isinstance(parsed, dict):
            continue

        json_count += 1
        if first_example is None:
            # Show an output that actually is JSON, not merely outputs[0].
            first_example = stripped
        for key in parsed:
            key_counts[key] = key_counts.get(key, 0) + 1

    if json_count == 0 or json_count < len(outputs) * 0.7:
        return None

    # Insertion order of key_counts == order keys were first seen.
    common_keys = [k for k, v in key_counts.items() if v >= json_count * 0.5]
    format_spec = f"JSON object with keys: {', '.join(common_keys)}"

    return _create_format_result(
        'json',
        format_spec,
        first_example[:200],
        common_keys,
        avg_length,
        max_length
    )
|
|
|
|
|
|
|
|
# One "Key: value" / "Key=value" pair. Keys may contain spaces or tabs but NOT
# newlines, so a preceding line of prose is never glued onto the next key.
_KV_PAIR_PATTERN = re.compile(r'([A-Za-z_][A-Za-z0-9_ \t]*)[ \t]*[:=][ \t]*([^|;\n,]+)')

# Candidate pair separators in priority order, mapped to how each is rendered
# when joining keys in the generated format spec.
_KV_SEPARATOR_DISPLAY = {
    '|': ' | ',
    '\n': '\n',
    ';': '; ',
    ',': ', ',
}


def _detect_key_value_format(outputs: List[str], avg_length: int, max_length: int) -> Optional[Dict[str, Any]]:
    """Detect key-value formats like 'Department: X | Sentiment: Y'.

    An output counts as key-value when it contains at least two ``key: value``
    (or ``key=value``) pairs. At least 60% of the outputs must qualify.

    Args:
        outputs: Non-empty candidate output strings.
        avg_length: Mean output length, forwarded to the result.
        max_length: Longest output length, forwarded to the result.

    Returns:
        A 'key_value' format result, or ``None``. ``detected_keys`` holds up
        to five keys present in at least 40% of the key-value outputs, ordered
        by frequency. Keys are matched case-insensitively and shown with the
        spelling first seen. The separator shown in the spec is the one used
        by most key-value outputs.
    """
    if not outputs:
        # Without this guard, 0 >= 0 * 0.6 would report an empty set as key-value.
        return None

    kv_count = 0
    key_presence: Dict[str, int] = {}   # normalized key -> number of outputs containing it
    key_spelling: Dict[str, str] = {}   # normalized key -> first-seen original spelling
    separator_votes: Dict[str, int] = {}
    first_example: Optional[str] = None

    for output in outputs:
        matches = _KV_PAIR_PATTERN.findall(output)
        if len(matches) < 2:
            continue

        kv_count += 1
        if first_example is None:
            first_example = output

        seen_in_output = set()
        for raw_key, _value in matches:
            key = raw_key.strip()
            normalized = key.lower()
            key_spelling.setdefault(normalized, key)
            # Count each key once per output, so the threshold below means
            # "appears in N% of outputs" rather than "N% of all occurrences".
            if normalized not in seen_in_output:
                seen_in_output.add(normalized)
                key_presence[normalized] = key_presence.get(normalized, 0) + 1

        # First separator (in priority order) present in this output casts a vote.
        separator = next((sep for sep in _KV_SEPARATOR_DISPLAY if sep in output), None)
        if separator is not None:
            separator_votes[separator] = separator_votes.get(separator, 0) + 1

    if kv_count == 0 or kv_count < len(outputs) * 0.6:
        return None

    ranked = sorted(key_presence.items(), key=lambda item: -item[1])
    common_keys = [key_spelling[k] for k, v in ranked if v >= kv_count * 0.4][:5]

    if separator_votes:
        # max() keeps the first-seen separator on ties.
        sep_display = _KV_SEPARATOR_DISPLAY[max(separator_votes, key=separator_votes.get)]
    else:
        sep_display = ' | '

    pair_templates = [f'{k}: [value]' for k in common_keys]
    format_spec = f"Key-value pairs: {sep_display.join(pair_templates)}"

    return _create_format_result(
        'key_value',
        format_spec,
        first_example or '',
        common_keys,
        avg_length,
        max_length
    )
|
|
|
|
|
|
|
|
# Line prefixes that mark a bullet item ("- x", "* x", "• x") or a numbered
# item ("1. x", "2) x"); compiled once instead of on every line.
_LIST_ITEM_PATTERNS = (
    re.compile(r'^[-*•]\s+'),
    re.compile(r'^\d+[.)]\s+'),
)


def _detect_list_format(outputs: List[str], avg_length: int, max_length: int) -> Optional[Dict[str, Any]]:
    """Detect bullet/numbered list formats.

    An output counts as a list when at least half of its non-blank lines start
    with a bullet or number marker. Blank lines are ignored, so double-spaced
    lists are recognised. At least 60% of the outputs must qualify.

    Args:
        outputs: Non-empty candidate output strings.
        avg_length: Mean output length, forwarded to the result.
        max_length: Longest output length, forwarded to the result.

    Returns:
        A 'list' format result, or ``None``.
    """
    if not outputs:
        # Without this guard, 0 >= 0 * 0.6 would report an empty set as a list.
        return None

    list_count = 0
    for output in outputs:
        lines = [line.strip() for line in output.strip().split('\n') if line.strip()]
        if not lines:
            continue
        item_lines = sum(
            1 for line in lines
            if any(pattern.match(line) for pattern in _LIST_ITEM_PATTERNS)
        )
        if item_lines >= len(lines) * 0.5:
            list_count += 1

    if list_count >= len(outputs) * 0.6:
        return _create_format_result(
            'list',
            'Bullet or numbered list format',
            outputs[0][:200],
            [],
            avg_length,
            max_length
        )

    return None
|
|
|
|
|
|
|
|
def _detect_structured_text(outputs: List[str], avg_length: int, max_length: int) -> Optional[Dict[str, Any]]:
    """Classify outputs as structured text when they span several lines.

    The integer (floored) mean line count of the stripped outputs must be at
    least two. An empty ``outputs`` list is treated as single-line.

    Args:
        outputs: Candidate output strings.
        avg_length: Mean output length, forwarded to the result.
        max_length: Longest output length, forwarded to the result.

    Returns:
        A 'structured_text' format result, or ``None``.
    """
    if outputs:
        total_lines = 0
        for text in outputs:
            # N newline characters delimit N + 1 lines (same as len(split('\n'))).
            total_lines += text.strip().count('\n') + 1
        mean_lines = total_lines // len(outputs)
    else:
        mean_lines = 1

    if mean_lines < 2:
        return None

    # mean_lines >= 2 implies outputs is non-empty here.
    return _create_format_result(
        'structured_text',
        f'Structured text with ~{mean_lines} lines',
        outputs[0][:200],
        [],
        avg_length,
        max_length,
    )
|
|
|
|
|
|
|
|
def _create_format_result(
    format_type: str,
    format_spec: str,
    format_example: str,
    detected_keys: List[str],
    avg_length: int,
    max_length: int = 0
) -> Dict[str, Any]:
    """Create a standardized format detection result.

    Builds the prompt constraint text for ``format_type``. 'unknown' gets an
    empty constraint, because there is nothing to enforce. Any type without a
    dedicated branch (e.g. 'free_text') gets the generic concise-response
    constraint.

    Args:
        format_type: Detected format identifier.
        format_spec: Human-readable format specification.
        format_example: Example output. It is stored and embedded truncated
            to 200 characters (150 in the JSON constraint).
        detected_keys: Keys/fields detected for structured formats.
        avg_length: Mean expected output length.
        max_length: Longest expected output length.

    Returns:
        Dict with keys format_type, format_spec, format_example,
        format_constraint, detected_keys, avg_length and max_length.
    """
    keys = list(detected_keys or [])
    # Same truncation as the stored 'format_example', so a very long example
    # cannot blow up the constraint text.
    example = format_example[:200] if format_example else ''

    if format_type == 'json':
        keys_text = ', '.join(keys) if keys else 'as shown in examples'
        constraint = f"""OUTPUT FORMAT REQUIREMENT:
- Return ONLY a valid JSON object
- Required keys: {keys_text}
- NO explanations, NO prose, NO markdown code blocks
- Maximum length: ~{max_length} characters
- Example format: {example[:150]}"""

    elif format_type == 'key_value':
        constraint = f"""OUTPUT FORMAT REQUIREMENT:
- Return ONLY in key-value format: {format_spec}
- NO explanations, NO reasoning, NO additional text
- Be CONCISE - output should be ~{avg_length} characters max
- Example: {example}"""

    elif format_type == 'list':
        constraint = f"""OUTPUT FORMAT REQUIREMENT:
- Return as a bullet or numbered list
- NO explanations before or after the list
- Keep it concise (~{avg_length} characters)"""

    elif format_type == 'structured_text':
        constraint = f"""OUTPUT FORMAT REQUIREMENT:
- Follow the structured format shown in examples
- NO additional explanations or commentary
- Keep output concise (~{avg_length} characters)"""

    elif format_type == 'unknown':
        # Nothing was detected; a "~0 characters" constraint would be noise.
        constraint = ''

    else:
        constraint = f"""OUTPUT FORMAT REQUIREMENT:
- Keep response CONCISE and DIRECT
- NO lengthy explanations or reasoning
- Target length: ~{avg_length} characters (max {max_length})
- Match the format/style of the expected examples"""

    return {
        'format_type': format_type,
        'format_spec': format_spec,
        'format_example': example,
        'format_constraint': constraint,
        'detected_keys': keys,
        'avg_length': avg_length,
        'max_length': max_length
    }
|
|
|
|
|
|
|
|
def build_format_aware_reflection_prompt(
    base_prompt: str,
    format_info: Dict[str, Any],
    include_example: bool = True
) -> str:
    """
    Enhance a reflection prompt with format awareness.

    Appends a section containing ``format_info['format_constraint']``, a list
    of common failure modes and (optionally) the format example.

    Args:
        base_prompt: The original reflection prompt
        format_info: Format detection result from detect_output_format()
        include_example: Whether to include format example

    Returns:
        Enhanced prompt with format constraints. ``base_prompt`` is returned
        unchanged when ``format_info`` is empty, has format_type 'unknown', or
        carries no constraint text.
    """
    if not format_info or format_info.get('format_type') == 'unknown':
        return base_prompt

    constraint = format_info.get('format_constraint')
    if not constraint:
        # A "match this EXACT format:" header with nothing under it would only
        # confuse the reflection model.
        return base_prompt

    format_section = f"""

🎯 CRITICAL FORMAT REQUIREMENT:
The optimized prompt MUST produce outputs that match this EXACT format:

{constraint}

⚠️ COMMON FAILURE MODES TO AVOID:
1. Generating explanations when only the answer is needed
2. Adding "Here's the analysis..." or similar preambles
3. Producing verbose output when concise is required
4. Wrong structure (e.g., prose instead of key-value pairs)
"""

    example = format_info.get('format_example')
    if include_example and example:
        format_section += f"""
📋 EXAMPLE OF CORRECT OUTPUT FORMAT:
{example}
"""

    return base_prompt + format_section
|
|
|
|
|
|
|
|
def generate_format_feedback(
    predicted_output: str,
    expected_output: str,
    format_info: Dict[str, Any]
) -> str:
    """
    Generate specific feedback about format compliance.

    Checks performed:
      * length against ``avg_length`` (> 3x) and ``max_length`` (> 2x; only
        when ``max_length`` is positive);
      * for 'json': the prediction must parse to a JSON object, and when the
        expected output is itself a JSON object, the prediction must contain
        all of its keys;
      * for 'key_value': the prediction must contain a ':';
      * explanatory phrases ("let me", "here is", ...) matched as whole words.

    Args:
        predicted_output: What the model actually produced
        expected_output: The ground truth output
        format_info: Format detection result (``None`` is treated as empty)

    Returns:
        Bullet list of format issues prefixed with a header, or "" if none.
    """
    format_info = format_info or {}
    predicted_len = len(predicted_output) if predicted_output else 0

    issues = []

    avg_length = format_info.get('avg_length', 0) or 0
    max_length = format_info.get('max_length', 0) or 0
    if avg_length > 0:
        if predicted_len > avg_length * 3:
            issues.append(f"OUTPUT TOO VERBOSE: Generated {predicted_len} chars, expected ~{avg_length} chars")
        elif max_length > 0 and predicted_len > max_length * 2:
            # max_length == 0 means "unknown"; comparing against it would flag
            # every non-empty output.
            issues.append(f"OUTPUT TOO LONG: {predicted_len} chars vs max expected {max_length}")

    format_type = format_info.get('format_type', 'unknown')

    if format_type == 'json':
        try:
            parsed = json.loads(predicted_output.strip() if predicted_output else '{}')
        except json.JSONDecodeError:
            issues.append("FORMAT ERROR: Expected JSON but got non-JSON output")
        else:
            if not isinstance(parsed, dict):
                # Scalars and arrays are valid JSON but not the required object.
                issues.append("FORMAT ERROR: Expected a JSON object but got a different JSON value")
            else:
                expected_obj = None
                if expected_output:
                    try:
                        expected_obj = json.loads(expected_output.strip())
                    except ValueError:
                        expected_obj = None
                if isinstance(expected_obj, dict):
                    missing = [k for k in expected_obj if k not in parsed]
                    if missing:
                        issues.append(f"FORMAT ERROR: JSON output is missing expected keys: {', '.join(missing)}")

    elif format_type == 'key_value':
        if predicted_output and ':' not in predicted_output:
            issues.append("FORMAT ERROR: Expected key-value pairs (Key: Value) but output lacks this structure")

    verbose_indicators = [
        'let me', 'i will', 'here is', "here's", 'analysis:', 'step-by-step',
        'first,', 'to begin', 'in order to', 'the following', 'please note'
    ]

    if predicted_output:
        lower_output = predicted_output.lower()
        found_verbose = []
        for phrase in verbose_indicators:
            # Require a word boundary before the phrase (and after it when it
            # ends in a word character) so e.g. "Hawaii will" is not "i will".
            tail = r'(?!\w)' if phrase[-1].isalnum() else ''
            if re.search(r'(?<!\w)' + re.escape(phrase) + tail, lower_output):
                found_verbose.append(phrase)
        if found_verbose:
            issues.append(f"VERBOSITY WARNING: Output contains explanatory phrases: {', '.join(found_verbose[:3])}")

    if not issues:
        return ""

    return "\n🚨 FORMAT ISSUES DETECTED:\n" + "\n".join(f"  • {issue}" for issue in issues)
|
|
|