|
|
""" |
|
|
metrics.py - Evaluation metrics for code completion validation. |
|
|
|
|
|
Implement functions to calculate bracket accuracy, indentation, |
|
|
and other code-specific metrics. |
|
|
""" |
|
|
|
|
|
import re
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
|
|
|
|
|
|
|
|
@dataclass |
|
|
class TestResult: |
|
|
"""Individual test result.""" |
|
|
test_name: str |
|
|
category: str |
|
|
passed: bool |
|
|
prompt: str |
|
|
generated: str |
|
|
expected_patterns: List[str] |
|
|
matched_patterns: List[str] |
|
|
failed_patterns: List[str] |
|
|
forbidden_matches: List[str] |
|
|
score: float |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class CategoryResult: |
|
|
"""Aggregated result for a category.""" |
|
|
category: str |
|
|
total_tests: int |
|
|
passed_tests: int |
|
|
accuracy: float |
|
|
test_results: List[TestResult] |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class ValidationReport: |
|
|
"""Complete validation report.""" |
|
|
model_name: str |
|
|
total_tests: int |
|
|
total_passed: int |
|
|
overall_accuracy: float |
|
|
category_results: Dict[str, CategoryResult] |
|
|
bracket_accuracy: float |
|
|
indentation_accuracy: float |
|
|
structure_accuracy: float |
|
|
|
|
|
|
|
|
def check_brackets_balanced(text: str) -> Tuple[bool, str]: |
|
|
""" |
|
|
Checks if brackets are balanced. |
|
|
|
|
|
Returns: |
|
|
(is_balanced, error_message) |
|
|
""" |
|
|
stack = [] |
|
|
pairs = {'(': ')', '[': ']', '{': '}'} |
|
|
|
|
|
for i, char in enumerate(text): |
|
|
if char in pairs: |
|
|
stack.append((char, i)) |
|
|
elif char in pairs.values(): |
|
|
if not stack: |
|
|
return False, f"Extra bracket '{char}' at position {i}" |
|
|
opening, pos = stack.pop() |
|
|
if pairs[opening] != char: |
|
|
return False, f"Mismatch: '{opening}' at position {pos} closed with '{char}' at position {i}" |
|
|
|
|
|
if stack: |
|
|
unclosed = [(char, pos) for char, pos in stack] |
|
|
return False, f"Unclosed brackets: {unclosed}" |
|
|
|
|
|
return True, "OK" |
|
|
|
|
|
|
|
|
def count_bracket_errors(prompt: str, generated: str) -> Dict[str, int]: |
|
|
""" |
|
|
Counts bracket errors in generated code. |
|
|
|
|
|
Returns: |
|
|
Dictionary with error counts by type |
|
|
""" |
|
|
full_code = prompt + generated |
|
|
|
|
|
errors = { |
|
|
'unclosed_parens': 0, |
|
|
'unclosed_brackets': 0, |
|
|
'unclosed_braces': 0, |
|
|
'extra_closing': 0 |
|
|
} |
|
|
|
|
|
|
|
|
parens = full_code.count('(') - full_code.count(')') |
|
|
brackets = full_code.count('[') - full_code.count(']') |
|
|
braces = full_code.count('{') - full_code.count('}') |
|
|
|
|
|
if parens > 0: |
|
|
errors['unclosed_parens'] = parens |
|
|
elif parens < 0: |
|
|
errors['extra_closing'] += abs(parens) |
|
|
|
|
|
if brackets > 0: |
|
|
errors['unclosed_brackets'] = brackets |
|
|
elif brackets < 0: |
|
|
errors['extra_closing'] += abs(brackets) |
|
|
|
|
|
if braces > 0: |
|
|
errors['unclosed_braces'] = braces |
|
|
elif braces < 0: |
|
|
errors['extra_closing'] += abs(braces) |
|
|
|
|
|
return errors |
|
|
|
|
|
|
|
|
def check_indentation(text: str) -> Dict[str, any]: |
|
|
""" |
|
|
Analyzes indentation quality in code. |
|
|
|
|
|
Returns: |
|
|
Dictionary with indentation metrics |
|
|
""" |
|
|
lines = text.split('\n') |
|
|
|
|
|
stats = { |
|
|
'total_lines': len(lines), |
|
|
'indented_lines': 0, |
|
|
'consistent_indent': True, |
|
|
'indent_style': None, |
|
|
'indent_size': None, |
|
|
'indent_errors': [] |
|
|
} |
|
|
|
|
|
indent_sizes = [] |
|
|
|
|
|
for i, line in enumerate(lines): |
|
|
if not line.strip(): |
|
|
continue |
|
|
|
|
|
|
|
|
stripped = line.lstrip() |
|
|
indent = len(line) - len(stripped) |
|
|
|
|
|
if indent > 0: |
|
|
stats['indented_lines'] += 1 |
|
|
|
|
|
|
|
|
if line.startswith('\t'): |
|
|
if stats['indent_style'] is None: |
|
|
stats['indent_style'] = 'tabs' |
|
|
elif stats['indent_style'] == 'spaces': |
|
|
stats['consistent_indent'] = False |
|
|
else: |
|
|
if stats['indent_style'] is None: |
|
|
stats['indent_style'] = 'spaces' |
|
|
elif stats['indent_style'] == 'tabs': |
|
|
stats['consistent_indent'] = False |
|
|
|
|
|
if stats['indent_style'] == 'spaces': |
|
|
indent_sizes.append(indent) |
|
|
|
|
|
|
|
|
if indent_sizes: |
|
|
|
|
|
common_indents = Counter(indent_sizes) |
|
|
stats['indent_size'] = min(common_indents.keys()) if common_indents else 4 |
|
|
|
|
|
return stats |
|
|
|
|
|
|
|
|
def evaluate_test_case( |
|
|
prompt: str, |
|
|
generated: str, |
|
|
expected_patterns: List[str], |
|
|
forbidden_patterns: List[str] = None |
|
|
) -> Tuple[bool, float, List[str], List[str], List[str]]: |
|
|
""" |
|
|
Evaluates a test case. |
|
|
|
|
|
Returns: |
|
|
(passed, score, matched_patterns, failed_patterns, forbidden_matches) |
|
|
""" |
|
|
if forbidden_patterns is None: |
|
|
forbidden_patterns = [] |
|
|
|
|
|
matched = [] |
|
|
failed = [] |
|
|
forbidden_found = [] |
|
|
|
|
|
|
|
|
for pattern in expected_patterns: |
|
|
try: |
|
|
if re.search(pattern, generated, re.MULTILINE): |
|
|
matched.append(pattern) |
|
|
else: |
|
|
failed.append(pattern) |
|
|
except re.error: |
|
|
|
|
|
if pattern in generated: |
|
|
matched.append(pattern) |
|
|
else: |
|
|
failed.append(pattern) |
|
|
|
|
|
|
|
|
for pattern in forbidden_patterns: |
|
|
try: |
|
|
if re.search(pattern, generated, re.MULTILINE): |
|
|
forbidden_found.append(pattern) |
|
|
except re.error: |
|
|
if pattern in generated: |
|
|
forbidden_found.append(pattern) |
|
|
|
|
|
|
|
|
if expected_patterns: |
|
|
score = len(matched) / len(expected_patterns) |
|
|
else: |
|
|
score = 1.0 |
|
|
|
|
|
|
|
|
if forbidden_found: |
|
|
score *= 0.5 |
|
|
|
|
|
passed = len(matched) > 0 and len(forbidden_found) == 0 |
|
|
|
|
|
return passed, score, matched, failed, forbidden_found |
|
|
|
|
|
|
|
|
def calculate_bracket_accuracy(results: List[TestResult]) -> float: |
|
|
"""Calculates accuracy specific to brackets.""" |
|
|
bracket_tests = [r for r in results if r.category == 'brackets'] |
|
|
if not bracket_tests: |
|
|
return 0.0 |
|
|
return sum(1 for t in bracket_tests if t.passed) / len(bracket_tests) |
|
|
|
|
|
|
|
|
def calculate_indentation_accuracy(results: List[TestResult]) -> float: |
|
|
"""Calculates accuracy specific to indentation.""" |
|
|
indent_tests = [r for r in results if r.category == 'indentation'] |
|
|
if not indent_tests: |
|
|
return 0.0 |
|
|
return sum(1 for t in indent_tests if t.passed) / len(indent_tests) |
|
|
|
|
|
|
|
|
def generate_report( |
|
|
model_name: str, |
|
|
results: List[TestResult] |
|
|
) -> ValidationReport: |
|
|
""" |
|
|
Generates complete validation report. |
|
|
""" |
|
|
|
|
|
categories = {} |
|
|
for result in results: |
|
|
if result.category not in categories: |
|
|
categories[result.category] = [] |
|
|
categories[result.category].append(result) |
|
|
|
|
|
|
|
|
category_results = {} |
|
|
for cat, cat_results in categories.items(): |
|
|
passed = sum(1 for r in cat_results if r.passed) |
|
|
category_results[cat] = CategoryResult( |
|
|
category=cat, |
|
|
total_tests=len(cat_results), |
|
|
passed_tests=passed, |
|
|
accuracy=passed / len(cat_results) if cat_results else 0, |
|
|
test_results=cat_results |
|
|
) |
|
|
|
|
|
|
|
|
total = len(results) |
|
|
passed = sum(1 for r in results if r.passed) |
|
|
|
|
|
return ValidationReport( |
|
|
model_name=model_name, |
|
|
total_tests=total, |
|
|
total_passed=passed, |
|
|
overall_accuracy=passed / total if total > 0 else 0, |
|
|
category_results=category_results, |
|
|
bracket_accuracy=calculate_bracket_accuracy(results), |
|
|
indentation_accuracy=calculate_indentation_accuracy(results), |
|
|
structure_accuracy=sum(1 for r in results if r.category == 'structure' and r.passed) / |
|
|
max(1, len([r for r in results if r.category == 'structure'])) |
|
|
) |
|
|
|
|
|
|
|
|
def format_report(report: ValidationReport) -> str: |
|
|
"""Formats report for printing.""" |
|
|
lines = [ |
|
|
"=" * 60, |
|
|
f"π VALIDATION REPORT: {report.model_name}", |
|
|
"=" * 60, |
|
|
"", |
|
|
f"π OVERALL RESULTS", |
|
|
f" Total tests: {report.total_tests}", |
|
|
f" Passed tests: {report.total_passed}", |
|
|
f" Overall Accuracy: {report.overall_accuracy:.1%}", |
|
|
"", |
|
|
"π SPECIFIC METRICS", |
|
|
f" Bracket Accuracy: {report.bracket_accuracy:.1%}", |
|
|
f" Indentation Accuracy: {report.indentation_accuracy:.1%}", |
|
|
f" Structure Accuracy: {report.structure_accuracy:.1%}", |
|
|
"", |
|
|
"π RESULTS BY CATEGORY", |
|
|
] |
|
|
|
|
|
for cat_name, cat_result in report.category_results.items(): |
|
|
status = "β
" if cat_result.accuracy >= 0.7 else "β οΈ" if cat_result.accuracy >= 0.5 else "β" |
|
|
lines.append(f" {status} {cat_name}: {cat_result.passed_tests}/{cat_result.total_tests} ({cat_result.accuracy:.1%})") |
|
|
|
|
|
lines.extend([ |
|
|
"", |
|
|
"=" * 60 |
|
|
]) |
|
|
|
|
|
return "\n".join(lines) |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
|
print("π§ͺ Testing metrics...") |
|
|
|
|
|
|
|
|
is_bal, msg = check_brackets_balanced("def foo(a, b):") |
|
|
print(f"Balanced '(a, b)': {is_bal} - {msg}") |
|
|
|
|
|
is_bal, msg = check_brackets_balanced("def foo(a, b:") |
|
|
print(f"Balanced '(a, b:': {is_bal} - {msg}") |
|
|
|
|
|
|
|
|
passed, score, matched, failed, forbidden = evaluate_test_case( |
|
|
prompt="def hello(", |
|
|
generated="name):\n print(name)", |
|
|
expected_patterns=[r"\)", r":"] |
|
|
) |
|
|
print(f"Test result: passed={passed}, score={score}, matched={matched}") |
|
|
|